update llm_cogvideox_flux demo test

zR 2024-09-17 23:15:19 +08:00
parent b410841bcf
commit db309f3242
14 changed files with 95 additions and 75 deletions

View File

@@ -23,6 +23,11 @@ the videos in the `videos` directory.
 └── videos.txt
 ```
+You can download the [Disney Steamboat Willie](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)
+video fine-tuning dataset from here.
+It is used as a test dataset for fine-tuning.
 ### Configuration Files and Execution
 `accelerate` configuration files are as follows:
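Below is a minimal sketch of fetching the dataset linked above with `huggingface_hub`; the `local_dir` name matches the `DATASET_PATH` used by the fine-tuning scripts in this commit, and the rest is an assumption rather than part of the repository's tooling.

```python
# Sketch: download the Disney-VideoGeneration-Dataset used as the fine-tuning test set.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Wild-Heart/Disney-VideoGeneration-Dataset",
    repo_type="dataset",
    local_dir="Disney-VideoGeneration-Dataset",  # matches DATASET_PATH in the training scripts
)
```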

View File

@@ -21,6 +21,10 @@
 └── videos.txt
 ```
+You can download the [Disney Steamboat Willie](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)
+video fine-tuning dataset from here.
+It is used as a test dataset for fine-tuning.
 ### Configuration Files and Execution
 `accelerate` configuration files are as follows:

View File

@@ -2,9 +2,10 @@
 export MODEL_PATH="THUDM/CogVideoX-2b"
 export CACHE_PATH="~/.cache"
-export DATASET_PATH="disney"
+export DATASET_PATH="Disney-VideoGeneration-Dataset"
 export OUTPUT_PATH="cogvideox-lora-multi-gpu"
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
 accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \
   train_cogvideox_lora.py \

View File

@@ -2,9 +2,10 @@
 export MODEL_PATH="THUDM/CogVideoX-2b"
 export CACHE_PATH="~/.cache"
-export DATASET_PATH="disney"
+export DATASET_PATH="Disney-VideoGeneration-Dataset"
 export OUTPUT_PATH="cogvideox-lora-single-gpu"
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
 accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
   train_cogvideox_lora.py \

View File

@@ -1073,7 +1073,6 @@ def main(args):
         target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     transformer.add_adapter(transformer_lora_config)
-
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
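For reference, a minimal sketch of attaching a LoRA adapter with these target modules to the CogVideoX transformer via `peft` and `diffusers`; the rank and alpha values below are illustrative assumptions, not the script's defaults.

```python
# Sketch: inject a LoRA adapter into the CogVideoX transformer attention projections.
import torch
from diffusers import CogVideoXTransformer3DModel
from peft import LoraConfig

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.float16
)
transformer_lora_config = LoraConfig(
    r=128,                  # LoRA rank (assumed; the training script takes this from --rank)
    lora_alpha=64,          # scaling factor (assumed)
    init_lora_weights=True,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # same modules as in the diff above
)
transformer.add_adapter(transformer_lora_config)  # only the injected LoRA matrices stay trainable
```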

View File

@@ -22,6 +22,7 @@ pip install -r requirements.txt
 ### 2. Download model weights
 First, go to the SAT mirror to download the model weights. For the CogVideoX-2B model, please download as follows:
 ```shell
 mkdir CogVideoX-2b-sat
 cd CogVideoX-2b-sat
@@ -32,16 +33,14 @@ wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1
 mv 'index.html?dl=1' transformer.zip
 unzip transformer.zip
 ```
-For the CogVideoX-5B model, please download as follows (VAE files are the same):
-```shell
-mkdir CogVideoX-5b-sat
-cd CogVideoX-5b-sat
-wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
-mv 'index.html?dl=1' vae.zip
-unzip vae.zip
-```
-Then, you need to go to [Tsinghua Cloud Disk](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list) to download our model and unzip it.
-After sorting, the complete model structure of the two models should be as follows:
+For the CogVideoX-5B model, please download the `transformers` files from the following links
+(the VAE files are the same as for the 2B model):
+
++ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
++ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)
+
+Next, arrange the model files into the following structure:
 ```
 .
@@ -53,7 +52,8 @@ After sorting, the complete model structure of the two models should be as follows:
 └── 3d-vae.pt
 ```
 Due to the large size of the model weight files, using `git lfs` is recommended. Installation of `git lfs` can be
 found [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing).
 Next, clone the T5 model. It is not trained or fine-tuned, but it must still be used.
 > The T5 model is also available on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b).
@@ -160,14 +160,14 @@ model:
          ucg_rate: 0.1
          target: sgm.modules.encoders.modules.FrozenT5Embedder
          params:
-            model_dir: "{absolute_path/to/your/t5-v1_1-xxl}/t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl weights folder
+            model_dir: "t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl weights folder
            max_length: 226

  first_stage_config:
    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
    params:
      cp_size: 1
-      ckpt_path: "{absolute_path/to/your/t5-v1_1-xxl}/CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder
+      ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder
      ignore_keys: [ 'loss' ]

      loss_config:
@@ -254,13 +254,14 @@ args:
   sampling_num_frames: 13 # Must be 13, 11 or 9
   sampling_fps: 8
   fp16: True # For CogVideoX-2B
   # bf16: True # For CogVideoX-5B
   output_dir: outputs/
   force_inference: True
 ```
 + Modify `configs/test.txt` if multiple prompts are required; each line is one prompt.
 + For better prompt formatting, refer to [convert_demo.py](../inference/convert_demo.py), for which you should set the
   OPENAI_API_KEY environment variable.
 + Modify `input_type` in `configs/inference.yaml` if you use the command line as the prompt input.
 ```yaml
@@ -408,27 +409,31 @@ python ../tools/convert_weight_sat2hf.py
 ### Exporting Hugging Face Diffusers LoRA Weights from SAT Checkpoints
 After completing the training using the above steps, we get a SAT checkpoint with LoRA weights. You can find the file
 at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.
 The script for exporting LoRA weights can be found in the CogVideoX repository at `tools/export_sat_lora_weight.py`.
 After exporting, you can use `load_cogvideox_lora.py` for inference.
-#### Export command:
+Export command:
 ```bash
 python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
 ```
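As a rough sketch of the inference step mentioned above, the exported directory can also be loaded directly with the `diffusers` pipeline; the base model id, LoRA directory, adapter name, and prompt below are assumptions for illustration.

```python
# Sketch: run inference with the exported HF-format LoRA (paths and ids are placeholders).
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("export_hf_lora_weights_1", adapter_name="sat-lora")  # exported directory
pipe.set_adapters(["sat-lora"], [1.0])  # adapter name and LoRA scale

frames = pipe(
    prompt="A panda playing a guitar in a bamboo forest, cinematic lighting.",
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]
export_to_video(frames, "output.mp4", fps=8)
```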
 This training mainly modified the following model structures. The table below lists the corresponding structure mappings
 for converting to the HF (Hugging Face) format LoRA structure. As you can see, LoRA adds a low-rank weight to the
 model's attention structure.
 ```
 'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
 'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
 'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
 'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
 'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
 'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
 'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
 'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
 ```
 Using `export_sat_lora_weight.py`, you can convert the SAT checkpoint into the HF LoRA format.
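The mapping above can be applied as a simple key-renaming pass. The sketch below assumes a flat `dict` of LoRA tensors and is meant only to illustrate the table, not to replace `export_sat_lora_weight.py`.

```python
# Sketch: rename SAT LoRA keys to the Hugging Face layout listed in the table above.
SAT_TO_HF = {
    "attention.query_key_value.matrix_A.0": "attn1.to_q.lora_A.weight",
    "attention.query_key_value.matrix_A.1": "attn1.to_k.lora_A.weight",
    "attention.query_key_value.matrix_A.2": "attn1.to_v.lora_A.weight",
    "attention.query_key_value.matrix_B.0": "attn1.to_q.lora_B.weight",
    "attention.query_key_value.matrix_B.1": "attn1.to_k.lora_B.weight",
    "attention.query_key_value.matrix_B.2": "attn1.to_v.lora_B.weight",
    "attention.dense.matrix_A.0": "attn1.to_out.0.lora_A.weight",
    "attention.dense.matrix_B.0": "attn1.to_out.0.lora_B.weight",
}

def rename_lora_keys(sat_state_dict):
    """Swap each SAT suffix for its HF counterpart, keeping any transformer-block prefix."""
    hf_state_dict = {}
    for name, tensor in sat_state_dict.items():
        for sat_suffix, hf_suffix in SAT_TO_HF.items():
            if name.endswith(sat_suffix):
                hf_state_dict[name[: -len(sat_suffix)] + hf_suffix] = tensor
                break
        else:
            hf_state_dict[name] = tensor  # keys outside the table pass through unchanged
    return hf_state_dict
```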

View File

@@ -32,18 +32,12 @@ mv 'index.html?dl=1' transformer.zip
 unzip transformer.zip
 ```
-For the CogVideoX-5B model, please download as follows (the VAE files are the same):
-```shell
-mkdir CogVideoX-5b-sat
-cd CogVideoX-5b-sat
-wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
-mv 'index.html?dl=1' vae.zip
-unzip vae.zip
-```
-Then, you need to go to [Tsinghua Cloud Disk](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list) to download the model and unzip it.
-After sorting, the complete model structure of the two models should be as follows:
+For the CogVideoX-5B model, please download the `transformers` files from the following links (the VAE files are the same as for the 2B model):
++ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
++ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)
+Next, arrange the model files into the following structure:
+The model structure should be as follows:
 ```
 .
@@ -55,8 +49,9 @@ unzip vae.zip
 └── 3d-vae.pt
 ```
 Because the model weight files are large, using `git lfs` is recommended. Installation instructions for `git lfs`
 can be found [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing).
 ```shell
 git lfs install
 ```
@@ -166,14 +161,14 @@ model:
          ucg_rate: 0.1
          target: sgm.modules.encoders.modules.FrozenT5Embedder
          params:
-            model_dir: "{absolute_path/to/your/t5-v1_1-xxl}/t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl folder
+            model_dir: "t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl folder
            max_length: 226

  first_stage_config:
    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
    params:
      cp_size: 1
-      ckpt_path: "{absolute_path/to/your/t5-v1_1-xxl}/CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder
+      ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder
      ignore_keys: [ 'loss' ]

      loss_config:
@@ -244,6 +239,7 @@ model:
     exp: 5
     num_steps: 50
 ```
+
 ### 4. Modify the `configs/inference.yaml` file.
 ```yaml
@@ -259,7 +255,7 @@ args:
   sampling_num_frames: 13 # Must be 13, 11 or 9
   sampling_fps: 8
   fp16: True # For CogVideoX-2B
   # bf16: True # For CogVideoX-5B
   output_dir: outputs/
   force_inference: True
 ```
@@ -417,25 +413,23 @@ python ../tools/convert_weight_sat2hf.py
 The script for exporting LoRA weights can be found in the CogVideoX repository at `tools/export_sat_lora_weight.py`. After exporting, you can use `load_cogvideox_lora.py` for inference.
-#### Export command:
+Export command:
 ```bash
 python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
 ```
 This training mainly modified the following model structures. The table below lists the mappings used when converting to the HF (Hugging Face) LoRA structure. As you can see, LoRA adds low-rank weights to the model's attention mechanism.
 ```
 'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
 'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
 'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
 'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
 'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
 'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
 'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
 'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
 ```
 Using `export_sat_lora_weight.py`, you can convert the SAT checkpoint into the HF LoRA format.

View File

@@ -162,14 +162,14 @@ model:
          ucg_rate: 0.1
          target: sgm.modules.encoders.modules.FrozenT5Embedder
          params:
-            model_dir: "{absolute_path/to/your/t5-v1_1-xxl}/t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl weights folder
+            model_dir: "t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl weights folder
            max_length: 226

  first_stage_config:
    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
    params:
      cp_size: 1
-      ckpt_path: "{absolute_path/to/your/t5-v1_1-xxl}/CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder
+      ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder
      ignore_keys: [ 'loss' ]

      loss_config:
@@ -294,9 +294,9 @@ bash inference.sh
 ```
 .
 ├── labels
 │   ├── 1.txt
 │   ├── 2.txt
 │   ├── ...
 └── videos
     ├── 1.mp4
     ├── 2.mp4
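A small sketch of how the paired layout above can be consumed: each caption file in `labels/` matches the video with the same stem in `videos/`. The dataset root path is an assumed placeholder.

```python
# Sketch: pair each caption in labels/ with the video of the same stem in videos/.
from pathlib import Path

root = Path("path/to/your/dataset")  # assumed dataset root
pairs = []
for label_file in sorted((root / "labels").glob("*.txt")):
    video_file = root / "videos" / f"{label_file.stem}.mp4"
    if video_file.exists():
        pairs.append((video_file, label_file.read_text(encoding="utf-8").strip()))
```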

View File

@@ -1,7 +1,7 @@
 model:
   scale_factor: 0.7 # different from cogvideox_2b_infer.yaml
   disable_first_stage_autocast: true
-  not_trainable_prefixes: ['all'] ## Using Lora
+  not_trainable_prefixes: ['all'] # Using Lora
   log_keys:
     - txt

@@ -53,7 +53,7 @@ model:
       hidden_size_head: 64
       text_length: 226
-      lora_config: ## Using Lora
+      lora_config: # Using Lora
        target: sat.model.finetune.lora2.LoraMixin
        params:
          r: 128

View File

@@ -225,7 +225,7 @@ class CrossAttention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
-        ## old
+        # old
         """
         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
         del q, k

@@ -241,7 +241,7 @@ class CrossAttention(nn.Module):
         out = einsum('b i j, b j d -> b i d', sim, v)
         """
-        ## new
+        # new
         with sdp_kernel(**BACKEND_MAP[self.backend]):
             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
             out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # scale is dim_head ** -0.5 per default
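The commented-out einsum block and the SDPA call in this hunk compute the same attention; a small self-contained sketch of the two paths is shown below, with simplified shapes and no mask (both simplifications are assumptions, not the module's exact code).

```python
# Sketch: explicit softmax-attention vs. torch's fused scaled_dot_product_attention.
import torch
import torch.nn.functional as F

b, h, n, d = 2, 8, 16, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# old path: similarity matrix, softmax, weighted sum
sim = torch.einsum("b h i d, b h j d -> b h i j", q, k) * d**-0.5
out_old = torch.einsum("b h i j, b h j d -> b h i d", sim.softmax(dim=-1), v)

# new path: fused kernel; scale defaults to dim_head ** -0.5
out_new = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_old, out_new, atol=1e-5))  # True up to numerical tolerance
```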

View File

@@ -34,7 +34,6 @@ def convert_module_to_f32(x):
     pass

-## go
 class AttentionPool2d(nn.Module):
     """
     Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

View File

@@ -225,7 +225,7 @@ class CrossAttention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
-        ## old
+        # old
         """
         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
         del q, k

@@ -241,7 +241,7 @@ class CrossAttention(nn.Module):
         out = einsum('b i j, b j d -> b i d', sim, v)
         """
-        ## new
+        # new
         with sdp_kernel(**BACKEND_MAP[self.backend]):
             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
             out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # scale is dim_head ** -0.5 per default

View File

@@ -1,12 +1,20 @@
 #!/bin/bash

-NUM_VIDEOS=100
+NUM_VIDEOS=10
 INFERENCE_STEPS=50
 GUIDANCE_SCALE=7.0
 OUTPUT_DIR_PREFIX="outputs/gpu_"
 LOG_DIR_PREFIX="logs/gpu_"

-CUDA_DEVICES=${CUDA_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"}
+VIDEO_MODEL_PATH="/share/official_pretrains/hf_home/CogVideoX-5b-I2V"
+LLM_MODEL_PATH="/share/home/zyx/Models/Meta-Llama-3.1-8B-Instruct"
+IMAGE_MODEL_PATH="/share/home/zyx/Models/FLUX.1-dev"
+
+#VIDEO_MODEL_PATH="THUDM/CogVideoX-5B-I2V"
+#LLM_MODEL_PATH="THUDM/glm-4-9b-chat"
+#IMAGE_MODEL_PATH="black-forest-labs/FLUX.1-dev"
+
+CUDA_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"}

 IFS=',' read -r -a GPU_ARRAY <<< "$CUDA_DEVICES"

@@ -15,6 +23,9 @@ do
     GPU=${GPU_ARRAY[$i]}
     echo "Starting task on GPU $GPU..."
     CUDA_VISIBLE_DEVICES=$GPU nohup python3 llm_flux_cogvideox.py \
+        --caption_generator_model_id $LLM_MODEL_PATH \
+        --image_generator_model_id $IMAGE_MODEL_PATH \
+        --model_path $VIDEO_MODEL_PATH \
         --num_videos $NUM_VIDEOS \
         --image_generator_num_inference_steps $INFERENCE_STEPS \
         --guidance_scale $GUIDANCE_SCALE \
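For reference, a Python sketch of the same per-GPU launch loop; the CLI flags mirror the hunk above, and the Hub model ids from the commented-out lines are used as assumed defaults.

```python
# Sketch: spawn one llm_flux_cogvideox.py job per visible GPU, mirroring the shell loop above.
import os
import subprocess

gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
for gpu in gpus:
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": gpu}
    subprocess.Popen(
        [
            "python3", "llm_flux_cogvideox.py",
            "--caption_generator_model_id", "THUDM/glm-4-9b-chat",         # assumed Hub id
            "--image_generator_model_id", "black-forest-labs/FLUX.1-dev",  # assumed Hub id
            "--model_path", "THUDM/CogVideoX-5b-I2V",                      # assumed Hub id
            "--num_videos", "10",
            "--image_generator_num_inference_steps", "50",
            "--guidance_scale", "7.0",
        ],
        env=env,
    )
```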

View File

@@ -48,6 +48,7 @@ There are a few rules to follow:
 - If the user mentions to summarize the prompt in [X] words, make sure to not exceed the limit.
 Your responses should just be the video generation prompt. Here are examples:
+- A lone figure stands on a city rooftop at night, gazing up at the full moon. The moon glows brightly, casting a gentle light over the quiet cityscape. Below, the windows of countless homes shine with warm lights, creating a contrast between the bustling life below and the peaceful solitude above. The scene captures the essence of the Mid-Autumn Festival, where despite the distance, the figure feels connected to loved ones through the shared beauty of the moonlit sky.
 - "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
 - "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall"
 """.strip()
@@ -172,7 +173,7 @@ def main(args: Dict[str, Any]) -> None:
     captions = []
     for i in range(args.num_videos):
-        num_words = random.choice([100, 150, 200])
+        num_words = random.choice([50, 75, 100])
         user_prompt = USER_PROMPT.format(num_words)

         messages = [