Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-09-25 16:55:58 +08:00)

Commit b410841bcf: llm-flux-cogvideox-i2v-tools
Parent: 775b0e1ba3
@@ -53,7 +53,7 @@ pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
 ).to(device)
 
 pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
-    "THUDM/CogVideoX-5b",
+    "THUDM/CogVideoX-5b-I2V",
     transformer=CogVideoXTransformer3DModel.from_pretrained(
         "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
     ),
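For reference, a minimal sketch of how an image-to-video pipeline loaded this way might be invoked on its own. The device string, the prompt, and the input image path are illustrative placeholders, not code from this commit:

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Standalone I2V usage sketch; the demo above instead shares components across pipelines.
pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("input.png")  # placeholder input path
frames = pipe_image(
    image=image,
    prompt="A placeholder prompt describing the desired motion",
    num_frames=49,
    guidance_scale=7.0,
).frames[0]
export_to_video(frames, "output.mp4", fps=8)
```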
@@ -65,10 +65,10 @@ pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
 ).to(device)
 
 
-pipe.transformer.to(memory_format=torch.channels_last)
-pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
-pipe_image.transformer.to(memory_format=torch.channels_last)
-pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True)
+# pipe.transformer.to(memory_format=torch.channels_last)
+# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+# pipe_image.transformer.to(memory_format=torch.channels_last)
+# pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True)
 
 os.makedirs("./output", exist_ok=True)
 os.makedirs("./gradio_tmp", exist_ok=True)
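The compilation lines above are commented out rather than removed, so they can still be re-enabled by hand. A hedged sketch of gating them behind an environment variable; the `COGVIDEOX_COMPILE` flag name is an assumption for illustration and is not defined anywhere in this commit:

```python
import os
import torch

# Assumed opt-in flag (not part of this commit); reuses the pipe/pipe_image objects defined above.
if os.environ.get("COGVIDEOX_COMPILE", "0") == "1":
    for p in (pipe, pipe_image):
        p.transformer.to(memory_format=torch.channels_last)
        p.transformer = torch.compile(p.transformer, mode="max-autotune", fullgraph=True)
```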
@@ -294,7 +294,8 @@ def delete_old_files():
 
 
 threading.Thread(target=delete_old_files, daemon=True).start()
-examples = [["horse.mp4"], ["kitten.mp4"], ["train_running.mp4"]]
+examples_videos = [["example_videos/horse.mp4"], ["example_videos/kitten.mp4"], ["example_videos/train_running.mp4"]]
+examples_images = [["example_images/beach.png"], ["example_images/street.png"], ["example_images/camping.png"]]
 
 with gr.Blocks() as demo:
     gr.Markdown("""
@@ -302,7 +303,8 @@ with gr.Blocks() as demo:
               CogVideoX-5B Huggingface Space🤗
           </div>
           <div style="text-align: center;">
-              <a href="https://huggingface.co/THUDM/CogVideoX-5B">🤗 5B Model Hub</a> |
+              <a href="https://huggingface.co/THUDM/CogVideoX-5B">🤗 5B(T2V) Model Hub</a> |
+              <a href="https://huggingface.co/THUDM/CogVideoX-5B-I2V">🤗 5B(I2V) Model Hub</a> |
               <a href="https://github.com/THUDM/CogVideo">🌐 Github</a> |
               <a href="https://arxiv.org/pdf/2408.06072">📜 arxiv </a>
           </div>
@@ -320,10 +322,11 @@ with gr.Blocks() as demo:
         with gr.Column():
             with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
                 image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
+                examples_component_images = gr.Examples(examples_images, inputs=[examples_images], cache_examples=False)
             with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
                 video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
                 strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
-                examples_component = gr.Examples(examples, inputs=[video_input], cache_examples=False)
+                examples_component_videos = gr.Examples(examples_videos, inputs=[examples_videos], cache_examples=False)
             prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
 
             with gr.Row():
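As a point of reference, `gr.Examples` is usually wired so that `inputs` lists the target components each example row should populate when clicked. A minimal, generic usage sketch (component names and example paths mirror the ones above, but this is not code from this commit):

```python
import gradio as gr

with gr.Blocks() as demo:
    image_input = gr.Image(label="Input Image")
    video_input = gr.Video(label="Input Video")
    # Clicking an example row fills the listed input component with that example.
    gr.Examples(
        examples=[["example_images/beach.png"]],
        inputs=[image_input],
        cache_examples=False,
    )
    gr.Examples(
        examples=[["example_videos/horse.mp4"]],
        inputs=[video_input],
        cache_examples=False,
    )
```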
@@ -338,7 +341,7 @@ with gr.Blocks() as demo:
                     label="Inference Seed (Enter a positive number, -1 for random)", value=-1
                 )
             with gr.Row():
-                enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 1440 × 960)", value=False)
+                enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False)
                 enable_rife = gr.Checkbox(label="Frame Interpolation (8fps -> 16fps)", value=False)
             gr.Markdown(
                 "✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution).<br> The entire process is based on open-source solutions."
@@ -356,7 +359,7 @@ with gr.Blocks() as demo:
     gr.Markdown("""
        <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
            <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
-               🎥 Video Gallery
+               🎥 Video Gallery(For 5B)
            </div>
            <tr>
                <td style="width: 25%; vertical-align: top; font-size: 0.9em;">
BIN  inference/gradio_composite_demo/example_images/beach.png  (new file, 376 KiB; binary file not shown)
BIN  inference/gradio_composite_demo/example_images/camping.png  (new file, 473 KiB; binary file not shown)
BIN  inference/gradio_composite_demo/example_images/street.png  (new file, 467 KiB; binary file not shown)

inference/gradio_composite_demo/example_videos/horse.mp4  (new file, 3 lines, Git LFS pointer)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c857bbc0d197c0751db9d6da9b5c85eafd163511ff9b0e10be65adf8ef9e352
+size 453387

BIN  inference/gradio_composite_demo/example_videos/kitten.mp4  (new file; binary file not shown)
BIN  inference/gradio_composite_demo/example_videos/train_running.mp4  (new file; binary file not shown)
@@ -3,7 +3,7 @@ safetensors>=0.4.5
 spandrel>=0.4.0
 tqdm>=4.66.5
 scikit-video>=1.1.11
-diffusers>=0.30.1
+git+https://github.com/huggingface/diffusers.git@main
 transformers>=4.44.0
 accelerate>=0.34.2
 opencv-python>=4.10.0.84
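The diffusers requirement moves from a released version to the main branch, presumably because the CogVideoX image-to-video and video-to-video pipelines had not yet shipped in a stable release at the time. A quick post-install sanity check (a sketch, not part of the commit):

```python
# Verify that the installed diffusers build exposes the pipelines this commit relies on.
import diffusers

print(diffusers.__version__)
from diffusers import (  # raises ImportError on builds that predate these pipelines
    CogVideoXImageToVideoPipeline,
    CogVideoXVideoToVideoPipeline,
)
```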
@@ -18,7 +18,10 @@ pip install -r requirements.txt
 
 ### 2. Download the model weights
 
-First, go to the SAT mirror to download the model weights. For the CogVideoX-2B model, download it as follows:
+First, go to the SAT mirror to download the model weights.
+
+For the CogVideoX-2B model, download it as follows:
+
 
 ```shell
 mkdir CogVideoX-2b-sat
 cd CogVideoX-2b-sat
@@ -29,28 +32,27 @@ wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1
 mv 'index.html?dl=1' transformer.zip
 unzip transformer.zip
 ```
-For the CogVideoX-5B model, download it as follows (the VAE file is the same):
-```shell
-mkdir CogVideoX-5b-sat
-cd CogVideoX-5b-sat
-wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
-mv 'index.html?dl=1' vae.zip
-unzip vae.zip
-```
-Then, go to [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list) to download our model and unzip it.
-After organizing, the complete structure of the two models should look as follows:
+Download the `transformers` files for the CogVideoX-5B model from the links below (the VAE file is the same as for 2B):
+
++ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
++ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)
+
+Next, arrange the model files into the following layout:
+
 
 ```
 .
 ├── transformer
 │   ├── 1000 (or 1)
 │   │   └── mp_rank_00_model_states.pt
 │   └── latest
 └── vae
     └── 3d-vae.pt
 ```
 
-Since the model weight files are large, we recommend using `git lfs`; see [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for how to install `git lfs`.
+Since the model weight files are large, we recommend using `git lfs`. For `git lfs`
+installation, see [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing).
 
 ```shell
 git lfs install
 ```
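A small helper for double-checking the layout above before launching SAT inference; a sketch under the assumption that the checkpoint root directory is passed in as a path, with the function name chosen for illustration only:

```python
import pathlib

def check_sat_layout(root: str) -> None:
    """Verify the SAT checkpoint layout documented above (illustrative sketch)."""
    root = pathlib.Path(root)
    # transformer/<step>/mp_rank_00_model_states.pt, where <step> is e.g. 1000 or 1
    if not list((root / "transformer").glob("*/mp_rank_00_model_states.pt")):
        raise FileNotFoundError("missing transformer/<step>/mp_rank_00_model_states.pt")
    for path in (root / "transformer" / "latest", root / "vae" / "3d-vae.pt"):
        if not path.exists():
            raise FileNotFoundError(f"missing {path}")

check_sat_layout("CogVideoX-5b-sat")  # illustrative path
```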
@@ -410,15 +412,16 @@ python ../tools/convert_weight_sat2hf.py
 After training with the steps above, we obtain a SAT checkpoint with LoRA weights; you can find this file at {args.save}/1000/1000/mp_rank_00_model_states.pt.
 
 The script for exporting the LoRA weights is tools/export_sat_lora_weight.py in the CogVideoX repository; after exporting, use load_cogvideox_lora.py for inference.
-- Export command
+
+Export command:
 ```
 python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
-···
+```
 
 This training mainly modified the model structures below. The list shows the correspondence between the SAT LoRA structure and the structure after conversion to HF format; you can see that LoRA adds a low-rank weight on top of the model's attention structure.
 
 ```
 
 'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
 'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
 'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
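To illustrate the correspondence above, a hedged sketch of how such SAT-to-HF key renaming is typically applied to a state dict. The mapping contains only the pairs listed here, and the function name is illustrative rather than part of the repository:

```python
# Sketch: rename SAT LoRA key suffixes to their HF (diffusers) counterparts,
# using only the documented pairs from the mapping above.
SAT_TO_HF_LORA_KEYS = {
    "attention.query_key_value.matrix_A.0": "attn1.to_q.lora_A.weight",
    "attention.query_key_value.matrix_A.1": "attn1.to_k.lora_A.weight",
    "attention.query_key_value.matrix_A.2": "attn1.to_v.lora_A.weight",
}

def rename_lora_keys(state_dict: dict) -> dict:
    """Return a copy of state_dict with the documented SAT suffixes renamed (illustrative)."""
    renamed = {}
    for key, value in state_dict.items():
        for sat_suffix, hf_suffix in SAT_TO_HF_LORA_KEYS.items():
            if key.endswith(sat_suffix):
                key = key[: -len(sat_suffix)] + hf_suffix
                break
        renamed[key] = value
    return renamed
```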
tools/llm_flux_cogvideox/generate.sh  (new file, 24 lines)
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+NUM_VIDEOS=100
+INFERENCE_STEPS=50
+GUIDANCE_SCALE=7.0
+OUTPUT_DIR_PREFIX="outputs/gpu_"
+LOG_DIR_PREFIX="logs/gpu_"
+
+CUDA_DEVICES=${CUDA_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"}
+
+IFS=',' read -r -a GPU_ARRAY <<< "$CUDA_DEVICES"
+
+for i in "${!GPU_ARRAY[@]}"
+do
+    GPU=${GPU_ARRAY[$i]}
+    echo "Starting task on GPU $GPU..."
+    CUDA_VISIBLE_DEVICES=$GPU nohup python3 llm_flux_cogvideox.py \
+        --num_videos $NUM_VIDEOS \
+        --image_generator_num_inference_steps $INFERENCE_STEPS \
+        --guidance_scale $GUIDANCE_SCALE \
+        --use_dynamic_cfg \
+        --output_dir ${OUTPUT_DIR_PREFIX}${GPU} \
+        > ${LOG_DIR_PREFIX}${GPU}.log 2>&1 &
+done
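For environments where bash is unavailable, a rough Python equivalent of the per-GPU fan-out in generate.sh. The defaults mirror the script; this launcher is a sketch, not part of the commit:

```python
import os
import subprocess

# Launch one llm_flux_cogvideox.py process per visible GPU, mirroring generate.sh.
gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(",")
os.makedirs("logs", exist_ok=True)

procs = []
for gpu in gpus:
    log = open(f"logs/gpu_{gpu}.log", "w")
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": gpu}
    procs.append(subprocess.Popen(
        ["python3", "llm_flux_cogvideox.py",
         "--num_videos", "100",
         "--image_generator_num_inference_steps", "50",
         "--guidance_scale", "7.0",
         "--use_dynamic_cfg",
         "--output_dir", f"outputs/gpu_{gpu}"],
        stdout=log, stderr=subprocess.STDOUT, env=env,
    ))

for p in procs:
    p.wait()
```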
tools/llm_flux_cogvideox/llm_flux_cogvideox.py  (new file, 256 lines)
@@ -0,0 +1,256 @@
+"""
+The original experimental code for this project can be found at:
+
+https://gist.github.com/a-r-r-o-w/d070cce059ab4ceab3a9f289ff83c69c
+
+By using this code, description prompts will be generated through a local large language model, and images will be
+generated using the black-forest-labs/FLUX.1-dev model, followed by video generation via CogVideoX.
+The entire process utilizes open-source solutions, without the need for any API keys.
+
+You can use the generate.sh file in the same folder to automate running this code
+for batch generation of videos and images.
+
+bash generate.sh
+
+"""
+
+import argparse
+import gc
+import json
+import os
+import pathlib
+import random
+from typing import Any, Dict
+
+from transformers import AutoTokenizer
+
+os.environ["TORCH_LOGS"] = "+dynamo,recompiles,graph_breaks"
+os.environ["TORCHDYNAMO_VERBOSE"] = "1"
+
+import numpy as np
+import torch
+import transformers
+from diffusers import CogVideoXImageToVideoPipeline, CogVideoXDPMScheduler, DiffusionPipeline
+from diffusers.utils.logging import get_logger
+from diffusers.utils import export_to_video
+
+torch.set_float32_matmul_precision("high")
+
+logger = get_logger(__name__)
+
+SYSTEM_PROMPT = """
+You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe.
+
+For example, if you respond with "A beautiful morning in the woods with the sun peeking through the trees", the video generation model will create a video of exactly as described. Your task is to summarize the descriptions of videos provided by users, and create detailed prompts to feed into the generative model.
+
+There are a few rules to follow:
+- You will only ever output a single video description per request.
+- If the user mentions to summarize the prompt in [X] words, make sure to not exceed the limit.
+
+Your responses should just be the video generation prompt. Here are examples:
+- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
+- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall"
+""".strip()
+
+USER_PROMPT = """
+Could you generate a prompt for a video generation model?
+Please limit the prompt to [{0}] words.
+""".strip()
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--num_videos",
+        type=int,
+        default=5,
+        help="Number of unique videos you would like to generate."
+    )
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        default="THUDM/CogVideoX-5B",
+        help="The path of the image-to-video CogVideoX-5B model.",
+    )
+    parser.add_argument(
+        "--caption_generator_model_id",
+        type=str,
+        default="THUDM/glm-4-9b-chat",
+        help="Caption generation model (default: GLM-4-9B).",
+    )
+    parser.add_argument(
+        "--caption_generator_cache_dir",
+        type=str,
+        default=None,
+        help="Cache directory for caption generation model."
+    )
+    parser.add_argument(
+        "--image_generator_model_id",
+        type=str,
+        default="black-forest-labs/FLUX.1-dev",
+        help="Image generation model."
+    )
+    parser.add_argument(
+        "--image_generator_cache_dir",
+        type=str,
+        default=None,
+        help="Cache directory for image generation model."
+    )
+    parser.add_argument(
+        "--image_generator_num_inference_steps",
+        type=int,
+        default=50,
+        help="Number of inference steps for the image generation model."
+    )
+    parser.add_argument(
+        "--guidance_scale",
+        type=float,
+        default=7,
+        help="Guidance scale to be used for generation."
+    )
+    parser.add_argument(
+        "--use_dynamic_cfg",
+        action="store_true",
+        help="Whether or not to use cosine dynamic guidance for generation [Recommended].",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="outputs/",
+        help="Location where generated images and videos should be stored.",
+    )
+    parser.add_argument(
+        "--compile",
+        action="store_true",
+        help="Whether or not to compile the transformer of image and video generators."
+    )
+    parser.add_argument(
+        "--enable_vae_tiling",
+        action="store_true",
+        help="Whether or not to use VAE tiling when encoding/decoding."
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="Seed for reproducibility."
+    )
+    return parser.parse_args()
+
+
+def reset_memory():
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+    torch.cuda.reset_accumulated_memory_stats()
+
+
+@torch.no_grad()
+def main(args: Dict[str, Any]) -> None:
+    output_dir = pathlib.Path(args.output_dir)
+    os.makedirs(output_dir.as_posix(), exist_ok=True)
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+
+    reset_memory()
+    tokenizer = AutoTokenizer.from_pretrained(args.caption_generator_model_id, trust_remote_code=True)
+    caption_generator = transformers.pipeline(
+        "text-generation",
+        model=args.caption_generator_model_id,
+        device_map="auto",
+        model_kwargs={
+            "local_files_only": True,
+            "cache_dir": args.caption_generator_cache_dir,
+            "torch_dtype": torch.bfloat16,
+        },
+        trust_remote_code=True,
+        tokenizer=tokenizer
+    )
+
+    captions = []
+    for i in range(args.num_videos):
+        num_words = random.choice([100, 150, 200])
+        user_prompt = USER_PROMPT.format(num_words)
+
+        messages = [
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": user_prompt},
+        ]
+
+        outputs = caption_generator(messages, max_new_tokens=226)
+        caption = outputs[0]["generated_text"][-1]["content"]
+        if caption.startswith("\"") and caption.endswith("\""):
+            caption = caption[1:-1]
+        captions.append(caption)
+        logger.info(f"Generated caption: {caption}")
+
+    with open(output_dir / "captions.json", "w") as file:
+        json.dump(captions, file)
+
+    del caption_generator
+    reset_memory()
+
+    image_generator = DiffusionPipeline.from_pretrained(
+        args.image_generator_model_id,
+        cache_dir=args.image_generator_cache_dir,
+        torch_dtype=torch.bfloat16
+    )
+    image_generator.to("cuda")
+
+    if args.compile:
+        image_generator.transformer = torch.compile(image_generator.transformer, mode="max-autotune", fullgraph=True)
+
+    if args.enable_vae_tiling:
+        image_generator.vae.enable_tiling()
+
+    images = []
+    for index, caption in enumerate(captions):
+        image = image_generator(
+            prompt=caption,
+            height=480,
+            width=720,
+            num_inference_steps=args.image_generator_num_inference_steps,
+            guidance_scale=3.5,
+        ).images[0]
+        filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
+        image.save(output_dir / f"{index}_{filename}.png")
+        images.append(image)
+
+    del image_generator
+    reset_memory()
+
+    video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
+        args.model_path, torch_dtype=torch.bfloat16).to("cuda")
+    video_generator.scheduler = CogVideoXDPMScheduler.from_config(
+        video_generator.scheduler.config,
+        timestep_spacing="trailing")
+
+    if args.compile:
+        video_generator.transformer = torch.compile(video_generator.transformer, mode="max-autotune", fullgraph=True)
+
+    if args.enable_vae_tiling:
+        video_generator.vae.enable_tiling()
+
+    generator = torch.Generator().manual_seed(args.seed)
+    for index, (caption, image) in enumerate(zip(captions, images)):
+        video = video_generator(
+            image=image,
+            prompt=caption,
+            height=480,
+            width=720,
+            num_frames=49,
+            num_inference_steps=50,
+            guidance_scale=args.guidance_scale,
+            use_dynamic_cfg=args.use_dynamic_cfg,
+            generator=generator,
+        ).frames[0]
+        filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
+        export_to_video(video, output_dir / f"{index}_{filename}.mp4", fps=8)
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
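As a rough usage example, the script above might first be exercised with a single video as a smoke test before launching the full batch run in generate.sh. The argument values and the output directory below are illustrative, not taken from the commit:

```python
# Sketch: small smoke-test invocation of tools/llm_flux_cogvideox/llm_flux_cogvideox.py.
import subprocess

subprocess.run(
    [
        "python3", "tools/llm_flux_cogvideox/llm_flux_cogvideox.py",
        "--num_videos", "1",
        "--image_generator_num_inference_steps", "30",
        "--guidance_scale", "7.0",
        "--use_dynamic_cfg",
        "--enable_vae_tiling",
        "--output_dir", "outputs/smoke_test",
    ],
    check=True,
)
```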