mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-12-03 03:02:09 +08:00
add tqdm
This commit is contained in:
parent
8cc364086b
commit
4dfdd0a2f2
1
cog.yaml
1
cog.yaml
@@ -19,6 +19,7 @@ build:
|
||||
- "SwissArmyTransformer==0.2.9"
|
||||
- "torch==1.9.0"
|
||||
- "torchvision==0.10.0"
|
||||
- "tqdm==4.64.0"
|
||||
|
||||
run:
|
||||
- "mkdir -p /sharefs/cogview-new; cd /sharefs/cogview-new; wget https://models.nmb.ai/cogvideo/cogvideo-stage1.tar.gz -O - | tar xz"
|
||||
|
||||
28
predict.py
28
predict.py
@@ -12,6 +12,7 @@ import torch
|
||||
import time
|
||||
import logging,sys
|
||||
import stat
|
||||
from tqdm import trange
|
||||
from torchvision.utils import save_image
|
||||
from icetk import icetk as tokenizer
|
||||
import torch.distributed as dist
|
||||
@@ -111,13 +112,13 @@ def my_save_multiple_images(imgs, path, subdir, debug=True):
|
||||
# imgs: list of tensor images
|
||||
if debug:
|
||||
imgs = torch.cat(imgs, dim=0)
|
||||
print("\nSave to: ", path, flush=True)
|
||||
#print("\nSave to: ", path, flush=True)
|
||||
save_image(imgs, path, normalize=True)
|
||||
else:
|
||||
print("\nSave to: ", path, flush=True)
|
||||
#print("\nSave to: ", path, flush=True)
|
||||
single_frame_path = os.path.join(path, subdir)
|
||||
os.makedirs(single_frame_path, exist_ok=True)
|
||||
for i in range(len(imgs)):
|
||||
for i in trange(len(imgs)):
|
||||
save_image(imgs[i], os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'), normalize=True)
|
||||
os.chmod(os.path.join(single_frame_path,f'{str(i).rjust(4,"0")}.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
|
||||
save_image(torch.cat(imgs, dim=0), os.path.join(single_frame_path,f'frame_concat.jpg'), normalize=True)
|
||||
@@ -532,6 +533,8 @@ class Predictor(BasePredictor):
|
||||
yield Path(f"{workdir}/output/stage2/0.gif")
|
||||
|
||||
logging.debug("complete, exiting")
|
||||
raise StopIteration()
|
||||
|
||||
|
||||
def process_stage1(self, model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", image_text_suffix="", outputdir=None, batch_size=1):
|
||||
process_start_time = time.time()
|
||||
@@ -552,7 +555,7 @@ class Predictor(BasePredictor):
|
||||
|
||||
seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
|
||||
output_list_1st = []
|
||||
for tim in range(max(batch_size // mbz, 1)):
|
||||
for tim in trange(max(batch_size // mbz, 1)):
|
||||
start_time = time.time()
|
||||
output_list_1st.append(
|
||||
my_filling_sequence(model, args,seq_1st.clone(),
|
||||
@@ -597,7 +600,7 @@ class Predictor(BasePredictor):
|
||||
guider_seq = None
|
||||
video_log_text_attention_weights = 1.4
|
||||
|
||||
for tim in range(max(batch_size // mbz, 1)):
|
||||
for tim in trange(max(batch_size // mbz, 1)):
|
||||
start_time = time.time()
|
||||
input_seq = seq[:min(batch_size, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
|
||||
guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
|
||||
@@ -702,7 +705,7 @@ class Predictor(BasePredictor):
|
||||
assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
|
||||
output_list = []
|
||||
start_time = time.time()
|
||||
for tim in range(max(generate_batchsize_total // mbz, 1)):
|
||||
for tim in trange(max(generate_batchsize_total // mbz, 1)):
|
||||
input_seq = seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
|
||||
guider_seq2 = (guider_seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
|
||||
output_list.append(
|
||||
@@ -753,14 +756,14 @@ class Predictor(BasePredictor):
|
||||
sred_tokens = self.dsr(text_seq, parent_given_tokens_2d)
|
||||
decoded_sr_videos = []
|
||||
|
||||
for sample_i in range(sample_num):
|
||||
for sample_i in trange(sample_num):
|
||||
decoded_sr_imgs = []
|
||||
for frame_i in range(frame_num_per_sample):
|
||||
decoded_sr_img = tokenizer.decode(image_ids=sred_tokens[frame_i+sample_i*frame_num_per_sample][-3600:])
|
||||
decoded_sr_imgs.append(torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480)))
|
||||
decoded_sr_videos.append(decoded_sr_imgs)
|
||||
|
||||
for sample_i in range(sample_num):
|
||||
for sample_i in trange(sample_num):
|
||||
my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
|
||||
os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
|
||||
|
||||
@@ -769,20 +772,17 @@ class Predictor(BasePredictor):
|
||||
#imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge]
|
||||
#os.makedirs(outputdir, exist_ok=True)
|
||||
#my_save_multiple_images(imgs, outputdir,subdir="frames", debug=False)
|
||||
#os.system(f"gifmaker -i '{outputdir}'/frames/0*.jpg -o '{outputdir}/{str(float(duration))}_concat.gif' -d 0.2")
|
||||
|
||||
|
||||
output_tokens = torch.cat(output_list, dim=0)[:, 1+text_len:]
|
||||
#os.system(f"gifmaker -i '{outputdir}'/frames/0*.jpg -o '{outputdir}/{str(float(duration))}_concat.gif' -d 0.2")
|
||||
decoded_videos = []
|
||||
|
||||
for sample_i in range(sample_num):
|
||||
for sample_i in trange(sample_num):
|
||||
decoded_imgs = []
|
||||
for frame_i in range(frame_num_per_sample):
|
||||
decoded_img = tokenizer.decode(image_ids=parent_given_tokens_2d[frame_i+sample_i*frame_num_per_sample][-3600:])
|
||||
decoded_imgs.append(torch.nn.functional.interpolate(decoded_img, size=(480, 480)))
|
||||
decoded_videos.append(decoded_imgs)
|
||||
|
||||
for sample_i in range(sample_num):
|
||||
for sample_i in trange(sample_num):
|
||||
my_save_multiple_images(decoded_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
|
||||
os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user