mirror of https://github.com/THUDM/CogVideo.git
Base generation on image prompts
commit c52d82f6b0
parent 4f3eba5521
predict.py | 64 lines changed
@@ -4,7 +4,7 @@ import subprocess
 import tempfile
 import random
 import typing
-from typing_extensions import Self
+from PIL import Image, UnidentifiedImageError
 from deep_translator import GoogleTranslator
 from cog import BasePredictor, Input, Path

@@ -444,18 +444,19 @@ class Predictor(BasePredictor):
         args.use_guidance_stage1 = True
         args.use_guidance_stage2 = False
         args.both_stages = True
-        args.device = torch.device('cuda')
+        args.device = torch.device("cuda")
+        self.image_prompt = None

         self.translator = GoogleTranslator(source="en", target="zh-CN")
-        self.model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1')
+        self.model_stage1, args = InferenceModel_Sequential.from_pretrained(args, "cogvideo-stage1")
         self.model_stage1.eval()
         self.model_stage1 = self.model_stage1.cpu()
-        self.model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, 'cogvideo-stage2')
+        self.model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, "cogvideo-stage2")
         self.model_stage2.eval()
         self.model_stage2 = self.model_stage2.cpu()

         # enable dsr if model exists
-        if os.path.exists('/sharefs/cogview-new/cogview2-dsr'):
+        if os.path.exists("/sharefs/cogview-new/cogview2-dsr"):
             subprocess.check_output("python setup.py develop", cwd="/src/Image-Local-Attention", shell=True)
             sys.path.append('./Image-Local-Attention')
             from sr_pipeline import DirectSuperResolution
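Note that both stages are loaded in eval mode and immediately parked on the CPU; each is presumably moved to the GPU only for its own pass, so the two checkpoints never have to share device memory. A minimal sketch of that staging pattern in plain PyTorch (run_staged is an illustrative helper, not part of the commit):

    import torch

    def run_staged(model, fn):
        # Hypothetical helper: borrow the GPU for one stage, then hand the
        # memory back before the next stage runs.
        model = model.cuda()
        try:
            return fn(model)
        finally:
            model.cpu()
            torch.cuda.empty_cache()  # release cached blocks held by this stage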
@@ -485,6 +486,7 @@ class Predictor(BasePredictor):
             description="Run both stages (uncheck to run more quickly and output only a few frames)", default=True
         ),
         use_guidance: bool = Input(description="Use stage 1 guidance (recommended)", default=True),
+        image_prompt: Path = Input(description="Starting image")
     ) -> typing.Iterator[Path]:
         if translate:
             prompt = self.translator.translate(prompt.strip())
@@ -495,6 +497,13 @@ class Predictor(BasePredictor):
         self.args.seed = seed
         self.args.use_guidance_stage1 = use_guidance
         self.prompt = prompt
+        if os.path.exists(image_prompt):
+            try:
+                Image.open(image_prompt)
+            except (FileNotFoundError, UnidentifiedImageError):
+                logging.debug("Bad image prompt; ignoring")  # Is there a better way to input images?
+            else:
+                self.image_prompt = Image.open(image_prompt)
         self.args.both_stages = both_stages

         for file in self.run():
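The predict() change above validates the image prompt with a try/except/else: Image.open() raises UnidentifiedImageError on anything PIL cannot parse, and only the else branch keeps the image. A self-contained sketch of the same pattern, assuming only Pillow (load_image_prompt is an illustrative name, not part of the commit):

    import logging
    import os

    from PIL import Image, UnidentifiedImageError

    def load_image_prompt(path):
        """Return a PIL Image if `path` points to a readable image, else None."""
        if not os.path.exists(path):
            return None
        try:
            Image.open(path)  # probe only; raises if the file is not an image
        except (FileNotFoundError, UnidentifiedImageError):
            logging.debug("Bad image prompt; ignoring")
            return None
        else:
            return Image.open(path)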
@@ -546,25 +555,28 @@ class Predictor(BasePredictor):
         text_len_1st = len(seq_1st) - frame_len*1 - 1

         seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
-        output_list_1st = []
-        for tim in range(max(batch_size // mbz, 1)):
-            start_time = time.time()
-            output_list_1st.append(
-                my_filling_sequence(self.model_stage1, args, seq_1st.clone(),
-                    batch_size=min(batch_size, mbz),
-                    get_masks_and_position_ids=get_masks_and_position_ids_stage1,
-                    text_len=text_len_1st,
-                    frame_len=frame_len,
-                    strategy=strategy_cogview2,
-                    strategy2=strategy_cogvideo,
-                    log_text_attention_weights=1.4,
-                    enforce_no_swin=True,
-                    mode_stage1=True,
-                )[0]
-            )
-            logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
-        output_tokens_1st = torch.cat(output_list_1st, dim=0)
-        given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1)  # given_tokens.shape: [bs, frame_num, 400]
+        if self.image_prompt is None:
+            output_list_1st = []
+            for tim in range(max(batch_size // mbz, 1)):
+                start_time = time.time()
+                output_list_1st.append(
+                    my_filling_sequence(self.model_stage1, args, seq_1st.clone(),
+                        batch_size=min(batch_size, mbz),
+                        get_masks_and_position_ids=get_masks_and_position_ids_stage1,
+                        text_len=text_len_1st,
+                        frame_len=frame_len,
+                        strategy=strategy_cogview2,
+                        strategy2=strategy_cogvideo,
+                        log_text_attention_weights=1.4,
+                        enforce_no_swin=True,
+                        mode_stage1=True,
+                    )[0]
+                )
+                logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
+            output_tokens_1st = torch.cat(output_list_1st, dim=0)
+            given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1)  # given_tokens.shape: [bs, frame_num, 400]
+        else:
+            given_tokens = tokenizer.encode(image_path=self.image_prompt, image_size=160).repeat(batch_size, 1).unsqueeze(1)

         # generate subsequent frames:
         total_frames = self.generate_frame_num
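Shape-wise, the new else branch swaps the sampled first frame for the encoded image: one frame is 400 tokens (a 20x20 latent grid from a 160x160 image), repeated across the batch and given a frame axis. A toy sketch of just the tensor plumbing, with a random stand-in for the real CogVideo tokenizer:

    import torch

    batch_size, frame_tokens = 2, 400  # one frame = 20x20 latent grid

    # stand-in for tokenizer.encode(image_path=..., image_size=160),
    # assumed here to return a [1, 400] row of image token ids
    image_tokens = torch.randint(0, 8192, (1, frame_tokens))

    given_tokens = image_tokens.repeat(batch_size, 1).unsqueeze(1)
    assert given_tokens.shape == (batch_size, 1, frame_tokens)  # [bs, frame_num, 400]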
@@ -766,7 +778,7 @@ class Predictor(BasePredictor):

         for sample_i in range(sample_num):
             my_save_multiple_images(decoded_sr_videos[sample_i], outputdir, subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
-            output_file = f'{outputdir}/{sample_i+sample_num*gpu_rank}.gif'
+            output_file = f"{outputdir}/{sample_i+sample_num*gpu_rank}.gif"
             subprocess.check_output(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{output_file}' -d 0.125", shell=True)
             yield output_file

@@ -790,6 +802,6 @@ class Predictor(BasePredictor):

         for sample_i in range(sample_num):
             my_save_multiple_images(decoded_videos[sample_i], outputdir, subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
-            output_file = f'{outputdir}/{sample_i+sample_num*gpu_rank}.gif'
+            output_file = f"{outputdir}/{sample_i+sample_num*gpu_rank}.gif"
             subprocess.check_output(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{output_file}' -d 0.125", shell=True)
             yield output_file
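For context, the gifmaker invocation above assembles the saved frames at 0.125 s per frame (8 fps). A rough Pillow-only equivalent, with illustrative paths:

    from pathlib import Path
    from PIL import Image

    frames_dir = Path("output/frames/0")  # illustrative path
    frames = [Image.open(p) for p in sorted(frames_dir.glob("0*.jpg"))]
    # duration is per-frame display time in milliseconds: 125 ms = 8 fps
    frames[0].save("output/0.gif", save_all=True,
                   append_images=frames[1:], duration=125, loop=0)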