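"""Video sampling script from the CogVideo repository
(https://github.com/THUDM/CogVideo), SAT implementation.

Reads text prompts from stdin or from a file (optionally sharded across
data-parallel ranks), builds the conditioning, runs the diffusion sampler,
and saves each result as an MP4 or as raw latents. In image-to-video mode,
each input line carries both the prompt and an image path, joined by "@@".

Launch details (config files, distributed setup) come from arguments.get_args
and the SAT launcher; the exact command-line flags are defined there, not in
this file.
"""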
import os
import math
import argparse
from typing import List, Union

from tqdm import tqdm
from omegaconf import ListConfig
from PIL import Image
import imageio

import torch
import numpy as np
from einops import rearrange, repeat
import torchvision.transforms as TT

from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from sat import mpu

from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args


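# Interactive prompt source: yields (prompt, index) pairs typed on stdin
# until EOF (Ctrl-D).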
def read_from_cli():
    cnt = 0
    try:
        while True:
            x = input("Please input English text (Ctrl-D quit): ")
            yield x.strip(), cnt
            cnt += 1
    except EOFError:
        pass


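# File prompt source, sharded across data-parallel ranks: each rank takes
# every world_size-th line, so ranks process disjoint prompts in parallel.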
def read_from_file(p, rank=0, world_size=1):
    with open(p, "r") as fin:
        cnt = -1
        for l in fin:
            cnt += 1
            if cnt % world_size != rank:
                continue
            yield l.strip(), cnt


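# Distinct input keys consumed by the conditioner's embedders (e.g. "txt").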
def get_unique_embedder_keys_from_conditioner(conditioner):
    return list({x.input_key for x in conditioner.embedders})


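# Assemble the conditional (batch) and unconditional (batch_uc) conditioning
# dictionaries from value_dict, broadcast/tiled to batch shape N. Keys not
# handled explicitly are passed through unchanged.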
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = (
                np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            )
            batch_uc["txt"] = (
                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            )
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = (
                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = (
                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
            )
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
            )
        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "fps":
            batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
        elif key == "fps_id":
            batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
            )
        elif key == "pool_image":
            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
                device, dtype=torch.half
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    # Mirror any tensor entries not explicitly set above into the
    # unconditional batch.
    for key in batch:
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc


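# Despite the name, this only writes MP4s: each sample in the batch becomes
# <save_path>/<index>.mp4, with frames converted from CHW float tensors in
# [0, 1] to HWC uint8.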
def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None
):
    os.makedirs(save_path, exist_ok=True)

    for i, vid in enumerate(video_batch):
        frames = []
        for frame in vid:
            frame = rearrange(frame, "c h w -> h w c")
            frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
            frames.append(frame)
        now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
        with imageio.get_writer(now_save_path, fps=fps) as writer:
            for frame in frames:
                writer.append_data(frame)


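# Main sampling loop: instantiate/load the model, iterate over prompts,
# run the diffusion sampler, and decode and save each result.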
def sampling_main(args, model_cls):
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    load_checkpoint(model, args)
    model.eval()

    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

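    # One sample per prompt; the "txt" embedding is zeroed out for the
    # unconditional branch of classifier-free guidance.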
    sample_func = model.sample
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    T, C = args.sampling_num_frames, args.latent_channels
    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            if args.image2video:
                # Image-to-video: prompt and conditioning-image path are
                # joined with "@@" on each input line.
                text, image_path = text.split("@@")
                assert os.path.exists(image_path), image_path
                image = Image.open(image_path).convert("RGB")
                (img_W, img_H) = image.size

                def nearest_multiple_of_16(n):
                    lower_multiple = (n // 16) * 16
                    upper_multiple = (n // 16 + 1) * 16
                    if abs(n - lower_multiple) < abs(n - upper_multiple):
                        return lower_multiple
                    else:
                        return upper_multiple

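                # Pin the short side of the latent grid to 96 and scale the
                # other side to keep the aspect ratio; the factor of 8 converts
                # between latent and pixel resolution, and the pixel dimension
                # is snapped to the nearest multiple of 16.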
                if img_H < img_W:
                    H = 96
                    W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8
                else:
                    W = 96
                    H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
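                # Resize to pixel resolution, normalize to [-1, 1], encode to
                # first-stage latents, then zero-pad along time out to T frames.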
                chained_transforms = []
                chained_transforms.append(
                    TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)
                )
                chained_transforms.append(TT.ToTensor())
                transform = TT.Compose(chained_transforms)
                image = transform(image).unsqueeze(0).to("cuda")
                image = image * 2.0 - 1.0
                image = image.unsqueeze(2).to(torch.bfloat16)
                image = model.encode_first_stage(image, None)
                image = image / model.scale_factor
                image = image.permute(0, 2, 1, 3, 4).contiguous()
                pad_shape = (image.shape[0], T - 1, C, H, W)
                image = torch.concat(
                    [image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1
                )
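            # Text-to-video: use the configured output size with no image
            # conditioning.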
            else:
                image_size = args.sampling_image_size
                H, W = image_size[0], image_size[1]
                F = 8  # 8x downsampled
                image = None

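            # Broadcast the prompt from the first rank of each model-parallel
            # group so all ranks sample with identical conditioning.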
            text_cast = [text]
            mp_size = mpu.get_model_parallel_world_size()
            global_rank = torch.distributed.get_rank() // mp_size
            src = global_rank * mp_size
            torch.distributed.broadcast_object_list(
                text_cast, src=src, group=mpu.get_model_parallel_group()
            )
            text = text_cast[0]
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                num_samples,
            )
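            # Debug: print the shape (or length) of each conditioning entry.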
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(l) for l in batch[key]])
                else:
                    print(key, batch[key])
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

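            # Move non-crossattn conditioning tensors to the GPU, truncated to
            # the requested number of samples.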
            for k in c:
                if k != "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

            if args.image2video:
                c["concat"] = image
                uc["concat"] = image

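            # Draw args.batch_size independent videos for this prompt.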
            for index in range(args.batch_size):
                if args.image2video:
                    samples_z = sample_func(
                        c,
                        uc=uc,
                        batch_size=1,
                        shape=(T, C, H, W),
                        ofs=torch.tensor([2.0]).to("cuda"),
                    )
                else:
                    samples_z = sample_func(
                        c,
                        uc=uc,
                        batch_size=1,
                        shape=(T, C, H // F, W // F),
                    ).to("cuda")

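                # (b, t, c, h, w) -> (b, c, t, h, w) for first-stage decoding.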
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
                if args.only_save_latents:
                    samples_z = 1.0 / model.scale_factor * samples_z
                    save_path = os.path.join(
                        args.output_dir,
                        str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
                        str(index),
                    )
                    os.makedirs(save_path, exist_ok=True)
                    torch.save(samples_z, os.path.join(save_path, "latent.pt"))
                    with open(os.path.join(save_path, "text.txt"), "w") as f:
                        f.write(text)
                else:
                    samples_x = model.decode_first_stage(samples_z).to(torch.float32)
                    samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                    save_path = os.path.join(
                        args.output_dir,
                        str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
                        str(index),
                    )
                    if mpu.get_model_parallel_rank() == 0:
                        save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)


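# Entry point: map OpenMPI rank variables to torch.distributed conventions,
# parse arguments via SAT, and force inference-friendly overrides (no model
# parallelism, no activation checkpointing).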
if __name__ == "__main__":
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()

    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config
    args.model_config.first_stage_config.params.cp_size = 1
    args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False

    sampling_main(args, model_cls=SATVideoDiffusionEngine)