format and check fp16 for cogvideox2b

Author: zR
Date:   2025-01-07 13:16:18 +08:00
Parent: 1b886326b2
Commit: 1789f07256

15 changed files with 166 additions and 201 deletions

View File

@@ -8,5 +8,5 @@ __all__ = [
     "I2VDatasetWithBuckets",
     "T2VDatasetWithResize",
     "T2VDatasetWithBuckets",
-    "BucketSampler"
+    "BucketSampler",
 ]

View File

@@ -37,7 +37,6 @@ class BucketSampler(Sampler):
         self._raised_warning_for_drop_last = False

     def __len__(self):
         if self.drop_last and not self._raised_warning_for_drop_last:
             self._raised_warning_for_drop_last = True
@@ -46,7 +45,6 @@ class BucketSampler(Sampler):
             )
         return (len(self.data_source) + self.batch_size - 1) // self.batch_size

     def __iter__(self):
         for index, data in enumerate(self.data_source):
             video_metadata = data["video_metadata"]
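
Note: the __len__ shown above uses the standard integer ceiling-division idiom, (n + batch_size - 1) // batch_size, so a final partial batch is still counted. A quick standalone check with made-up numbers (not from the repo):

    # 10 samples with batch_size 4 -> ceil(10 / 4) = 3 batches (4, 4, 2).
    data_len, batch_size = 10, 4
    num_batches = (data_len + batch_size - 1) // batch_size
    assert num_batches == 3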

View File

@@ -13,11 +13,12 @@ from safetensors.torch import save_file, load_file
 from finetune.constants import LOG_NAME, LOG_LEVEL

 from .utils import (
-    load_prompts, load_videos, load_images,
+    load_prompts,
+    load_videos,
+    load_images,
     preprocess_image_with_resize,
     preprocess_video_with_resize,
-    preprocess_video_with_buckets
+    preprocess_video_with_buckets,
 )

 if TYPE_CHECKING:
@@ -46,6 +47,7 @@ class BaseI2VDataset(Dataset):
         device (torch.device): Device to load the data on
         encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
     """
+
     def __init__(
         self,
         data_root: str,
@@ -55,7 +57,7 @@ class BaseI2VDataset(Dataset):
         device: torch.device,
         trainer: "Trainer" = None,
         *args,
-        **kwargs
+        **kwargs,
     ) -> None:
         super().__init__()
@@ -120,7 +122,10 @@ class BaseI2VDataset(Dataset):
         if prompt_embedding_path.exists():
             prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
-            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+            logger.debug(
+                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
+                main_process_only=False,
+            )
         else:
             prompt_embedding = self.encode_text(prompt)
             prompt_embedding = prompt_embedding.to("cpu")
@@ -235,6 +240,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
         height (int): Target height for resizing videos and images
         width (int): Target width for resizing videos and images
     """
+
     def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -242,11 +248,7 @@
         self.height = height
         self.width = width

-        self.__frame_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
         self.__image_transforms = self.__frame_transforms

     @override
@@ -271,14 +273,14 @@

 class I2VDatasetWithBuckets(BaseI2VDataset):
     def __init__(
         self,
         video_resolution_buckets: List[Tuple[int, int, int]],
         vae_temporal_compression_ratio: int,
         vae_height_compression_ratio: int,
         vae_width_compression_ratio: int,
-        *args, **kwargs
+        *args,
+        **kwargs,
     ) -> None:
         super().__init__(*args, **kwargs)
@@ -290,11 +292,7 @@ class I2VDatasetWithBuckets(BaseI2VDataset):
             )
             for b in video_resolution_buckets
         ]
-        self.__frame_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
         self.__image_transforms = self.__frame_transforms

     @override
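
Note: the transform that several hunks above collapse onto one line, transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]), rescales frame values from [0, 255] to [-1, 1]. A small self-contained sketch of the same normalization (the tensor shape here is made up for illustration):

    import torch
    from torchvision import transforms

    # Same lambda as in the dataset classes: map [0, 255] -> [-1, 1].
    frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])

    frames = torch.randint(0, 256, (49, 3, 480, 720), dtype=torch.uint8)  # (F, C, H, W)
    normalized = frame_transform(frames.float())
    assert normalized.min() >= -1.0 and normalized.max() <= 1.0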

View File

@@ -12,11 +12,7 @@ from safetensors.torch import save_file, load_file
 from finetune.constants import LOG_NAME, LOG_LEVEL

-from .utils import (
-    load_prompts, load_videos,
-    preprocess_video_with_resize,
-    preprocess_video_with_buckets
-)
+from .utils import load_prompts, load_videos, preprocess_video_with_resize, preprocess_video_with_buckets

 if TYPE_CHECKING:
     from finetune.trainer import Trainer
@@ -52,7 +48,7 @@ class BaseT2VDataset(Dataset):
         device: torch.device = None,
         trainer: "Trainer" = None,
         *args,
-        **kwargs
+        **kwargs,
     ) -> None:
         super().__init__()
@@ -108,7 +104,10 @@
         if prompt_embedding_path.exists():
             prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
-            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+            logger.debug(
+                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
+                main_process_only=False,
+            )
         else:
             prompt_embedding = self.encode_text(prompt)
             prompt_embedding = prompt_embedding.to("cpu")
@@ -203,16 +202,15 @@ class T2VDatasetWithResize(BaseT2VDataset):
         self.height = height
         self.width = width

-        self.__frame_transform = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])

     @override
     def preprocess(self, video_path: Path) -> torch.Tensor:
         return preprocess_video_with_resize(
-            video_path, self.max_num_frames, self.height, self.width,
+            video_path,
+            self.max_num_frames,
+            self.height,
+            self.width,
         )

     @override
@@ -221,18 +219,16 @@

 class T2VDatasetWithBuckets(BaseT2VDataset):
     def __init__(
         self,
         video_resolution_buckets: List[Tuple[int, int, int]],
         vae_temporal_compression_ratio: int,
         vae_height_compression_ratio: int,
         vae_width_compression_ratio: int,
-        *args, **kwargs
+        *args,
+        **kwargs,
     ) -> None:
-        """
-        """
+        """ """
         super().__init__(*args, **kwargs)

         self.video_resolution_buckets = [
@@ -244,17 +240,11 @@ class T2VDatasetWithBuckets(BaseT2VDataset):
             for b in video_resolution_buckets
         ]

-        self.__frame_transform = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])

     @override
     def preprocess(self, video_path: Path) -> torch.Tensor:
-        return preprocess_video_with_buckets(
-            video_path, self.video_resolution_buckets
-        )
+        return preprocess_video_with_buckets(video_path, self.video_resolution_buckets)

     @override
     def video_transform(self, frames: torch.Tensor) -> torch.Tensor:

View File

@@ -15,6 +15,7 @@ decord.bridge.set_bridge("torch")

 ########## loaders ##########

+
 def load_prompts(prompt_path: Path) -> List[str]:
     with open(prompt_path, "r", encoding="utf-8") as file:
         return [line.strip() for line in file.readlines() if len(line.strip()) > 0]
@@ -32,6 +33,7 @@ def load_images(image_path: Path) -> List[Path]:

 ########## preprocessors ##########

+
 def preprocess_image_with_resize(
     image_path: Path | str,
     height: int,
@@ -96,7 +98,7 @@ def preprocess_video_with_resize(
         indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
         frames = video_reader.get_batch(indices)
-        frames = frames[: max_num_frames].float()
+        frames = frames[:max_num_frames].float()
         frames = frames.permute(0, 3, 1, 2).contiguous()
         return frames
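
Note: the slice change above is whitespace-only ([: max_num_frames] becomes [:max_num_frames]). For context, the surrounding code samples evenly strided frame indices and then truncates to max_num_frames; a rough standalone illustration with made-up numbers:

    # Illustrative only: a 100-frame clip sampled down to at most 49 frames.
    video_num_frames, max_num_frames = 100, 49
    indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))  # stride 2 -> 50 indices
    frames = indices[:max_num_frames]  # truncated to exactly 49
    assert len(frames) == 49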

View File

@@ -5,8 +5,8 @@ from pathlib import Path
 package_dir = Path(__file__).parent

 for subdir in package_dir.iterdir():
-    if subdir.is_dir() and not subdir.name.startswith('_'):
-        for module_path in subdir.glob('*.py'):
+    if subdir.is_dir() and not subdir.name.startswith("_"):
+        for module_path in subdir.glob("*.py"):
             module_name = module_path.stem
             full_module_name = f".{subdir.name}.{module_name}"
             importlib.import_module(full_module_name, package=__name__)

View File

@@ -30,25 +30,15 @@ class CogVideoXI2VLoraTrainer(Trainer):
         components.pipeline_cls = CogVideoXImageToVideoPipeline

-        components.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, subfolder="tokenizer"
-        )
+        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")

-        components.text_encoder = T5EncoderModel.from_pretrained(
-            model_path, subfolder="text_encoder"
-        )
+        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")

-        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
-            model_path, subfolder="transformer"
-        )
+        components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")

-        components.vae = AutoencoderKLCogVideoX.from_pretrained(
-            model_path, subfolder="vae"
-        )
+        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")

-        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
-            model_path, subfolder="scheduler"
-        )
+        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

         return components
@@ -59,7 +49,7 @@
             text_encoder=self.components.text_encoder,
             vae=self.components.vae,
             transformer=unwrap_model(self.accelerator, self.components.transformer),
-            scheduler=self.components.scheduler
+            scheduler=self.components.scheduler,
         )
         return pipe
@@ -88,11 +78,7 @@
     @override
     def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
-        ret = {
-            "encoded_videos": [],
-            "prompt_embedding": [],
-            "images": []
-        }
+        ret = {"encoded_videos": [], "prompt_embedding": [], "images": []}

         for sample in samples:
             encoded_video = sample["encoded_video"]
@@ -144,8 +130,7 @@
         # Sample a random timestep for each sample
         timesteps = torch.randint(
-            0, self.components.scheduler.config.num_train_timesteps,
-            (batch_size,), device=self.accelerator.device
+            0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
         )
         timesteps = timesteps.long()
@@ -183,7 +168,9 @@
         )

         # Predict noise
-        ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
+        ofs_emb = (
+            None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
+        )
         predicted_noise = self.components.transformer(
             hidden_states=latent_img_noisy,
             encoder_hidden_states=prompt_embedding,
@@ -222,7 +209,7 @@
             width=self.state.train_width,
             prompt=prompt,
             image=image,
-            generator=self.state.generator
+            generator=self.state.generator,
         ).frames[0]

         return [("video", video_generate)]
@@ -233,7 +220,7 @@
         num_frames: int,
         transformer_config: Dict,
         vae_scale_factor_spatial: int,
-        device: torch.device
+        device: torch.device,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
         grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
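
Note: grid_height and grid_width above count transformer patches per spatial axis of the latent. As a rough worked example, assuming the commonly used CogVideoX values vae_scale_factor_spatial = 8 and patch_size = 2 (neither value is shown in this diff), a 480x720 training resolution yields a 30x45 patch grid:

    # Illustrative numbers only; read the real values from the model config.
    height, width = 480, 720
    vae_scale_factor_spatial, patch_size = 8, 2
    grid_height = height // (vae_scale_factor_spatial * patch_size)  # 480 // 16 = 30
    grid_width = width // (vae_scale_factor_spatial * patch_size)    # 720 // 16 = 45
    assert (grid_height, grid_width) == (30, 45)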

View File

@@ -31,25 +31,15 @@ class CogVideoXT2VLoraTrainer(Trainer):
         components.pipeline_cls = CogVideoXPipeline

-        components.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, subfolder="tokenizer"
-        )
+        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")

-        components.text_encoder = T5EncoderModel.from_pretrained(
-            model_path, subfolder="text_encoder"
-        )
+        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")

-        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
-            model_path, subfolder="transformer"
-        )
+        components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")

-        components.vae = AutoencoderKLCogVideoX.from_pretrained(
-            model_path, subfolder="vae"
-        )
+        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")

-        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
-            model_path, subfolder="scheduler"
-        )
+        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

         return components
@@ -60,7 +50,7 @@
             text_encoder=self.components.text_encoder,
             vae=self.components.vae,
             transformer=unwrap_model(self.accelerator, self.components.transformer),
-            scheduler=self.components.scheduler
+            scheduler=self.components.scheduler,
         )
         return pipe
@@ -89,10 +79,7 @@
     @override
     def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
-        ret = {
-            "encoded_videos": [],
-            "prompt_embedding": []
-        }
+        ret = {"encoded_videos": [], "prompt_embedding": []}

         for sample in samples:
             encoded_video = sample["encoded_video"]
@@ -116,10 +103,20 @@
         patch_size_t = self.state.transformer_config.patch_size_t
         if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
-            raise ValueError("Number of frames in latent must be divisible by patch size, please check your args for training.")
+            raise ValueError(
+                "Number of frames in latent must be divisible by patch size, please check your args for training."
+            )

         # Add 2 random noise frames at the beginning of frame dimension
-        noise_frames = torch.randn(latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
+        noise_frames = torch.randn(
+            latent.shape[0],
+            latent.shape[1],
+            2,
+            latent.shape[3],
+            latent.shape[4],
+            device=latent.device,
+            dtype=latent.dtype,
+        )
         latent = torch.cat([noise_frames, latent], dim=2)

         batch_size, num_channels, num_frames, height, width = latent.shape
@@ -130,8 +127,7 @@
         # Sample a random timestep for each sample
         timesteps = torch.randint(
-            0, self.components.scheduler.config.num_train_timesteps,
-            (batch_size,), device=self.accelerator.device
+            0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
         )
         timesteps = timesteps.long()
@@ -193,7 +189,7 @@
             height=self.state.train_height,
             width=self.state.train_width,
             prompt=prompt,
-            generator=self.state.generator
+            generator=self.state.generator,
         ).frames[0]

         return [("video", video_generate)]
@@ -204,7 +200,7 @@
         num_frames: int,
         transformer_config: Dict,
         vae_scale_factor_spatial: int,
-        device: torch.device
+        device: torch.device,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
         grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)

View File

@@ -78,10 +78,10 @@ class Args(BaseModel):
     ########## Validation ##########
     do_validation: bool = False
     validation_steps: int | None = None  # if set, should be a multiple of checkpointing_steps
     validation_dir: Path | None  # if set do_validation, should not be None
     validation_prompts: str | None  # if set do_validation, should not be None
     validation_images: str | None  # if set do_validation and model_type == i2v, should not be None
     validation_videos: str | None  # if set do_validation and model_type == v2v, should not be None
     gen_fps: int = 15

     #### deprecated args: gen_video_resolution
@@ -150,12 +150,13 @@
             return v
         except ValueError as e:
-            if str(e) == "not enough values to unpack (expected 3, got 0)" or \
-               str(e) == "invalid literal for int() with base 10":
+            if (
+                str(e) == "not enough values to unpack (expected 3, got 0)"
+                or str(e) == "invalid literal for int() with base 10"
+            ):
                 raise ValueError("train_resolution must be in format 'frames x height x width'")
             raise e

     @classmethod
     def parse_args(cls):
         """Parse command line arguments and return Args instance"""
@@ -208,8 +209,7 @@
         # LoRA parameters
         parser.add_argument("--rank", type=int, default=128)
         parser.add_argument("--lora_alpha", type=int, default=64)
-        parser.add_argument("--target_modules", type=str, nargs="+",
-                            default=["to_q", "to_k", "to_v", "to_out.0"])
+        parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])

         # Checkpointing
         parser.add_argument("--checkpointing_steps", type=int, default=200)

View File

@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import List, Dict, Any

 from pydantic import BaseModel, field_validator

+
 class State(BaseModel):
     model_config = {"arbitrary_types_allowed": True}

View File

@@ -3,11 +3,15 @@ import os
 from pathlib import Path
 import cv2

+
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory")
+    parser.add_argument(
+        "--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory"
+    )
     return parser.parse_args()

+
 args = parse_args()

 # Create data/images directory if it doesn't exist

View File

@@ -1,4 +1,3 @@
-import os
 import logging
 import math
 import json
@@ -32,24 +31,23 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic
 from finetune.schemas import Args, State, Components
 from finetune.utils import (
-    unwrap_model, cast_training_params,
+    unwrap_model,
+    cast_training_params,
     get_optimizer,
     get_memory_statistics,
     free_memory,
     unload_model,
     get_latest_ckpt_path_to_resume_from,
     get_intermediate_ckpt_path,
-    get_latest_ckpt_path_to_resume_from,
-    get_intermediate_ckpt_path,
-    string_to_filename
+    string_to_filename,
 )
 from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
 from finetune.datasets.utils import (
-    load_prompts, load_images, load_videos,
-    preprocess_image_with_resize, preprocess_video_with_resize
+    load_prompts,
+    load_images,
+    load_videos,
+    preprocess_image_with_resize,
+    preprocess_video_with_resize,
 )
 from finetune.constants import LOG_NAME, LOG_LEVEL
@@ -59,7 +57,7 @@ logger = get_logger(LOG_NAME, LOG_LEVEL)

 _DTYPE_MAP = {
     "fp32": torch.float32,
-    "fp16": torch.float16,
+    "fp16": torch.float16,  # FP16 is Only Support for CogVideoX-2B
     "bf16": torch.bfloat16,
 }
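
Note: the inline comment added in this hunk records that fp16 training is only intended for the CogVideoX-2B checkpoint; the larger CogVideoX models are meant to be trained in bf16. The runtime check implied by the commit title is not visible in the hunks shown here; a purely hypothetical guard along those lines could look like this (illustrative sketch, not the commit's code):

    # Hypothetical sketch, not taken from this commit: reject fp16 for non-2B checkpoints.
    def check_fp16_support(mixed_precision: str, model_path: str) -> None:
        if mixed_precision == "fp16" and "2b" not in model_path.lower():
            raise ValueError("fp16 mixed precision is only supported for CogVideoX-2B; use bf16 instead.")

    check_fp16_support("fp16", "THUDM/CogVideoX-2b")  # passes for a 2B model path
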
@@ -69,12 +67,12 @@ class Trainer:
     UNLOAD_LIST: List[str] = None

     def __init__(self, args: Args) -> None:
         self.args = args
         self.state = State(
             weight_dtype=self.__get_training_dtype(),
             train_frames=self.args.train_resolution[0],
             train_height=self.args.train_resolution[1],
-            train_width=self.args.train_resolution[2]
+            train_width=self.args.train_resolution[2],
         )

         self.components = Components()
@@ -140,7 +138,9 @@
     def check_setting(self) -> None:
         # Check for unload_list
         if self.UNLOAD_LIST is None:
-            logger.warning("\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m")
+            logger.warning(
+                "\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
+            )
         else:
             for name in self.UNLOAD_LIST:
                 if name not in self.components.model_fields:
@@ -174,7 +174,7 @@
                 max_num_frames=sample_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
-                trainer=self
+                trainer=self,
             )
         elif self.args.model_type == "t2v":
             self.dataset = T2VDatasetWithResize(
@@ -183,7 +183,7 @@
                 max_num_frames=sample_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
-                trainer=self
+                trainer=self,
             )
         else:
             raise ValueError(f"Invalid model type: {self.args.model_type}")
@@ -204,7 +204,8 @@
                 pin_memory=self.args.pin_memory,
             )
             tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
-            for _ in tmp_data_loader: ...
+            for _ in tmp_data_loader:
+                ...
             self.accelerator.wait_for_everyone()
             logger.info("Precomputing latent for video and prompt embedding ... Done")
@@ -218,16 +219,15 @@
             batch_size=self.args.batch_size,
             num_workers=self.args.num_workers,
             pin_memory=self.args.pin_memory,
-            shuffle=True
+            shuffle=True,
         )

     def prepare_trainable_parameters(self):
         logger.info("Initializing trainable parameters")

         # For now only lora is supported
         for attr_name, component in vars(self.components).items():
-            if hasattr(component, 'requires_grad_'):
+            if hasattr(component, "requires_grad_"):
                 component.requires_grad_(False)

         # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
@@ -452,10 +452,7 @@
                 progress_bar.set_postfix(logs)

                 # Maybe run validation
-                should_run_validation = (
-                    self.args.do_validation
-                    and global_step % self.args.validation_steps == 0
-                )
+                should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0
                 if should_run_validation:
                     del loss
                     free_memory()
@@ -520,9 +517,7 @@
             video = self.state.validation_videos[i]

             if image is not None:
-                image = preprocess_image_with_resize(
-                    image, self.state.train_height, self.state.train_width
-                )
+                image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
                 # Convert image tensor (C, H, W) to PIL images
                 image = image.to(torch.uint8)
                 image = image.permute(1, 2, 0).cpu().numpy()
@@ -534,17 +529,13 @@
                 )

                 # Convert video tensor (F, C, H, W) to list of PIL images
                 video = (video * 255).round().clamp(0, 255).to(torch.uint8)
-                video = [Image.fromarray(frame.permute(1,2,0).cpu().numpy()) for frame in video]
+                video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

             logger.debug(
                 f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                 main_process_only=False,
             )
-            validation_artifacts = self.validation_step({
-                "prompt": prompt,
-                "image": image,
-                "video": video
-            }, pipe)
+            validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
             prompt_filename = string_to_filename(prompt)[:25]
             artifacts = {
                 "image": {"type": "image", "value": image},
@@ -663,7 +654,7 @@
     def __load_components(self):
         components = self.components.model_dump()
         for name, component in components.items():
-            if not isinstance(component, type) and hasattr(component, 'to'):
+            if not isinstance(component, type) and hasattr(component, "to"):
                 if name in self.UNLOAD_LIST:
                     continue
                 # setattr(self.components, name, component.to(self.accelerator.device))
@@ -672,9 +663,9 @@
     def __unload_components(self):
         components = self.components.model_dump()
         for name, component in components.items():
-            if not isinstance(component, type) and hasattr(component, 'to'):
+            if not isinstance(component, type) and hasattr(component, "to"):
                 if name in self.UNLOAD_LIST:
-                    setattr(self.components, name, component.to('cpu'))
+                    setattr(self.components, name, component.to("cpu"))

     def __prepare_saving_loading_hooks(self, transformer_lora_config):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -711,9 +702,7 @@
                 ):
                     transformer_ = unwrap_model(self.accelerator, model)
                 else:
-                    raise ValueError(
-                        f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
-                    )
+                    raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
             else:
                 transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
                     self.args.model_path, subfolder="transformer"