update readme for video caption

huangshiyu 2024-09-19 13:48:05 +08:00
parent 6a2efb844b
commit 5bdc3f65f1
9 changed files with 293 additions and 13 deletions

View File

@@ -27,7 +27,7 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
controllability. With this release, the CogVideoX series now supports three tasks: text-to-video, video extension, and
image-to-video generation. Feel free to try it out [online](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space).
- 🔥🔥 **News**: ```2024/9/19```: The caption model used in the CogVideoX training process to convert video data into text
-descriptions, [cogvlm2-llama3-caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), is now open-source. Feel
+descriptions, [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), is now open-source. Feel
free to download and use it.
- 🔥 **News**: ```2024/9/16```: We have added an automated video generation tool! You can now use local open-source
models + FLUX + CogVideoX to automatically generate high-quality videos. Feel free

View File

@@ -25,7 +25,7 @@
- 🔥🔥 **News**: ```2024/9/19```: We have open-sourced **CogVideoX-5B-I2V**, the image-to-video model in the CogVideoX
series. It takes an image as a background input and, combined with a prompt, generates a video, offering stronger controllability. With this, the CogVideoX series supports three tasks: text-to-video, video extension, and image-to-video generation. Feel free to [try it online](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space).
- 🔥🔥 **News**: ```2024/9/19```: The caption model used in the CogVideoX training process to convert video data into text
-descriptions, [cogvlm2-llama3-caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption),
+descriptions, [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption),
is now open-source. Feel free to download and use it.
- 🔥 **News**: ```2024/9/16```: We have added an automated video generation tool! You can now use local open-source models + FLUX + CogVideoX to automatically generate high-quality videos. Feel free to [try it out](tools/llm_flux_cogvideox/llm_flux_cogvideox.py).

View File

@@ -27,7 +27,7 @@
The model can take an image as a background input and, combined with a prompt, generate a video, offering stronger controllability.
With this, the CogVideoX series supports three tasks: text-to-video, video extension, and image-to-video generation. Feel free to [try it online](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space).
- 🔥🔥 **News**: ```2024/9/19```: The caption model used in the CogVideoX training process to convert video data into text
-descriptions, [cogvlm2-llama3-caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption),
+descriptions, [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption),
is now open-source. Feel free to download and use it.
- 🔥 **News**: ```2024/9/16```: We have added an automated video generation tool! You can use local open-source models + FLUX + CogVideoX to automatically generate high-quality videos. Feel free to [try it out](tools/llm_flux_cogvideox/llm_flux_cogvideox.py).

View File

@@ -3,11 +3,37 @@
Typically, most video data does not come with corresponding descriptive text, so it is necessary to convert the video
data into textual descriptions to provide the essential training data for text-to-video models.
## Update and News
- 🔥🔥 **News**: ```2024/9/19```: The caption model used in the CogVideoX training process to convert video data into text
descriptions, [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), is now open-source. Feel
free to download and use it.
## Video Caption via CogVLM2-Caption
🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/)
CogVLM2-Caption is a video captioning model used to generate training data for the CogVideoX model.
### Install
```shell
pip install -r requirements.txt
```
### Usage
```shell
python video_caption.py
```
Example:
<div align="center">
<img width="600px" height="auto" src="./assests/CogVLM2-Caption-example.png">
</div>
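Beyond the single-clip demo (`video_caption.py` captions a local `test.mp4` by default), its `predict` function can be reused for batch captioning. Below is a minimal sketch, not part of this repository, assuming it runs alongside `video_caption.py` and that clips live in a hypothetical `videos/` folder; importing the script loads the tokenizer and model once, so the per-clip cost is only decoding and generation:
```python
# Hypothetical batch-captioning wrapper (assumption: run from this directory
# so video_caption.py is importable; the model is loaded at import time).
from pathlib import Path

from video_caption import predict

PROMPT = "Please describe this video in detail."

for clip in sorted(Path("videos").glob("*.mp4")):  # assumed input folder
    caption = predict(PROMPT, clip.read_bytes(), temperature=0.1)
    print(f"{clip.name}\t{caption}")
```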
## Video Caption via CogVLM2-Video
-<p align="center">
-🤗 <a href="https://huggingface.co/THUDM/cogvlm2-video-llama3-chat">Hugging Face</a>&nbsp;&nbsp; | &nbsp;&nbsp;🤖 <a href="https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat">ModelScope</a>&nbsp;&nbsp; | &nbsp;&nbsp;📑 <a href="https://cogvlm2-video.github.io/">Blog</a>&nbsp;&nbsp; | &nbsp;&nbsp;<a href="http://cogvlm2-online.cogviewai.cn:7868/">💬 Online Demo</a>
-</p>
+[Code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
CogVLM2-Video is a versatile video understanding model equipped with timestamp-based question answering capabilities.
Users can input prompts such as `Please describe this video in detail.` to the model to obtain a detailed video caption:
@@ -16,3 +42,26 @@ Users can input prompts such as `Please describe this video in detail.` to the m
</div>
Users can use the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) to load the model or configure a RESTful API to generate video captions.
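For the RESTful option, a client might look like the sketch below. The URL, route, and JSON field names are assumptions for illustration only; the actual request schema is defined by the demo code in the linked CogVLM2 repository:
```python
# Illustrative client for a self-hosted CogVLM2-Video captioning service.
# The endpoint and payload fields are assumptions, not the demo's real API.
import base64

import requests

with open("test.mp4", "rb") as f:
    payload = {
        "video": base64.b64encode(f.read()).decode(),  # assumed field name
        "question": "Please describe this video in detail.",
        "temperature": 0.1,
    }

resp = requests.post("http://localhost:7868/video_qa",  # assumed URL/route
                     json=payload, timeout=300)
resp.raise_for_status()
print(resp.json())
```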
## Citation
🌟 If you find our work helpful, please leave us a star and cite our paper.
CogVLM2-Caption:
```
@article{yang2024cogvideox,
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
journal={arXiv preprint arXiv:2408.06072},
year={2024}
}
```
CogVLM2-Video:
```
@article{hong2024cogvlm2,
title={CogVLM2: Visual Language Models for Image and Video Understanding},
author={Hong, Wenyi and Wang, Weihan and Ding, Ming and Yu, Wenmeng and Lv, Qingsong and Wang, Yan and Cheng, Yean and Huang, Shiyu and Ji, Junhui and Xue, Zhao and others},
journal={arXiv preprint arXiv:2408.16500},
year={2024}
}
```

View File

@@ -2,11 +2,37 @@
Typically, most video data does not come with corresponding descriptive text, so it is necessary to convert the video data into textual descriptions to provide the essential training data for text-to-video models.
## Update and News
- 🔥🔥 **News**: ```2024/9/19```: The caption model used in the CogVideoX training process to convert video data into text
descriptions, [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), is now open-source. Feel free to download and use it.
## Video Caption via CogVLM2-Caption
🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/)
CogVLM2-Caption is a video captioning model used to generate training data for the CogVideoX model.
### Install
```shell
pip install -r requirements.txt
```
### Usage
```shell
python video_caption.py
```
Example:
<div align="center">
<img width="600px" height="auto" src="./assests/CogVLM2-Caption-example.png">
</div>
## Video Caption via CogVLM2-Video
-<p align="center">
-🤗 <a href="https://huggingface.co/THUDM/cogvlm2-video-llama3-chat">Hugging Face</a>&nbsp;&nbsp; | &nbsp;&nbsp;🤖 <a href="https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat">ModelScope</a>&nbsp;&nbsp; | &nbsp;&nbsp;📑 <a href="https://cogvlm2-video.github.io/">Blog</a>&nbsp;&nbsp; | &nbsp;&nbsp;<a href="http://cogvlm2-online.cogviewai.cn:7868/">💬 Online Demo</a>
-</p>
+[Code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
CogVLM2-Video is a versatile video understanding model equipped with timestamp-based question answering capabilities. Users can input prompts such as `Please describe this video in detail.` to the model to obtain a detailed video caption:
<div align="center">
@@ -14,3 +40,26 @@ CogVLM2-Video is a versatile video understanding model equipped with timestamp-
</div>
Users can use the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) to load the model or configure a RESTful API to generate video captions.
## Citation
🌟 If you find our work helpful, please leave us a star and cite our paper.
CogVLM2-Caption:
```
@article{yang2024cogvideox,
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
journal={arXiv preprint arXiv:2408.06072},
year={2024}
}
```
CogVLM2-Video:
```
@article{hong2024cogvlm2,
title={CogVLM2: Visual Language Models for Image and Video Understanding},
author={Hong, Wenyi and Wang, Weihan and Ding, Ming and Yu, Wenmeng and Lv, Qingsong and Wang, Yan and Cheng, Yean and Huang, Shiyu and Ji, Junhui and Xue, Zhao and others},
journal={arXiv preprint arXiv:2408.16500},
year={2024}
}
```

View File

@@ -2,11 +2,38 @@
Typically, most video data does not come with corresponding descriptive text, so it is necessary to convert the video data into textual descriptions to provide the essential training data for text-to-video models.
## Project Updates
- 🔥🔥 **News**: ```2024/9/19```: The caption model used in the CogVideoX training process to convert video data into text
descriptions, [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), is now open-source. Feel free to download and use it.
## Generating Video Captions with CogVLM2-Caption
🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/)
CogVLM2-Caption is a video caption model used to generate training data for the CogVideoX model.
### Install Dependencies
```shell
pip install -r requirements.txt
```
### Run the Caption Model
```shell
python video_caption.py
```
Example:
<div align="center">
<img width="600px" height="auto" src="./assests/CogVLM2-Caption-example.png">
</div>
## Generating Video Captions with CogVLM2-Video
-🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
+[Code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
-CogVLM2-Video is a versatile video understanding model with timestamp-based question answering capabilities. Users can input prompts such as `请详细描述这个视频` to the model to obtain a detailed video caption:
+CogVLM2-Video is a versatile video understanding model with timestamp-based question answering capabilities. Users can input prompts such as `Describe this video in detail.` to the model to obtain a detailed video caption:
<div align="center">
@@ -14,3 +41,27 @@ CogVLM2-Video is a versatile video understanding model with timestamp-based
</div>
Users can use the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) to load the model or configure a RESTful API to generate video captions.
## Citation
🌟 If you find our work helpful, please leave us a star and cite our paper.
CogVLM2-Caption:
```
@article{yang2024cogvideox,
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
journal={arXiv preprint arXiv:2408.06072},
year={2024}
}
```
CogVLM2-Video:
```
@article{hong2024cogvlm2,
title={CogVLM2: Visual Language Models for Image and Video Understanding},
author={Hong, Wenyi and Wang, Weihan and Ding, Ming and Yu, Wenmeng and Lv, Qingsong and Wang, Yan and Cheng, Yean and Huang, Shiyu and Ji, Junhui and Xue, Zhao and others},
journal={arXiv preprint arXiv:2408.16500},
year={2024}
}
```

Binary file not shown (new file, 1.1 MiB)

View File

@@ -0,0 +1,23 @@
decord>=0.6.0
# Per https://download.pytorch.org/whl/torch/ , the Python version must be in [3.8, 3.11]
torch==2.1.0
torchvision==0.16.0
pytorchvideo==0.1.5
xformers
transformers==4.42.4
#git+https://github.com/huggingface/transformers.git
huggingface-hub>=0.23.0
pillow
chainlit>=1.0
pydantic>=2.7.1
timm>=0.9.16
openai>=1.30.1
loguru>=0.7.2
einops
sse-starlette>=2.1.0
flask
gunicorn
gevent
requests
gradio

View File

@@ -0,0 +1,108 @@
import io
import argparse

import numpy as np
import torch
from decord import cpu, VideoReader, bridge
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "THUDM/cogvlm2-llama3-caption"

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# bfloat16 needs compute capability >= 8.0 (Ampere or newer); otherwise fall back to float16.
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 \
    else torch.float16

parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0)
# parse_args([]) ignores real command-line arguments so the file stays importable
# as a module; note that --quant is parsed but not applied in this minimal demo.
args = parser.parse_args([])


def load_video(video_data, strategy='chat'):
    # Decode an in-memory MP4 stream and sample up to `num_frames` frames,
    # returned as a (C, T, H, W) torch tensor.
    bridge.set_bridge('torch')
    mp4_stream = video_data
    num_frames = 24
    decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))

    frame_id_list = None
    total_frames = len(decord_vr)
    if strategy == 'base':
        # Sample frames uniformly from the first 60 seconds of the clip.
        clip_end_sec = 60
        clip_start_sec = 0
        start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
        end_frame = min(total_frames,
                        int(clip_end_sec * decord_vr.get_avg_fps())) if clip_end_sec is not None else total_frames
        frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
    elif strategy == 'chat':
        # Pick the frame closest to each whole second, up to `num_frames` frames.
        timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
        timestamps = [i[0] for i in timestamps]
        max_second = round(max(timestamps)) + 1
        frame_id_list = []
        for second in range(max_second):
            closest_num = min(timestamps, key=lambda x: abs(x - second))
            index = timestamps.index(closest_num)
            frame_id_list.append(index)
            if len(frame_id_list) >= num_frames:
                break

    video_data = decord_vr.get_batch(frame_id_list)
    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    return video_data


tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True
).eval().to(DEVICE)


def predict(prompt, video_data, temperature):
    strategy = 'chat'
    video = load_video(video_data, strategy=strategy)

    history = []
    query = prompt
    inputs = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=query,
        images=[video],
        history=history,
        template_version=strategy
    )
    inputs = {
        'input_ids': inputs['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': inputs['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[inputs['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        "pad_token_id": 128002,
        # do_sample=False makes decoding greedy, so top_k, top_p, and
        # temperature have no effect on the output here.
        "top_k": 1,
        "do_sample": False,
        "top_p": 0.1,
        "temperature": temperature,
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        outputs = outputs[:, inputs['input_ids'].shape[1]:]  # keep only newly generated tokens
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response


def test():
    prompt = "Please describe this video in detail."
    temperature = 0.1
    with open('test.mp4', 'rb') as f:  # sample clip path; supply your own video
        video_data = f.read()
    response = predict(prompt, video_data, temperature)
    print(response)


if __name__ == '__main__':
    test()