This commit is contained in:
Stephan Auerhahn 2022-07-25 11:41:53 +00:00
parent 8cc364086b
commit 4dfdd0a2f2
2 changed files with 15 additions and 14 deletions

View File

@ -19,6 +19,7 @@ build:
- "SwissArmyTransformer==0.2.9"
- "torch==1.9.0"
- "torchvision==0.10.0"
- "tqdm==4.64.0"
run:
- "mkdir -p /sharefs/cogview-new; cd /sharefs/cogview-new; wget https://models.nmb.ai/cogvideo/cogvideo-stage1.tar.gz -O - | tar xz"

View File

@ -12,6 +12,7 @@ import torch
import time
import logging,sys
import stat
from tqdm import trange
from torchvision.utils import save_image
from icetk import icetk as tokenizer
import torch.distributed as dist
@ -111,13 +112,13 @@ def my_save_multiple_images(imgs, path, subdir, debug=True):
# imgs: list of tensor images
if debug:
imgs = torch.cat(imgs, dim=0)
print("\nSave to: ", path, flush=True)
#print("\nSave to: ", path, flush=True)
save_image(imgs, path, normalize=True)
else:
print("\nSave to: ", path, flush=True)
#print("\nSave to: ", path, flush=True)
single_frame_path = os.path.join(path, subdir)
os.makedirs(single_frame_path, exist_ok=True)
for i in range(len(imgs)):
for i in trange(len(imgs)):
save_image(imgs[i], os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'), normalize=True)
os.chmod(os.path.join(single_frame_path,f'{str(i).rjust(4,"0")}.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
save_image(torch.cat(imgs, dim=0), os.path.join(single_frame_path,f'frame_concat.jpg'), normalize=True)
@ -532,6 +533,8 @@ class Predictor(BasePredictor):
yield Path(f"{workdir}/output/stage2/0.gif")
logging.debug("complete, exiting")
raise StopIteration()
def process_stage1(self, model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", image_text_suffix="", outputdir=None, batch_size=1):
process_start_time = time.time()
@ -552,7 +555,7 @@ class Predictor(BasePredictor):
seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
output_list_1st = []
for tim in range(max(batch_size // mbz, 1)):
for tim in trange(max(batch_size // mbz, 1)):
start_time = time.time()
output_list_1st.append(
my_filling_sequence(model, args,seq_1st.clone(),
@ -597,7 +600,7 @@ class Predictor(BasePredictor):
guider_seq = None
video_log_text_attention_weights = 1.4
for tim in range(max(batch_size // mbz, 1)):
for tim in trange(max(batch_size // mbz, 1)):
start_time = time.time()
input_seq = seq[:min(batch_size, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
@ -702,7 +705,7 @@ class Predictor(BasePredictor):
assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
output_list = []
start_time = time.time()
for tim in range(max(generate_batchsize_total // mbz, 1)):
for tim in trange(max(generate_batchsize_total // mbz, 1)):
input_seq = seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
guider_seq2 = (guider_seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
output_list.append(
@ -753,14 +756,14 @@ class Predictor(BasePredictor):
sred_tokens = self.dsr(text_seq, parent_given_tokens_2d)
decoded_sr_videos = []
for sample_i in range(sample_num):
for sample_i in trange(sample_num):
decoded_sr_imgs = []
for frame_i in range(frame_num_per_sample):
decoded_sr_img = tokenizer.decode(image_ids=sred_tokens[frame_i+sample_i*frame_num_per_sample][-3600:])
decoded_sr_imgs.append(torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480)))
decoded_sr_videos.append(decoded_sr_imgs)
for sample_i in range(sample_num):
for sample_i in trange(sample_num):
my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
@ -769,20 +772,17 @@ class Predictor(BasePredictor):
#imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge]
#os.makedirs(outputdir, exist_ok=True)
#my_save_multiple_images(imgs, outputdir,subdir="frames", debug=False)
#os.system(f"gifmaker -i '{outputdir}'/frames/0*.jpg -o '{outputdir}/{str(float(duration))}_concat.gif' -d 0.2")
output_tokens = torch.cat(output_list, dim=0)[:, 1+text_len:]
#os.system(f"gifmaker -i '{outputdir}'/frames/0*.jpg -o '{outputdir}/{str(float(duration))}_concat.gif' -d 0.2")
decoded_videos = []
for sample_i in range(sample_num):
for sample_i in trange(sample_num):
decoded_imgs = []
for frame_i in range(frame_num_per_sample):
decoded_img = tokenizer.decode(image_ids=parent_given_tokens_2d[frame_i+sample_i*frame_num_per_sample][-3600:])
decoded_imgs.append(torch.nn.functional.interpolate(decoded_img, size=(480, 480)))
decoded_videos.append(decoded_imgs)
for sample_i in range(sample_num):
for sample_i in trange(sample_num):
my_save_multiple_images(decoded_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")