google-labs-jules[bot] 193b1f4dcb feat: Add knowledge distillation for logo generation
This commit introduces a knowledge distillation module to enhance logo generation in the CogVideoX-2B text-to-video model.

The key changes include:

- A new `KDTrainer` class that inherits from `CogVideoXT2VLoraTrainer`. This trainer loads a teacher model (OpenLogo Faster R-CNN) and computes a knowledge distillation loss to guide the student model.
- The `kd` training type is now supported, allowing users to select it from the command line.
- New command-line arguments (`teacher_model_path`, `teacher_model_num_classes`, `kd_loss_weight`) have been added to configure the knowledge distillation process.
- A new configuration file (`cogvideox_2b_kd.yaml`) is provided as an example for running a `kd` training session.
2025-08-21 09:14:51 +00:00

60 lines
2.3 KiB
Python

from typing import Dict, Literal
from finetune.trainer import Trainer
# Module-level registry filled in by register() and read by the lookup helpers:
# model name -> training type -> registered trainer class.
SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}
def register(model_name: str, training_type: Literal["lora", "sft", "kd"], trainer_cls: Trainer):
    """Register a trainer class for a model under a specific training type.

    Args:
        model_name (str): Name of the model to register (e.g. "cogvideox-5b")
        training_type (Literal["lora", "sft", "kd"]): Type of training -
            one of "lora", "sft" or "kd"
        trainer_cls (Trainer): Trainer class to register.

    Raises:
        ValueError: If `training_type` is already registered for `model_name`.
    """
    # Fetch the per-model mapping, creating an empty one on first registration.
    registered_types = SUPPORTED_MODELS.setdefault(model_name, {})

    # Refuse to silently overwrite an existing registration.
    if training_type in registered_types:
        raise ValueError(f"Training type {training_type} already exists for model {model_name}")

    registered_types[training_type] = trainer_cls
def show_supported_models():
    """Print every registered model name followed by its training types.

    Output goes to stdout: a header, then for each model its name, a dashed
    underline of the same length, and one line per registered training type.
    """
    print("\nSupported Models:")
    print("================")

    for model_name in SUPPORTED_MODELS:
        # Underline is exactly as wide as the model name printed above it.
        underline = "-" * len(model_name)
        print(f"\n{model_name}")
        print(underline)

        for training_type in SUPPORTED_MODELS[model_name]:
            print(f"{training_type}")
def get_model_cls(model_type: str, training_type: Literal["lora", "sft", "kd"]) -> Trainer:
    """Get the trainer class for a specific model and training type.

    Args:
        model_type (str): Name the model was registered under.
        training_type (Literal["lora", "sft", "kd"]): Training type to look up.

    Returns:
        Trainer: The trainer class stored in SUPPORTED_MODELS for this
            model / training type pair.

    Raises:
        ValueError: If `model_type` is not registered, or if `training_type`
            is not registered for that model. The available alternatives are
            printed to stdout before the error is raised.
    """
    if model_type not in SUPPORTED_MODELS:
        # List what is available so the user can correct the model name.
        print(f"\nModel '{model_type}' is not supported.")
        print("\nSupported models are:")
        for supported_model in SUPPORTED_MODELS:
            print(f"{supported_model}")
        raise ValueError(f"Model '{model_type}' is not supported")

    if training_type not in SUPPORTED_MODELS[model_type]:
        # The model exists; list the training types registered for it.
        print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
        print(f"\nSupported training types for '{model_type}' are:")
        for supported_type in SUPPORTED_MODELS[model_type]:
            print(f"{supported_type}")
        raise ValueError(
            f"Training type '{training_type}' is not supported for model '{model_type}'"
        )

    return SUPPORTED_MODELS[model_type][training_type]