zR 2024-08-07 19:27:53 +08:00
parent 5a69462c8b
commit 125432d403
5 changed files with 78 additions and 48 deletions

View File

@ -20,6 +20,8 @@
## Update and News
- 🔥 **News**: `2024/8/7`: CogVideoX has been integrated into `diffusers` version 0.30.0. Inference can now be performed on a single 3090 GPU. For more details, please refer to the [code](inference/cli_demo.py); a minimal usage sketch is shown below.
- 🔥 **News**: ``2024/8/6``: We have also open-sourced the **3D Causal VAE** used in **CogVideoX-2B**, which can reconstruct the video almost losslessly.
- 🔥 **News**: ``2024/8/6``: We have open-sourced **CogVideoX-2B**, the first model in the CogVideoX series of video generation models.
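As a companion to the `2024/8/7` item above, here is a minimal single-GPU sketch of the `diffusers` path. It is illustrative only: the prompt is a placeholder, and it uses the stock `export_to_video` helper, whereas the repo's [cli_demo](inference/cli_demo.py) ships an imageio-based exporter to avoid the "green screen" artifact.

```python
# Minimal single-GPU sketch, assuming diffusers >= 0.30.0 with CogVideoX support.
# See inference/cli_demo.py in this repo for the full, maintained version.
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # keeps peak VRAM within a 24GB card such as a 3090

frames = pipe(
    prompt="A panda playing a guitar in a bamboo forest",  # placeholder prompt
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]

# The repo's demo uses an imageio-based exporter instead, to avoid color artifacts.
export_to_video(frames, "output.mp4", fps=8)
```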
@ -106,14 +108,14 @@ along with related basic information:
| Model Name                                  | CogVideoX-2B                                                                                                                                                                                    |
|---------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Prompt Language                             | English                                                                                                                                                                                         |
| Single GPU Inference (FP16)                 | 18GB using [SAT](https://github.com/THUDM/SwissArmyTransformer) <br> 23.9GB using diffusers                                                                                                     |
| Multi-GPU Inference (FP16)                  | 20GB minimum per GPU using diffusers                                                                                                                                                            |
| GPU Memory Required for Fine-tuning (bs=1)  | 40GB                                                                                                                                                                                            |
| Prompt Max Length                           | 226 Tokens                                                                                                                                                                                      |
| Video Length                                | 6 seconds                                                                                                                                                                                       |
| Frames Per Second                           | 8 frames                                                                                                                                                                                        |
| Resolution                                  | 720 * 480                                                                                                                                                                                       |
| Quantized Inference                         | Not Supported                                                                                                                                                                                   |
| Download Link (HF diffusers Model)          | 🤗 [Huggingface](https://huggingface.co/THUDM/CogVideoX-2B) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) [💫 WiseModel](https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b) |
| Download Link (SAT Model)                   | [SAT](./sat/README.md)                                                                                                                                                                          |
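For the multi-GPU row in the table above, the comment added to [cli_demo](inference/cli_demo.py) in this commit suggests loading the pipeline with `device_map="balanced"` and dropping the CPU-offload call. A hedged sketch under that assumption (2 or more GPUs, each with more than 20GB of memory; the prompt is a placeholder):

```python
# Multi-GPU sketch: shard the pipeline across the available GPUs instead of offloading to CPU.
# Assumes 2+ GPUs with >20GB each, per the table above and the note in inference/cli_demo.py.
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-2b",
    torch_dtype=torch.float16,
    device_map="balanced",  # do not also call pipe.enable_model_cpu_offload() in this mode
)

frames = pipe(
    prompt="A lighthouse on a cliff at sunset",  # placeholder prompt
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]
```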
@ -132,14 +134,16 @@ of the **CogVideoX** open-source model.
CogVideoX is trained on long captions, so the input text must be converted to match the training distribution using an LLM. By default, the script uses GLM-4, but it can also be replaced with any other LLM such as GPT, Gemini, etc. (a sketch of this conversion step appears after this list).
+ [gradio_web_demo](inference/gradio_web_demo.py): A simple gradio web UI demonstrating how to use the CogVideoX-2B model to generate videos.
<div style="text-align: center;">
  <img src="resources/gradio_demo.png" style="width: 100%; height: auto;" />
</div>
+ [streamlit_web_demo](inference/streamlit_web_demo.py): A simple streamlit web application demonstrating how to use the CogVideoX-2B model to generate videos.
<div style="text-align: center;">
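Returning to the prompt-conversion step described for [convert_demo](inference/convert_demo.py) above, the sketch below shows the general shape of that call. It assumes an OpenAI-compatible chat endpoint and heavily abbreviates the real system prompt, so treat the client setup, model name, and wording as placeholders rather than the script's actual implementation.

```python
# Hedged sketch: expand a short user prompt into the long, detailed caption style
# CogVideoX was trained on. inference/convert_demo.py is the authoritative reference.
from openai import OpenAI

SYS_PROMPT = (
    "You are part of a team of bots that creates videos. You work with an assistant bot "
    "that will draw anything you say in square brackets. "  # the real prompt continues with much more detail
    "Rewrite the user's request as a single, richly detailed video description."
)

client = OpenAI()  # point this at GLM-4, GPT, Gemini, or any compatible endpoint

def convert_prompt(short_prompt: str, model: str = "glm-4") -> str:  # model name is a placeholder
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYS_PROMPT},
            {"role": "user", "content": short_prompt},
        ],
        temperature=0.01,
    )
    return response.choices[0].message.content

long_prompt = convert_prompt("a cat playing the piano")
```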

View File

@ -21,24 +21,26 @@
## Project Updates
- 🔥 **News**: ``2024/8/7``: CogVideoX has been merged into `diffusers` version 0.30.0, and inference now runs on a single 3090 GPU. For details, see the [code](inference/cli_demo.py).
- 🔥 **News**: ``2024/8/6``: We have open-sourced the **3D Causal VAE** used in **CogVideoX-2B**, which can reconstruct videos almost losslessly (see the round-trip sketch below).
- 🔥 **News**: ``2024/8/6``: We have open-sourced **CogVideoX-2B**, the first model in the CogVideoX series of video generation models.
- 🌱 **Source**: ```2022/5/19```: We open-sourced the CogVideo video generation model (now available in the `CogVideo` branch), the first open-source large Transformer-based text-to-video model. See the [ICLR'23 paper](https://arxiv.org/abs/2205.15868) for technical details.

**More powerful models with larger parameter counts are on the way. Stay tuned!**
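As a companion to the `2024/8/6` VAE item above, here is a hedged sketch of a reconstruction round trip. [cli_vae_demo](inference/cli_vae_demo.py) is the authoritative reference; the `AutoencoderKLCogVideoX` class and its `encode()`/`decode()` calls are assumed from the diffusers 0.30.0 integration, and the dummy clip is kept small on purpose, since a full 49-frame 720x480 clip needs far more memory (the standalone VAE demo is quoted at 71GB below).

```python
# Hedged sketch: encode a short clip with the 3D Causal VAE and decode it back.
# Class name and layout (subfolder="vae") are assumptions based on the diffusers 0.30.0 release.
import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

# Small random dummy clip: [batch, channels, frames, height, width]; a real video would be
# normalized to [-1, 1]. A frame count of the form 4k+1 matches the temporal compression.
video = torch.randn(1, 3, 9, 256, 256, dtype=torch.float16, device="cuda")

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()  # compressed spatio-temporal latents
    reconstruction = vae.decode(latents).sample       # should be close to the input
```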
## Table of Contents
Jump to a specific section:
- [Quick Start](#快速开始)
  - [SAT](#sat)
  - [Diffusers](#Diffusers)
- [CogVideoX-2B Video Gallery](#cogvideox-2b-视频作品)
- [Introduction to the CogVideoX Model](#模型介绍)
- [Full Project Code Structure](#完整项目代码结构)
  - [Inference](#inference)
  - [SAT](#sat)
  - [Tools](#tools)
- [Open-Source Project Plan](#开源项目规划)
- [Model License](#模型协议)
- [Introduction to CogVideo (ICLR'23)](#cogvideoiclr23)
@ -53,8 +55,9 @@
### SAT
See [sat_demo](sat/README.md) in the sat folder: it contains inference and fine-tuning code for the SAT weights. We recommend building on this code to improve the CogVideoX model structure; researchers can use it to iterate and develop quickly. (18 GB for inference, 40 GB for LoRA fine-tuning)

### Diffusers
@ -64,7 +67,6 @@ pip install -r requirements.txt
See [diffusers_demo](inference/cli_demo.py): a more detailed explanation of the inference code, covering the key parameters. (36 GB for inference; memory optimization and fine-tuning code are in development)
## CogVideoX-2B Video Gallery
<div align="center">
@ -93,19 +95,19 @@ CogVideoX is an open-source video generation model with the same origin as [清影](https://chatglm.cn/video?fr=osm_cogvideox)
The table below lists the video generation models we currently provide, along with their basic information:
| Model Name                                | CogVideoX-2B                                                                                                                    |
|-------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------|
| Prompt Language                           | English                                                                                                                          |
| Single-GPU Inference (FP16) Memory Usage  | 18GB using [SAT](https://github.com/THUDM/SwissArmyTransformer) <br> 23.9GB using diffusers                                      |
| Multi-GPU Inference (FP16) Memory Usage   | 20GB minimum per GPU using diffusers                                                                                             |
| Fine-tuning Memory Usage (bs=1)           | 42GB                                                                                                                             |
| Prompt Max Length                         | 226 Tokens                                                                                                                       |
| Video Length                              | 6 seconds                                                                                                                        |
| Frames Per Second                         | 8 frames                                                                                                                         |
| Resolution                                | 720 * 480                                                                                                                        |
| Quantized Inference                       | Not Supported                                                                                                                    |
| Download Link (Diffusers Model)           | 🤗 [Huggingface](https://huggingface.co/THUDM/CogVideoX-2B) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b)   |
| Download Link (SAT Model)                 | [SAT](./sat/README_zh.md)                                                                                                        |
## Full Project Code Structure
@ -115,7 +117,8 @@ CogVideoX is an open-source video generation model with the same origin as [清影](https://chatglm.cn/video?fr=osm_cogvideox)
+ [diffusers_demo](inference/cli_demo.py): A more detailed walkthrough of the inference code, covering the meaning of the common parameters.
+ [diffusers_vae_demo](inference/cli_vae_demo.py): Runs the VAE inference code on its own; it currently requires 71GB of GPU memory and will be optimized in the future.
+ [convert_demo](inference/convert_demo.py): Converts user input into a long prompt suitable for CogVideoX. Because CogVideoX is trained on long captions, the input text must be converted by an LLM into long text matching the training distribution. The script uses GLM-4 by default, but it can be replaced with GPT, Gemini, or any other large language model.
+ [gradio_web_demo](inference/gradio_web_demo.py): A simple gradio web app demonstrating how to generate videos with the CogVideoX-2B model.
<div style="text-align: center;">
@ -140,9 +143,10 @@ CogVideoX is an open-source video generation model with the same origin as [清影](https://chatglm.cn/video?fr=osm_cogvideox)
+ [convert_weight_sat2hf](tools/convert_weight_sat2hf.py): Converts SAT model weights to Huggingface model weights.
+ [caption_demo](tools/caption/README_zh.md): A captioning tool: a model that understands videos and describes them in text.

## CogVideo (ICLR'23)
The official repo of [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](https://arxiv.org/abs/2205.15868) is in the [CogVideo branch](https://github.com/THUDM/CogVideo/tree/CogVideo).

**CogVideo can generate high-frame-rate videos; a 4-second clip with 32 frames is shown below.**
@ -155,11 +159,12 @@ CogVideoX is an open-source video generation model with the same origin as [清影](https://chatglm.cn/video?fr=osm_cogvideox)
<video src="https://github.com/user-attachments/assets/ea3af39a-3160-4999-90ec-2f7863c5b0e9" width="80%" controls autoplay></video>
</div>

The demo site for CogVideo is at [https://models.aminer.cn/cogvideo](https://models.aminer.cn/cogvideo/), where you can try text-to-video generation. *The original input is in Chinese.*

## Citation
🌟 If you find our work helpful, please cite our paper and leave us a star.
```
@article{yang2024cogvideox,

View File

@ -22,7 +22,7 @@ from diffusers import CogVideoXPipeline
def export_to_video_imageio(
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
) -> str:
    """
    Export the video frames to a video file using the imageio lib to avoid the "green screen" issue (for example with CogVideoX)
@ -38,17 +38,34 @@ def export_to_video_imageio(
def generate_video(
    prompt: str,
    model_path: str,
    output_path: str = "./output.mp4",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
):
"""
Generates a video based on the given prompt and saves it to the specified path.
Parameters:
- prompt (str): The description of the video to be generated.
- model_path (str): The path of the pre-trained model to be used.
- output_path (str): The path where the generated video will be saved.
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
- num_videos_per_prompt (int): Number of videos to generate per prompt.
- device (str): The device to use for computation (e.g., "cuda" or "cpu").
- dtype (torch.dtype): The data type for computation (default is torch.float16).
"""
    # Load the pre-trained CogVideoX pipeline with the specified precision (float16).
    # Model CPU offload must be enabled to avoid an OOM issue on a GPU with 24GB of memory.
    # To enable multi-GPU inference (2 or more GPUs, each with more than 20GB of memory),
    # add device_map="balanced" to the from_pretrained call and remove `pipe.enable_model_cpu_offload()`.
    pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
    pipe.enable_model_cpu_offload()
    # Encode the prompt to get the prompt embeddings
    prompt_embeds, _ = pipe.encode_prompt(
@ -60,18 +77,19 @@ def generate_video(
        device=device,  # Device to use for computation
        dtype=dtype,  # Data type for computation
    )
    # Generate the video frames using the pipeline
    video = pipe(
        num_inference_steps=num_inference_steps,  # Number of inference steps
        guidance_scale=guidance_scale,  # Guidance scale for classifier-free guidance
        prompt_embeds=prompt_embeds,  # Encoded prompt embeddings
        negative_prompt_embeds=torch.zeros_like(prompt_embeds),  # Negative prompts are not supported
    ).frames[0]

    # Export the generated frames to a video file. fps must be 8
    export_to_video_imageio(video, output_path, fps=8)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
    parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")

View File

@ -16,7 +16,8 @@ import PIL
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype)
pipe.enable_model_cpu_offload()
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets. sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
@ -104,7 +105,7 @@ def infer(
        device=device,
        dtype=dtype,
    )

    video = pipe(
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,

View File

@ -39,7 +39,9 @@ def load_model(model_path: str, dtype: torch.dtype, device: str) -> CogVideoXPip
    Returns:
    - CogVideoXPipeline: Loaded model pipeline.
    """
    pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
    pipe.enable_model_cpu_offload()
    return pipe
# Define a function to generate video based on the provided prompt and model path # Define a function to generate video based on the provided prompt and model path