update gradio webdemo

zR 2024-08-07 16:49:11 +08:00
parent a73101c958
commit 5a69462c8b
5 changed files with 8 additions and 23 deletions

View File

@@ -132,14 +132,14 @@ of the **CogVideoX** open-source model.
CogVideoX is trained on long caption, we need to convert the input text to be consistent with the training
distribution using a LLM. By default, the script uses GLM4, but it can also be replaced with any other LLM such as
GPT, Gemini, etc.
- + [gradio_demo](gradio_demo.py): A simple gradio web UI demonstrating how to use the CogVideoX-2B model to generate
+ + [gradio_web_demo](inference/gradio_web_demo.py): A simple gradio web UI demonstrating how to use the CogVideoX-2B model to generate
videos.
<div style="text-align: center;">
<img src="resources/gradio_demo.png" style="width: 100%; height: auto;" />
</div>
- + [web_demo](inference/web_demo.py): A simple streamlit web application demonstrating how to use the CogVideoX-2B model
+ + [streamlit_web_demo](inference/streamlit_web_demo.py): A simple streamlit web application demonstrating how to use the CogVideoX-2B model
to generate videos.
<div style="text-align: center;">
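
The `gradio_web_demo.py` referenced in the renamed entry above is essentially a thin gradio front end over the same CogVideoX-2B generation code. A minimal sketch of that wiring is shown below; the `generate_video()` helper here is a hypothetical placeholder for the real pipeline call, and only the UI plumbing is meant to be illustrative:

```python
# Minimal sketch of a gradio text-to-video UI in the spirit of gradio_web_demo.py.
# generate_video() is a hypothetical stand-in for the CogVideoX-2B pipeline call.
import gradio as gr


def generate_video(prompt: str) -> str:
    # The real demo would run the CogVideoX-2B pipeline here and return the
    # path of the rendered .mp4 file.
    raise NotImplementedError


demo = gr.Interface(
    fn=generate_video,
    inputs=gr.Textbox(label="Prompt", lines=3),
    outputs=gr.Video(label="Generated video"),
    title="CogVideoX-2B Text-to-Video",
)

if __name__ == "__main__":
    demo.launch()
```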

View File

@@ -116,13 +116,13 @@ CogVideoX is the open-source model of the same origin as [清影](https://chatglm.cn/video?fr=osm_cogvideox)
+ [diffusers_demo](inference/cli_demo.py): A more detailed walkthrough of the inference code, covering the meaning of the common parameters.
+ [diffusers_vae_demo](inference/cli_vae_demo.py): Runs the VAE inference code on its own; it currently requires 71GB of VRAM and will be optimized in the future.
+ [convert_demo](inference/convert_demo.py): How to convert user input into a long prompt suitable for CogVideoX. Because CogVideoX is trained on long text, the input text distribution has to be converted by an LLM into long text consistent with the training data. The script uses GLM4 by default, but it can also be replaced with GPT, Gemini, or any other large language model.
- + [gradio_demo](gradio_demo.py): A simple gradio web application demonstrating how to use the CogVideoX-2B model to generate videos.
+ + [gradio_web_demo](inference/gradio_web_demo.py): A simple gradio web application demonstrating how to use the CogVideoX-2B model to generate videos.
<div style="text-align: center;">
<img src="resources/gradio_demo.png" style="width: 100%; height: auto;" />
</div>
- + [web_demo](inference/web_demo.py): A simple streamlit web application demonstrating how to use the CogVideoX-2B model to generate videos.
+ + [streamlit_web_demo](inference/streamlit_web_demo.py): A simple streamlit web application demonstrating how to use the CogVideoX-2B model to generate videos.
<div style="text-align: center;">
<img src="resources/web_demo.png" style="width: 100%; height: auto;" />
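
The convert_demo entry listed in both READMEs describes expanding short user prompts into long, training-style captions with an LLM before they are fed to CogVideoX. A minimal sketch of that idea follows; the client configuration, model name, and system prompt are illustrative assumptions rather than the repo's actual values, and any OpenAI-compatible LLM endpoint could be substituted:

```python
# Minimal sketch of LLM-based prompt expansion as described for convert_demo.py.
# The endpoint, model name, and system prompt below are illustrative assumptions;
# GLM4, GPT, Gemini, or any OpenAI-compatible LLM can fill this role.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY / OPENAI_BASE_URL from the environment

SYSTEM_PROMPT = (
    "You are a prompt engineer for a text-to-video model. Rewrite the user's short "
    "description as a single detailed, cinematic caption of roughly 100 words."
)


def convert_prompt(short_prompt: str, model: str = "gpt-4o") -> str:
    # Ask the LLM to expand the short user prompt into a long caption that
    # matches the distribution CogVideoX was trained on.
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": short_prompt},
        ],
        temperature=0.7,
    )
    return response.choices[0].message.content


# Example: convert_prompt("a cat playing piano") -> a ~100-word cinematic caption.
```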

View File

@@ -47,20 +47,6 @@ def generate_video(
device: str = "cuda",
dtype: torch.dtype = torch.float16,
):
"""
Generates a video based on the given prompt and saves it to the specified path.
Parameters:
- prompt (str): The description of the video to be generated.
- model_path (str): The path of the pre-trained model to be used.
- output_path (str): The path where the generated video will be saved.
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
- num_videos_per_prompt (int): Number of videos to generate per prompt.
- device (str): The device to use for computation (e.g., "cuda" or "cpu").
- dtype (torch.dtype): The data type for computation (default is torch.float16).
"""
# Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
@@ -74,7 +60,8 @@ def generate_video(
device=device, # Device to use for computation
dtype=dtype, # Data type for computation
)
+ # Must enable model CPU offload to avoid OOM issue on GPU with 24GB memory
+ pipe.enable_model_cpu_offload()
# Generate the video frames using the pipeline
video = pipe(
num_inference_steps=num_inference_steps, # Number of inference steps
@@ -82,11 +69,9 @@
prompt_embeds=prompt_embeds, # Encoded prompt embeddings
negative_prompt_embeds=torch.zeros_like(prompt_embeds), # Not Supported negative prompt
).frames[0]
# Export the generated frames to a video file. fps must be 8
export_to_video_imageio(video, output_path, fps=8)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")

View File

@@ -104,7 +104,7 @@ def infer(
device=device,
dtype=dtype,
)
+ pipe.enable_model_cpu_offload()
video = pipe(
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,

View File

@@ -76,7 +76,7 @@ def generate_video(
device=device,
dtype=dtype,
)
+ pipe.enable_model_cpu_offload()
# Generate video
video = pipe(
num_inference_steps=num_inference_steps,