Based image prompts

neverix 2022-07-27 00:10:31 +03:00
parent 4f3eba5521
commit c52d82f6b0

@@ -4,7 +4,7 @@ import subprocess
 import tempfile
 import random
 import typing
-from typing_extensions import Self
+from PIL import Image, UnidentifiedImageError
 from deep_translator import GoogleTranslator
 from cog import BasePredictor, Input, Path
@@ -444,18 +444,19 @@ class Predictor(BasePredictor):
         args.use_guidance_stage1 = True
         args.use_guidance_stage2 = False
         args.both_stages = True
-        args.device = torch.device('cuda')
+        args.device = torch.device("cuda")
+        self.image_prompt = None
         self.translator = GoogleTranslator(source="en", target="zh-CN")
-        self.model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1')
+        self.model_stage1, args = InferenceModel_Sequential.from_pretrained(args, "cogvideo-stage1")
         self.model_stage1.eval()
         self.model_stage1 = self.model_stage1.cpu()
-        self.model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, 'cogvideo-stage2')
+        self.model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, "cogvideo-stage2")
         self.model_stage2.eval()
         self.model_stage2 = self.model_stage2.cpu()
         # enable dsr if model exists
-        if os.path.exists('/sharefs/cogview-new/cogview2-dsr'):
+        if os.path.exists("/sharefs/cogview-new/cogview2-dsr"):
             subprocess.check_output("python setup.py develop", cwd="/src/Image-Local-Attention", shell=True)
             sys.path.append('./Image-Local-Attention')
             from sr_pipeline import DirectSuperResolution
@@ -485,6 +486,7 @@ class Predictor(BasePredictor):
             description="Run both stages (uncheck to run more quickly and output only a few frames)", default=True
         ),
         use_guidance: bool = Input(description="Use stage 1 guidance (recommended)", default=True),
+        image_prompt: Path = Input(description="Starting image")
     ) -> typing.Iterator[Path]:
         if translate:
             prompt = self.translator.translate(prompt.strip())
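
The new image_prompt input is a cog Path. File inputs of this kind reach predict() as ordinary local file paths, which is why the code added below can hand the value straight to os.path.exists and Image.open; declared without a default, the input is also treated as required. A minimal illustration of the path-like behavior, using a purely hypothetical file name:

import os
from pathlib import Path

image_prompt = Path("/tmp/start.png")  # stand-in for the staged upload cog passes in
print(os.path.exists(image_prompt))    # True only if that file is actually present
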
@@ -495,6 +497,13 @@ class Predictor(BasePredictor):
         self.args.seed = seed
         self.args.use_guidance_stage1 = use_guidance
         self.prompt = prompt
+        if os.path.exists(image_prompt):
+            try:
+                Image.open(image_prompt)
+            except (FileNotFoundError, UnidentifiedImageError):
+                logging.debug("Bad image prompt; ignoring") # Is there a better way to input images?
+            else:
+                self.image_prompt = Image.open(image_prompt)
         self.args.both_stages = both_stages
         for file in self.run():
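
The guard above stores an image on self.image_prompt only when the file both exists and can be decoded by PIL; otherwise it logs at debug level and the predictor falls back to the text-only path. An equivalent standalone sketch of the same check (the helper name is illustrative, not part of the commit):

import logging

from PIL import Image, UnidentifiedImageError


def load_image_prompt(path):
    """Return a PIL image for `path`, or None if it cannot be opened as an image."""
    try:
        image = Image.open(path)
    except (FileNotFoundError, UnidentifiedImageError):
        logging.debug("Bad image prompt; ignoring")
        return None
    return image
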
@@ -546,25 +555,28 @@ class Predictor(BasePredictor):
         text_len_1st = len(seq_1st) - frame_len*1 - 1
         seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
-        output_list_1st = []
-        for tim in range(max(batch_size // mbz, 1)):
-            start_time = time.time()
-            output_list_1st.append(
-                my_filling_sequence(self.model_stage1, args,seq_1st.clone(),
-                    batch_size=min(batch_size, mbz),
-                    get_masks_and_position_ids=get_masks_and_position_ids_stage1,
-                    text_len=text_len_1st,
-                    frame_len=frame_len,
-                    strategy=strategy_cogview2,
-                    strategy2=strategy_cogvideo,
-                    log_text_attention_weights=1.4,
-                    enforce_no_swin=True,
-                    mode_stage1=True,
-                )[0]
-            )
-            logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
-        output_tokens_1st = torch.cat(output_list_1st, dim=0)
-        given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400]
+        if self.image_prompt is None:
+            output_list_1st = []
+            for tim in range(max(batch_size // mbz, 1)):
+                start_time = time.time()
+                output_list_1st.append(
+                    my_filling_sequence(self.model_stage1, args,seq_1st.clone(),
+                        batch_size=min(batch_size, mbz),
+                        get_masks_and_position_ids=get_masks_and_position_ids_stage1,
+                        text_len=text_len_1st,
+                        frame_len=frame_len,
+                        strategy=strategy_cogview2,
+                        strategy2=strategy_cogvideo,
+                        log_text_attention_weights=1.4,
+                        enforce_no_swin=True,
+                        mode_stage1=True,
+                    )[0]
+                )
+                logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
+            output_tokens_1st = torch.cat(output_list_1st, dim=0)
+            given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400]
+        else:
+            given_tokens = tokenizer.encode(image_path=self.image_prompt, image_size=160).repeat(batch_size, 1).unsqueeze(1)
         # generate subsequent frames:
         total_frames = self.generate_frame_num
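
In both branches given_tokens ends up with shape [batch_size, frame_num, 400]: the stage-1 branch slices 400 image tokens out of the sampled sequence, while the new image-prompt branch tiles the tokens of the encoded starting image across the batch. A small shape check, with a dummy tensor standing in for tokenizer.encode(...) (assumed here to yield a [1, 400] token row for the 160-pixel encode used above):

import torch

batch_size = 4
frame_tokens = torch.randint(0, 20000, (1, 400))               # stand-in for tokenizer.encode(...)
given_tokens = frame_tokens.repeat(batch_size, 1).unsqueeze(1)  # tile across the batch, add frame dim
print(given_tokens.shape)                                       # torch.Size([4, 1, 400])
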
@@ -766,7 +778,7 @@ class Predictor(BasePredictor):
             for sample_i in range(sample_num):
                 my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
-                output_file = f'{outputdir}/{sample_i+sample_num*gpu_rank}.gif'
+                output_file = f"{outputdir}/{sample_i+sample_num*gpu_rank}.gif"
                 subprocess.check_output(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{output_file}' -d 0.125", shell=True)
                 yield output_file
@@ -790,6 +802,6 @@ class Predictor(BasePredictor):
             for sample_i in range(sample_num):
                 my_save_multiple_images(decoded_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
-                output_file = f'{outputdir}/{sample_i+sample_num*gpu_rank}.gif'
+                output_file = f"{outputdir}/{sample_i+sample_num*gpu_rank}.gif"
                 subprocess.check_output(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{output_file}' -d 0.125", shell=True)
                 yield output_file