From c52d82f6b06c7aa2901585f141f0d60fd99ffd78 Mon Sep 17 00:00:00 2001 From: neverix Date: Wed, 27 Jul 2022 00:10:31 +0300 Subject: [PATCH 1/5] Based image prompts --- predict.py | 64 ++++++++++++++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/predict.py b/predict.py index 9da9bba..d59eb36 100644 --- a/predict.py +++ b/predict.py @@ -4,7 +4,7 @@ import subprocess import tempfile import random import typing -from typing_extensions import Self +from PIL import Image, UnidentifiedImageError from deep_translator import GoogleTranslator from cog import BasePredictor, Input, Path @@ -444,18 +444,19 @@ class Predictor(BasePredictor): args.use_guidance_stage1 = True args.use_guidance_stage2 = False args.both_stages = True - args.device = torch.device('cuda') + args.device = torch.device("cuda") + self.image_prompt = None self.translator = GoogleTranslator(source="en", target="zh-CN") - self.model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1') + self.model_stage1, args = InferenceModel_Sequential.from_pretrained(args, "cogvideo-stage1") self.model_stage1.eval() self.model_stage1 = self.model_stage1.cpu() - self.model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, 'cogvideo-stage2') + self.model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, "cogvideo-stage2") self.model_stage2.eval() self.model_stage2 = self.model_stage2.cpu() # enable dsr if model exists - if os.path.exists('/sharefs/cogview-new/cogview2-dsr'): + if os.path.exists("/sharefs/cogview-new/cogview2-dsr"): subprocess.check_output("python setup.py develop", cwd="/src/Image-Local-Attention", shell=True) sys.path.append('./Image-Local-Attention') from sr_pipeline import DirectSuperResolution @@ -485,6 +486,7 @@ class Predictor(BasePredictor): description="Run both stages (uncheck to run more quickly and output only a few frames)", default=True ), use_guidance: bool = Input(description="Use stage 1 guidance (recommended)", default=True), + image_prompt: Path = Input(description="Starting image") ) -> typing.Iterator[Path]: if translate: prompt = self.translator.translate(prompt.strip()) @@ -495,6 +497,13 @@ class Predictor(BasePredictor): self.args.seed = seed self.args.use_guidance_stage1 = use_guidance self.prompt = prompt + if os.path.exists(image_prompt): + try: + Image.open(image_prompt) + except (FileNotFoundError, UnidentifiedImageError): + logging.debug("Bad image prompt; ignoring") # Is there a better way to input images? + else: + self.image_prompt = Image.open(image_prompt) self.args.both_stages = both_stages for file in self.run(): @@ -546,25 +555,28 @@ class Predictor(BasePredictor): text_len_1st = len(seq_1st) - frame_len*1 - 1 seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0) - output_list_1st = [] - for tim in range(max(batch_size // mbz, 1)): - start_time = time.time() - output_list_1st.append( - my_filling_sequence(self.model_stage1, args,seq_1st.clone(), - batch_size=min(batch_size, mbz), - get_masks_and_position_ids=get_masks_and_position_ids_stage1, - text_len=text_len_1st, - frame_len=frame_len, - strategy=strategy_cogview2, - strategy2=strategy_cogvideo, - log_text_attention_weights=1.4, - enforce_no_swin=True, - mode_stage1=True, - )[0] - ) - logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time)) - output_tokens_1st = torch.cat(output_list_1st, dim=0) - given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400] + if self.image_prompt is None: + output_list_1st = [] + for tim in range(max(batch_size // mbz, 1)): + start_time = time.time() + output_list_1st.append( + my_filling_sequence(self.model_stage1, args,seq_1st.clone(), + batch_size=min(batch_size, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage1, + text_len=text_len_1st, + frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=1.4, + enforce_no_swin=True, + mode_stage1=True, + )[0] + ) + logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time)) + output_tokens_1st = torch.cat(output_list_1st, dim=0) + given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400] + else: + given_tokens = tokenizer.encode(image_path=self.image_prompt, image_size=160).repeat(batch_size, 1).unsqueeze(1) # generate subsequent frames: total_frames = self.generate_frame_num @@ -766,7 +778,7 @@ class Predictor(BasePredictor): for sample_i in range(sample_num): my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False) - output_file = f'{outputdir}/{sample_i+sample_num*gpu_rank}.gif' + output_file = f"{outputdir}/{sample_i+sample_num*gpu_rank}.gif" subprocess.check_output(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{output_file}' -d 0.125", shell=True) yield output_file @@ -790,6 +802,6 @@ class Predictor(BasePredictor): for sample_i in range(sample_num): my_save_multiple_images(decoded_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False) - output_file = f'{outputdir}/{sample_i+sample_num*gpu_rank}.gif' + output_file = f"{outputdir}/{sample_i+sample_num*gpu_rank}.gif" subprocess.check_output(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{output_file}' -d 0.125", shell=True) yield output_file \ No newline at end of file From 7e5cd59fd0b23b96c3aa4471490857acf3121233 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Tue, 26 Jul 2022 22:58:12 -0700 Subject: [PATCH 2/5] Update predict.py --- predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/predict.py b/predict.py index d59eb36..21a1524 100644 --- a/predict.py +++ b/predict.py @@ -499,7 +499,7 @@ class Predictor(BasePredictor): self.prompt = prompt if os.path.exists(image_prompt): try: - Image.open(image_prompt) + Image.open(str(image_prompt)) except (FileNotFoundError, UnidentifiedImageError): logging.debug("Bad image prompt; ignoring") # Is there a better way to input images? else: From e0067c5132611e09340895a88e4d5bc737a3a570 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Tue, 26 Jul 2022 22:58:19 -0700 Subject: [PATCH 3/5] Update predict.py --- predict.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/predict.py b/predict.py index 21a1524..18cfa32 100644 --- a/predict.py +++ b/predict.py @@ -497,7 +497,8 @@ class Predictor(BasePredictor): self.args.seed = seed self.args.use_guidance_stage1 = use_guidance self.prompt = prompt - if os.path.exists(image_prompt): + self.image_prompt = None + if os.path.exists(str(image_prompt)): try: Image.open(str(image_prompt)) except (FileNotFoundError, UnidentifiedImageError): From 634a513ba7b047b433b39c136b23428b5c19e699 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Tue, 26 Jul 2022 22:58:23 -0700 Subject: [PATCH 4/5] Update predict.py --- predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/predict.py b/predict.py index 18cfa32..8e38b35 100644 --- a/predict.py +++ b/predict.py @@ -486,7 +486,7 @@ class Predictor(BasePredictor): description="Run both stages (uncheck to run more quickly and output only a few frames)", default=True ), use_guidance: bool = Input(description="Use stage 1 guidance (recommended)", default=True), - image_prompt: Path = Input(description="Starting image") + image_prompt: Path = Input(description="Starting image", default=None) ) -> typing.Iterator[Path]: if translate: prompt = self.translator.translate(prompt.strip()) From a054801c18b13e88eddfdf6ccaa3b9ad9eaeba07 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Tue, 26 Jul 2022 22:58:29 -0700 Subject: [PATCH 5/5] Update predict.py --- predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/predict.py b/predict.py index 8e38b35..93e26e9 100644 --- a/predict.py +++ b/predict.py @@ -504,7 +504,7 @@ class Predictor(BasePredictor): except (FileNotFoundError, UnidentifiedImageError): logging.debug("Bad image prompt; ignoring") # Is there a better way to input images? else: - self.image_prompt = Image.open(image_prompt) + self.image_prompt = str(image_prompt) self.args.both_stages = both_stages for file in self.run():