Mirror of https://github.com/THUDM/CogVideo.git
Synced 2025-12-02 18:52:08 +08:00

Commit: c52d82f6b0 ("Based image prompts")
Parent: 4f3eba5521
Changed file: predict.py (26 changed lines)
@@ -4,7 +4,7 @@ import subprocess
 import tempfile
 import random
 import typing
-from typing_extensions import Self
+from PIL import Image, UnidentifiedImageError
 from deep_translator import GoogleTranslator
 from cog import BasePredictor, Input, Path
 
@@ -444,18 +444,19 @@ class Predictor(BasePredictor):
         args.use_guidance_stage1 = True
         args.use_guidance_stage2 = False
         args.both_stages = True
-        args.device = torch.device('cuda')
+        args.device = torch.device("cuda")
+        self.image_prompt = None
 
         self.translator = GoogleTranslator(source="en", target="zh-CN")
-        self.model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1')
+        self.model_stage1, args = InferenceModel_Sequential.from_pretrained(args, "cogvideo-stage1")
         self.model_stage1.eval()
         self.model_stage1 = self.model_stage1.cpu()
-        self.model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, 'cogvideo-stage2')
+        self.model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, "cogvideo-stage2")
         self.model_stage2.eval()
         self.model_stage2 = self.model_stage2.cpu()
 
         # enable dsr if model exists
-        if os.path.exists('/sharefs/cogview-new/cogview2-dsr'):
+        if os.path.exists("/sharefs/cogview-new/cogview2-dsr"):
             subprocess.check_output("python setup.py develop", cwd="/src/Image-Local-Attention", shell=True)
             sys.path.append('./Image-Local-Attention')
             from sr_pipeline import DirectSuperResolution
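Besides normalizing quotes, setup() now initializes self.image_prompt = None so that predict() can record a user-supplied starting image and the generation code can branch on its presence. Note that both stages are moved back to the CPU right after loading, which suggests each model is shuttled onto the GPU only while it is generating. A minimal sketch of that stage-swapping pattern, assuming a single CUDA device (the on_gpu helper is illustrative, not part of this commit):

from contextlib import contextmanager

import torch


@contextmanager
def on_gpu(model, device=torch.device("cuda")):
    # Keep the model on the GPU only while it is actively generating.
    model.to(device)
    try:
        yield model
    finally:
        model.cpu()               # hand GPU memory back before the next stage runs
        torch.cuda.empty_cache()  # release cached blocks held by the allocator

Usage would be roughly: with on_gpu(self.model_stage1) as m: ... for stage 1, then the same for stage 2.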
@@ -485,6 +486,7 @@ class Predictor(BasePredictor):
             description="Run both stages (uncheck to run more quickly and output only a few frames)", default=True
         ),
         use_guidance: bool = Input(description="Use stage 1 guidance (recommended)", default=True),
+        image_prompt: Path = Input(description="Starting image")
     ) -> typing.Iterator[Path]:
         if translate:
             prompt = self.translator.translate(prompt.strip())
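The new image_prompt input is declared without a default, so cog should treat it as required, even though the text-only path is still supported further down (the "if self.image_prompt is None" branch). If an optional input is the intent, a hedged sketch would look like this (the default=None and the Optional annotation are assumptions, not what this commit does):

from typing import Optional

from cog import BasePredictor, Input, Path


class Predictor(BasePredictor):
    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        # default=None makes the starting image optional, so text-only
        # requests do not have to upload a file.
        image_prompt: Optional[Path] = Input(description="Starting image", default=None),
    ) -> Path:
        ...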
@@ -495,6 +497,13 @@ class Predictor(BasePredictor):
         self.args.seed = seed
         self.args.use_guidance_stage1 = use_guidance
         self.prompt = prompt
+        if os.path.exists(image_prompt):
+            try:
+                Image.open(image_prompt)
+            except (FileNotFoundError, UnidentifiedImageError):
+                logging.debug("Bad image prompt; ignoring")  # Is there a better way to input images?
+            else:
+                self.image_prompt = Image.open(image_prompt)
         self.args.both_stages = both_stages
 
         for file in self.run():
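This validation opens the file twice and relies on Image.open alone, which is lazy and may not catch truncated files until the pixels are actually read. A standalone sketch of a slightly stricter version (load_image_prompt is a hypothetical helper, not in the commit):

import logging

from PIL import Image, UnidentifiedImageError


def load_image_prompt(path):
    # Return a usable PIL image, or None if the file is missing or not an image.
    try:
        img = Image.open(path)
        img.verify()  # checks file integrity; leaves this handle unusable afterwards
    except (FileNotFoundError, UnidentifiedImageError, OSError):
        logging.debug("Bad image prompt %s; ignoring", path)
        return None
    return Image.open(path)  # reopen, since verify() consumed the first handle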
@@ -546,6 +555,7 @@ class Predictor(BasePredictor):
         text_len_1st = len(seq_1st) - frame_len*1 - 1
 
         seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
+        if self.image_prompt is None:
             output_list_1st = []
             for tim in range(max(batch_size // mbz, 1)):
                 start_time = time.time()
@@ -565,6 +575,8 @@ class Predictor(BasePredictor):
                 logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
             output_tokens_1st = torch.cat(output_list_1st, dim=0)
             given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1)  # given_tokens.shape: [bs, frame_num, 400]
+        else:
+            given_tokens = tokenizer.encode(image_path=self.image_prompt, image_size=160).repeat(batch_size, 1).unsqueeze(1)
 
         # generate subsequent frames:
         total_frames = self.generate_frame_num
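When an image prompt is present, the sampled first-frame tokens are replaced by tokens encoded directly from the image; note that the already-opened PIL image is passed through the tokenizer's image_path argument, so this assumes the tokenizer accepts an image object there. The slice width of 400 matches a 20x20 token grid for a 160x160 frame, i.e. a tokenizer that downsamples by 8 in each dimension. A quick sanity check of the shape bookkeeping, with illustrative values:

import torch

batch_size, text_len_1st = 2, 63  # illustrative values
frame_tokens = (160 // 8) ** 2    # 20 x 20 grid -> 400 tokens per frame
assert frame_tokens == 400

# Text-only path: slice one frame's tokens out of the sampled sequence.
output_tokens_1st = torch.zeros(batch_size, text_len_1st + 1 + frame_tokens, dtype=torch.long)
given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st + 401].unsqueeze(1)
print(given_tokens.shape)  # torch.Size([2, 1, 400]) -> [bs, frame_num, 400]

# Image-prompt path: one encoded frame, repeated across the batch.
encoded = torch.zeros(1, frame_tokens, dtype=torch.long)  # stand-in for tokenizer.encode(...)
given_from_image = encoded.repeat(batch_size, 1).unsqueeze(1)
print(given_from_image.shape)  # torch.Size([2, 1, 400]), same as the text-only path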
@@ -766,7 +778,7 @@ class Predictor(BasePredictor):
 
             for sample_i in range(sample_num):
                 my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
-                output_file = f'{outputdir}/{sample_i+sample_num*gpu_rank}.gif'
+                output_file = f"{outputdir}/{sample_i+sample_num*gpu_rank}.gif"
                 subprocess.check_output(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{output_file}' -d 0.125", shell=True)
                 yield output_file
 
@@ -790,6 +802,6 @@ class Predictor(BasePredictor):
 
             for sample_i in range(sample_num):
                 my_save_multiple_images(decoded_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
-                output_file = f'{outputdir}/{sample_i+sample_num*gpu_rank}.gif'
+                output_file = f"{outputdir}/{sample_i+sample_num*gpu_rank}.gif"
                 subprocess.check_output(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{output_file}' -d 0.125", shell=True)
                 yield output_file
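Both output loops (super-resolved and plain) build a gifmaker command with an f-string under shell=True, quoting the paths by hand; that is fragile if outputdir ever contains spaces or quotes. A hedged alternative that assembles the GIF in-process with PIL, keeping the 0*.jpg frame pattern and 0.125 s delay from the command above (frames_to_gif is illustrative, not part of the repo):

import glob

from PIL import Image


def frames_to_gif(frames_dir, output_file, delay_s=0.125):
    # Assemble sorted JPEG frames into a looping GIF without invoking a shell.
    paths = sorted(glob.glob(f"{frames_dir}/0*.jpg"))
    frames = [Image.open(p) for p in paths]
    frames[0].save(
        output_file,
        save_all=True,
        append_images=frames[1:],
        duration=int(delay_s * 1000),  # PIL expects milliseconds per frame
        loop=0,                        # 0 means loop forever
    )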