diff --git a/README_zh.md b/README_zh.md
index cae5b40..2e26810 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -116,13 +116,13 @@ CogVideoX is the open-source model sharing the same origin as [清影](https://chatglm.cn/video?fr=osm_cogvideox)
+ [diffusers_demo](inference/cli_demo.py): A more detailed walkthrough of the inference code, covering the meaning of the common parameters.
+ [diffusers_vae_demo](inference/cli_vae_demo.py): Inference code that runs the VAE on its own; it currently requires 71GB of VRAM and will be optimized in the future.
+ [convert_demo](inference/convert_demo.py): How to convert user input into a long prompt suitable for CogVideoX. Because CogVideoX is trained on long texts, the input text must be rewritten by an LLM into long text matching the training distribution. The script uses GLM-4 by default, but it can be replaced with any other large language model such as GPT or Gemini (see the sketch below).
-+ [gradio_demo](gradio_demo.py): A simple Gradio web application that shows how to generate videos with the CogVideoX-2B model.
++ [gradio_web_demo](inference/gradio_web_demo.py): A simple Gradio web application that shows how to generate videos with the CogVideoX-2B model.
-+ [web_demo](inference/web_demo.py): A simple Streamlit web application that shows how to generate videos with the CogVideoX-2B model.
++ [streamlit_web_demo](inference/streamlit_web_demo.py): A simple Streamlit web application that shows how to generate videos with the CogVideoX-2B model.
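The conversion described for convert_demo amounts to one chat call to an LLM with a system prompt that asks for a long, detailed rewrite of the user's short description. Below is a minimal sketch of that step, assuming an OpenAI-compatible client; the model name, system prompt wording, and client setup in the real inference/convert_demo.py may differ.

from openai import OpenAI  # assumed OpenAI-compatible client; GLM-4 exposes such an endpoint

# Hypothetical system prompt; the wording used by convert_demo.py may differ.
SYSTEM_PROMPT = (
    "Expand the user's short video description into one detailed paragraph, "
    "matching the long captions CogVideoX was trained on."
)

def convert_prompt(short_prompt: str, model: str = "glm-4") -> str:
    # Reads the API key and base URL from the environment.
    client = OpenAI()
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": short_prompt},
        ],
        temperature=0.7,
    )
    return response.choices[0].message.content

# e.g. convert_prompt("a cat playing the piano") returns one long, detailed paragraph.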

diff --git a/inference/cli_demo.py b/inference/cli_demo.py
index c480d43..a1bb764 100644
--- a/inference/cli_demo.py
+++ b/inference/cli_demo.py
@@ -47,20 +47,6 @@ def generate_video(
device: str = "cuda",
dtype: torch.dtype = torch.float16,
):
- """
- Generates a video based on the given prompt and saves it to the specified path.
-
- Parameters:
- - prompt (str): The description of the video to be generated.
- - model_path (str): The path of the pre-trained model to be used.
- - output_path (str): The path where the generated video will be saved.
- - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
- - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
- - num_videos_per_prompt (int): Number of videos to generate per prompt.
- - device (str): The device to use for computation (e.g., "cuda" or "cpu").
- - dtype (torch.dtype): The data type for computation (default is torch.float16).
- """
-
# Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
@@ -74,7 +60,8 @@ def generate_video(
device=device, # Device to use for computation
dtype=dtype, # Data type for computation
)
-
+ # Must enable model CPU offload to avoid OOM issues on a GPU with 24GB of memory
+ pipe.enable_model_cpu_offload()
# Generate the video frames using the pipeline
video = pipe(
num_inference_steps=num_inference_steps, # Number of inference steps
@@ -82,11 +69,9 @@ def generate_video(
prompt_embeds=prompt_embeds, # Encoded prompt embeddings
        negative_prompt_embeds=torch.zeros_like(prompt_embeds),  # Negative prompts are not supported
).frames[0]
-
# Export the generated frames to a video file. fps must be 8
export_to_video_imageio(video, output_path, fps=8)
-
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
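The hunk above relies on pipe.enable_model_cpu_offload(), which keeps submodules on the CPU and moves each one to the GPU only for its forward pass, so peak VRAM fits on a 24GB card. A minimal standalone sketch of that pattern follows, assuming the diffusers CogVideoXPipeline API and the THUDM/CogVideoX-2b checkpoint; the prompt-based call and export helper are illustrative, since cli_demo.py passes precomputed prompt embeddings and uses its own imageio export helper.

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load in half precision; offloading will manage device placement.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
# Move each submodule to the GPU only while it runs, to avoid OOM on 24GB cards.
pipe.enable_model_cpu_offload()

video = pipe(
    prompt="A panda strums a guitar in a sunlit bamboo forest.",
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]

export_to_video(video, "output.mp4", fps=8)  # fps must be 8 for CogVideoX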
diff --git a/gradio_demo.py b/inference/gradio_web_demo.py
similarity index 99%
rename from gradio_demo.py
rename to inference/gradio_web_demo.py
index ea0b020..9f36254 100644
--- a/gradio_demo.py
+++ b/inference/gradio_web_demo.py
@@ -104,7 +104,7 @@ def infer(
device=device,
dtype=dtype,
)
-
+ pipe.enable_model_cpu_offload()
video = pipe(
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
diff --git a/inference/web_demo.py b/inference/streamlit_web_demo.py
similarity index 99%
rename from inference/web_demo.py
rename to inference/streamlit_web_demo.py
index 8695975..6df62db 100644
--- a/inference/web_demo.py
+++ b/inference/streamlit_web_demo.py
@@ -76,7 +76,7 @@ def generate_video(
device=device,
dtype=dtype,
)
-
+ pipe.enable_model_cpu_offload()
# Generate video
video = pipe(
num_inference_steps=num_inference_steps,