diff --git a/README.md b/README.md index 3341042..25b7d33 100644 --- a/README.md +++ b/README.md @@ -60,8 +60,8 @@ along with related basic information: | Model Name | CogVideoX-2B | |-------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------| | Prompt Language | English | -| GPU Memory Required for Inference (FP16) | 36GB using diffusers (will be optimized before the PR is merged) and 18GB using [SAT](https://github.com/THUDM/SwissArmyTransformer) | -| GPU Memory Required for Fine-tuning(bs=1) | 42GB | +| GPU Memory Required for Inference (FP16) | 18GB if using [SAT](https://github.com/THUDM/SwissArmyTransformer); 36GB if using diffusers (will be optimized before the PR is merged) | +| GPU Memory Required for Fine-tuning(bs=1) | 40GB | | Prompt Max Length | 226 Tokens | | Video Length | 6 seconds | | Frames Per Second | 8 frames | @@ -112,7 +112,7 @@ This folder contains some tools for model conversion / caption generation, etc. - [x] CogVideoX model fine-tuning example (SAT) - [ ] CogVideoX model fine-tuning example (Huggingface / SAT) - [ ] Open source CogVideoX-Pro (adapted for CogVideoX-2B suite) - - [ ] Release CogVideoX technical report + - [x] Release CogVideoX technical report We welcome your contributions. You can click [here](resources/contribute.md) for more information. @@ -126,4 +126,4 @@ The model weights and implementation code are released under the [CogVideoX LICE 🌟 If you find our work helpful, please leave us a star. 🌟 -The paper is still being written and will be released soon. Stay tuned! \ No newline at end of file +The paper on Arxiv is coming soon! \ No newline at end of file diff --git a/sat/configs/cogvideox_2b_infer.yaml b/sat/configs/cogvideox_2b_infer.yaml index a9cc17c..adf9de2 100644 --- a/sat/configs/cogvideox_2b_infer.yaml +++ b/sat/configs/cogvideox_2b_infer.yaml @@ -5,7 +5,7 @@ args: batch_size: 1 input_type: txt input_file: test.txt - sampling_num_frames: 13 # Must be 11,13 or 19 + sampling_num_frames: 13 # Must be 13, 11 or 9 sampling_fps: 8 fp16: True output_dir: outputs/ diff --git a/sat/configs/cogvideox_2b_sft.yaml b/sat/configs/cogvideox_2b_sft.yaml index 21cd2d7..1cac09b 100644 --- a/sat/configs/cogvideox_2b_sft.yaml +++ b/sat/configs/cogvideox_2b_sft.yaml @@ -122,7 +122,7 @@ model: lora_config: ## Using Lora target: sat.model.finetune.lora2.LoraMixin params: - r: 256 + r: 128 patch_embed_config: target: dit_video_concat.ImagePatchEmbeddingMixin diff --git a/sat/configs/test.txt b/sat/configs/test.txt index 02c383d..b732bbd 100644 --- a/sat/configs/test.txt +++ b/sat/configs/test.txt @@ -1,2 +1,3 @@ -A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance. -A cinematic view of Earth rotating in space, showcasing the planet's vibrant blue oceans and swirling white clouds, high quality, realistic. The scene transitions from day to night, highlighting the twinkling city lights and the soft glow of the moon reflecting on the surface. Stars and distant galaxies form a breathtaking backdrop, adding to the grandeur and beauty of the Earth seen from space. \ No newline at end of file +In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict. +The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds. +A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting. \ No newline at end of file diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py index 07dcb85..951e93e 100644 --- a/sat/diffusion_video.py +++ b/sat/diffusion_video.py @@ -166,7 +166,7 @@ class SATVideoDiffusionEngine(nn.Module): else: kwargs = {} use_cp = False - out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], use_cp=use_cp, **kwargs) + out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs) all_out.append(out) out = torch.cat(all_out, dim=0) return out @@ -186,7 +186,7 @@ class SATVideoDiffusionEngine(nn.Module): all_out = [] with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): for n in range(n_rounds): - out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples], use_cp=use_cp) + out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples]) all_out.append(out) z = torch.cat(all_out, dim=0) z = self.scale_factor * z diff --git a/sat/finetune.sh b/sat/finetune.sh index 1d353ac..da31247 100644 --- a/sat/finetune.sh +++ b/sat/finetune.sh @@ -1,6 +1,5 @@ #! /bin/bash -module load cuda echo "RUN on `hostname`, CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1" diff --git a/sat/sample_video.py b/sat/sample_video.py index 30d7794..8ca4b5a 100644 --- a/sat/sample_video.py +++ b/sat/sample_video.py @@ -136,8 +136,11 @@ def sampling_main(args, model_cls): T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8 num_samples = [1] force_uc_zero_embeddings = ["txt"] + device = model.device with torch.no_grad(): for text, cnt in tqdm(data_iter): + # reload model on GPU + model.to(device) print("rank:", rank, "start to process", text, cnt) # TODO: broadcast image2video value_dict = { @@ -166,6 +169,8 @@ def sampling_main(args, model_cls): if not k == "crossattn": c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)) for index in range(args.batch_size): + # reload model on GPU + model.to(device) samples_z = sample_func( c, uc=uc, @@ -173,11 +178,18 @@ def sampling_main(args, model_cls): shape=(T, C, H // F, W // F), ) samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous() + + # Unload the model from GPU to save GPU memory + model.to('cpu') + torch.cuda.empty_cache() + first_stage_model = model.first_stage_model + first_stage_model = first_stage_model.to(device) latent = 1.0 / model.scale_factor * samples_z - + + # Decode latent serial to save GPU memory recons = [] - loop_num = (T - 1) // 2 + loop_num = (T-1)//2 for i in range(loop_num): if i == 0: start_frame, end_frame = 0, 3 @@ -188,11 +200,10 @@ def sampling_main(args, model_cls): else: clear_fake_cp_cache = False with torch.no_grad(): - recon = model.first_stage_model.decode(latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache) + recon = first_stage_model.decode(latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache) recons.append(recon) - recon = torch.cat(recons, dim=2).to(torch.float32) samples_x = recon.permute(0, 2, 1, 3, 4).contiguous() samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()