# Mirrored from https://github.com/THUDM/CogVideo.git (synced 2025-04-05 19:41:59 +08:00)
# Standard library
import argparse
import math
import os
from typing import List, Union

# Third-party
import imageio
import numpy as np
import torch
import torchvision.transforms as TT
from einops import rearrange
from omegaconf import ListConfig
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import center_crop, resize
from tqdm import tqdm

# SwissArmyTransformer (SAT) framework
from sat import mpu
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint

# Project-local
from arguments import get_args
from diffusion_video import SATVideoDiffusionEngine
def read_from_cli():
    """Yield ``(prompt, index)`` pairs read interactively from stdin.

    Prompts the user repeatedly; terminates cleanly on EOF (Ctrl-D).

    Yields:
        tuple[str, int]: the stripped input line and a 0-based counter.
    """
    cnt = 0
    try:
        while True:
            x = input("Please input English text (Ctrl-D quit): ")
            yield x.strip(), cnt
            cnt += 1
    # EOF simply ends the generator; the bound variable was unused, so drop it.
    except EOFError:
        pass
def read_from_file(p, rank=0, world_size=1):
    """Yield ``(stripped_line, line_index)`` pairs from text file *p*.

    Lines are sharded round-robin across ``world_size`` workers: this worker
    only yields lines whose 0-based index is congruent to ``rank``.
    """
    with open(p, "r") as fin:
        for idx, line in enumerate(fin):
            if idx % world_size == rank:
                yield line.strip(), idx
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the deduplicated input keys of the conditioner's embedders (order unspecified)."""
    unique_keys = {embedder.input_key for embedder in conditioner.embedders}
    return list(unique_keys)
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    """Build conditioning dicts for the sampler.

    Args:
        keys: embedder input keys to populate (from the conditioner).
        value_dict: source values; must contain "prompt" and "negative_prompt"
            for the "txt" key, and a matching entry for every other key.
        N: batch shape; math.prod(N) prompts are replicated for the "txt" key.
        T: optional number of video frames, stored under "num_video_frames".
        device: unused here; kept for interface compatibility with callers.

    Returns:
        (batch, batch_uc): conditional and unconditional dicts. Tensor values
        present only in ``batch`` are cloned into ``batch_uc``.
    """
    batch = {}
    batch_uc = {}
    total = math.prod(N)

    for key in keys:
        if key == "txt":
            # Replicate the (negative) prompt to the full batch shape.
            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=total).reshape(N).tolist()
            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=total).reshape(N).tolist()
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    # Mirror any tensor-valued conditioning into the unconditional dict.
    for key in batch:
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
    """Write every video in *video_batch* to ``save_path/<index>.mp4``.

    Args:
        video_batch: batch of videos; each video is a sequence of (C, H, W)
            frames with values in [0, 1].
        save_path: output directory (created if missing).
        fps: frame rate of the written mp4 files.
        args, key: unused; kept for interface compatibility with callers.
    """
    os.makedirs(save_path, exist_ok=True)

    for idx, video in enumerate(video_batch):
        # Convert each frame from CHW float [0, 1] to HWC uint8 [0, 255].
        frames = [
            (255.0 * rearrange(frame, "c h w -> h w c")).cpu().numpy().astype(np.uint8)
            for frame in video
        ]
        out_file = os.path.join(save_path, f"{idx:06d}.mp4")
        with imageio.get_writer(out_file, fps=fps) as writer:
            for frame in frames:
                writer.append_data(frame)
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    """Resize a (1, C, H, W) tensor to cover *image_size*, then crop to it.

    The tensor is resized (bicubic) so the target rectangle fits entirely
    inside it, then cropped. With ``reshape_mode`` "random" or "none" the crop
    origin is sampled uniformly; "center" centers the crop; anything else
    raises NotImplementedError. Returns a (C, H, W) tensor (batch dim removed).
    """
    target_h, target_w = image_size
    # Wider than the target aspect ratio -> match height, else match width.
    if arr.shape[3] / arr.shape[2] > target_w / target_h:
        new_w = int(arr.shape[3] * target_h / arr.shape[2])
        arr = resize(arr, size=[target_h, new_w], interpolation=InterpolationMode.BICUBIC)
    else:
        new_h = int(arr.shape[2] * target_w / arr.shape[3])
        arr = resize(arr, size=[new_h, target_w], interpolation=InterpolationMode.BICUBIC)

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - target_h
    delta_w = w - target_w

    if reshape_mode in ("random", "none"):
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    return TT.functional.crop(arr, top=top, left=left, height=target_h, width=target_w)
def sampling_main(args, model_cls):
    """Run text-to-video sampling end to end.

    Builds (or accepts) the diffusion engine, loads its checkpoint, reads
    prompts from the CLI or a sharded text file, samples latents, decodes them
    with the first-stage VAE in short temporal chunks to limit GPU memory, and
    writes mp4 files via save_video_as_grid_and_mp4.

    Args:
        args: parsed SAT arguments (input_type, input_file, sampling_num_frames,
            latent_channels, batch_size, output_dir, sampling_fps, ...).
        model_cls: either a model class to instantiate via get_model, or an
            already-built model instance.
    """
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    load_checkpoint(model, args)
    model.eval()

    # BUGFIX: `rank` is printed inside the sampling loop below, but it was only
    # assigned in the "txt" branch — the "cli" path raised NameError. Default
    # to single-process values so both branches are safe.
    rank, world_size = 0, 1
    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        print("rank and world_size", rank, world_size)
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    image_size = [480, 720]

    sample_func = model.sample
    # T frames of H x W pixels, C latent channels, F = spatial VAE downsample factor.
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device
    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            # reload model on GPU
            model.to(device)
            print("rank:", rank, "start to process", text, cnt)
            # TODO: broadcast image2video
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )
            # Log the shape/size of every conditioning entry for debugging.
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(l) for l in batch[key]])
                else:
                    print(key, batch[key])
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            # Truncate non-crossattn conditioning to the sample count and move to GPU.
            for k in c:
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
            for index in range(args.batch_size):
                # reload model on GPU
                model.to(device)
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                )
                # (B, T, C, h, w) -> (B, C, T, h, w) for the VAE decoder.
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                # Unload the diffusion model from GPU to save GPU memory; only
                # the first-stage VAE is needed for decoding.
                model.to("cpu")
                torch.cuda.empty_cache()
                first_stage_model = model.first_stage_model
                first_stage_model = first_stage_model.to(device)

                latent = 1.0 / model.scale_factor * samples_z

                # Decode the latent serially in short temporal chunks to save GPU memory.
                recons = []
                loop_num = (T - 1) // 2
                for i in range(loop_num):
                    if i == 0:
                        start_frame, end_frame = 0, 3
                    else:
                        start_frame, end_frame = i * 2 + 1, i * 2 + 3
                    # Clear the fake context-parallel cache only on the last chunk.
                    clear_fake_cp_cache = i == loop_num - 1
                    with torch.no_grad():
                        recon = first_stage_model.decode(
                            latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                        )

                    recons.append(recon)

                recon = torch.cat(recons, dim=2).to(torch.float32)
                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                # Map from [-1, 1] to [0, 1] and move to CPU for saving.
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                save_path = os.path.join(
                    args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
                )
                if mpu.get_model_parallel_rank() == 0:
                    save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
if __name__ == "__main__":
    # When launched under OpenMPI, translate its rank/size env vars into the
    # names torch.distributed expects.
    _MPI_ENV_MAP = {
        "OMPI_COMM_WORLD_LOCAL_RANK": "LOCAL_RANK",
        "OMPI_COMM_WORLD_SIZE": "WORLD_SIZE",
        "OMPI_COMM_WORLD_RANK": "RANK",
    }
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        for mpi_name, torch_name in _MPI_ENV_MAP.items():
            os.environ[torch_name] = os.environ[mpi_name]

    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()

    # Merge SAT-parsed arguments with anything parse_known_args captured.
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config

    # Force single-device, non-checkpointed inference settings.
    args.model_config.first_stage_config.params.cp_size = 1
    args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False

    sampling_main(args, model_cls=SATVideoDiffusionEngine)