update with diffusers

This commit is contained in:
zR 2024-11-11 22:41:28 +08:00
parent 68d93ce8fc
commit 2151a3bdfb

View File

@ -36,7 +36,7 @@ def generate_video(
model_path: str,
lora_path: str = None,
lora_rank: int = 128,
num_frames=81,
num_frames: int = 81,
output_path: str = "./output.mp4",
image_or_video_path: str = "",
num_inference_steps: int = 50,
@ -57,7 +57,7 @@ def generate_video(
- lora_rank (int): The rank of the LoRA weights.
- output_path (str): The path where the generated video will be saved.
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
- num_frames (int): Number of frames to generate.
- num_frames (int): Number of frames to generate. CogVideoX1.0 generates 49 frames for 6 seconds at 8 fps, while CogVideoX1.5 produces either 81 or 161 frames, corresponding to 5 seconds or 10 seconds at 16 fps.
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
- num_videos_per_prompt (int): Number of videos to generate per prompt.
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
@ -99,9 +99,9 @@ def generate_video(
# turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
# and enable to("cuda")
# pipe.to("cuda")
pipe.to("cuda")
pipe.enable_sequential_cpu_offload()
# pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
@ -154,7 +154,7 @@ if __name__ == "__main__":
help="The path of the image to be used as the background of the video",
)
parser.add_argument(
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model use"
"--model_path", type=str, default="THUDM/CogVideoX1.5-5B", help="Path of the pre-trained model use"
)
parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")