Support exporting LoRA weights from SAT checkpoint files; the export script lives in the CogVideoX repository at tools/export_sat_lora_weight.py.

After export, run inference with load_cogvideox_lora.py.
glide-the 2024-09-11 16:24:31 +08:00
parent 98466e674c
commit b2b772a942
6 changed files with 275 additions and 3 deletions

Binary file not shown (added, 60 KiB).

@@ -406,4 +406,30 @@ The SAT weight format is different from Huggingface's weight format and needs to
python ../tools/convert_weight_sat2hf.py
```
**Note**: This content has not yet been tested with LoRA fine-tuned models.
### Exporting Huggingface Diffusers LoRA Weights from SAT Checkpoints
After completing the training using the steps above, you will have a SAT checkpoint with LoRA weights. You can find the file at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.
The script for exporting LoRA weights is in the CogVideoX repository at `tools/export_sat_lora_weight.py`. After exporting, you can use `load_cogvideox_lora.py` for inference.
#### Export command:
```bash
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```
This training mainly modified the following model structures. The mapping below lists how each converts to the HF (Hugging Face) LoRA structure; as the names indicate, LoRA adds a pair of low-rank weights to each of the model's attention projections (a minimal sketch of this update follows the mapping).
```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```
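To make the `lora_A`/`lora_B` pairing concrete, here is a minimal sketch of the low-rank update that fusing applies to a frozen projection weight. The shapes and values below are illustrative, not read from a checkpoint:

```python
import torch

d_model, rank, alpha = 3072, 128, 1.0  # illustrative sizes; SAT trains with alpha = 1
W = torch.randn(d_model, d_model)      # frozen base projection weight (e.g. to_q)
lora_A = torch.randn(rank, d_model)    # "matrix_A" in the SAT checkpoint
lora_B = torch.zeros(d_model, rank)    # "matrix_B" in the SAT checkpoint

# Fusing adds the low-rank product, scaled by lora_scale = alpha / rank.
W_fused = W + (alpha / rank) * (lora_B @ lora_A)
```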
Using `export_sat_lora_weight.py`, you can convert the SAT checkpoint into the HF LoRA format.
![Exported HF LoRA weights](../resources/hf_lora_weights.png)
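For reference, the new `load_cogvideox_lora.py` (shown in full later in this commit) boils down to the following steps; the model ID and directory below are placeholders, not values from the commit:

```python
import torch
from diffusers import CogVideoXPipeline

# Placeholder model ID and LoRA directory; substitute your own.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights(
    "export_hf_lora_weights_1/",  # directory written by export_sat_lora_weight.py
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="test_1",
)
pipe.fuse_lora(lora_scale=1 / 128)  # lora_scale = alpha / rank; SAT's default alpha is 1, rank 128 for 2B
video = pipe(prompt="a cat playing piano", num_frames=49, num_inference_steps=50).frames
```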

@@ -411,4 +411,34 @@ The SAT weight format differs from the Huggingface weight format and needs to be converted
python ../tools/convert_weight_sat2hf.py
```
**Note**: This content has not yet been tested with LoRA fine-tuned models.
### Exporting Huggingface Diffusers LoRA Weights from SAT Checkpoints
After completing the steps above, you will have a SAT checkpoint with LoRA weights. You can find the file at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.
The script for exporting LoRA weights is in the CogVideoX repository at `tools/export_sat_lora_weight.py`. After exporting, you can use `load_cogvideox_lora.py` for inference.
#### Export command:
```bash
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```
This training mainly modified the following model structures. The mapping below shows how they correspond to the HF (Hugging Face) LoRA structure; as you can see, LoRA adds low-rank weights to the model's attention mechanism.
```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```
Using `export_sat_lora_weight.py`, you can convert the SAT checkpoint into the HF LoRA format.
![Exported HF LoRA weights](../resources/hf_lora_weights.png)
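As a quick sanity check on the exported directory, you can reload the safetensors file and confirm it contains the 240 tensors the export script expects. The file name below assumes the Diffusers default that the inference script also uses:

```python
from safetensors.torch import load_file

# Default file name written by LoraBaseMixin.write_lora_layers with safe_serialization=True.
lora_sd = load_file("export_hf_lora_weights_1/pytorch_lora_weights.safetensors")
assert len(lora_sd) == 240, f"expected 240 LoRA tensors, got {len(lora_sd)}"
# Every key should end in a lora_A.weight / lora_B.weight pair, e.g. ...attn1.to_q.lora_A.weight
print(sorted(lora_sd)[:4])
```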

@@ -404,4 +404,30 @@ The SAT weight format is different from the Huggingface weight format and needs to be converted. Please run
python ../tools/convert_weight_sat2hf.py
```
**Note**: This content has not yet been tested with LoRA fine-tuned models.
### Exporting Huggingface Diffusers LoRA Weights from SAT Checkpoints
After training with the steps above, you will have a SAT checkpoint with LoRA weights; you can find the file at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.
The export script is in the CogVideoX repository at `tools/export_sat_lora_weight.py`; after exporting, use `load_cogvideox_lora.py` for inference.
- Export command:
```bash
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```
This training mainly modified the following model structures. The mapping below lists how they convert to the HF-format LoRA structure; as you can see, LoRA adds a low-rank weight to the model's attention structure.
```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```
Use `export_sat_lora_weight.py` to convert it into the HF-format LoRA structure.
![Exported HF LoRA weights](../resources/hf_lora_weights.png)

tools/export_sat_lora_weight.py
@@ -0,0 +1,83 @@
from typing import Any, Dict

import argparse

import torch
from diffusers.loaders.lora_base import LoraBaseMixin


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    # SAT checkpoints may wrap the weights in "model", "module", or "state_dict" entries.
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


# Mapping from SAT LoRA parameter names to HF Diffusers LoRA parameter names.
LORA_KEYS_RENAME = {
    'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
    'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
    'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
    'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
    'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
    'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
    'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
    'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
}

PREFIX_KEY = "model.diffusion_model."
SAT_UNIT_KEY = "layers"
LORA_PREFIX_KEY = "transformer_blocks"


def export_lora_weight(ckpt_path, lora_save_directory):
    merge_original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))

    lora_state_dict = {}
    for key in list(merge_original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY) :]
        for special_key, lora_key in LORA_KEYS_RENAME.items():
            if new_key.endswith(special_key):
                # Rename the LoRA tensor and map SAT's "layers" to Diffusers' "transformer_blocks".
                new_key = new_key.replace(special_key, lora_key)
                new_key = new_key.replace(SAT_UNIT_KEY, LORA_PREFIX_KEY)
                lora_state_dict[new_key] = merge_original_state_dict[key]

    # 8 LoRA tensors per transformer block; the final length should be 240.
    if len(lora_state_dict) != 240:
        raise ValueError("lora_state_dict length is not 240")

    LoraBaseMixin.write_lora_layers(
        state_dict=lora_state_dict,
        save_directory=lora_save_directory,
        is_main_process=True,
        weight_name=None,
        save_function=None,
        safe_serialization=True,
    )


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sat_pt_path", type=str, required=True, help="Path to original SAT transformer checkpoint"
    )
    parser.add_argument(
        "--lora_save_directory", type=str, required=True, help="Path where converted LoRA weights should be saved"
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    export_lora_weight(args.sat_pt_path, args.lora_save_directory)

load_cogvideox_lora.py
@@ -0,0 +1,107 @@
# Copyright 2024 The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import os
from datetime import datetime

import torch
from diffusers import CogVideoXDPMScheduler, CogVideoXPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import export_to_video

device = "cuda" if torch.cuda.is_available() else "cpu"


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--lora_weights_path",
        type=str,
        default=None,
        required=True,
        help="Path to the LoRA weights.",
    )
    parser.add_argument(
        "--lora_r",
        type=int,
        default=128,
        help="""LoRA rank. The default is 128 for the 2B transformer and 256 for the 5B transformer.
        It is used to compute lora_scale, which by default divides the alpha value by the rank,
        for stable learning and to prevent underflow. In the SAT training framework,
        alpha is set to 1 by default. A higher rank gives better expressive capability,
        but it requires more memory and training time; increasing it blindly isn't always better.
        The formula is: lora_scale = alpha / lora_r.
        """,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    pipe = CogVideoXPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)
    pipe.load_lora_weights(
        args.lora_weights_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1"
    )
    # lora_scale = alpha / rank; SAT trains with alpha = 1, so this is 1 / lora_r.
    pipe.fuse_lora(lora_scale=1 / args.lora_r)
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    os.makedirs(args.output_dir, exist_ok=True)

    prompt = """In the heart of a bustling city, a young woman with long, flowing brown hair and a radiant smile stands out. She's donned in a cozy white beanie adorned with playful animal ears, adding a touch of whimsy to her appearance. Her eyes sparkle with joy as she looks directly into the camera, her expression inviting and warm. The background is a blur of activity, with indistinct figures moving about, suggesting a lively public space. The lighting is soft and diffused, casting a gentle glow on her face and highlighting her features. The overall mood is cheerful and vibrant, capturing a moment of happiness in the midst of urban life.
    """

    latents = pipe(
        prompt=prompt,
        num_videos_per_prompt=1,
        num_inference_steps=50,
        num_frames=49,
        use_dynamic_cfg=True,
        output_type="pt",
        guidance_scale=3.0,
        generator=torch.Generator(device="cpu").manual_seed(42),
    ).frames

    batch_size = latents.shape[0]
    batch_video_frames = []
    for batch_idx in range(batch_size):
        pt_image = latents[batch_idx]
        pt_image = torch.stack([pt_image[i] for i in range(pt_image.shape[0])])
        # Convert the (frames, C, H, W) tensor into a list of PIL images.
        image_np = VaeImageProcessor.pt_to_numpy(pt_image)
        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
        batch_video_frames.append(image_pil)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"{args.output_dir}/{timestamp}.mp4"
    # Spread the frames over roughly six seconds of video.
    fps = math.ceil((len(batch_video_frames[0]) - 1) / 6)
    export_to_video(batch_video_frames[0], video_path, fps=fps)