Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-06 03:57:56 +08:00)
feat: add SFT support with ZeRO optimization strategies
- Add SFT (Supervised Fine-Tuning) trainers for all model variants:
  - CogVideoX I2V and T2V
  - CogVideoX-1.5 I2V and T2V
- Add DeepSpeed ZeRO configuration files:
  - ZeRO-2 with and without CPU offload
  - ZeRO-3 with and without CPU offload
- Add base accelerate config for distributed training
- Update trainer.py to support SFT training mode

This enables full-parameter fine-tuning with memory-efficient distributed training using DeepSpeed ZeRO optimization.
parent e213b6c083
commit caa24bdc36
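The new SFT trainers plug into the existing trainer registry: each sft_trainer.py module calls register(model_name, "sft", trainer_cls) at import time, which files the class under SUPPORTED_MODELS[model_name]["sft"], so an SFT run differs from a LoRA run only in the training_type used for the lookup. A minimal sketch of that lookup is shown below; the get_trainer_cls helper and the commented usage lines are illustrative assumptions, not code contained in this commit.

# Illustrative sketch only: resolve a trainer class from the registry that the
# new sft_trainer.py modules populate via register(...) at import time.
from finetune.models.utils import SUPPORTED_MODELS  # assumed import path


def get_trainer_cls(model_name: str, training_type: str):
    """Return the trainer class registered for e.g. ("cogvideox-t2v", "sft")."""
    try:
        return SUPPORTED_MODELS[model_name][training_type]
    except KeyError as err:
        raise ValueError(f"No trainer registered for {model_name}/{training_type}") from err


# Hypothetical usage (entry point, argument handling and method names are assumptions):
# trainer_cls = get_trainer_cls("cogvideox1.5-i2v", "sft")
# trainer = trainer_cls(args)
# trainer.fit()

Under ZeRO, such a run would typically be launched through accelerate launch --config_file finetune/accelerate_config.yaml together with the repo's training script, with one of the zero*.yaml files selected inside that config.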
finetune/accelerate_config.yaml (new file, 27 lines)
@@ -0,0 +1,27 @@
compute_environment: LOCAL_MACHINE

# gpu_ids: "0" # 0,1,2,3,4,5,6,7
# num_processes: 1

gpu_ids: all
num_processes: 8

debug: false
deepspeed_config:
  deepspeed_config_file: /path/to/your/configs/zero2.yaml
  # deepspeed_config_file: /path/to/your/configs/zero2_offload.yaml
  # deepspeed_config_file: /path/to/your/configs/zero3.yaml
  # deepspeed_config_file: /path/to/your/configs/zero3_offload.yaml
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
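Note: the single deepspeed_config_file entry above picks which of the four ZeRO configs below is active. As a rough guide, ZeRO-2 shards optimizer state and gradients across GPUs, ZeRO-3 additionally shards the model parameters, and the _offload variants move optimizer state (and, for ZeRO-3, parameters) into CPU memory to save GPU memory at some speed cost. A config like this is normally consumed as accelerate launch --config_file finetune/accelerate_config.yaml followed by the training script; the script name itself is not part of this diff.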
finetune/configs/zero2.yaml (new file, 38 lines)
@@ -0,0 +1,38 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/configs/zero2_offload.yaml (new file, 42 lines)
@@ -0,0 +1,42 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        }
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/configs/zero3.yaml (new file, 43 lines)
@@ -0,0 +1,43 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": 5e8,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e5
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/configs/zero3_offload.yaml (new file, 51 lines)
@@ -0,0 +1,51 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": 5e8,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e6
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/models/cogvideox1dot5_i2v/sft_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_i2v.sft_trainer import CogVideoXI2VSftTrainer
from ..utils import register


class CogVideoX1dot5I2VSftTrainer(CogVideoXI2VSftTrainer):
    pass


register("cogvideox1.5-i2v", "sft", CogVideoX1dot5I2VSftTrainer)
finetune/models/cogvideox1dot5_t2v/sft_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_t2v.sft_trainer import CogVideoXT2VSftTrainer
from ..utils import register


class CogVideoX1dot5T2VSftTrainer(CogVideoXT2VSftTrainer):
    pass


register("cogvideox1.5-t2v", "sft", CogVideoX1dot5T2VSftTrainer)
finetune/models/cogvideox_i2v/sft_trainer.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import dotmap
from diffusers import CogVideoXImageToVideoPipeline
from typing_extensions import override

from finetune.utils import unwrap_model

from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register


class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
    @override
    def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
        origin_model = unwrap_model(self.accelerator, self.components.transformer)
        self.components.transformer.config.update(origin_model.config)
        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
        pipe = CogVideoXImageToVideoPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            transformer=self.components.transformer,
            scheduler=self.components.scheduler,
        )
        return pipe


register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)
finetune/models/cogvideox_t2v/sft_trainer.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import dotmap
from diffusers import CogVideoXPipeline
from typing_extensions import override

from finetune.utils import unwrap_model

from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register


class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
    @override
    def initialize_pipeline(self) -> CogVideoXPipeline:
        origin_model = unwrap_model(self.accelerator, self.components.transformer)
        self.components.transformer.config.update(origin_model.config)
        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
        pipe = CogVideoXPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            transformer=self.components.transformer,
            scheduler=self.components.scheduler,
        )
        return pipe


register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)
finetune/models/utils.py
@@ -15,17 +15,12 @@ def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls
         trainer_cls (Trainer): Trainer class to register.
     """
 
-    # Check if model_name exists in SUPPORTED_MODELS
+    # Check if model_name and training_type exists in SUPPORTED_MODELS
     if model_name not in SUPPORTED_MODELS:
         SUPPORTED_MODELS[model_name] = {}
-    else:
-        raise ValueError(f"Model {model_name} already exists")
 
-    # Check if training_type exists for this model
-    if training_type not in SUPPORTED_MODELS[model_name]:
-        SUPPORTED_MODELS[model_name][training_type] = {}
-    else:
-        raise ValueError(f"Training type {training_type} already exists for model {model_name}")
+    if training_type in SUPPORTED_MODELS[model_name]:
+        raise ValueError(f"Training type {training_type} already exists for model {model_name}")
 
     SUPPORTED_MODELS[model_name][training_type] = trainer_cls
 
finetune/trainer.py
@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import math
@@ -71,7 +72,7 @@ class Trainer:
             train_width=self.args.train_resolution[2],
         )
 
-        self.components = Components()
+        self.components: Components = self.load_components()
         self.accelerator: Accelerator = None
         self.dataset: Dataset = None
         self.data_loader: DataLoader = None
@@ -145,9 +146,6 @@ class Trainer:
     def prepare_models(self) -> None:
         logger.info("Initializing models")
 
-        # Initialize model components
-        self.components = self.load_components()
-
         if self.components.vae is not None:
             if self.args.enable_slicing:
                 self.components.vae.enable_slicing()
@@ -159,15 +157,11 @@ class Trainer:
     def prepare_dataset(self) -> None:
         logger.info("Initializing dataset and dataloader")
 
-        # self.state.train_frames includes one padding frame for image conditioning
-        # so we only sample train_frames - 1 frames from the actual video
-        sample_frames = self.state.train_frames - 1
-
         if self.args.model_type == "i2v":
             self.dataset = I2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                max_num_frames=sample_frames,
+                max_num_frames=self.state.train_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
                 trainer=self,
@@ -176,7 +170,7 @@ class Trainer:
             self.dataset = T2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                max_num_frames=sample_frames,
+                max_num_frames=self.state.train_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
                 trainer=self,
@@ -223,12 +217,7 @@ class Trainer:
     def prepare_trainable_parameters(self):
         logger.info("Initializing trainable parameters")
 
-        # For now only lora is supported
-        for attr_name, component in vars(self.components).items():
-            if hasattr(component, "requires_grad_"):
-                component.requires_grad_(False)
-
-        # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+        # For mixed precision training we cast all non-trainable weights to half-precision
         # as these weights are only used for inference, keeping weights in full precision is not required.
         weight_dtype = self.state.weight_dtype
 
@@ -238,35 +227,47 @@ class Trainer:
                 "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
             )
 
-        self.__load_components()
+        # For LoRA, we freeze all the parameters
+        # For SFT, we train all the parameters in transformer model
+        for attr_name, component in vars(self.components).items():
+            if hasattr(component, "requires_grad_"):
+                if self.args.training_type == "sft" and attr_name == "transformer":
+                    component.requires_grad_(True)
+                else:
+                    component.requires_grad_(False)
 
+        if self.args.training_type == "lora":
+            transformer_lora_config = LoraConfig(
+                r=self.args.rank,
+                lora_alpha=self.args.lora_alpha,
+                init_lora_weights=True,
+                target_modules=self.args.target_modules,
+            )
+            self.components.transformer.add_adapter(transformer_lora_config)
+            self.__prepare_saving_loading_hooks(transformer_lora_config)
+
+        # Load components needed for training to GPU (except transformer),
+        # and cast them to the specified data type
+        self.__move_components_to_device(dtype=weight_dtype)
+
         if self.args.gradient_checkpointing:
             self.components.transformer.enable_gradient_checkpointing()
 
-        transformer_lora_config = LoraConfig(
-            r=self.args.rank,
-            lora_alpha=self.args.lora_alpha,
-            init_lora_weights=True,
-            target_modules=self.args.target_modules,
-        )
-        self.components.transformer.add_adapter(transformer_lora_config)
-        self.__prepare_saving_loading_hooks(transformer_lora_config)
-
     def prepare_optimizer(self) -> None:
         logger.info("Initializing optimizer and lr scheduler")
 
         # Make sure the trainable params are in float32
-        if self.args.mixed_precision != "no":
-            # only upcast trainable parameters (LoRA) into fp32
-            cast_training_params([self.components.transformer], dtype=torch.float32)
+        cast_training_params([self.components.transformer], dtype=torch.float32)
 
-        transformer_lora_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
+        # For LoRA, we only want to train the LoRA weights
+        # For SFT, we want to train all the parameters
+        trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
         transformer_parameters_with_lr = {
-            "params": transformer_lora_parameters,
+            "params": trainable_parameters,
             "lr": self.args.learning_rate,
         }
         params_to_optimize = [transformer_parameters_with_lr]
-        self.state.num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters)
+        self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)
 
         use_deepspeed_opt = (
             self.accelerator.state.deepspeed_plugin is not None
@@ -502,13 +503,15 @@
         # Convert all model weights to training dtype
         # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
         pipe = pipe.to(dtype=self.state.weight_dtype)
+
         #################################
 
         all_processes_artifacts = []
         for i in range(num_validation_samples):
-            # Skip current validation on all processes but one
-            if i % accelerator.num_processes != accelerator.process_index:
-                continue
+            if self.accelerator.deepspeed_plugin and self.accelerator.deepspeed_plugin.zero_stage != 3:
+                # Skip current validation on all processes but one
+                if i % accelerator.num_processes != accelerator.process_index:
+                    continue
 
             prompt = self.state.validation_prompts[i]
             image = self.state.validation_images[i]
@@ -534,7 +537,19 @@
                 main_process_only=False,
             )
             validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
+
+            if (
+                self.accelerator.deepspeed_plugin is not None
+                and self.accelerator.deepspeed_plugin.zero_stage == 3
+                and not accelerator.is_main_process
+            ):
+                continue
+
             prompt_filename = string_to_filename(prompt)[:25]
+            # Calculate hash of reversed prompt as a unique identifier
+            reversed_prompt = prompt[::-1]
+            hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]
+
             artifacts = {
                 "image": {"type": "image", "value": image},
                 "video": {"type": "video", "value": video},
@@ -553,7 +568,7 @@
                     continue
 
                 extension = "png" if artifact_type == "image" else "mp4"
-                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
+                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
                 validation_path = self.args.output_dir / "validation_res"
                 validation_path.mkdir(parents=True, exist_ok=True)
                 filename = str(validation_path / filename)
@@ -587,11 +602,15 @@
         pipe.remove_all_hooks()
         del pipe
         # Unload models except those needed for training
-        self.__unload_components()
+        self.__move_components_to_cpu()
 
         # Load models except those not needed for training
-        self.__load_components()
-        # Change LoRA weights back to fp32
-        cast_training_params([self.components.transformer], dtype=torch.float32)
+        self.__move_components_to_device(dtype=self.state.weight_dtype)
+        self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
+
+        if self.accelerator.state.deepspeed_plugin is None:
+            # Change trainable weights back to fp32
+            cast_training_params([self.components.transformer], dtype=torch.float32)
 
         accelerator.wait_for_everyone()
+
@@ -649,16 +668,21 @@
         else:
             raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
 
-    def __load_components(self):
+    def __move_components_to_device(self, dtype):
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
                 if name in self.UNLOAD_LIST:
                     continue
-                # setattr(self.components, name, component.to(self.accelerator.device))
-                setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
 
-    def __unload_components(self):
+                # We don't need to move transformer to device
+                # because we will prepare it in the `prepare_for_training()`
+                if name == "transformer":
+                    continue
+
+                setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
+
+    def __move_components_to_cpu(self):
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
@@ -723,13 +747,6 @@ class Trainer:
                     f" {unexpected_keys}. "
                 )
 
-            # Make sure the trainable params are in float32. This is again needed since the base models
-            # are in `weight_dtype`. More details:
-            # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
-            if self.args.mixed_precision == "fp16":
-                # only upcast trainable parameters (LoRA) into fp32
-                cast_training_params([transformer_])
-
         self.accelerator.register_save_state_pre_hook(save_model_hook)
         self.accelerator.register_load_state_pre_hook(load_model_hook)
 
|
Loading…
x
Reference in New Issue
Block a user