update readme for video caption

This commit is contained in:
huangshiyu 2024-09-19 13:48:05 +08:00
parent 6a2efb844b
commit 5bdc3f65f1
9 changed files with 293 additions and 13 deletions

View File

@ -27,7 +27,7 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
controllability. With this release, the CogVideoX series now supports three tasks: text-to-video, video extension, and controllability. With this release, the CogVideoX series now supports three tasks: text-to-video, video extension, and
image-to-video generation. Feel free to try it out [online](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space). image-to-video generation. Feel free to try it out [online](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space).
- 🔥🔥 **News**: ```2024/9/19```: The caption model used in the CogVideoX training process to convert video data into text - 🔥🔥 **News**: ```2024/9/19```: The caption model used in the CogVideoX training process to convert video data into text
descriptions, [cogvlm2-llama3-caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), is now open-source. Feel descriptions, [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), is now open-source. Feel
free to download and use it. free to download and use it.
- 🔥 **News**: ```2024/9/16```: We have added an automated video generation tool! You can now use local open-source - 🔥 **News**: ```2024/9/16```: We have added an automated video generation tool! You can now use local open-source
models + FLUX + CogVideoX to automatically generate high-quality videos. Feel free models + FLUX + CogVideoX to automatically generate high-quality videos. Feel free

View File

@ -25,7 +25,7 @@
- 🔥🔥 **ニュース**: ```2024/9/19```: CogVideoXシリーズの画像生成ビデオモデル **CogVideoX-5B-I2V** - 🔥🔥 **ニュース**: ```2024/9/19```: CogVideoXシリーズの画像生成ビデオモデル **CogVideoX-5B-I2V**
をオープンソース化しました。このモデルでは、背景として画像を入力し、プロンプトと組み合わせてビデオを生成でき、より強力なコントロール性を提供します。これで、CogVideoXシリーズは、テキスト生成ビデオ、ビデオ拡張、画像生成ビデオの3つのタスクをサポートしています。ぜひ [オンラインでお試しください](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space)。 をオープンソース化しました。このモデルでは、背景として画像を入力し、プロンプトと組み合わせてビデオを生成でき、より強力なコントロール性を提供します。これで、CogVideoXシリーズは、テキスト生成ビデオ、ビデオ拡張、画像生成ビデオの3つのタスクをサポートしています。ぜひ [オンラインでお試しください](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space)。
- 🔥🔥 **ニュース**: ```2024/9/19```CogVideoX - 🔥🔥 **ニュース**: ```2024/9/19```CogVideoX
のトレーニングプロセスで、ビデオデータをテキストに変換するためのキャプションモデル [cogvlm2-llama3-caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption) のトレーニングプロセスで、ビデオデータをテキストに変換するためのキャプションモデル [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption)
がオープンソース化されました。ぜひダウンロードしてご利用ください。 がオープンソース化されました。ぜひダウンロードしてご利用ください。
- 🔥 **ニュース**: ```2024/9/16```: 自動動画生成ツールを追加しました!オープンソースのローカルモデル + FLUX + CogVideoX - 🔥 **ニュース**: ```2024/9/16```: 自動動画生成ツールを追加しました!オープンソースのローカルモデル + FLUX + CogVideoX
を使用して、高品質な動画を自動生成できます。ぜひ[お試しください](tools/llm_flux_cogvideox/llm_flux_cogvideox.py)。 を使用して、高品質な動画を自動生成できます。ぜひ[お試しください](tools/llm_flux_cogvideox/llm_flux_cogvideox.py)。

View File

