diff --git a/cog.yaml b/cog.yaml
index 54e243f..989f5a3 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -19,6 +19,7 @@ build:
     - "SwissArmyTransformer==0.2.9"
     - "torch==1.9.0"
     - "torchvision==0.10.0"
+    - "tqdm==4.64.0"
   run:
     - "mkdir -p /sharefs/cogview-new; cd /sharefs/cogview-new; wget https://models.nmb.ai/cogvideo/cogvideo-stage1.tar.gz -O - | tar xz"
diff --git a/predict.py b/predict.py
index e3ab3af..6404ae8 100644
--- a/predict.py
+++ b/predict.py
@@ -12,6 +12,7 @@ import torch
 import time
 import logging,sys
 import stat
+from tqdm import trange
 from torchvision.utils import save_image
 from icetk import icetk as tokenizer
 import torch.distributed as dist
@@ -111,13 +112,13 @@ def my_save_multiple_images(imgs, path, subdir, debug=True):
     # imgs: list of tensor images
     if debug:
         imgs = torch.cat(imgs, dim=0)
-        print("\nSave to: ", path, flush=True)
+        #print("\nSave to: ", path, flush=True)
         save_image(imgs, path, normalize=True)
     else:
-        print("\nSave to: ", path, flush=True)
+        #print("\nSave to: ", path, flush=True)
         single_frame_path = os.path.join(path, subdir)
         os.makedirs(single_frame_path, exist_ok=True)
-        for i in range(len(imgs)):
+        for i in trange(len(imgs)):
             save_image(imgs[i], os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'), normalize=True)
             os.chmod(os.path.join(single_frame_path,f'{str(i).rjust(4,"0")}.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
         save_image(torch.cat(imgs, dim=0), os.path.join(single_frame_path,f'frame_concat.jpg'), normalize=True)
@@ -532,6 +533,8 @@ class Predictor(BasePredictor):
         yield Path(f"{workdir}/output/stage2/0.gif")
         logging.debug("complete, exiting")
 
+        raise StopIteration()
+
     def process_stage1(self, model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", image_text_suffix="", outputdir=None, batch_size=1):
         process_start_time = time.time()
 
@@ -552,7 +555,7 @@ class Predictor(BasePredictor):
 
         seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
         output_list_1st = []
-        for tim in range(max(batch_size // mbz, 1)):
+        for tim in trange(max(batch_size // mbz, 1)):
             start_time = time.time()
             output_list_1st.append(
                 my_filling_sequence(model, args,seq_1st.clone(),
@@ -597,7 +600,7 @@ class Predictor(BasePredictor):
             guider_seq = None
         video_log_text_attention_weights = 1.4
 
-        for tim in range(max(batch_size // mbz, 1)):
+        for tim in trange(max(batch_size // mbz, 1)):
             start_time = time.time()
             input_seq = seq[:min(batch_size, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
             guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
@@ -702,7 +705,7 @@ class Predictor(BasePredictor):
         assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
         output_list = []
         start_time = time.time()
-        for tim in range(max(generate_batchsize_total // mbz, 1)):
+        for tim in trange(max(generate_batchsize_total // mbz, 1)):
             input_seq = seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
             guider_seq2 = (guider_seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
             output_list.append(
@@ -753,14 +756,14 @@ class Predictor(BasePredictor):
             sred_tokens = self.dsr(text_seq, parent_given_tokens_2d)
 
             decoded_sr_videos = []
-            for sample_i in range(sample_num):
+            for sample_i in trange(sample_num):
                 decoded_sr_imgs = []
                 for frame_i in range(frame_num_per_sample):
                     decoded_sr_img = tokenizer.decode(image_ids=sred_tokens[frame_i+sample_i*frame_num_per_sample][-3600:])
                     decoded_sr_imgs.append(torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480)))
                 decoded_sr_videos.append(decoded_sr_imgs)
 
-            for sample_i in range(sample_num):
+            for sample_i in trange(sample_num):
                 my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
                 os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
 
@@ -769,20 +772,17 @@ class Predictor(BasePredictor):
 
             #imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge]
             #os.makedirs(outputdir, exist_ok=True)
             #my_save_multiple_images(imgs, outputdir,subdir="frames", debug=False)
-            #os.system(f"gifmaker -i '{outputdir}'/frames/0*.jpg -o '{outputdir}/{str(float(duration))}_concat.gif' -d 0.2")
-
-
-        output_tokens = torch.cat(output_list, dim=0)[:, 1+text_len:]
+            #os.system(f"gifmaker -i '{outputdir}'/frames/0*.jpg -o '{outputdir}/{str(float(duration))}_concat.gif' -d 0.2")
 
         decoded_videos = []
-        for sample_i in range(sample_num):
+        for sample_i in trange(sample_num):
             decoded_imgs = []
             for frame_i in range(frame_num_per_sample):
                 decoded_img = tokenizer.decode(image_ids=parent_given_tokens_2d[frame_i+sample_i*frame_num_per_sample][-3600:])
                 decoded_imgs.append(torch.nn.functional.interpolate(decoded_img, size=(480, 480)))
             decoded_videos.append(decoded_imgs)
 
-        for sample_i in range(sample_num):
+        for sample_i in trange(sample_num):
             my_save_multiple_images(decoded_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
             os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
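
Notes on the patch:

One caution about the `raise StopIteration()` added at the end of `predict()`: since `predict()` yields `Path` objects it is a generator, and under PEP 479 (the default behavior from Python 3.7 on) a `StopIteration` raised inside a generator body is re-raised as a `RuntimeError`. The idiomatic way to end a generator is a bare `return`. A minimal sketch with hypothetical generators, not the actual `predict()` method:

    # Hypothetical minimal generators illustrating PEP 479; not the real predict().
    def gen_with_raise():
        yield "0.gif"
        raise StopIteration()   # Python 3.7+: re-raised as RuntimeError

    def gen_with_return():
        yield "0.gif"
        return                  # idiomatic: falling off the end stops iteration

    print(list(gen_with_return()))   # ['0.gif']
    # list(gen_with_raise())         # RuntimeError: generator raised StopIteration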
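
For the `range` → `trange` substitutions: `tqdm.trange(n)` is shorthand for `tqdm(range(n))`, so each converted loop iterates exactly as before but renders a progress bar on stderr, which compensates for the per-save `print` calls being commented out in `my_save_multiple_images`. A quick sketch of the equivalence:

    from tqdm import tqdm, trange

    for i in trange(4):          # same iteration as range(4), plus a progress bar
        pass

    for i in tqdm(range(4)):     # equivalent spelled-out form
        pass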
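
Unrelated to the changes themselves but visible in the context lines: the `os.chmod` mode `stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU` evaluates to `0o777` (read/write/execute for user, group, and other). Because the three masks are disjoint, `+` happens to give the right answer here, though permission bits are conventionally combined with bitwise OR:

    import stat

    mode = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO   # rwxrwxrwx
    assert mode == 0o777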