mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
feat(models): add scaffolding
This commit is contained in:
parent
918ebb5a54
commit
85e00a1082
12
finetune/models/__init__.py
Normal file
12
finetune/models/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
package_dir = Path(__file__).parent
|
||||
|
||||
for subdir in package_dir.iterdir():
|
||||
if subdir.is_dir() and not subdir.name.startswith('_'):
|
||||
for module_path in subdir.glob('*.py'):
|
||||
module_name = module_path.stem
|
||||
full_module_name = f".{subdir.name}.{module_name}"
|
||||
importlib.import_module(full_module_name, package=__name__)
|
29
finetune/models/cogvideox1dot5_i2v/lora_trainer.py
Normal file
29
finetune/models/cogvideox1dot5_i2v/lora_trainer.py
Normal file
@ -0,0 +1,29 @@
|
||||
import torch
|
||||
|
||||
from typing_extensions import override
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from finetune.trainer import Trainer
|
||||
from ..utils import register
|
||||
|
||||
|
||||
class CogVideoX1dot5I2VLoraTrainer(Trainer):
|
||||
|
||||
@override
|
||||
def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def load_components(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def compute_loss(self, batch) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def validate(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
register("cogvideox1.5-i2v", "lora", CogVideoX1dot5I2VLoraTrainer)
|
29
finetune/models/cogvideox1dot5_t2v/lora_trainer.py
Normal file
29
finetune/models/cogvideox1dot5_t2v/lora_trainer.py
Normal file
@ -0,0 +1,29 @@
|
||||
import torch
|
||||
|
||||
from typing_extensions import override
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from finetune.trainer import Trainer
|
||||
from ..utils import register
|
||||
|
||||
|
||||
class CogVideoX1dot5T2VLoraTrainer(Trainer):
|
||||
|
||||
@override
|
||||
def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def load_components(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def compute_loss(self, batch) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def validate(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
register("cogvideox1.5-t2v", "lora", CogVideoX1dot5T2VLoraTrainer)
|
29
finetune/models/cogvideox_i2v/lora_trainer.py
Normal file
29
finetune/models/cogvideox_i2v/lora_trainer.py
Normal file
@ -0,0 +1,29 @@
|
||||
import torch
|
||||
|
||||
from typing_extensions import override
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from finetune.trainer import Trainer
|
||||
from ..utils import register
|
||||
|
||||
|
||||
class CogVideoXI2VLoraTrainer(Trainer):
|
||||
|
||||
@override
|
||||
def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def load_components(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def compute_loss(self, batch) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def validate(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
|
62
finetune/models/utils.py
Normal file
62
finetune/models/utils.py
Normal file
@ -0,0 +1,62 @@
|
||||
from typing import Literal, Dict
|
||||
|
||||
from finetune.trainer import Trainer
|
||||
|
||||
|
||||
SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}
|
||||
|
||||
|
||||
def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: Trainer):
|
||||
"""Register a model and its associated functions for a specific training type.
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to register (e.g. "cogvideox-5b")
|
||||
training_type (Literal["lora", "sft"]): Type of training - either "lora" or "sft"
|
||||
trainer_cls (Trainer): Trainer class to register.
|
||||
"""
|
||||
|
||||
# Check if model_name exists in SUPPORTED_MODELS
|
||||
if model_name not in SUPPORTED_MODELS:
|
||||
SUPPORTED_MODELS[model_name] = {}
|
||||
else:
|
||||
raise ValueError(f"Model {model_name} already exists")
|
||||
|
||||
# Check if training_type exists for this model
|
||||
if training_type not in SUPPORTED_MODELS[model_name]:
|
||||
SUPPORTED_MODELS[model_name][training_type] = {}
|
||||
else:
|
||||
raise ValueError(f"Training type {training_type} already exists for model {model_name}")
|
||||
|
||||
SUPPORTED_MODELS[model_name][training_type] = trainer_cls
|
||||
|
||||
|
||||
def show_supported_models():
|
||||
"""Print all currently supported models and their training types."""
|
||||
|
||||
print("\nSupported Models:")
|
||||
print("================")
|
||||
|
||||
for model_name, training_types in SUPPORTED_MODELS.items():
|
||||
print(f"\n{model_name}")
|
||||
print("-" * len(model_name))
|
||||
for training_type in training_types:
|
||||
print(f" • {training_type}")
|
||||
|
||||
|
||||
def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Trainer:
|
||||
"""Get the trainer class for a specific model and training type."""
|
||||
if model_type not in SUPPORTED_MODELS:
|
||||
print(f"\nModel '{model_type}' is not supported.")
|
||||
print("\nSupported models are:")
|
||||
for supported_model in SUPPORTED_MODELS:
|
||||
print(f" • {supported_model}")
|
||||
raise ValueError(f"Model '{model_type}' is not supported")
|
||||
|
||||
if training_type not in SUPPORTED_MODELS[model_type]:
|
||||
print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
|
||||
print(f"\nSupported training types for '{model_type}' are:")
|
||||
for supported_type in SUPPORTED_MODELS[model_type]:
|
||||
print(f" • {supported_type}")
|
||||
raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")
|
||||
|
||||
return SUPPORTED_MODELS[model_type][training_type]
|
Loading…
x
Reference in New Issue
Block a user