From 85e00a1082c041904691551cc33800b6ab4502c7 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Fri, 27 Dec 2024 09:59:49 +0000
Subject: [PATCH] feat(models): add scaffolding

---
 finetune/models/__init__.py                    | 12 ++++
 .../models/cogvideox1dot5_i2v/lora_trainer.py  | 29 +++++++++
 .../models/cogvideox1dot5_t2v/lora_trainer.py  | 29 +++++++++
 finetune/models/cogvideox_i2v/lora_trainer.py  | 29 +++++++++
 finetune/models/utils.py                       | 62 +++++++++++++++++++
 5 files changed, 161 insertions(+)
 create mode 100644 finetune/models/__init__.py
 create mode 100644 finetune/models/cogvideox1dot5_i2v/lora_trainer.py
 create mode 100644 finetune/models/cogvideox1dot5_t2v/lora_trainer.py
 create mode 100644 finetune/models/cogvideox_i2v/lora_trainer.py
 create mode 100644 finetune/models/utils.py

diff --git a/finetune/models/__init__.py b/finetune/models/__init__.py
new file mode 100644
index 0000000..b315ff5
--- /dev/null
+++ b/finetune/models/__init__.py
@@ -0,0 +1,12 @@
+import importlib
+from pathlib import Path
+
+
+package_dir = Path(__file__).parent
+
+for subdir in package_dir.iterdir():
+    if subdir.is_dir() and not subdir.name.startswith('_'):
+        for module_path in subdir.glob('*.py'):
+            module_name = module_path.stem
+            full_module_name = f".{subdir.name}.{module_name}"
+            importlib.import_module(full_module_name, package=__name__)
diff --git a/finetune/models/cogvideox1dot5_i2v/lora_trainer.py b/finetune/models/cogvideox1dot5_i2v/lora_trainer.py
new file mode 100644
index 0000000..6ef9dd4
--- /dev/null
+++ b/finetune/models/cogvideox1dot5_i2v/lora_trainer.py
@@ -0,0 +1,29 @@
+import torch
+
+from typing_extensions import override
+from typing import Any, Dict, List
+
+from finetune.trainer import Trainer
+from ..utils import register
+
+
+class CogVideoX1dot5I2VLoraTrainer(Trainer):
+
+    @override
+    def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
+        raise NotImplementedError
+
+    @override
+    def load_components(self) -> Dict[str, Any]:
+        raise NotImplementedError
+
+    @override
+    def compute_loss(self, batch) -> torch.Tensor:
+        raise NotImplementedError
+
+    @override
+    def validate(self) -> None:
+        raise NotImplementedError
+
+
+register("cogvideox1.5-i2v", "lora", CogVideoX1dot5I2VLoraTrainer)
diff --git a/finetune/models/cogvideox1dot5_t2v/lora_trainer.py b/finetune/models/cogvideox1dot5_t2v/lora_trainer.py
new file mode 100644
index 0000000..dfc2a78
--- /dev/null
+++ b/finetune/models/cogvideox1dot5_t2v/lora_trainer.py
@@ -0,0 +1,29 @@
+import torch
+
+from typing_extensions import override
+from typing import Any, Dict, List
+
+from finetune.trainer import Trainer
+from ..utils import register
+
+
+class CogVideoX1dot5T2VLoraTrainer(Trainer):
+
+    @override
+    def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
+        raise NotImplementedError
+
+    @override
+    def load_components(self) -> Dict[str, Any]:
+        raise NotImplementedError
+
+    @override
+    def compute_loss(self, batch) -> torch.Tensor:
+        raise NotImplementedError
+
+    @override
+    def validate(self) -> None:
+        raise NotImplementedError
+
+
+register("cogvideox1.5-t2v", "lora", CogVideoX1dot5T2VLoraTrainer)
diff --git a/finetune/models/cogvideox_i2v/lora_trainer.py b/finetune/models/cogvideox_i2v/lora_trainer.py
new file mode 100644
index 0000000..d625f18
--- /dev/null
+++ b/finetune/models/cogvideox_i2v/lora_trainer.py
@@ -0,0 +1,29 @@
+import torch
+
+from typing_extensions import override
+from typing import Any, Dict, List
+
+from finetune.trainer import Trainer
+from ..utils import register
+
+
+class CogVideoXI2VLoraTrainer(Trainer):
+
+    @override
+    def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
+        raise NotImplementedError
+
+    @override
+    def load_components(self) -> Dict[str, Any]:
+        raise NotImplementedError
+
+    @override
+    def compute_loss(self, batch) -> torch.Tensor:
+        raise NotImplementedError
+
+    @override
+    def validate(self) -> None:
+        raise NotImplementedError
+
+
+register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
\ No newline at end of file
diff --git a/finetune/models/utils.py b/finetune/models/utils.py
new file mode 100644
index 0000000..fd5a455
--- /dev/null
+++ b/finetune/models/utils.py
@@ -0,0 +1,62 @@
+from typing import Literal, Dict
+
+from finetune.trainer import Trainer
+
+
+SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}
+
+
+def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: Trainer):
+    """Register a model and its associated functions for a specific training type.
+
+    Args:
+        model_name (str): Name of the model to register (e.g. "cogvideox-5b")
+        training_type (Literal["lora", "sft"]): Type of training - either "lora" or "sft"
+        trainer_cls (Trainer): Trainer class to register.
+    """
+
+    # Check if model_name exists in SUPPORTED_MODELS
+    if model_name not in SUPPORTED_MODELS:
+        SUPPORTED_MODELS[model_name] = {}
+    else:
+        raise ValueError(f"Model {model_name} already exists")
+
+    # Check if training_type exists for this model
+    if training_type not in SUPPORTED_MODELS[model_name]:
+        SUPPORTED_MODELS[model_name][training_type] = {}
+    else:
+        raise ValueError(f"Training type {training_type} already exists for model {model_name}")
+
+    SUPPORTED_MODELS[model_name][training_type] = trainer_cls
+
+
+def show_supported_models():
+    """Print all currently supported models and their training types."""
+
+    print("\nSupported Models:")
+    print("================")
+
+    for model_name, training_types in SUPPORTED_MODELS.items():
+        print(f"\n{model_name}")
+        print("-" * len(model_name))
+        for training_type in training_types:
+            print(f"  • {training_type}")
+
+
+def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Trainer:
+    """Get the trainer class for a specific model and training type."""
+    if model_type not in SUPPORTED_MODELS:
+        print(f"\nModel '{model_type}' is not supported.")
+        print("\nSupported models are:")
+        for supported_model in SUPPORTED_MODELS:
+            print(f"  • {supported_model}")
+        raise ValueError(f"Model '{model_type}' is not supported")
+
+    if training_type not in SUPPORTED_MODELS[model_type]:
+        print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
+        print(f"\nSupported training types for '{model_type}' are:")
+        for supported_type in SUPPORTED_MODELS[model_type]:
+            print(f"  • {supported_type}")
+        raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")
+
+    return SUPPORTED_MODELS[model_type][training_type]
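
Usage note (not part of the patch): a minimal sketch of how this registry is intended to be consumed, based only on the code above. Importing the finetune.models package runs its __init__.py, which walks every non-underscore subdirectory and imports its modules; each lora_trainer.py module calls register(...) at import time as a side effect, so by the time get_model_cls is called the SUPPORTED_MODELS table is populated. The snippet below only retrieves the trainer class; how that class is instantiated depends on finetune.trainer.Trainer, which this patch does not define.

    # Importing the package triggers finetune/models/__init__.py, which auto-imports
    # every model subpackage and thereby executes their module-level register(...) calls.
    import finetune.models
    from finetune.models.utils import get_model_cls, show_supported_models

    # Print the registered model names and their available training types.
    show_supported_models()

    # Look up the trainer class registered for a given model / training type pair.
    TrainerCls = get_model_cls("cogvideox-i2v", "lora")  # -> CogVideoXI2VLoraTrainer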