From 0360745dc8afd395f117345d1f1dea65a72da940 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Fri, 23 Aug 2024 15:04:55 +0800 Subject: [PATCH] update cli_demo --- inference/cli_demo.py | 32 ++++++++++++++++---------------- sat/configs/inference.yaml | 2 +- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/inference/cli_demo.py b/inference/cli_demo.py index 5e69ac1..e63ce05 100644 --- a/inference/cli_demo.py +++ b/inference/cli_demo.py @@ -21,8 +21,9 @@ import numpy as np import torch from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler + def export_to_video_imageio( - video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8 + video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8 ) -> str: """ Export the video frames to a video file using imageio lib to Avoid "green screen" issue (for example CogVideoX) @@ -38,13 +39,13 @@ def export_to_video_imageio( def generate_video( - prompt: str, - model_path: str, - output_path: str = "./output.mp4", - num_inference_steps: int = 50, - guidance_scale: float = 6.0, - num_videos_per_prompt: int = 1, - dtype: torch.dtype = torch.float16, + prompt: str, + model_path: str, + output_path: str = "./output.mp4", + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + num_videos_per_prompt: int = 1, + dtype: torch.dtype = torch.bloat16, ): """ Generates a video based on the given prompt and saves it to the specified path. @@ -56,11 +57,11 @@ def generate_video( - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality. - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt. - num_videos_per_prompt (int): Number of videos to generate per prompt. - - dtype (torch.dtype): The data type for computation (default is torch.float16). + - dtype (torch.dtype): The data type for computation (default is torch.bfloat16). """ - # 1. Load the pre-trained CogVideoX pipeline with the specified precision (float16). + # 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16). # add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload() # function to use Multi GPUs. @@ -79,8 +80,7 @@ def generate_video( torch.cuda.reset_accumulated_memory_stats() torch.cuda.reset_peak_memory_stats() - # Using with diffusers branch `main` to enable tiling. This will cost ONLY 12GB GPU memory. - # pipe.vae.enable_tiling() + pipe.vae.enable_tiling() # 4. Generate the video frames based on the prompt. # `num_frames` is the Number of frames to generate. @@ -90,7 +90,7 @@ def generate_video( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt num_inference_steps=num_inference_steps, # Number of inference steps - num_frames=48, # Number of frames to generateļ¼Œchanged to 49 for diffusers version `0.31.0` and after. + num_frames=49, # Number of frames to generateļ¼Œchanged to 49 for diffusers version `0.31.0` and after. guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance generator=torch.Generator().manual_seed(42), # Set the seed for reproducibility ).frames[0] @@ -103,7 +103,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX") parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated") parser.add_argument( - "--model_path", type=str, default="THUDM/CogVideoX-2b", help="The path of the pre-trained model to be used" + "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used" ) parser.add_argument( "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved" @@ -114,13 +114,13 @@ if __name__ == "__main__": parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance") parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt") parser.add_argument( - "--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')" + "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')" ) args = parser.parse_args() # Convert dtype argument to torch.dtype, NOT suggest BF16. - dtype = torch.float16 if args.dtype == "float16" else torch.float32 + dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 # main function to generate video. generate_video( diff --git a/sat/configs/inference.yaml b/sat/configs/inference.yaml index bf90d34..a745639 100644 --- a/sat/configs/inference.yaml +++ b/sat/configs/inference.yaml @@ -1,7 +1,7 @@ args: latent_channels: 16 mode: inference - # load: "{your_CogVideoX-2b-sat_path}/transformer" # This is for Full model without lora adapter + load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter batch_size: 1