Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-06-29 13:55:53 +08:00)
feat: add SFT support with ZeRO optimization strategies

- Add SFT (Supervised Fine-Tuning) trainers for all model variants:
  - CogVideoX I2V and T2V
  - CogVideoX-1.5 I2V and T2V
- Add DeepSpeed ZeRO configuration files:
  - ZeRO-2 with and without CPU offload
  - ZeRO-3 with and without CPU offload
- Add base accelerate config for distributed training
- Update trainer.py to support SFT training mode

This enables full-parameter fine-tuning with memory-efficient distributed training using DeepSpeed ZeRO optimization.
This commit is contained in:
parent e213b6c083
commit caa24bdc36
finetune/accelerate_config.yaml | 27 lines (new file)
@@ -0,0 +1,27 @@
compute_environment: LOCAL_MACHINE
# gpu_ids: "0" # 0,1,2,3,4,5,6,7
# num_processes: 1
gpu_ids: all
num_processes: 8

debug: false
deepspeed_config:
  deepspeed_config_file: /path/to/your/configs/zero2.yaml
  # deepspeed_config_file: /path/to/your/configs/zero2_offload.yaml
  # deepspeed_config_file: /path/to/your/configs/zero3.yaml
  # deepspeed_config_file: /path/to/your/configs/zero3_offload.yaml
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
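For orientation: a training run is launched by pointing Accelerate at this file. A minimal sketch, assuming an entry point named train.py (the script name and its location are placeholders, not part of this commit):

import subprocess

# "accelerate launch --config_file ..." is the standard Accelerate CLI;
# only the training script name below is a hypothetical placeholder.
subprocess.run(
    [
        "accelerate", "launch",
        "--config_file", "finetune/accelerate_config.yaml",
        "train.py",  # hypothetical entry point
    ],
    check=True,
)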
finetune/configs/zero2.yaml | 38 lines (new file)
@@ -0,0 +1,38 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
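Note that "train_batch_size" is left as "auto": DeepSpeed resolves it as micro-batch size × gradient-accumulation steps × data-parallel world size. A quick sanity check with the values above and the 8 processes from accelerate_config.yaml:

# Effective global batch size DeepSpeed derives from the "auto" field.
train_micro_batch_size_per_gpu = 1
gradient_accumulation_steps = 1
world_size = 8  # num_processes from accelerate_config.yaml

train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size
print(train_batch_size)  # -> 8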
finetune/configs/zero2_offload.yaml | 42 lines (new file)
@@ -0,0 +1,42 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        }
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/configs/zero3.yaml | 43 lines (new file)
@@ -0,0 +1,43 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": 5e8,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e5
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
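Both ZeRO-3 files define "stage3_prefetch_bucket_size" and "stage3_param_persistence_threshold" twice; with a standard JSON parser the last occurrence wins, so the "auto" values are effectively ignored. A quick check of that last-wins behavior:

import json

cfg = json.loads('{"stage3_param_persistence_threshold": "auto", "stage3_param_persistence_threshold": 1e5}')
print(cfg["stage3_param_persistence_threshold"])  # -> 100000.0 (the later value wins)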
finetune/configs/zero3_offload.yaml | 51 lines (new file)
@@ -0,0 +1,51 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": 5e8,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e6
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
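For a sense of what offloading buys: with AdamW, the fp32 momentum and variance states cost 8 bytes per parameter, plus 4 bytes for fp32 master weights, all of which CPU offload moves to host RAM. A rough back-of-the-envelope sketch (the 5B parameter count is illustrative, not from this commit):

# Rough optimizer-state footprint that CPU offload moves off the GPU.
params = 5e9            # illustrative model size
adam_state_bytes = 8    # fp32 momentum + variance per parameter
fp32_master_bytes = 4   # fp32 master weights, also offloaded
gpu_savings_gb = params * (adam_state_bytes + fp32_master_bytes) / 1e9
print(f"~{gpu_savings_gb:.0f} GB moved to CPU RAM (before ZeRO partitioning across GPUs)")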
finetune/models/cogvideox1dot5_i2v/sft_trainer.py | 9 lines (new file)
@@ -0,0 +1,9 @@
from ..cogvideox_i2v.sft_trainer import CogVideoXI2VSftTrainer
from ..utils import register


class CogVideoX1dot5I2VSftTrainer(CogVideoXI2VSftTrainer):
    pass


register("cogvideox1.5-i2v", "sft", CogVideoX1dot5I2VSftTrainer)
finetune/models/cogvideox1dot5_t2v/sft_trainer.py | 9 lines (new file)
@@ -0,0 +1,9 @@
from ..cogvideox_t2v.sft_trainer import CogVideoXT2VSftTrainer
from ..utils import register


class CogVideoX1dot5T2VSftTrainer(CogVideoXT2VSftTrainer):
    pass


register("cogvideox1.5-t2v", "sft", CogVideoX1dot5T2VSftTrainer)
finetune/models/cogvideox_i2v/sft_trainer.py | 27 lines (new file)
@@ -0,0 +1,27 @@
import dotmap
from diffusers import CogVideoXImageToVideoPipeline
from typing_extensions import override

from finetune.utils import unwrap_model

from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register


class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
    @override
    def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
        origin_model = unwrap_model(self.accelerator, self.components.transformer)
        self.components.transformer.config.update(origin_model.config)
        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
        pipe = CogVideoXImageToVideoPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            transformer=self.components.transformer,
            scheduler=self.components.scheduler,
        )
        return pipe


register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)
finetune/models/cogvideox_t2v/sft_trainer.py | 27 lines (new file)
@@ -0,0 +1,27 @@
import dotmap
from diffusers import CogVideoXPipeline
from typing_extensions import override

from finetune.utils import unwrap_model

from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register


class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
    @override
    def initialize_pipeline(self) -> CogVideoXPipeline:
        origin_model = unwrap_model(self.accelerator, self.components.transformer)
        self.components.transformer.config.update(origin_model.config)
        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
        pipe = CogVideoXPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            transformer=self.components.transformer,
            scheduler=self.components.scheduler,
        )
        return pipe


register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)
finetune/models/utils.py (modified)
@@ -15,17 +15,12 @@ def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls
         trainer_cls (Trainer): Trainer class to register.
     """
 
-    # Check if model_name exists in SUPPORTED_MODELS
+    # Check if model_name and training_type exists in SUPPORTED_MODELS
     if model_name not in SUPPORTED_MODELS:
         SUPPORTED_MODELS[model_name] = {}
-    else:
-        raise ValueError(f"Model {model_name} already exists")
-
-    # Check if training_type exists for this model
-    if training_type not in SUPPORTED_MODELS[model_name]:
-        SUPPORTED_MODELS[model_name][training_type] = {}
-    else:
-        raise ValueError(f"Training type {training_type} already exists for model {model_name}")
+    if training_type in SUPPORTED_MODELS[model_name]:
+        raise ValueError(f"Training type {training_type} already exists for model {model_name}")
 
     SUPPORTED_MODELS[model_name][training_type] = trainer_cls
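The change lets one model register several training types (here, "lora" and "sft") while still rejecting duplicates. A self-contained sketch of the same registry pattern (the DummyTrainer class is a stand-in, not code from this commit):

from typing import Dict

SUPPORTED_MODELS: Dict[str, Dict[str, type]] = {}

def register(model_name: str, training_type: str, trainer_cls: type) -> None:
    # Mirrors the updated logic: one slot per (model, training_type) pair.
    if model_name not in SUPPORTED_MODELS:
        SUPPORTED_MODELS[model_name] = {}
    if training_type in SUPPORTED_MODELS[model_name]:
        raise ValueError(f"Training type {training_type} already exists for model {model_name}")
    SUPPORTED_MODELS[model_name][training_type] = trainer_cls

class DummyTrainer:  # stand-in for a real trainer class
    pass

register("cogvideox1.5-i2v", "sft", DummyTrainer)
print(SUPPORTED_MODELS["cogvideox1.5-i2v"]["sft"])  # -> <class '__main__.DummyTrainer'>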
finetune/trainer.py (modified)
@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import math
@@ -71,7 +72,7 @@ class Trainer:
             train_width=self.args.train_resolution[2],
         )
 
-        self.components = Components()
+        self.components: Components = self.load_components()
         self.accelerator: Accelerator = None
         self.dataset: Dataset = None
         self.data_loader: DataLoader = None
@@ -145,9 +146,6 @@ class Trainer:
     def prepare_models(self) -> None:
         logger.info("Initializing models")
 
-        # Initialize model components
-        self.components = self.load_components()
-
         if self.components.vae is not None:
             if self.args.enable_slicing:
                 self.components.vae.enable_slicing()
@@ -159,15 +157,11 @@ class Trainer:
     def prepare_dataset(self) -> None:
         logger.info("Initializing dataset and dataloader")
 
-        # self.state.train_frames includes one padding frame for image conditioning
-        # so we only sample train_frames - 1 frames from the actual video
-        sample_frames = self.state.train_frames - 1
-
         if self.args.model_type == "i2v":
             self.dataset = I2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                max_num_frames=sample_frames,
+                max_num_frames=self.state.train_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
                 trainer=self,
@@ -176,7 +170,7 @@ class Trainer:
             self.dataset = T2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                max_num_frames=sample_frames,
+                max_num_frames=self.state.train_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
                 trainer=self,
@@ -223,12 +217,7 @@ class Trainer:
     def prepare_trainable_parameters(self):
         logger.info("Initializing trainable parameters")
 
-        # For now only lora is supported
-        for attr_name, component in vars(self.components).items():
-            if hasattr(component, "requires_grad_"):
-                component.requires_grad_(False)
-
-        # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+        # For mixed precision training we cast all non-trainable weights to half-precision
         # as these weights are only used for inference, keeping weights in full precision is not required.
         weight_dtype = self.state.weight_dtype
@@ -238,35 +227,47 @@ class Trainer:
                 "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
             )
 
-        self.__load_components()
+        # For LoRA, we freeze all the parameters
+        # For SFT, we train all the parameters in transformer model
+        for attr_name, component in vars(self.components).items():
+            if hasattr(component, "requires_grad_"):
+                if self.args.training_type == "sft" and attr_name == "transformer":
+                    component.requires_grad_(True)
+                else:
+                    component.requires_grad_(False)
+
+        if self.args.training_type == "lora":
+            transformer_lora_config = LoraConfig(
+                r=self.args.rank,
+                lora_alpha=self.args.lora_alpha,
+                init_lora_weights=True,
+                target_modules=self.args.target_modules,
+            )
+            self.components.transformer.add_adapter(transformer_lora_config)
+            self.__prepare_saving_loading_hooks(transformer_lora_config)
+
+        # Load components needed for training to GPU (except transformer),
+        # and cast them to the specified data type
+        self.__move_components_to_device(dtype=weight_dtype)
 
         if self.args.gradient_checkpointing:
             self.components.transformer.enable_gradient_checkpointing()
-
-        transformer_lora_config = LoraConfig(
-            r=self.args.rank,
-            lora_alpha=self.args.lora_alpha,
-            init_lora_weights=True,
-            target_modules=self.args.target_modules,
-        )
-        self.components.transformer.add_adapter(transformer_lora_config)
-        self.__prepare_saving_loading_hooks(transformer_lora_config)
 
     def prepare_optimizer(self) -> None:
         logger.info("Initializing optimizer and lr scheduler")
 
         # Make sure the trainable params are in float32
-        if self.args.mixed_precision != "no":
-            # only upcast trainable parameters (LoRA) into fp32
-            cast_training_params([self.components.transformer], dtype=torch.float32)
+        cast_training_params([self.components.transformer], dtype=torch.float32)
 
-        transformer_lora_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
+        # For LoRA, we only want to train the LoRA weights
+        # For SFT, we want to train all the parameters
+        trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
         transformer_parameters_with_lr = {
-            "params": transformer_lora_parameters,
+            "params": trainable_parameters,
             "lr": self.args.learning_rate,
         }
         params_to_optimize = [transformer_parameters_with_lr]
-        self.state.num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters)
+        self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)
 
         use_deepspeed_opt = (
             self.accelerator.state.deepspeed_plugin is not None
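The SFT/LoRA split above reduces to one rule: under SFT the transformer keeps gradients and everything else is frozen. A self-contained toy illustration of the same pattern (assumes only PyTorch; the module layout is a placeholder, not code from this commit):

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(4, 4)  # trained under SFT
        self.vae = nn.Linear(4, 4)          # always frozen

model = Toy()
training_type = "sft"
for name, module in model.named_children():
    # Same decision as prepare_trainable_parameters: only the transformer trains under SFT.
    module.requires_grad_(training_type == "sft" and name == "transformer")

trainable = [p for p in model.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable))  # -> 20 (transformer weight + bias)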
@@ -502,13 +503,15 @@ class Trainer:
         # Convert all model weights to training dtype
         # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
         pipe = pipe.to(dtype=self.state.weight_dtype)
 
         #################################
 
         all_processes_artifacts = []
         for i in range(num_validation_samples):
-            # Skip current validation on all processes but one
-            if i % accelerator.num_processes != accelerator.process_index:
-                continue
+            if self.accelerator.deepspeed_plugin and self.accelerator.deepspeed_plugin.zero_stage != 3:
+                # Skip current validation on all processes but one
+                if i % accelerator.num_processes != accelerator.process_index:
+                    continue
 
             prompt = self.state.validation_prompts[i]
             image = self.state.validation_images[i]
@@ -534,7 +537,19 @@ class Trainer:
                 main_process_only=False,
             )
             validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
+
+            if (
+                self.accelerator.deepspeed_plugin is not None
+                and self.accelerator.deepspeed_plugin.zero_stage == 3
+                and not accelerator.is_main_process
+            ):
+                continue
+
             prompt_filename = string_to_filename(prompt)[:25]
+            # Calculate hash of reversed prompt as a unique identifier
+            reversed_prompt = prompt[::-1]
+            hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]
 
             artifacts = {
                 "image": {"type": "image", "value": image},
                 "video": {"type": "video", "value": video},
@@ -553,7 +568,7 @@ class Trainer:
                 continue
 
             extension = "png" if artifact_type == "image" else "mp4"
-            filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
+            filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
             validation_path = self.args.output_dir / "validation_res"
             validation_path.mkdir(parents=True, exist_ok=True)
             filename = str(validation_path / filename)
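The new hash_suffix disambiguates validation files whose prompts share the same first 25 characters. A self-contained check of the scheme (the prompt, step, and process index are illustrative, and string_to_filename is approximated by plain slicing):

import hashlib

prompt = "A panda playing guitar in a bamboo forest"  # illustrative prompt
hash_suffix = hashlib.md5(prompt[::-1].encode()).hexdigest()[:5]
print(f"validation-100-0-{prompt[:25]}-{hash_suffix}.mp4")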
@@ -587,11 +602,15 @@ class Trainer:
         pipe.remove_all_hooks()
         del pipe
         # Unload models except those needed for training
-        self.__unload_components()
+        self.__move_components_to_cpu()
 
         # Load models except those not needed for training
-        self.__load_components()
-        # Change LoRA weights back to fp32
-        cast_training_params([self.components.transformer], dtype=torch.float32)
+        self.__move_components_to_device(dtype=self.state.weight_dtype)
+        self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
+
+        if self.accelerator.state.deepspeed_plugin is None:
+            # Change trainable weights back to fp32
+            cast_training_params([self.components.transformer], dtype=torch.float32)
 
         accelerator.wait_for_everyone()
@@ -649,16 +668,21 @@ class Trainer:
         else:
             raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
 
-    def __load_components(self):
+    def __move_components_to_device(self, dtype):
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
                 if name in self.UNLOAD_LIST:
                     continue
-                # setattr(self.components, name, component.to(self.accelerator.device))
-                setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
 
-    def __unload_components(self):
+                # We don't need to move transformer to device
+                # because we will prepare it in the `prepare_for_training()`
+                if name == "transformer":
+                    continue
+
+                setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
+
+    def __move_components_to_cpu(self):
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
@@ -723,13 +747,6 @@ class Trainer:
                     f" {unexpected_keys}. "
                 )
 
-        # Make sure the trainable params are in float32. This is again needed since the base models
-        # are in `weight_dtype`. More details:
-        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
-        if self.args.mixed_precision == "fp16":
-            # only upcast trainable parameters (LoRA) into fp32
-            cast_training_params([transformer_])
-
         self.accelerator.register_save_state_pre_hook(save_model_hook)
         self.accelerator.register_load_state_pre_hook(load_model_hook)