feat(models): add scaffolding

OleehyO 2024-12-27 09:59:49 +00:00
parent 918ebb5a54
commit 85e00a1082
5 changed files with 161 additions and 0 deletions

@@ -0,0 +1,12 @@
import importlib
from pathlib import Path

package_dir = Path(__file__).parent

# Import every module inside each non-private subpackage so that the
# trainer modules' register(...) calls run when this package is imported.
for subdir in package_dir.iterdir():
    if subdir.is_dir() and not subdir.name.startswith('_'):
        for module_path in subdir.glob('*.py'):
            module_name = module_path.stem
            full_module_name = f".{subdir.name}.{module_name}"
            importlib.import_module(full_module_name, package=__name__)
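For orientation, here is a minimal sketch of what one pass of this scan does, assuming a hypothetical subpackage cogvideox_i2v containing lora_trainer.py; the real subdirectory names are not shown in this commit.

# Hypothetical illustration: with finetune/models/cogvideox_i2v/lora_trainer.py on disk,
# one iteration of the loop above reduces to this relative import, executed purely
# for its side effect (the module-level register(...) call).
import importlib

importlib.import_module(".cogvideox_i2v.lora_trainer", package="finetune.models")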

@@ -0,0 +1,29 @@
import torch
from typing_extensions import override
from typing import Any, Dict, List

from finetune.trainer import Trainer

from ..utils import register


class CogVideoX1dot5I2VLoraTrainer(Trainer):
    @override
    def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
        raise NotImplementedError

    @override
    def load_components(self) -> Dict[str, Any]:
        raise NotImplementedError

    @override
    def compute_loss(self, batch) -> torch.Tensor:
        raise NotImplementedError

    @override
    def validate(self) -> None:
        raise NotImplementedError


register("cogvideox1.5-i2v", "lora", CogVideoX1dot5I2VLoraTrainer)

@@ -0,0 +1,29 @@
import torch
from typing_extensions import override
from typing import Any, Dict, List

from finetune.trainer import Trainer

from ..utils import register


class CogVideoX1dot5T2VLoraTrainer(Trainer):
    @override
    def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
        raise NotImplementedError

    @override
    def load_components(self) -> Dict[str, Any]:
        raise NotImplementedError

    @override
    def compute_loss(self, batch) -> torch.Tensor:
        raise NotImplementedError

    @override
    def validate(self) -> None:
        raise NotImplementedError


register("cogvideox1.5-t2v", "lora", CogVideoX1dot5T2VLoraTrainer)

@@ -0,0 +1,29 @@
import torch
from typing_extensions import override
from typing import Any, Dict, List

from finetune.trainer import Trainer

from ..utils import register


class CogVideoXI2VLoraTrainer(Trainer):
    @override
    def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
        raise NotImplementedError

    @override
    def load_components(self) -> Dict[str, Any]:
        raise NotImplementedError

    @override
    def compute_loss(self, batch) -> torch.Tensor:
        raise NotImplementedError

    @override
    def validate(self) -> None:
        raise NotImplementedError


register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)

finetune/models/utils.py

@@ -0,0 +1,62 @@
from typing import Literal, Dict

from finetune.trainer import Trainer


SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}


def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: Trainer):
    """Register a model and its associated functions for a specific training type.

    Args:
        model_name (str): Name of the model to register (e.g. "cogvideox-5b")
        training_type (Literal["lora", "sft"]): Type of training - either "lora" or "sft"
        trainer_cls (Trainer): Trainer class to register.
    """
    # Check if model_name exists in SUPPORTED_MODELS
    if model_name not in SUPPORTED_MODELS:
        SUPPORTED_MODELS[model_name] = {}
    else:
        raise ValueError(f"Model {model_name} already exists")

    # Check if training_type exists for this model
    if training_type not in SUPPORTED_MODELS[model_name]:
        SUPPORTED_MODELS[model_name][training_type] = {}
    else:
        raise ValueError(f"Training type {training_type} already exists for model {model_name}")

    SUPPORTED_MODELS[model_name][training_type] = trainer_cls


def show_supported_models():
    """Print all currently supported models and their training types."""
    print("\nSupported Models:")
    print("================")

    for model_name, training_types in SUPPORTED_MODELS.items():
        print(f"\n{model_name}")
        print("-" * len(model_name))
        for training_type in training_types:
            print(f"{training_type}")


def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Trainer:
    """Get the trainer class for a specific model and training type."""
    if model_type not in SUPPORTED_MODELS:
        print(f"\nModel '{model_type}' is not supported.")
        print("\nSupported models are:")
        for supported_model in SUPPORTED_MODELS:
            print(f"{supported_model}")
        raise ValueError(f"Model '{model_type}' is not supported")

    if training_type not in SUPPORTED_MODELS[model_type]:
        print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
        print(f"\nSupported training types for '{model_type}' are:")
        for supported_type in SUPPORTED_MODELS[model_type]:
            print(f"{supported_type}")
        raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")

    return SUPPORTED_MODELS[model_type][training_type]
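
A hedged usage sketch of how a training entry point might resolve a trainer through this registry; get_model_cls, show_supported_models, and the registered keys come from this commit, while the surrounding script is illustrative.

# Illustrative only: importing finetune.models runs the auto-import scan in
# __init__.py, which executes the register(...) calls and populates SUPPORTED_MODELS.
import finetune.models  # imported for its registration side effects
from finetune.models.utils import get_model_cls, show_supported_models

show_supported_models()

trainer_cls = get_model_cls("cogvideox1.5-t2v", "lora")
print(trainer_cls.__name__)  # CogVideoX1dot5T2VLoraTrainer

# At this stage of the scaffolding every overridden method still raises
# NotImplementedError, so constructing and running the trainer is deferred
# to later commits.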