@ -27,7 +27,7 @@
。该模型可以将一张图像作为背景输入,结合提示词一起生成视频,具有更强的可控性。 。该模型可以将一张图像作为背景输入,结合提示词一起生成视频,具有更强的可控性。
至此CogVideoX系列模型已经支持文本生成视频视频续写图片生成视频三种任务。欢迎前往在线[体验](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space)。 至此CogVideoX系列模型已经支持文本生成视频视频续写图片生成视频三种任务。欢迎前往在线[体验](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space)。
- 🔥🔥 **News**: ```2024/9/19```: CogVideoX 训练过程中用于将视频数据转换为文本描述的 Caption - 🔥🔥 **News**: ```2024/9/19```: CogVideoX 训练过程中用于将视频数据转换为文本描述的 Caption
模型 [cogvlm2-llama3-caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption) 模型 [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption)
已经开源。欢迎前往下载并使用。 已经开源。欢迎前往下载并使用。
- 🔥 **News**: ```2024/9/16```: 我们添加自动化生成视频工具,你可以使用本地开源模型 + FLUX + CogVideoX - 🔥 **News**: ```2024/9/16```: 我们添加自动化生成视频工具,你可以使用本地开源模型 + FLUX + CogVideoX
实现自动生成优质视频,欢迎[体验](tools/llm_flux_cogvideox/llm_flux_cogvideox.py) 实现自动生成优质视频,欢迎[体验](tools/llm_flux_cogvideox/llm_flux_cogvideox.py)

View File

@ -3,11 +3,37 @@
Typically, most video data does not come with corresponding descriptive text, so it is necessary to convert the video Typically, most video data does not come with corresponding descriptive text, so it is necessary to convert the video
data into textual descriptions to provide the essential training data for text-to-video models. data into textual descriptions to provide the essential training data for text-to-video models.
## Update and News
- 🔥🔥 **News**: ```2024/9/19```: The caption model used in the CogVideoX training process to convert video data into text
descriptions, [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), is now open-source. Feel
free to download and use it.
## Video Caption via CogVLM2-Caption
🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/)
CogVLM2-Caption is a video captioning model used to generate training data for the CogVideoX model.
### Install
```shell
pip install -r requirements.txt
```
### Usage
```shell
python video_caption.py
```
Example:
<div align="center">
<img width="600px" height="auto" src="./assests/CogVLM2-Caption-example.png">
</div>
## Video Caption via CogVLM2-Video ## Video Caption via CogVLM2-Video
<p align="center"> [Code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
🤗 <a href="https://huggingface.co/THUDM/cogvlm2-video-llama3-chat">Hugging Face</a>&nbsp&nbsp | &nbsp&nbsp🤖 <a href="https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat">ModelScope</a>&nbsp&nbsp | &nbsp&nbsp 📑 <a href="https://cogvlm2-video.github.io/">Blog</a> &nbsp&nbsp <a href="http://cogvlm2-online.cogviewai.cn:7868/">💬 Online Demo</a>&nbsp&nbsp
</p>
CogVLM2-Video is a versatile video understanding model equipped with timestamp-based question answering capabilities. CogVLM2-Video is a versatile video understanding model equipped with timestamp-based question answering capabilities.
Users can input prompts such as `Please describe this video in detail.` to the model to obtain a detailed video caption: Users can input prompts such as `Please describe this video in detail.` to the model to obtain a detailed video caption:
@ -16,3 +42,26 @@ Users can input prompts such as `Please describe this video in detail.` to the m
</div> </div>
Users can use the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) to load the model or configure a RESTful API to generate video captions. Users can use the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) to load the model or configure a RESTful API to generate video captions.
## Citation
🌟 If you find our work helpful, please leave us a star and cite our paper.
CogVLM2-Caption:
```
@article{yang2024cogvideox,
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
journal={arXiv preprint arXiv:2408.06072},
year={2024}
}
```
CogVLM2-Video:
```
@article{hong2024cogvlm2,
title={CogVLM2: Visual Language Models for Image and Video Understanding},
author={Hong, Wenyi and Wang, Weihan and Ding, Ming and Yu, Wenmeng and Lv, Qingsong and Wang, Yan and Cheng, Yean and Huang, Shiyu and Ji, Junhui and Xue, Zhao and others},
journal={arXiv preprint arXiv:2408.16500},
year={2024}
}
```

View File

