mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-12-02 18:52:08 +08:00
remove raise
This commit is contained in:
parent
e8f23104b9
commit
abfc7af8f9
20
predict.py
20
predict.py
@ -183,7 +183,7 @@ def my_filling_sequence(
|
||||
guider_tokens = guider_tokens[..., :context_length-guider_index_delta]
|
||||
guider_input_tokens = guider_tokens.clone()
|
||||
|
||||
for fid in range(current_frame_num):
|
||||
for fid in trange(current_frame_num):
|
||||
input_tokens[:, text_len+400*fid] = tokenizer['<start_of_image>']
|
||||
if guider_seq is not None:
|
||||
guider_input_tokens[:, guider_text_len+400*fid] = tokenizer['<start_of_image>']
|
||||
@ -532,9 +532,7 @@ class Predictor(BasePredictor):
|
||||
gpu_rank=0, gpu_parallel_size=1)
|
||||
yield Path(f"{workdir}/output/stage2/0.gif")
|
||||
|
||||
logging.debug("complete, exiting")
|
||||
raise StopIteration()
|
||||
|
||||
logging.debug("complete, exiting")
|
||||
|
||||
def process_stage1(self, model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", image_text_suffix="", outputdir=None, batch_size=1):
|
||||
process_start_time = time.time()
|
||||
@ -555,7 +553,7 @@ class Predictor(BasePredictor):
|
||||
|
||||
seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
|
||||
output_list_1st = []
|
||||
for tim in trange(max(batch_size // mbz, 1)):
|
||||
for tim in range(max(batch_size // mbz, 1)):
|
||||
start_time = time.time()
|
||||
output_list_1st.append(
|
||||
my_filling_sequence(model, args,seq_1st.clone(),
|
||||
@ -600,7 +598,7 @@ class Predictor(BasePredictor):
|
||||
guider_seq = None
|
||||
video_log_text_attention_weights = 1.4
|
||||
|
||||
for tim in trange(max(batch_size // mbz, 1)):
|
||||
for tim in range(max(batch_size // mbz, 1)):
|
||||
start_time = time.time()
|
||||
input_seq = seq[:min(batch_size, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
|
||||
guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
|
||||
@ -705,7 +703,7 @@ class Predictor(BasePredictor):
|
||||
assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
|
||||
output_list = []
|
||||
start_time = time.time()
|
||||
for tim in trange(max(generate_batchsize_total // mbz, 1)):
|
||||
for tim in range(max(generate_batchsize_total // mbz, 1)):
|
||||
input_seq = seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
|
||||
guider_seq2 = (guider_seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
|
||||
output_list.append(
|
||||
@ -756,14 +754,14 @@ class Predictor(BasePredictor):
|
||||
sred_tokens = self.dsr(text_seq, parent_given_tokens_2d)
|
||||
decoded_sr_videos = []
|
||||
|
||||
for sample_i in trange(sample_num):
|
||||
for sample_i in range(sample_num):
|
||||
decoded_sr_imgs = []
|
||||
for frame_i in range(frame_num_per_sample):
|
||||
decoded_sr_img = tokenizer.decode(image_ids=sred_tokens[frame_i+sample_i*frame_num_per_sample][-3600:])
|
||||
decoded_sr_imgs.append(torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480)))
|
||||
decoded_sr_videos.append(decoded_sr_imgs)
|
||||
|
||||
for sample_i in trange(sample_num):
|
||||
for sample_i in range(sample_num):
|
||||
my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
|
||||
os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
|
||||
|
||||
@ -775,14 +773,14 @@ class Predictor(BasePredictor):
|
||||
#os.system(f"gifmaker -i '{outputdir}'/frames/0*.jpg -o '{outputdir}/{str(float(duration))}_concat.gif' -d 0.2")
|
||||
decoded_videos = []
|
||||
|
||||
for sample_i in trange(sample_num):
|
||||
for sample_i in range(sample_num):
|
||||
decoded_imgs = []
|
||||
for frame_i in range(frame_num_per_sample):
|
||||
decoded_img = tokenizer.decode(image_ids=parent_given_tokens_2d[frame_i+sample_i*frame_num_per_sample][-3600:])
|
||||
decoded_imgs.append(torch.nn.functional.interpolate(decoded_img, size=(480, 480)))
|
||||
decoded_videos.append(decoded_imgs)
|
||||
|
||||
for sample_i in trange(sample_num):
|
||||
for sample_i in range(sample_num):
|
||||
my_save_multiple_images(decoded_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
|
||||
os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user