format and check fp16 for cogvideox2b

zR 2025-01-07 13:16:18 +08:00
parent 1b886326b2
commit 1789f07256
15 changed files with 166 additions and 201 deletions

View File

@ -8,5 +8,5 @@ __all__ = [
"I2VDatasetWithBuckets",
"T2VDatasetWithResize",
"T2VDatasetWithBuckets",
"BucketSampler"
"BucketSampler",
]

View File

@ -37,7 +37,6 @@ class BucketSampler(Sampler):
self._raised_warning_for_drop_last = False
def __len__(self):
if self.drop_last and not self._raised_warning_for_drop_last:
self._raised_warning_for_drop_last = True
@ -46,7 +45,6 @@ class BucketSampler(Sampler):
)
return (len(self.data_source) + self.batch_size - 1) // self.batch_size
def __iter__(self):
for index, data in enumerate(self.data_source):
video_metadata = data["video_metadata"]

View File

@ -13,11 +13,12 @@ from safetensors.torch import save_file, load_file
from finetune.constants import LOG_NAME, LOG_LEVEL
from .utils import (
load_prompts, load_videos, load_images,
load_prompts,
load_videos,
load_images,
preprocess_image_with_resize,
preprocess_video_with_resize,
preprocess_video_with_buckets
preprocess_video_with_buckets,
)
if TYPE_CHECKING:
@ -46,6 +47,7 @@ class BaseI2VDataset(Dataset):
device (torch.device): Device to load the data on
encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
"""
def __init__(
self,
data_root: str,
@ -55,7 +57,7 @@ class BaseI2VDataset(Dataset):
device: torch.device,
trainer: "Trainer" = None,
*args,
**kwargs
**kwargs,
) -> None:
super().__init__()
@ -120,7 +122,10 @@ class BaseI2VDataset(Dataset):
if prompt_embedding_path.exists():
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
logger.debug(
f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
main_process_only=False,
)
else:
prompt_embedding = self.encode_text(prompt)
prompt_embedding = prompt_embedding.to("cpu")
@ -235,6 +240,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
height (int): Target height for resizing videos and images
width (int): Target width for resizing videos and images
"""
def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@ -242,11 +248,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
self.height = height
self.width = width
self.__frame_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__image_transforms = self.__frame_transforms
@override
@ -271,14 +273,14 @@ class I2VDatasetWithResize(BaseI2VDataset):
class I2VDatasetWithBuckets(BaseI2VDataset):
def __init__(
self,
video_resolution_buckets: List[Tuple[int, int, int]],
vae_temporal_compression_ratio: int,
vae_height_compression_ratio: int,
vae_width_compression_ratio: int,
*args, **kwargs
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
@ -290,11 +292,7 @@ class I2VDatasetWithBuckets(BaseI2VDataset):
)
for b in video_resolution_buckets
]
self.__frame_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__image_transforms = self.__frame_transforms
@override

View File

@ -12,11 +12,7 @@ from safetensors.torch import save_file, load_file
from finetune.constants import LOG_NAME, LOG_LEVEL
from .utils import (
load_prompts, load_videos,
preprocess_video_with_resize,
preprocess_video_with_buckets
)
from .utils import load_prompts, load_videos, preprocess_video_with_resize, preprocess_video_with_buckets
if TYPE_CHECKING:
from finetune.trainer import Trainer
@ -52,7 +48,7 @@ class BaseT2VDataset(Dataset):
device: torch.device = None,
trainer: "Trainer" = None,
*args,
**kwargs
**kwargs,
) -> None:
super().__init__()
@ -108,7 +104,10 @@ class BaseT2VDataset(Dataset):
if prompt_embedding_path.exists():
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
logger.debug(
f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
main_process_only=False,
)
else:
prompt_embedding = self.encode_text(prompt)
prompt_embedding = prompt_embedding.to("cpu")
@ -203,16 +202,15 @@ class T2VDatasetWithResize(BaseT2VDataset):
self.height = height
self.width = width
self.__frame_transform = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
@override
def preprocess(self, video_path: Path) -> torch.Tensor:
return preprocess_video_with_resize(
video_path, self.max_num_frames, self.height, self.width,
video_path,
self.max_num_frames,
self.height,
self.width,
)
@override
@ -221,18 +219,16 @@ class T2VDatasetWithResize(BaseT2VDataset):
class T2VDatasetWithBuckets(BaseT2VDataset):
def __init__(
self,
video_resolution_buckets: List[Tuple[int, int, int]],
vae_temporal_compression_ratio: int,
vae_height_compression_ratio: int,
vae_width_compression_ratio: int,
*args, **kwargs
*args,
**kwargs,
) -> None:
"""
"""
""" """
super().__init__(*args, **kwargs)
self.video_resolution_buckets = [
@ -244,17 +240,11 @@ class T2VDatasetWithBuckets(BaseT2VDataset):
for b in video_resolution_buckets
]
self.__frame_transform = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
@override
def preprocess(self, video_path: Path) -> torch.Tensor:
return preprocess_video_with_buckets(
video_path, self.video_resolution_buckets
)
return preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:

View File

@ -15,6 +15,7 @@ decord.bridge.set_bridge("torch")
########## loaders ##########
def load_prompts(prompt_path: Path) -> List[str]:
with open(prompt_path, "r", encoding="utf-8") as file:
return [line.strip() for line in file.readlines() if len(line.strip()) > 0]
@ -32,6 +33,7 @@ def load_images(image_path: Path) -> List[Path]:
########## preprocessors ##########
def preprocess_image_with_resize(
image_path: Path | str,
height: int,
@ -96,7 +98,7 @@ def preprocess_video_with_resize(
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
frames = video_reader.get_batch(indices)
frames = frames[: max_num_frames].float()
frames = frames[:max_num_frames].float()
frames = frames.permute(0, 3, 1, 2).contiguous()
return frames
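A note on the slice above: with stride-based sampling, range(0, video_num_frames, video_num_frames // max_num_frames) can yield one index more than max_num_frames whenever the frame count is not an exact multiple of the stride, so the truncation is needed. A minimal illustration (values are made up):

video_num_frames, max_num_frames = 50, 8
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
print(len(indices))  # 9, one more than max_num_frames, hence frames[:max_num_frames]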

View File

@ -5,8 +5,8 @@ from pathlib import Path
package_dir = Path(__file__).parent
for subdir in package_dir.iterdir():
if subdir.is_dir() and not subdir.name.startswith('_'):
for module_path in subdir.glob('*.py'):
if subdir.is_dir() and not subdir.name.startswith("_"):
for module_path in subdir.glob("*.py"):
module_name = module_path.stem
full_module_name = f".{subdir.name}.{module_name}"
importlib.import_module(full_module_name, package=__name__)

View File

@ -30,25 +30,15 @@ class CogVideoXI2VLoraTrainer(Trainer):
components.pipeline_cls = CogVideoXImageToVideoPipeline
components.tokenizer = AutoTokenizer.from_pretrained(
model_path, subfolder="tokenizer"
)
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
components.text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder"
)
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer"
)
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
components.vae = AutoencoderKLCogVideoX.from_pretrained(
model_path, subfolder="vae"
)
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
model_path, subfolder="scheduler"
)
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
return components
@ -59,7 +49,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler
scheduler=self.components.scheduler,
)
return pipe
@ -88,11 +78,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
@override
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
ret = {
"encoded_videos": [],
"prompt_embedding": [],
"images": []
}
ret = {"encoded_videos": [], "prompt_embedding": [], "images": []}
for sample in samples:
encoded_video = sample["encoded_video"]
@ -144,8 +130,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
# Sample a random timestep for each sample
timesteps = torch.randint(
0, self.components.scheduler.config.num_train_timesteps,
(batch_size,), device=self.accelerator.device
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
)
timesteps = timesteps.long()
@ -183,7 +168,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
)
# Predict noise
ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
ofs_emb = (
None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
)
predicted_noise = self.components.transformer(
hidden_states=latent_img_noisy,
encoder_hidden_states=prompt_embedding,
@ -222,7 +209,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
width=self.state.train_width,
prompt=prompt,
image=image,
generator=self.state.generator
generator=self.state.generator,
).frames[0]
return [("video", video_generate)]
@ -233,7 +220,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
num_frames: int,
transformer_config: Dict,
vae_scale_factor_spatial: int,
device: torch.device
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
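The grid dimensions above divide the pixel resolution by the VAE spatial compression times the transformer patch size. A quick arithmetic sketch, assuming the typical CogVideoX values of vae_scale_factor_spatial = 8 and patch_size = 2 with a 480x720 training resolution:

height, width = 480, 720
vae_scale_factor_spatial, patch_size = 8, 2  # assumed typical CogVideoX values
grid_height = height // (vae_scale_factor_spatial * patch_size)  # 30
grid_width = width // (vae_scale_factor_spatial * patch_size)  # 45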

View File

@ -31,25 +31,15 @@ class CogVideoXT2VLoraTrainer(Trainer):
components.pipeline_cls = CogVideoXPipeline
components.tokenizer = AutoTokenizer.from_pretrained(
model_path, subfolder="tokenizer"
)
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
components.text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder"
)
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer"
)
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
components.vae = AutoencoderKLCogVideoX.from_pretrained(
model_path, subfolder="vae"
)
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
model_path, subfolder="scheduler"
)
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
return components
@ -60,7 +50,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler
scheduler=self.components.scheduler,
)
return pipe
@ -89,10 +79,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
@override
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
ret = {
"encoded_videos": [],
"prompt_embedding": []
}
ret = {"encoded_videos": [], "prompt_embedding": []}
for sample in samples:
encoded_video = sample["encoded_video"]
@ -116,10 +103,20 @@ class CogVideoXT2VLoraTrainer(Trainer):
patch_size_t = self.state.transformer_config.patch_size_t
if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
raise ValueError("Number of frames in latent must be divisible by patch size, please check your args for training.")
raise ValueError(
"Number of frames in latent must be divisible by patch size, please check your args for training."
)
# Add 2 random noise frames at the beginning of frame dimension
noise_frames = torch.randn(latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
noise_frames = torch.randn(
latent.shape[0],
latent.shape[1],
2,
latent.shape[3],
latent.shape[4],
device=latent.device,
dtype=latent.dtype,
)
latent = torch.cat([noise_frames, latent], dim=2)
batch_size, num_channels, num_frames, height, width = latent.shape
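The rewrapped torch.randn call above keeps the original behaviour: two frames of Gaussian noise, matching the latent in every dimension except the frame axis (dim 2), are prepended by the cat. A small shape check with an assumed latent shape:

import torch

latent = torch.randn(1, 16, 12, 60, 90)  # assumed (batch, channels, frames, height, width)
noise_frames = torch.randn(
    latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4],
    device=latent.device, dtype=latent.dtype,
)
latent = torch.cat([noise_frames, latent], dim=2)
print(latent.shape)  # torch.Size([1, 16, 14, 60, 90]) -- frame dim grows by 2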
@ -130,8 +127,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
# Sample a random timestep for each sample
timesteps = torch.randint(
0, self.components.scheduler.config.num_train_timesteps,
(batch_size,), device=self.accelerator.device
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
)
timesteps = timesteps.long()
@ -193,7 +189,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
height=self.state.train_height,
width=self.state.train_width,
prompt=prompt,
generator=self.state.generator
generator=self.state.generator,
).frames[0]
return [("video", video_generate)]
@ -204,7 +200,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
num_frames: int,
transformer_config: Dict,
vae_scale_factor_spatial: int,
device: torch.device
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)

View File

@ -150,12 +150,13 @@ class Args(BaseModel):
return v
except ValueError as e:
if str(e) == "not enough values to unpack (expected 3, got 0)" or \
str(e) == "invalid literal for int() with base 10":
if (
str(e) == "not enough values to unpack (expected 3, got 0)"
or str(e) == "invalid literal for int() with base 10"
):
raise ValueError("train_resolution must be in format 'frames x height x width'")
raise e
@classmethod
def parse_args(cls):
"""Parse command line arguments and return Args instance"""
@ -208,8 +209,7 @@ class Args(BaseModel):
# LoRA parameters
parser.add_argument("--rank", type=int, default=128)
parser.add_argument("--lora_alpha", type=int, default=64)
parser.add_argument("--target_modules", type=str, nargs="+",
default=["to_q", "to_k", "to_v", "to_out.0"])
parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])
# Checkpointing
parser.add_argument("--checkpointing_steps", type=int, default=200)

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import List, Dict, Any
from pydantic import BaseModel, field_validator
class State(BaseModel):
model_config = {"arbitrary_types_allowed": True}

View File

@ -3,11 +3,15 @@ import os
from pathlib import Path
import cv2
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory")
parser.add_argument(
"--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory"
)
return parser.parse_args()
args = parse_args()
# Create data/images directory if it doesn't exist

View File

@ -1,4 +1,3 @@
import os
import logging
import math
import json
@ -32,24 +31,23 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic
from finetune.schemas import Args, State, Components
from finetune.utils import (
unwrap_model, cast_training_params,
unwrap_model,
cast_training_params,
get_optimizer,
get_memory_statistics,
free_memory,
unload_model,
get_latest_ckpt_path_to_resume_from,
get_intermediate_ckpt_path,
get_latest_ckpt_path_to_resume_from,
get_intermediate_ckpt_path,
string_to_filename
string_to_filename,
)
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
from finetune.datasets.utils import (
load_prompts, load_images, load_videos,
preprocess_image_with_resize, preprocess_video_with_resize
load_prompts,
load_images,
load_videos,
preprocess_image_with_resize,
preprocess_video_with_resize,
)
from finetune.constants import LOG_NAME, LOG_LEVEL
@ -59,7 +57,7 @@ logger = get_logger(LOG_NAME, LOG_LEVEL)
_DTYPE_MAP = {
"fp32": torch.float32,
"fp16": torch.float16,
"fp16": torch.float16, # FP16 is Only Support for CogVideoX-2B
"bf16": torch.bfloat16,
}
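The new inline comment flags that fp16 is only supported for CogVideoX-2B. The corresponding check is not shown in this hunk; a minimal sketch of what such a guard could look like (the function name and matching rule are assumptions, not the repository's actual code):

def check_fp16_support(model_path: str, mixed_precision: str) -> None:
    # Hypothetical guard: fp16 is only valid for CogVideoX-2B checkpoints;
    # larger variants should fall back to bf16.
    if mixed_precision == "fp16" and "cogvideox-2b" not in model_path.lower():
        raise ValueError("fp16 mixed precision is only supported for CogVideoX-2B; use bf16 instead.")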
@ -74,7 +72,7 @@ class Trainer:
weight_dtype=self.__get_training_dtype(),
train_frames=self.args.train_resolution[0],
train_height=self.args.train_resolution[1],
train_width=self.args.train_resolution[2]
train_width=self.args.train_resolution[2],
)
self.components = Components()
@ -140,7 +138,9 @@ class Trainer:
def check_setting(self) -> None:
# Check for unload_list
if self.UNLOAD_LIST is None:
logger.warning("\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m")
logger.warning(
"\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
)
else:
for name in self.UNLOAD_LIST:
if name not in self.components.model_fields:
@ -174,7 +174,7 @@ class Trainer:
max_num_frames=sample_frames,
height=self.state.train_height,
width=self.state.train_width,
trainer=self
trainer=self,
)
elif self.args.model_type == "t2v":
self.dataset = T2VDatasetWithResize(
@ -183,7 +183,7 @@ class Trainer:
max_num_frames=sample_frames,
height=self.state.train_height,
width=self.state.train_width,
trainer=self
trainer=self,
)
else:
raise ValueError(f"Invalid model type: {self.args.model_type}")
@ -204,7 +204,8 @@ class Trainer:
pin_memory=self.args.pin_memory,
)
tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
for _ in tmp_data_loader: ...
for _ in tmp_data_loader:
...
self.accelerator.wait_for_everyone()
logger.info("Precomputing latent for video and prompt embedding ... Done")
@ -218,16 +219,15 @@ class Trainer:
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=self.args.pin_memory,
shuffle=True
shuffle=True,
)
def prepare_trainable_parameters(self):
logger.info("Initializing trainable parameters")
# For now only lora is supported
for attr_name, component in vars(self.components).items():
if hasattr(component, 'requires_grad_'):
if hasattr(component, "requires_grad_"):
component.requires_grad_(False)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
@ -452,10 +452,7 @@ class Trainer:
progress_bar.set_postfix(logs)
# Maybe run validation
should_run_validation = (
self.args.do_validation
and global_step % self.args.validation_steps == 0
)
should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0
if should_run_validation:
del loss
free_memory()
@ -520,9 +517,7 @@ class Trainer:
video = self.state.validation_videos[i]
if image is not None:
image = preprocess_image_with_resize(
image, self.state.train_height, self.state.train_width
)
image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
# Convert image tensor (C, H, W) to PIL images
image = image.to(torch.uint8)
image = image.permute(1, 2, 0).cpu().numpy()
@ -534,17 +529,13 @@ class Trainer:
)
# Convert video tensor (F, C, H, W) to list of PIL images
video = (video * 255).round().clamp(0, 255).to(torch.uint8)
video = [Image.fromarray(frame.permute(1,2,0).cpu().numpy()) for frame in video]
video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
logger.debug(
f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
main_process_only=False,
)
validation_artifacts = self.validation_step({
"prompt": prompt,
"image": image,
"video": video
}, pipe)
validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
prompt_filename = string_to_filename(prompt)[:25]
artifacts = {
"image": {"type": "image", "value": image},
@ -663,7 +654,7 @@ class Trainer:
def __load_components(self):
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, 'to'):
if not isinstance(component, type) and hasattr(component, "to"):
if name in self.UNLOAD_LIST:
continue
# setattr(self.components, name, component.to(self.accelerator.device))
@ -672,9 +663,9 @@ class Trainer:
def __unload_components(self):
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, 'to'):
if not isinstance(component, type) and hasattr(component, "to"):
if name in self.UNLOAD_LIST:
setattr(self.components, name, component.to('cpu'))
setattr(self.components, name, component.to("cpu"))
def __prepare_saving_loading_hooks(self, transformer_lora_config):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@ -711,9 +702,7 @@ class Trainer:
):
transformer_ = unwrap_model(self.accelerator, model)
else:
raise ValueError(
f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
)
raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
else:
transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
self.args.model_path, subfolder="transformer"