Mirror of https://github.com/THUDM/CogVideo.git
This commit introduces a knowledge distillation module to enhance logo generation in the CogVideoX-2B text-to-video model. The key changes include:

- A new `KDTrainer` class that inherits from `CogVideoXT2VLoraTrainer`. This trainer loads a teacher model (OpenLogo Faster R-CNN) and computes a knowledge distillation loss to guide the student model.
- The `kd` training type is now supported, allowing users to select it from the command line.
- New command-line arguments (`teacher_model_path`, `teacher_model_num_classes`, `kd_loss_weight`) have been added to configure the knowledge distillation process.
- A new configuration file (`cogvideox_2b_kd.yaml`) is provided as an example for running a `kd` training session.
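The `kd` path becomes selectable through the same registry and argument plumbing used for `lora` and `sft`. Below is a minimal, hypothetical argparse-style sketch of the three new options named in the commit; the repository's actual argument handling may differ, and the defaults and help strings are illustrative only.

import argparse

# Hypothetical sketch of the new knowledge-distillation options described in the
# commit message above. The option names come from the commit; the defaults and
# help text are assumptions for illustration.
parser = argparse.ArgumentParser()
parser.add_argument("--training_type", choices=["lora", "sft", "kd"], default="lora")
parser.add_argument("--teacher_model_path", type=str, default=None,
                    help="Path to the OpenLogo Faster R-CNN teacher checkpoint.")
parser.add_argument("--teacher_model_num_classes", type=int, default=None,
                    help="Number of classes the teacher detector predicts.")
parser.add_argument("--kd_loss_weight", type=float, default=1.0,
                    help="Weight of the knowledge-distillation loss term.")

Once parsed, the selected `training_type` is resolved to a trainer class through `get_model_cls()` in the registry module listed below (e.g. a `KDTrainer` registered under the `kd` key).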
60 lines · 2.3 KiB · Python
from typing import Dict, Literal

from finetune.trainer import Trainer


SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}

def register(model_name: str, training_type: Literal["lora", "sft", "kd"], trainer_cls: Trainer):
    """Register a model and its associated trainer class for a specific training type.

    Args:
        model_name (str): Name of the model to register (e.g. "cogvideox-5b")
        training_type (Literal["lora", "sft", "kd"]): Type of training - "lora", "sft", or "kd"
        trainer_cls (Trainer): Trainer class to register.
    """
    # Check whether model_name and training_type already exist in SUPPORTED_MODELS.
    if model_name not in SUPPORTED_MODELS:
        SUPPORTED_MODELS[model_name] = {}
    elif training_type in SUPPORTED_MODELS[model_name]:
        raise ValueError(f"Training type {training_type} already exists for model {model_name}")

    SUPPORTED_MODELS[model_name][training_type] = trainer_cls

def show_supported_models():
    """Print all currently supported models and their training types."""
    print("\nSupported Models:")
    print("================")

    for model_name, training_types in SUPPORTED_MODELS.items():
        print(f"\n{model_name}")
        print("-" * len(model_name))
        for training_type in training_types:
            print(f"  • {training_type}")

def get_model_cls(model_type: str, training_type: Literal["lora", "sft", "kd"]) -> Trainer:
    """Get the trainer class for a specific model and training type."""
    if model_type not in SUPPORTED_MODELS:
        print(f"\nModel '{model_type}' is not supported.")
        print("\nSupported models are:")
        for supported_model in SUPPORTED_MODELS:
            print(f"  • {supported_model}")
        raise ValueError(f"Model '{model_type}' is not supported")

    if training_type not in SUPPORTED_MODELS[model_type]:
        print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
        print(f"\nSupported training types for '{model_type}' are:")
        for supported_type in SUPPORTED_MODELS[model_type]:
            print(f"  • {supported_type}")
        raise ValueError(
            f"Training type '{training_type}' is not supported for model '{model_type}'"
        )

    return SUPPORTED_MODELS[model_type][training_type]
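A minimal usage sketch for the registry above. `DummyTrainer` is a placeholder defined only for illustration; it is not a class from the repository.

from finetune.trainer import Trainer


class DummyTrainer(Trainer):
    """Stand-in trainer used solely to demonstrate the registry API."""


# Register the placeholder under the docstring's example model key, then resolve it back.
register("cogvideox-5b", "lora", DummyTrainer)
show_supported_models()  # lists "cogvideox-5b" with its "lora" training type
trainer_cls = get_model_cls("cogvideox-5b", "lora")
assert trainer_cls is DummyTrainer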