mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
update cli_demo
This commit is contained in:
parent
edcddcd99c
commit
0360745dc8
@ -21,8 +21,9 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler
|
from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler
|
||||||
|
|
||||||
|
|
||||||
def export_to_video_imageio(
|
def export_to_video_imageio(
|
||||||
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
|
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Export the video frames to a video file using imageio lib to Avoid "green screen" issue (for example CogVideoX)
|
Export the video frames to a video file using imageio lib to Avoid "green screen" issue (for example CogVideoX)
|
||||||
@ -38,13 +39,13 @@ def export_to_video_imageio(
|
|||||||
|
|
||||||
|
|
||||||
def generate_video(
|
def generate_video(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
output_path: str = "./output.mp4",
|
output_path: str = "./output.mp4",
|
||||||
num_inference_steps: int = 50,
|
num_inference_steps: int = 50,
|
||||||
guidance_scale: float = 6.0,
|
guidance_scale: float = 6.0,
|
||||||
num_videos_per_prompt: int = 1,
|
num_videos_per_prompt: int = 1,
|
||||||
dtype: torch.dtype = torch.float16,
|
dtype: torch.dtype = torch.bloat16,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generates a video based on the given prompt and saves it to the specified path.
|
Generates a video based on the given prompt and saves it to the specified path.
|
||||||
@ -56,11 +57,11 @@ def generate_video(
|
|||||||
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
|
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
|
||||||
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
|
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
|
||||||
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
||||||
- dtype (torch.dtype): The data type for computation (default is torch.float16).
|
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (float16).
|
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
|
||||||
# add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload()
|
# add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload()
|
||||||
# function to use Multi GPUs.
|
# function to use Multi GPUs.
|
||||||
|
|
||||||
@ -79,8 +80,7 @@ def generate_video(
|
|||||||
torch.cuda.reset_accumulated_memory_stats()
|
torch.cuda.reset_accumulated_memory_stats()
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
# Using with diffusers branch `main` to enable tiling. This will cost ONLY 12GB GPU memory.
|
pipe.vae.enable_tiling()
|
||||||
# pipe.vae.enable_tiling()
|
|
||||||
|
|
||||||
# 4. Generate the video frames based on the prompt.
|
# 4. Generate the video frames based on the prompt.
|
||||||
# `num_frames` is the Number of frames to generate.
|
# `num_frames` is the Number of frames to generate.
|
||||||
@ -90,7 +90,7 @@ def generate_video(
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
|
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
|
||||||
num_inference_steps=num_inference_steps, # Number of inference steps
|
num_inference_steps=num_inference_steps, # Number of inference steps
|
||||||
num_frames=48, # Number of frames to generate,changed to 49 for diffusers version `0.31.0` and after.
|
num_frames=49, # Number of frames to generate,changed to 49 for diffusers version `0.31.0` and after.
|
||||||
guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance
|
guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance
|
||||||
generator=torch.Generator().manual_seed(42), # Set the seed for reproducibility
|
generator=torch.Generator().manual_seed(42), # Set the seed for reproducibility
|
||||||
).frames[0]
|
).frames[0]
|
||||||
@ -103,7 +103,7 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
|
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
|
||||||
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_path", type=str, default="THUDM/CogVideoX-2b", help="The path of the pre-trained model to be used"
|
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
|
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
|
||||||
@ -114,13 +114,13 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
||||||
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
|
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Convert dtype argument to torch.dtype, NOT suggest BF16.
|
# Convert dtype argument to torch.dtype, NOT suggest BF16.
|
||||||
dtype = torch.float16 if args.dtype == "float16" else torch.float32
|
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
|
||||||
|
|
||||||
# main function to generate video.
|
# main function to generate video.
|
||||||
generate_video(
|
generate_video(
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
args:
|
args:
|
||||||
latent_channels: 16
|
latent_channels: 16
|
||||||
mode: inference
|
mode: inference
|
||||||
# load: "{your_CogVideoX-2b-sat_path}/transformer" # This is for Full model without lora adapter
|
load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter
|
||||||
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
|
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
|
||||||
|
|
||||||
batch_size: 1
|
batch_size: 1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user