update gradio webdemo

This commit is contained in:
zR 2024-08-07 16:49:11 +08:00
parent a73101c958
commit 5a69462c8b
5 changed files with 8 additions and 23 deletions

View File

@ -132,14 +132,14 @@ of the **CogVideoX** open-source model.
CogVideoX is trained on long caption, we need to convert the input text to be consistent with the training
distribution using a LLM. By default, the script uses GLM4, but it can also be replaced with any other LLM such as
GPT, Gemini, etc.
+ [gradio_demo](gradio_demo.py): A simple gradio web UI demonstrating how to use the CogVideoX-2B model to generate
+ [gradio_web_demo](inference/gradio_web_demo.py): A simple gradio web UI demonstrating how to use the CogVideoX-2B model to generate
videos.
<div style="text-align: center;">
<img src="resources/gradio_demo.png" style="width: 100%; height: auto;" />
</div>
+ [web_demo](inference/web_demo.py): A simple streamlit web application demonstrating how to use the CogVideoX-2B model
+ [streamlit_web_demo](inference/streamlit_web_demo.py): A simple streamlit web application demonstrating how to use the CogVideoX-2B model
to generate videos.
<div style="text-align: center;">

View File

@ -116,13 +116,13 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
+ [diffusers_demo](inference/cli_demo.py): 更详细的推理代码讲解,常见参数的意义,在这里都会提及。
+ [diffusers_vae_demo](inference/cli_vae_demo.py): 单独执行VAE的推理代码目前需要71GB显存将来会优化。
+ [convert_demo](inference/convert_demo.py): 如何将用户的输入转换成适合 CogVideoX的长输入。因为CogVideoX是在长文本上训练的所以我们需要把输入文本的分布通过LLM转换为和训练一致的长文本。脚本中默认使用GLM4也可以替换为GPT、Gemini等任意大语言模型。
+ [gradio_demo](gradio_demo.py): 一个简单的gradio网页应用展示如何使用 CogVideoX-2B 模型生成视频。
+ [gradio_web_demo](inference/gradio_web_demo.py): 一个简单的gradio网页应用展示如何使用 CogVideoX-2B 模型生成视频。
<div style="text-align: center;">
<img src="resources/gradio_demo.png" style="width: 100%; height: auto;" />
</div>
+ [web_demo](inference/web_demo.py): 一个简单的streamlit网页应用展示如何使用 CogVideoX-2B 模型生成视频。
+ [streamlit_web_demo](inference/streamlit_web_demo.py): 一个简单的streamlit网页应用展示如何使用 CogVideoX-2B 模型生成视频。
<div style="text-align: center;">
<img src="resources/web_demo.png" style="width: 100%; height: auto;" />

View File

@ -47,20 +47,6 @@ def generate_video(
device: str = "cuda",
dtype: torch.dtype = torch.float16,
):
"""
Generates a video based on the given prompt and saves it to the specified path.
Parameters:
- prompt (str): The description of the video to be generated.
- model_path (str): The path of the pre-trained model to be used.
- output_path (str): The path where the generated video will be saved.
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
- num_videos_per_prompt (int): Number of videos to generate per prompt.
- device (str): The device to use for computation (e.g., "cuda" or "cpu").
- dtype (torch.dtype): The data type for computation (default is torch.float16).
"""
# Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
@ -74,7 +60,8 @@ def generate_video(
device=device, # Device to use for computation
dtype=dtype, # Data type for computation
)
# Must enable model CPU offload to avoid OOM issue on GPU with 24GB memory
pipe.enable_model_cpu_offload()
# Generate the video frames using the pipeline
video = pipe(
num_inference_steps=num_inference_steps, # Number of inference steps
@ -82,11 +69,9 @@ def generate_video(
prompt_embeds=prompt_embeds, # Encoded prompt embeddings
negative_prompt_embeds=torch.zeros_like(prompt_embeds), # Not Supported negative prompt
).frames[0]
# Export the generated frames to a video file. fps must be 8
export_to_video_imageio(video, output_path, fps=8)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")

View File

@ -104,7 +104,7 @@ def infer(
device=device,
dtype=dtype,
)
pipe.enable_model_cpu_offload()
video = pipe(
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,

View File

@ -76,7 +76,7 @@ def generate_video(
device=device,
dtype=dtype,
)
pipe.enable_model_cpu_offload()
# Generate video
video = pipe(
num_inference_steps=num_inference_steps,