mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 03:04:56 +08:00
60 lines
2.2 KiB
Python
60 lines
2.2 KiB
Python
from typing import Dict, Literal
|
|
|
|
from finetune.trainer import Trainer
|
|
|
|
|
|
# Module-level registry of trainers: the outer key is the model name, the inner
# key is the training type, and the value is the trainer registered for that
# pair. Filled by register() and read by show_supported_models() and
# get_model_cls().
SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}
|
|
|
|
|
|
def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: Trainer):
    """Add a trainer to SUPPORTED_MODELS under a (model name, training type) pair.

    Args:
        model_name (str): Name of the model to register (e.g. "cogvideox-5b").
        training_type (Literal["lora", "sft"]): Training type key, "lora" or "sft".
        trainer_cls (Trainer): Trainer class stored for this pair.

    Raises:
        ValueError: If `training_type` is already registered for `model_name`.
    """
    # Fetch the per-model mapping, creating an empty one on first registration.
    trainers_for_model = SUPPORTED_MODELS.setdefault(model_name, {})

    # Refuse to silently overwrite an existing entry.
    if training_type in trainers_for_model:
        raise ValueError(f"Training type {training_type} already exists for model {model_name}")

    trainers_for_model[training_type] = trainer_cls
|
|
|
|
|
|
def show_supported_models():
    """Write every registered model and its training types to stdout.

    Output is a "Supported Models:" banner followed by one section per model:
    the model name, a dashed underline of the same length, and one bullet per
    registered training type.
    """
    print("\nSupported Models:", "================", sep="\n")

    for name, registered_types in SUPPORTED_MODELS.items():
        # Model heading and an underline matching the name's length.
        print(f"\n{name}", "-" * len(name), sep="\n")
        for kind in registered_types:
            print(f" • {kind}")
|
|
|
|
|
|
def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Trainer:
    """Look up the trainer registered for a model and training type.

    Args:
        model_type (str): Model name used as the outer registry key.
        training_type (Literal["lora", "sft"]): Training type used as the inner key.

    Returns:
        The entry stored in SUPPORTED_MODELS[model_type][training_type].

    Raises:
        ValueError: If the model is not registered, or if the training type is
            not registered for that model. The available choices are printed
            to stdout before raising.
    """
    model_known = model_type in SUPPORTED_MODELS

    # Happy path first: both keys are present.
    if model_known and training_type in SUPPORTED_MODELS[model_type]:
        return SUPPORTED_MODELS[model_type][training_type]

    if not model_known:
        print(f"\nModel '{model_type}' is not supported.")
        print("\nSupported models are:")
        for known_model in SUPPORTED_MODELS:
            print(f" • {known_model}")
        raise ValueError(f"Model '{model_type}' is not supported")

    # The model exists, so the training type must be the missing key.
    print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
    print(f"\nSupported training types for '{model_type}' are:")
    for known_type in SUPPORTED_MODELS[model_type]:
        print(f" • {known_type}")
    raise ValueError(
        f"Training type '{training_type}' is not supported for model '{model_type}'"
    )
|