@ -2,11 +2,37 @@
通常、ほとんどのビデオデータには対応する説明文が付いていないため、ビデオデータをテキストの説明に変換して、テキストからビデオへのモデルに必要なトレーニングデータを提供する必要があります。 通常、ほとんどのビデオデータには対応する説明文が付いていないため、ビデオデータをテキストの説明に変換して、テキストからビデオへのモデルに必要なトレーニングデータを提供する必要があります。
## 更新とニュース
- 🔥🔥 **ニュース**: ```2024/9/19```CogVideoX
のトレーニングプロセスで、ビデオデータをテキストに変換するためのキャプションモデル [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption)
がオープンソース化されました。ぜひダウンロードしてご利用ください。
## CogVLM2-Captionによるビデオキャプション
🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/)
CogVLM2-Captionは、CogVideoXモデルのトレーニングデータを生成するために使用されるビデオキャプションモデルです。
### インストール
```shell
pip install -r requirements.txt
```
### 使用方法
```shell
python video_caption.py
```
例:
<div align="center">
<img width="600px" height="auto" src="./assests/CogVLM2-Caption-example.png">
</div>
## CogVLM2-Video を使用したビデオキャプション ## CogVLM2-Video を使用したビデオキャプション
<p align="center"> [Code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
🤗 <a href="https://huggingface.co/THUDM/cogvlm2-video-llama3-chat">Hugging Face</a>&nbsp&nbsp | &nbsp&nbsp🤖 <a href="https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat">ModelScope</a>&nbsp&nbsp | &nbsp&nbsp 📑 <a href="https://cogvlm2-video.github.io/">ブログ</a> &nbsp&nbsp <a href="http://cogvlm2-online.cogviewai.cn:7868/">💬 オンラインデモ</a>&nbsp&nbsp
</p>
CogVLM2-Video は、タイムスタンプベースの質問応答機能を備えた多機能なビデオ理解モデルです。ユーザーは `このビデオを詳細に説明してください。` などのプロンプトをモデルに入力して、詳細なビデオキャプションを取得できます: CogVLM2-Video は、タイムスタンプベースの質問応答機能を備えた多機能なビデオ理解モデルです。ユーザーは `このビデオを詳細に説明してください。` などのプロンプトをモデルに入力して、詳細なビデオキャプションを取得できます:
<div align="center"> <div align="center">
@ -14,3 +40,26 @@ CogVLM2-Video は、タイムスタンプベースの質問応答機能を備え
</div> </div>
ユーザーは提供された[コード](https://github.com/THUDM/CogVLM2/tree/main/video_demo)を使用してモデルをロードするか、RESTful API を構成してビデオキャプションを生成できます。 ユーザーは提供された[コード](https://github.com/THUDM/CogVLM2/tree/main/video_demo)を使用してモデルをロードするか、RESTful API を構成してビデオキャプションを生成できます。
## Citation
🌟 If you find our work helpful, please leave us a star and cite our paper.
CogVLM2-Caption:
```
@article{yang2024cogvideox,
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
journal={arXiv preprint arXiv:2408.06072},
year={2024}
}
```
CogVLM2-Video:
```
@article{hong2024cogvlm2,
title={CogVLM2: Visual Language Models for Image and Video Understanding},
author={Hong, Wenyi and Wang, Weihan and Ding, Ming and Yu, Wenmeng and Lv, Qingsong and Wang, Yan and Cheng, Yean and Huang, Shiyu and Ji, Junhui and Xue, Zhao and others},
journal={arXiv preprint arXiv:2408.16500},
year={2024}
}
```

View File

@ -2,11 +2,38 @@
通常,大多数视频数据不带有相应的描述性文本,因此需要将视频数据转换为文本描述,以提供必要的训练数据用于文本到视频模型。 通常,大多数视频数据不带有相应的描述性文本,因此需要将视频数据转换为文本描述,以提供必要的训练数据用于文本到视频模型。
## 项目更新
- 🔥🔥 **News**: ```2024/9/19```: CogVideoX 训练过程中用于将视频数据转换为文本描述的 Caption
模型 [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption)
已经开源。欢迎前往下载并使用。
## 通过 CogVLM2-Caption 模型生成视频Caption
🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/)
CogVLM2-Caption是用于生成CogVideoX模型训练数据的视频caption模型。
### 安装依赖
```shell
pip install -r requirements.txt
```
### 运行caption模型
```shell
python video_caption.py
```
示例:
<div align="center">
<img width="600px" height="auto" src="./assests/CogVLM2-Caption-example.png">
</div>
## 通过 CogVLM2-Video 模型生成视频Caption ## 通过 CogVLM2-Video 模型生成视频Caption
🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/) [Code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
CogVLM2-Video 是一个多功能的视频理解模型,具备基于时间戳的问题回答能力。用户可以输入诸如 `请详细描述这个视频` 的提示语给模型以获得详细的视频Caption CogVLM2-Video 是一个多功能的视频理解模型,具备基于时间戳的问题回答能力。用户可以输入诸如 `Describe this video in detail.` 的提示语给模型以获得详细的视频Caption
<div align="center"> <div align="center">
@ -14,3 +41,27 @@ CogVLM2-Video 是一个多功能的视频理解模型,具备基于时间戳的
</div> </div>
用户可以使用提供的[代码](https://github.com/THUDM/CogVLM2/tree/main/video_demo)加载模型或配置 RESTful API 来生成视频Caption。 用户可以使用提供的[代码](https://github.com/THUDM/CogVLM2/tree/main/video_demo)加载模型或配置 RESTful API 来生成视频Caption。
## Citation
🌟 If you find our work helpful, please leave us a star and cite our paper.
CogVLM2-Caption:
```
@article{yang2024cogvideox,
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
journal={arXiv preprint arXiv:2408.06072},
year={2024}
}
```
CogVLM2-Video:
```
@article{hong2024cogvlm2,
title={CogVLM2: Visual Language Models for Image and Video Understanding},
author={Hong, Wenyi and Wang, Weihan and Ding, Ming and Yu, Wenmeng and Lv, Qingsong and Wang, Yan and Cheng, Yean and Huang, Shiyu and Ji, Junhui and Xue, Zhao and others},
journal={arXiv preprint arXiv:2408.16500},
year={2024}
}
```

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

View File

@ -0,0 +1,23 @@
decord>=0.6.0
#根据https://download.pytorch.org/whl/torch/python版本为[3.8,3.11]
torch==2.1.0
torchvision== 0.16.0
pytorchvideo==0.1.5
xformers
transformers==4.42.4
#git+https://github.com/huggingface/transformers.git
huggingface-hub>=0.23.0
pillow
chainlit>=1.0
pydantic>=2.7.1
timm>=0.9.16
openai>=1.30.1
loguru>=0.7.2
pydantic>=2.7.1
einops
sse-starlette>=2.1.0
flask
gunicorn
gevent
requests
gradio

View File

@ -0,0 +1,108 @@
import io
import argparse
import numpy as np
import torch
from decord import cpu, VideoReader, bridge
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_PATH = "THUDM/cogvlm2-llama3-caption"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[
0] >= 8 else torch.float16
parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0)
args = parser.parse_args([])
def load_video(video_data, strategy='chat'):
bridge.set_bridge('torch')
mp4_stream = video_data
num_frames = 24
decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))
frame_id_list = None
total_frames = len(decord_vr)
if strategy == 'base':
clip_end_sec = 60
clip_start_sec = 0
start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
end_frame = min(total_frames,
int(clip_end_sec * decord_vr.get_avg_fps())) if clip_end_sec is not None else total_frames
frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
elif strategy == 'chat':
timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
timestamps = [i[0] for i in timestamps]
max_second = round(max(timestamps)) + 1
frame_id_list = []
for second in range(max_second):
closest_num = min(timestamps, key=lambda x: abs(x - second))
index = timestamps.index(closest_num)
frame_id_list.append(index)
if len(frame_id_list) >= num_frames:
break
video_data = decord_vr.get_batch(frame_id_list)
video_data = video_data.permute(3, 0, 1, 2)
return video_data
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
torch_dtype=TORCH_TYPE,
trust_remote_code=True
).eval().to(DEVICE)
def predict(prompt, video_data, temperature):
strategy = 'chat'
video = load_video(video_data, strategy=strategy)
history = []
query = prompt
inputs = model.build_conversation_input_ids(
tokenizer=tokenizer,
query=query,
images=[video],
history=history,
template_version=strategy
)
inputs = {
'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
'images': [[inputs['images'][0].to('cuda').to(TORCH_TYPE)]],
}
gen_kwargs = {
"max_new_tokens": 2048,
"pad_token_id": 128002,
"top_k": 1,
"do_sample": False,
"top_p": 0.1,
"temperature": temperature,
}
with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs)
outputs = outputs[:, inputs['input_ids'].shape[1]:]
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
return response
def test():
prompt = "Please describe this video in detail."
temperature = 0.1
video_data = open('test.mp4', 'rb').read()
response = predict(prompt, video_data, temperature)
print(response)
if __name__ == '__main__':
test()