import argparse
import math
import os
from typing import List, Union

import imageio
import numpy as np
import torch
import torchvision.transforms as TT
from einops import rearrange, repeat
from omegaconf import ListConfig
from PIL import Image
from tqdm import tqdm

from sat import mpu
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint

from arguments import get_args
from diffusion_video import SATVideoDiffusionEngine


def read_from_cli():
    """Yield (prompt, index) pairs typed interactively; Ctrl-D quits."""
    cnt = 0
    try:
        while True:
            x = input("Please input English text (Ctrl-D quit): ")
            yield x.strip(), cnt
            cnt += 1
    except EOFError:
        pass


def read_from_file(p, rank=0, world_size=1):
    """Yield (prompt, index) pairs from a text file, sharded across data-parallel ranks."""
    with open(p, "r") as fin:
        cnt = -1
        for line in fin:
            cnt += 1
            if cnt % world_size != rank:
                continue
            yield line.strip(), cnt


def get_unique_embedder_keys_from_conditioner(conditioner):
    return list({x.input_key for x in conditioner.embedders})


def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = (
                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
            )
        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "fps":
            batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
        elif key == "fps_id":
            batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
        elif key == "motion_bucket_id":
            batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
        elif key == "pool_image":
            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to("cuda"),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    # Mirror any tensor entries into the unconditional batch if not already set.
    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc


def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
    os.makedirs(save_path, exist_ok=True)

    for i, vid in enumerate(video_batch):
        gif_frames = []
        for frame in vid:
            frame = rearrange(frame, "c h w -> h w c")
            frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
            gif_frames.append(frame)
        now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
        with imageio.get_writer(now_save_path, fps=fps) as writer:
            for frame in gif_frames:
                writer.append_data(frame)


def sampling_main(args, model_cls):
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    load_checkpoint(model, args)
    model.eval()

    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    sample_func = model.sample
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    T, C = args.sampling_num_frames, args.latent_channels

    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            if args.image2video:
                # Input line format: "<prompt>@@<image_path>".
                text, image_path = text.split("@@")
                assert os.path.exists(image_path), image_path
                image = Image.open(image_path).convert("RGB")
                img_W, img_H = image.size

                def nearest_multiple_of_16(n):
                    lower_multiple = (n // 16) * 16
                    upper_multiple = lower_multiple + 16
                    if abs(n - lower_multiple) < abs(n - upper_multiple):
                        return lower_multiple
                    return upper_multiple

                # Fix the short side of the latent grid to 96 and scale the long
                # side to preserve the aspect ratio; pixel dimensions (latent * 8)
                # are snapped to the nearest multiple of 16.
                if img_H < img_W:
                    H = 96
                    W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8
                else:
                    W = 96
                    H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8

                transform = TT.Compose(
                    [
                        TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1),
                        TT.ToTensor(),
                    ]
                )
                image = transform(image).unsqueeze(0).to("cuda")
                image = image * 2.0 - 1.0  # rescale [0, 1] -> [-1, 1]
                image = image.unsqueeze(2).to(torch.bfloat16)  # add frame axis: b c t h w
                image = model.encode_first_stage(image, None)
                image = image / model.scale_factor
                image = image.permute(0, 2, 1, 3, 4).contiguous()
                # Pad the remaining T - 1 frames with zeros in latent space.
                pad_shape = (image.shape[0], T - 1, C, H, W)
                image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
            else:
                image_size = args.sampling_image_size
                H, W = image_size[0], image_size[1]
                F = 8  # 8x downsampled
                image = None

            # Broadcast the prompt so every rank in a model-parallel group
            # samples with the same text.
            text_cast = [text]
            mp_size = mpu.get_model_parallel_world_size()
            global_rank = torch.distributed.get_rank() // mp_size
            src = global_rank * mp_size
            torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
            text = text_cast[0]

            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(l) for l in batch[key]])
                else:
                    print(key, batch[key])

            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in c:
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

            if args.image2video:
                c["concat"] = image
                uc["concat"] = image

            for index in range(args.batch_size):
                if args.image2video:
                    samples_z = sample_func(
                        c,
                        uc=uc,
                        batch_size=1,
                        shape=(T, C, H, W),
                        ofs=torch.tensor([2.0]).to("cuda"),
                    )
                else:
                    samples_z = sample_func(
                        c,
                        uc=uc,
                        batch_size=1,
                        shape=(T, C, H // F, W // F),
                    ).to("cuda")
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                if args.only_save_latents:
                    samples_z = 1.0 / model.scale_factor * samples_z
                    save_path = os.path.join(
                        args.output_dir,
                        str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
                        str(index),
                    )
                    os.makedirs(save_path, exist_ok=True)
                    torch.save(samples_z, os.path.join(save_path, "latent.pt"))
                    with open(os.path.join(save_path, "text.txt"), "w") as f:
                        f.write(text)
                else:
                    samples_x = model.decode_first_stage(samples_z).to(torch.float32)
                    samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                    save_path = os.path.join(
                        args.output_dir,
                        str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
                        str(index),
                    )
                    if mpu.get_model_parallel_rank() == 0:
                        save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)


if __name__ == "__main__":
    # Map OpenMPI rank variables onto the names torch.distributed expects.
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]

    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()

    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config

    # Inference runs without context/model parallelism or activation checkpointing.
    args.model_config.first_stage_config.params.cp_size = 1
    args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False

    sampling_main(args, model_cls=SATVideoDiffusionEngine)