Based image prompts

This commit is contained in:
neverix 2022-07-27 00:10:31 +03:00
parent 4f3eba5521
commit c52d82f6b0

View File

@ -4,7 +4,7 @@ import subprocess
import tempfile
import random
import typing
from typing_extensions import Self
from PIL import Image, UnidentifiedImageError
from deep_translator import GoogleTranslator
from cog import BasePredictor, Input, Path
@ -444,18 +444,19 @@ class Predictor(BasePredictor):
args.use_guidance_stage1 = True
args.use_guidance_stage2 = False
args.both_stages = True
args.device = torch.device('cuda')
args.device = torch.device("cuda")
self.image_prompt = None
self.translator = GoogleTranslator(source="en", target="zh-CN")
self.model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1')
self.model_stage1, args = InferenceModel_Sequential.from_pretrained(args, "cogvideo-stage1")
self.model_stage1.eval()
self.model_stage1 = self.model_stage1.cpu()
self.model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, 'cogvideo-stage2')
self.model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, "cogvideo-stage2")
self.model_stage2.eval()
self.model_stage2 = self.model_stage2.cpu()
# enable dsr if model exists
if os.path.exists('/sharefs/cogview-new/cogview2-dsr'):
if os.path.exists("/sharefs/cogview-new/cogview2-dsr"):
subprocess.check_output("python setup.py develop", cwd="/src/Image-Local-Attention", shell=True)
sys.path.append('./Image-Local-Attention')
from sr_pipeline import DirectSuperResolution
@ -485,6 +486,7 @@ class Predictor(BasePredictor):
description="Run both stages (uncheck to run more quickly and output only a few frames)", default=True
),
use_guidance: bool = Input(description="Use stage 1 guidance (recommended)", default=True),
image_prompt: Path = Input(description="Starting image")
) -> typing.Iterator[Path]:
if translate:
prompt = self.translator.translate(prompt.strip())
@ -495,6 +497,13 @@ class Predictor(BasePredictor):
self.args.seed = seed
self.args.use_guidance_stage1 = use_guidance
self.prompt = prompt
if os.path.exists(image_prompt):
try:
Image.open(image_prompt)
except (FileNotFoundError, UnidentifiedImageError):
logging.debug("Bad image prompt; ignoring") # Is there a better way to input images?
else:
self.image_prompt = Image.open(image_prompt)
self.args.both_stages = both_stages
for file in self.run():
@ -546,6 +555,7 @@ class Predictor(BasePredictor):
text_len_1st = len(seq_1st) - frame_len*1 - 1
seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
if self.image_prompt is None:
output_list_1st = []
for tim in range(max(batch_size // mbz, 1)):
start_time = time.time()
@ -565,6 +575,8 @@ class Predictor(BasePredictor):
logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
output_tokens_1st = torch.cat(output_list_1st, dim=0)
given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400]
else:
given_tokens = tokenizer.encode(image_path=self.image_prompt, image_size=160).repeat(batch_size, 1).unsqueeze(1)
# generate subsequent frames:
total_frames = self.generate_frame_num
@ -766,7 +778,7 @@ class Predictor(BasePredictor):
for sample_i in range(sample_num):
my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
output_file = f'{outputdir}/{sample_i+sample_num*gpu_rank}.gif'
output_file = f"{outputdir}/{sample_i+sample_num*gpu_rank}.gif"
subprocess.check_output(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{output_file}' -d 0.125", shell=True)
yield output_file
@ -790,6 +802,6 @@ class Predictor(BasePredictor):
for sample_i in range(sample_num):
my_save_multiple_images(decoded_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
output_file = f'{outputdir}/{sample_i+sample_num*gpu_rank}.gif'
output_file = f"{outputdir}/{sample_i+sample_num*gpu_rank}.gif"
subprocess.check_output(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{output_file}' -d 0.125", shell=True)
yield output_file