mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 19:41:59 +08:00)

Added support for exporting LoRA weights from SAT checkpoint files. The export script is in the CogVideoX repository at tools/export_sat_lora_weight.py; after exporting, use load_cogvideox_lora.py for inference.

commit b2b772a942 (parent 98466e674c)
BIN resources/hf_lora_weights.png (new file, 60 KiB; binary file not shown)
@@ -406,4 +406,30 @@ The SAT weight format is different from Huggingface's weight format and needs to
python ../tools/convert_weight_sat2hf.py
```

**Note**: This content has not yet been tested with LoRA fine-tuned models.

### Exporting Huggingface Diffusers LoRA Weights from SAT Checkpoints

After completing training with the steps above, we get a SAT checkpoint with LoRA weights. You can find the file at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.

The script for exporting LoRA weights can be found in the CogVideoX repository at `tools/export_sat_lora_weight.py`. After exporting, you can use `load_cogvideox_lora.py` for inference.

#### Export command:

```bash
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```

This training mainly modified the following model structures. The table below lists the mapping from the SAT structure to the corresponding HF (Hugging Face) LoRA structure. As you can see, LoRA adds a low-rank weight to the model's attention structure.

```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```

Using `export_sat_lora_weight.py`, you can convert the SAT checkpoint into the HF LoRA format.

![hf_lora_weights](resources/hf_lora_weights.png)
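After exporting, the weights can be loaded directly with diffusers. A minimal sketch (the model ID and LoRA directory are example values; the full script added in this commit is `tools/load_cogvideox_lora.py`):

```python
import torch
from diffusers import CogVideoXPipeline

# Example values: adjust the model ID and LoRA directory to your setup.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("export_hf_lora_weights_1/", weight_name="pytorch_lora_weights.safetensors")
# lora_scale = alpha / rank; SAT trains with alpha = 1 and rank 128 for the 2B model.
pipe.fuse_lora(lora_scale=1 / 128)
```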
@@ -411,4 +411,34 @@ The SAT weight format differs from Huggingface's weight format and requires conversion
python ../tools/convert_weight_sat2hf.py
```

**Note**: This content has not yet been tested with LoRA fine-tuned models.

### Exporting Huggingface Diffusers LoRA Weights from SAT Checkpoints

After completing the steps above, you will have a SAT checkpoint with LoRA weights. The file is located at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.

The script for exporting LoRA weights is in the CogVideoX repository at `tools/export_sat_lora_weight.py`. After exporting, you can run inference with `load_cogvideox_lora.py`.

#### Export command:

```bash
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```

This training mainly modified the following model structures. The table below shows the corresponding mapping to the HF (Hugging Face) LoRA structure. As you can see, LoRA adds low-rank weights to the model's attention mechanism.

```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```

Using export_sat_lora_weight.py, you can convert the SAT checkpoint to the HF LoRA format.

![hf_lora_weights](resources/hf_lora_weights.png)
@@ -404,4 +404,30 @@ The SAT weight format differs from Huggingface's weight format and needs to be converted. Please run
python ../tools/convert_weight_sat2hf.py
```

**Note**: This content has not yet been tested with LoRA fine-tuned models.

### Exporting Huggingface Diffusers LoRA Weights from SAT Checkpoint Files

After the training steps above, we get a SAT checkpoint with LoRA weights; you can find the file at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.

The export script is in the CogVideoX repository at `tools/export_sat_lora_weight.py`; after exporting, use `load_cogvideox_lora.py` for inference.

- Export command

```
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```

This training mainly modified the following model structures. Listed below is the mapping for converting to the HF-format LoRA structure; as you can see, LoRA adds a low-rank weight on top of the model's attention structure.

```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```

Using export_sat_lora_weight.py, it can be converted to the HF-format LoRA structure.

![](resources/hf_lora_weights.png)
tools/export_sat_lora_weight.py (new file, 83 lines)
@@ -0,0 +1,83 @@
from typing import Any, Dict

import argparse

import torch
from diffusers.loaders.lora_base import LoraBaseMixin


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    # SAT checkpoints may nest the weights under "model", "module", or "state_dict".
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


# Mapping from SAT LoRA parameter names to HF (diffusers) LoRA parameter names.
LORA_KEYS_RENAME = {
    'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
    'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
    'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
    'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
    'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
    'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
    'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
    'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight',
}

PREFIX_KEY = "model.diffusion_model."
SAT_UNIT_KEY = "layers"
LORA_PREFIX_KEY = "transformer_blocks"


def export_lora_weight(ckpt_path, lora_save_directory):
    merge_original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))

    lora_state_dict = {}
    for key in list(merge_original_state_dict.keys()):
        # Strip the SAT prefix, rename the LoRA matrices, and swap the SAT
        # block container name ("layers") for the diffusers one ("transformer_blocks").
        new_key = key[len(PREFIX_KEY) :]
        for special_key, lora_key in LORA_KEYS_RENAME.items():
            if new_key.endswith(special_key):
                new_key = new_key.replace(special_key, lora_key)
                new_key = new_key.replace(SAT_UNIT_KEY, LORA_PREFIX_KEY)
                lora_state_dict[new_key] = merge_original_state_dict[key]

    # final length should be 240
    if len(lora_state_dict) != 240:
        raise ValueError(f"Expected 240 LoRA tensors in lora_state_dict, got {len(lora_state_dict)}")

    LoraBaseMixin.write_lora_layers(
        state_dict=lora_state_dict,
        save_directory=lora_save_directory,
        is_main_process=True,
        weight_name=None,
        save_function=None,
        safe_serialization=True,
    )


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sat_pt_path", type=str, required=True, help="Path to original sat transformer checkpoint"
    )
    parser.add_argument(
        "--lora_save_directory", type=str, required=True, help="Path where converted lora should be saved"
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    export_lora_weight(args.sat_pt_path, args.lora_save_directory)
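To make the renaming concrete, here is a small self-contained sketch that walks one assumed example key through the same transformation as export_lora_weight() (the exact layout of keys in a real SAT checkpoint may differ):

```python
PREFIX_KEY = "model.diffusion_model."

# Assumed example key, following the naming scheme the script handles.
key = "model.diffusion_model.layers.0.attention.query_key_value.matrix_A.0"
new_key = key[len(PREFIX_KEY):]
new_key = new_key.replace("attention.query_key_value.matrix_A.0", "attn1.to_q.lora_A.weight")
new_key = new_key.replace("layers", "transformer_blocks")
print(new_key)  # transformer_blocks.0.attn1.to_q.lora_A.weight
```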
tools/load_cogvideox_lora.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# Copyright 2024 The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import math
import os
from datetime import datetime

import torch
from diffusers import CogVideoXPipeline, CogVideoXDPMScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import export_to_video

device = "cuda" if torch.cuda.is_available() else "cpu"


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--lora_weights_path",
        type=str,
        default=None,
        required=True,
        help="Path to lora weights.",
    )
    parser.add_argument(
        "--lora_r",
        type=int,
        default=128,
        help="""LoRA rank: 128 by default for the 2B transformer and 256 for the 5B transformer.
        It is used to compute lora_scale = alpha / lora_r, which rescales the fused LoRA weights
        for stable learning and to prevent underflow; in the SAT training framework, alpha is set
        to 1 by default. A higher rank gives more expressive capacity but requires more memory and
        training time, so blindly increasing it is not always better.
        """,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    pipe = CogVideoXPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)
    pipe.load_lora_weights(args.lora_weights_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
    # lora_scale = alpha / rank; SAT trains with alpha = 1, so scale by 1 / lora_r.
    pipe.fuse_lora(lora_scale=1 / args.lora_r)

    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    os.makedirs(args.output_dir, exist_ok=True)
    prompt = """In the heart of a bustling city, a young woman with long, flowing brown hair and a radiant smile stands out. She's donned in a cozy white beanie adorned with playful animal ears, adding a touch of whimsy to her appearance. Her eyes sparkle with joy as she looks directly into the camera, her expression inviting and warm. The background is a blur of activity, with indistinct figures moving about, suggesting a lively public space. The lighting is soft and diffused, casting a gentle glow on her face and highlighting her features. The overall mood is cheerful and vibrant, capturing a moment of happiness in the midst of urban life."""
    latents = pipe(
        prompt=prompt,
        num_videos_per_prompt=1,
        num_inference_steps=50,
        num_frames=49,
        use_dynamic_cfg=True,
        output_type="pt",
        guidance_scale=3.0,
        generator=torch.Generator(device="cpu").manual_seed(42),
    ).frames
    batch_size = latents.shape[0]
    batch_video_frames = []
    for batch_idx in range(batch_size):
        pt_image = latents[batch_idx]
        pt_image = torch.stack([pt_image[i] for i in range(pt_image.shape[0])])

        # Convert the frame tensor to a list of PIL images.
        image_np = VaeImageProcessor.pt_to_numpy(pt_image)
        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
        batch_video_frames.append(image_pil)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"{args.output_dir}/{timestamp}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    video_frames = batch_video_frames[0]
    # 49 frames spread over ~6 seconds -> 8 fps.
    fps = math.ceil((len(video_frames) - 1) / 6)

    export_to_video(video_frames, video_path, fps=fps)
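An example invocation of the script above (the model ID and paths are placeholders; substitute your own):

```bash
python tools/load_cogvideox_lora.py \
  --pretrained_model_name_or_path THUDM/CogVideoX-2b \
  --lora_weights_path ./export_hf_lora_weights_1 \
  --lora_r 128 \
  --output_dir output
```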