From 5a69462c8b3cd282fe4ce5c52bc4f7d9d1993601 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Wed, 7 Aug 2024 16:49:11 +0800
Subject: [PATCH] Update Gradio web demo

---
 README.md                                  |  4 ++--
 README_zh.md                               |  4 ++--
 inference/cli_demo.py                      | 19 ++-----------------
 .../gradio_web_demo.py                     |  2 +-
 .../{web_demo.py => streamlit_web_demo.py} |  2 +-
 5 files changed, 8 insertions(+), 23 deletions(-)
 rename gradio_demo.py => inference/gradio_web_demo.py (99%)
 rename inference/{web_demo.py => streamlit_web_demo.py} (99%)

diff --git a/README.md b/README.md
index c292386..cc96605 100644
--- a/README.md
+++ b/README.md
@@ -132,14 +132,14 @@ of the **CogVideoX** open-source model.
   CogVideoX is trained on long captions, so we need to convert the input text to be consistent with the training
   distribution using an LLM. By default, the script uses GLM4, but it can also be replaced with any other LLM such as
   GPT, Gemini, etc.
-+ [gradio_demo](gradio_demo.py): A simple gradio web UI demonstrating how to use the CogVideoX-2B model to generate
++ [gradio_web_demo](inference/gradio_web_demo.py): A simple Gradio web UI demonstrating how to use the CogVideoX-2B model to generate
   videos.
-+ [web_demo](inference/web_demo.py): A simple streamlit web application demonstrating how to use the CogVideoX-2B model
++ [streamlit_web_demo](inference/streamlit_web_demo.py): A simple Streamlit web application demonstrating how to use the CogVideoX-2B model
   to generate videos.
diff --git a/README_zh.md b/README_zh.md
index cae5b40..2e26810 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -116,13 +116,13 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
 + [diffusers_demo](inference/cli_demo.py): 更详细的推理代码讲解,常见参数的意义,在这里都会提及。
 + [diffusers_vae_demo](inference/cli_vae_demo.py): 单独执行VAE的推理代码,目前需要71GB显存,将来会优化。
 + [convert_demo](inference/convert_demo.py): 如何将用户的输入转换成适合 CogVideoX的长输入。因为CogVideoX是在长文本上训练的,所以我们需要把输入文本的分布通过LLM转换为和训练一致的长文本。脚本中默认使用GLM4,也可以替换为GPT、Gemini等任意大语言模型。
-+ [gradio_demo](gradio_demo.py): 一个简单的gradio网页应用,展示如何使用 CogVideoX-2B 模型生成视频。
++ [gradio_web_demo](inference/gradio_web_demo.py): 一个简单的gradio网页应用,展示如何使用 CogVideoX-2B 模型生成视频。
-+ [web_demo](inference/web_demo.py): 一个简单的streamlit网页应用,展示如何使用 CogVideoX-2B 模型生成视频。
++ [streamlit_web_demo](inference/streamlit_web_demo.py): 一个简单的streamlit网页应用,展示如何使用 CogVideoX-2B 模型生成视频。
diff --git a/inference/cli_demo.py b/inference/cli_demo.py
index c480d43..a1bb764 100644
--- a/inference/cli_demo.py
+++ b/inference/cli_demo.py
@@ -47,20 +47,6 @@ def generate_video(
     device: str = "cuda",
     dtype: torch.dtype = torch.float16,
 ):
-    """
-    Generates a video based on the given prompt and saves it to the specified path.
-
-    Parameters:
-    - prompt (str): The description of the video to be generated.
-    - model_path (str): The path of the pre-trained model to be used.
-    - output_path (str): The path where the generated video will be saved.
-    - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
-    - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
-    - num_videos_per_prompt (int): Number of videos to generate per prompt.
-    - device (str): The device to use for computation (e.g., "cuda" or "cpu").
-    - dtype (torch.dtype): The data type for computation (default is torch.float16).
-    """
-
     # Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
     pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)

@@ -74,7 +60,8 @@ def generate_video(
         device=device,  # Device to use for computation
         dtype=dtype,  # Data type for computation
     )
-
+    # Must enable model CPU offload to avoid OOM on GPUs with 24GB of memory
+    pipe.enable_model_cpu_offload()
     # Generate the video frames using the pipeline
     video = pipe(
         num_inference_steps=num_inference_steps,  # Number of inference steps
@@ -82,11 +69,9 @@ def generate_video(
         prompt_embeds=prompt_embeds,  # Encoded prompt embeddings
         negative_prompt_embeds=torch.zeros_like(prompt_embeds),  # Negative prompt is not supported
     ).frames[0]
-
     # Export the generated frames to a video file. fps must be 8
     export_to_video_imageio(video, output_path, fps=8)
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
     parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
diff --git a/gradio_demo.py b/inference/gradio_web_demo.py
similarity index 99%
rename from gradio_demo.py
rename to inference/gradio_web_demo.py
index ea0b020..9f36254 100644
--- a/gradio_demo.py
+++ b/inference/gradio_web_demo.py
@@ -104,7 +104,7 @@ def infer(
         device=device,
         dtype=dtype,
     )
-
+    pipe.enable_model_cpu_offload()
     video = pipe(
         num_inference_steps=num_inference_steps,
         guidance_scale=guidance_scale,
diff --git a/inference/web_demo.py b/inference/streamlit_web_demo.py
similarity index 99%
rename from inference/web_demo.py
rename to inference/streamlit_web_demo.py
index 8695975..6df62db 100644
--- a/inference/web_demo.py
+++ b/inference/streamlit_web_demo.py
@@ -76,7 +76,7 @@ def generate_video(
         device=device,
         dtype=dtype,
     )
-
+    pipe.enable_model_cpu_offload()
     # Generate video
     video = pipe(
         num_inference_steps=num_inference_steps,
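
The substantive change shared by all three demos above is the same pattern: build the
pipeline, then call enable_model_cpu_offload() before sampling so that idle submodules
are kept on the CPU and moved to the GPU only while they run. Below is a minimal sketch
of that call pattern, not part of the patch itself: the checkpoint name
"THUDM/CogVideoX-2b" is the public Hugging Face id for the 2B model, the prompt and
sampling values are illustrative, and diffusers' standard export_to_video helper stands
in for the repo's export_to_video_imageio.

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load the 2B pipeline in float16. No explicit .to("cuda") here:
# model CPU offload manages device placement itself.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)

# The change this patch applies to every demo: offload idle submodules
# to CPU so a 24GB GPU does not run out of memory during sampling.
pipe.enable_model_cpu_offload()

video = pipe(
    prompt="A panda playing a guitar in a bamboo forest",
    num_inference_steps=50,  # illustrative; more steps can improve quality
    guidance_scale=6.0,      # illustrative classifier-free guidance scale
).frames[0]

# The demos export at a fixed 8 fps, matching the model's frame rate.
export_to_video(video, "output.mp4", fps=8)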