# CogVideo/sat/sample_video.py

import os
import math
import argparse
from typing import List, Union
from tqdm import tqdm
from omegaconf import ListConfig
from PIL import Image
import imageio
import torch
import numpy as np
from einops import rearrange, repeat
import torchvision.transforms as TT
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from sat import mpu
from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args
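
# Prompt sources: read_from_cli yields prompts typed interactively, while
# read_from_file yields lines from a text file and shards them across ranks
# (the caller passes the data-parallel rank / world size) so each rank only
# samples its own subset.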
def read_from_cli():
    cnt = 0
    try:
        while True:
            x = input("Please input English text (Ctrl-D quit): ")
            yield x.strip(), cnt
            cnt += 1
    except EOFError:
        pass

def read_from_file(p, rank=0, world_size=1):
    with open(p, "r") as fin:
        cnt = -1
        for l in fin:
            cnt += 1
            if cnt % world_size != rank:
                continue
            yield l.strip(), cnt

def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))
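
# get_batch builds the conditioner input for one prompt: for every embedder key it
# fills `batch` with the positive values and `batch_uc` with the negative/unconditional
# counterparts, falling back to a clone of the positive tensor when no negative exists.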
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = (
                np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            )
            batch_uc["txt"] = (
                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            )
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = (
                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = (
                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
            )
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
            )
        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "fps":
            batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
        elif key == "fps_id":
            batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
            )
        elif key == "pool_image":
            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
                device, dtype=torch.half
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to("cuda"),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc
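
# Writes every sample in the batch as an individual .mp4 under save_path
# (frames are converted from CHW float tensors in [0, 1] to HWC uint8).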
def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None
):
    os.makedirs(save_path, exist_ok=True)

    for i, vid in enumerate(video_batch):
        gif_frames = []
        for frame in vid:
            frame = rearrange(frame, "c h w -> h w c")
            frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
            gif_frames.append(frame)
        now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
        with imageio.get_writer(now_save_path, fps=fps) as writer:
            for frame in gif_frames:
                writer.append_data(frame)
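
# Main sampling loop: loads the SAT checkpoint, iterates over prompts, builds the
# conditionings, runs the diffusion sampler once per batch index, and either stores
# the raw latents or decodes them and writes .mp4 files.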
def sampling_main(args, model_cls):
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls
    load_checkpoint(model, args)
    model.eval()

    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    sample_func = model.sample
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    T, C = args.sampling_num_frames, args.latent_channels

    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            if args.image2video:
                # image-to-video prompts pack the text and the image path as "<prompt>@@<image_path>"
                text, image_path = text.split("@@")
                assert os.path.exists(image_path), image_path
                image = Image.open(image_path).convert("RGB")
                (img_W, img_H) = image.size

                def nearest_multiple_of_16(n):
                    lower_multiple = (n // 16) * 16
                    upper_multiple = (n // 16 + 1) * 16
                    if abs(n - lower_multiple) < abs(n - upper_multiple):
                        return lower_multiple
                    else:
                        return upper_multiple

                # keep the input aspect ratio: fix the short side of the latent to 96 and
                # round the long side so the pixel resolution (8x the latent) is a multiple of 16
                if img_H < img_W:
                    H = 96
                    W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8
                else:
                    W = 96
                    H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8

                chained_transforms = []
                chained_transforms.append(
                    TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)
                )
                chained_transforms.append(TT.ToTensor())
                transform = TT.Compose(chained_transforms)

                # encode the conditioning image into the latent space and zero-pad it to T frames
                image = transform(image).unsqueeze(0).to("cuda")
                image = image * 2.0 - 1.0
                image = image.unsqueeze(2).to(torch.bfloat16)
                image = model.encode_first_stage(image, None)
                image = image / model.scale_factor
                image = image.permute(0, 2, 1, 3, 4).contiguous()
                pad_shape = (image.shape[0], T - 1, C, H, W)
                image = torch.concat(
                    [image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1
                )
            else:
                image_size = args.sampling_image_size
                H, W = image_size[0], image_size[1]
                F = 8  # 8x downsampled
                image = None

            # all model-parallel ranks must see the same prompt, so broadcast it from the
            # first rank of each model-parallel group
            text_cast = [text]
            mp_size = mpu.get_model_parallel_world_size()
            global_rank = torch.distributed.get_rank() // mp_size
            src = global_rank * mp_size
            torch.distributed.broadcast_object_list(
                text_cast, src=src, group=mpu.get_model_parallel_group()
            )
            text = text_cast[0]

            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                num_samples,
            )
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(l) for l in batch[key]])
                else:
                    print(key, batch[key])

            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            # truncate non-crossattn conditionings to the sample count and move them to the GPU
            for k in c:
                if k != "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

            # the encoded conditioning image enters the model as "concat" conditioning
            if args.image2video:
                c["concat"] = image
                uc["concat"] = image

            for index in range(args.batch_size):
                if args.image2video:
                    samples_z = sample_func(
                        c,
                        uc=uc,
                        batch_size=1,
                        shape=(T, C, H, W),
                        ofs=torch.tensor([2.0]).to("cuda"),
                    )
                else:
                    samples_z = sample_func(
                        c,
                        uc=uc,
                        batch_size=1,
                        shape=(T, C, H // F, W // F),
                    ).to("cuda")
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                if args.only_save_latents:
                    samples_z = 1.0 / model.scale_factor * samples_z
                    save_path = os.path.join(
                        args.output_dir,
                        str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
                        str(index),
                    )
                    os.makedirs(save_path, exist_ok=True)
                    torch.save(samples_z, os.path.join(save_path, "latent.pt"))
                    with open(os.path.join(save_path, "text.txt"), "w") as f:
                        f.write(text)
                else:
                    samples_x = model.decode_first_stage(samples_z).to(torch.float32)
                    samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                    save_path = os.path.join(
                        args.output_dir,
                        str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
                        str(index),
                    )
                    if mpu.get_model_parallel_rank() == 0:
                        save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)


if __name__ == "__main__":
    # map OpenMPI launcher environment variables to the names torch.distributed expects
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]

    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config

    # inference-time overrides of the loaded model config: no context/model parallelism,
    # no activation checkpointing, no uniform sigma sampling
    args.model_config.first_stage_config.params.cp_size = 1
    args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False

    sampling_main(args, model_cls=SATVideoDiffusionEngine)
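
# Example invocation (a sketch, not the official launch script): the accepted flags come
# from SAT's argument parser in arguments.py, and the config paths below are assumptions.
#
#   python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed 42
#
# See the inference scripts shipped with the repository for the supported options.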