Format code and check fp16 support for CogVideoX-2B

zR 2025-01-07 13:16:18 +08:00
parent 1b886326b2
commit 1789f07256
15 changed files with 166 additions and 201 deletions

View File

@ -1,2 +1,2 @@
LOG_NAME = "trainer"
LOG_LEVEL = "INFO"
LOG_LEVEL = "INFO"

View File

@ -8,5 +8,5 @@ __all__ = [
"I2VDatasetWithBuckets",
"T2VDatasetWithResize",
"T2VDatasetWithBuckets",
"BucketSampler"
"BucketSampler",
]

View File

@ -37,7 +37,6 @@ class BucketSampler(Sampler):
self._raised_warning_for_drop_last = False
def __len__(self):
if self.drop_last and not self._raised_warning_for_drop_last:
self._raised_warning_for_drop_last = True
@ -46,7 +45,6 @@ class BucketSampler(Sampler):
)
return (len(self.data_source) + self.batch_size - 1) // self.batch_size
def __iter__(self):
for index, data in enumerate(self.data_source):
video_metadata = data["video_metadata"]
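
The sampler's __len__ above counts a trailing partial batch via ceiling division when drop_last is False, and __iter__ groups samples by their video metadata. A minimal sketch of the bucketing idea, using a hypothetical helper rather than the repository's exact BucketSampler logic:

```python
# Hypothetical illustration of resolution bucketing: indices are grouped by
# (frames, height, width); a batch is emitted whenever a bucket fills up, and
# leftover partial buckets are dropped (as with drop_last=True).
from collections import defaultdict

def bucket_batches(metadata, batch_size):
    buckets = defaultdict(list)
    for idx, meta in enumerate(metadata):
        key = (meta["num_frames"], meta["height"], meta["width"])
        buckets[key].append(idx)
        if len(buckets[key]) == batch_size:
            yield buckets.pop(key)

# __len__ without drop_last: (len(data_source) + batch_size - 1) // batch_size
```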

View File

@ -13,11 +13,12 @@ from safetensors.torch import save_file, load_file
from finetune.constants import LOG_NAME, LOG_LEVEL
from .utils import (
load_prompts, load_videos, load_images,
load_prompts,
load_videos,
load_images,
preprocess_image_with_resize,
preprocess_video_with_resize,
preprocess_video_with_buckets
preprocess_video_with_buckets,
)
if TYPE_CHECKING:
@ -46,6 +47,7 @@ class BaseI2VDataset(Dataset):
device (torch.device): Device to load the data on
encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
"""
def __init__(
self,
data_root: str,
@ -55,7 +57,7 @@ class BaseI2VDataset(Dataset):
device: torch.device,
trainer: "Trainer" = None,
*args,
**kwargs
**kwargs,
) -> None:
super().__init__()
@ -120,7 +122,10 @@ class BaseI2VDataset(Dataset):
if prompt_embedding_path.exists():
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
logger.debug(
f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
main_process_only=False,
)
else:
prompt_embedding = self.encode_text(prompt)
prompt_embedding = prompt_embedding.to("cpu")
@ -187,7 +192,7 @@ class BaseI2VDataset(Dataset):
- image(torch.Tensor) of shape [C, H, W]
"""
raise NotImplementedError("Subclass must implement this method")
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
"""
Applies transformations to a video.
@ -197,14 +202,14 @@ class BaseI2VDataset(Dataset):
with shape [F, C, H, W] where:
- F is number of frames
- C is number of channels (3 for RGB)
- H is height
- H is height
- W is width
Returns:
torch.Tensor: The transformed video tensor
"""
raise NotImplementedError("Subclass must implement this method")
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
"""
Applies transformations to an image.
@ -213,7 +218,7 @@ class BaseI2VDataset(Dataset):
image (torch.Tensor): A 3D tensor representing an image
with shape [C, H, W] where:
- C is number of channels (3 for RGB)
- H is height
- H is height
- W is width
Returns:
@ -235,6 +240,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
height (int): Target height for resizing videos and images
width (int): Target width for resizing videos and images
"""
def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@ -242,11 +248,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
self.height = height
self.width = width
self.__frame_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__image_transforms = self.__frame_transforms
@override
@ -260,25 +262,25 @@ class I2VDatasetWithResize(BaseI2VDataset):
else:
image = None
return video, image
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
@override
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
return self.__image_transforms(image)
class I2VDatasetWithBuckets(BaseI2VDataset):
def __init__(
self,
video_resolution_buckets: List[Tuple[int, int, int]],
vae_temporal_compression_ratio: int,
vae_height_compression_ratio: int,
vae_width_compression_ratio: int,
*args, **kwargs
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
@ -290,23 +292,19 @@ class I2VDatasetWithBuckets(BaseI2VDataset):
)
for b in video_resolution_buckets
]
self.__frame_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__image_transforms = self.__frame_transforms
@override
def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
return video, image
@override
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
@override
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
return self.__image_transforms(image)
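
The dataset above caches prompt embeddings (and encoded videos) as safetensors files so the text encoder only runs once per prompt. A hedged sketch of that load-or-encode pattern; the cache path handling and the encode_text callable are simplified assumptions:

```python
# Sketch of the embedding cache: load from disk if present, otherwise encode,
# move to CPU, and persist for later epochs/processes.
from pathlib import Path
import torch
from safetensors.torch import load_file, save_file

def get_prompt_embedding(cache_path: Path, prompt: str, encode_text) -> torch.Tensor:
    if cache_path.exists():
        return load_file(str(cache_path))["prompt_embedding"]
    embedding = encode_text(prompt).to("cpu")[0]  # assumed: drop the batch dim
    save_file({"prompt_embedding": embedding}, str(cache_path))
    return embedding
```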

View File

@ -12,11 +12,7 @@ from safetensors.torch import save_file, load_file
from finetune.constants import LOG_NAME, LOG_LEVEL
from .utils import (
load_prompts, load_videos,
preprocess_video_with_resize,
preprocess_video_with_buckets
)
from .utils import load_prompts, load_videos, preprocess_video_with_resize, preprocess_video_with_buckets
if TYPE_CHECKING:
from finetune.trainer import Trainer
@ -52,7 +48,7 @@ class BaseT2VDataset(Dataset):
device: torch.device = None,
trainer: "Trainer" = None,
*args,
**kwargs
**kwargs,
) -> None:
super().__init__()
@ -108,7 +104,10 @@ class BaseT2VDataset(Dataset):
if prompt_embedding_path.exists():
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
logger.debug(
f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
main_process_only=False,
)
else:
prompt_embedding = self.encode_text(prompt)
prompt_embedding = prompt_embedding.to("cpu")
@ -164,7 +163,7 @@ class BaseT2VDataset(Dataset):
- W is width
"""
raise NotImplementedError("Subclass must implement this method")
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
"""
Applies transformations to a video.
@ -174,7 +173,7 @@ class BaseT2VDataset(Dataset):
with shape [F, C, H, W] where:
- F is number of frames
- C is number of channels (3 for RGB)
- H is height
- H is height
- W is width
Returns:
@ -203,36 +202,33 @@ class T2VDatasetWithResize(BaseT2VDataset):
self.height = height
self.width = width
self.__frame_transform = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
@override
def preprocess(self, video_path: Path) -> torch.Tensor:
return preprocess_video_with_resize(
video_path, self.max_num_frames, self.height, self.width,
video_path,
self.max_num_frames,
self.height,
self.width,
)
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
class T2VDatasetWithBuckets(BaseT2VDataset):
def __init__(
self,
video_resolution_buckets: List[Tuple[int, int, int]],
vae_temporal_compression_ratio: int,
vae_height_compression_ratio: int,
vae_width_compression_ratio: int,
*args, **kwargs
*args,
**kwargs,
) -> None:
"""
"""
""" """
super().__init__(*args, **kwargs)
self.video_resolution_buckets = [
@ -244,18 +240,12 @@ class T2VDatasetWithBuckets(BaseT2VDataset):
for b in video_resolution_buckets
]
self.__frame_transform = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
@override
def preprocess(self, video_path: Path) -> torch.Tensor:
return preprocess_video_with_buckets(
video_path, self.video_resolution_buckets
)
return preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
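
Both the I2V and T2V datasets share the same per-frame transform, which maps uint8 pixel values in [0, 255] to the [-1, 1] range the VAE expects. A standalone illustration with arbitrary shapes:

```python
import torch
from torchvision import transforms

# Same Lambda as in the datasets above: [0, 255] -> [-1, 1].
frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])

frames = torch.randint(0, 256, (49, 3, 60, 90)).float()  # [F, C, H, W], illustrative shape
normalized = torch.stack([frame_transform(f) for f in frames], dim=0)
assert normalized.min().item() >= -1.0 and normalized.max().item() <= 1.0
```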

View File

@ -15,6 +15,7 @@ decord.bridge.set_bridge("torch")
########## loaders ##########
def load_prompts(prompt_path: Path) -> List[str]:
with open(prompt_path, "r", encoding="utf-8") as file:
return [line.strip() for line in file.readlines() if len(line.strip()) > 0]
@ -32,6 +33,7 @@ def load_images(image_path: Path) -> List[Path]:
########## preprocessors ##########
def preprocess_image_with_resize(
image_path: Path | str,
height: int,
@ -96,7 +98,7 @@ def preprocess_video_with_resize(
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
frames = video_reader.get_batch(indices)
frames = frames[: max_num_frames].float()
frames = frames[:max_num_frames].float()
frames = frames.permute(0, 3, 1, 2).contiguous()
return frames
@ -144,4 +146,4 @@ def preprocess_video_with_buckets(
nearest_res = (nearest_res[1], nearest_res[2])
frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)
return frames
return frames
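
preprocess_video_with_resize keeps at most max_num_frames frames by striding evenly through the clip and truncating. A hedged sketch of just the index selection; the decord read and the spatial resize are omitted:

```python
# Evenly strided frame indices, then truncation to exactly max_num_frames,
# mirroring the indices / frames[:max_num_frames] logic above.
from typing import List

def select_frame_indices(video_num_frames: int, max_num_frames: int) -> List[int]:
    stride = video_num_frames // max_num_frames
    return list(range(0, video_num_frames, stride))[:max_num_frames]

print(select_frame_indices(200, 49))  # 49 indices, 4 frames apart
```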

View File

@ -5,8 +5,8 @@ from pathlib import Path
package_dir = Path(__file__).parent
for subdir in package_dir.iterdir():
if subdir.is_dir() and not subdir.name.startswith('_'):
for module_path in subdir.glob('*.py'):
if subdir.is_dir() and not subdir.name.startswith("_"):
for module_path in subdir.glob("*.py"):
module_name = module_path.stem
full_module_name = f".{subdir.name}.{module_name}"
importlib.import_module(full_module_name, package=__name__)
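
The loop above imports every module in the package's subdirectories so that each trainer file's register(...) call executes at import time. The register() implementation itself is not part of this diff; a hedged sketch of what such a registry could look like:

```python
# Hypothetical registry mapping (model_name, training_type) -> trainer class;
# the names here are assumptions, only the register("cogvideox-i2v", "lora", ...)
# call pattern appears in this diff.
SUPPORTED_MODELS = {}

def register(model_name: str, training_type: str, trainer_cls) -> None:
    SUPPORTED_MODELS.setdefault(model_name, {})[training_type] = trainer_cls

def get_trainer_cls(model_name: str, training_type: str):
    return SUPPORTED_MODELS[model_name][training_type]
```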

View File

@ -30,28 +30,18 @@ class CogVideoXI2VLoraTrainer(Trainer):
components.pipeline_cls = CogVideoXImageToVideoPipeline
components.tokenizer = AutoTokenizer.from_pretrained(
model_path, subfolder="tokenizer"
)
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
components.text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder"
)
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer"
)
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
components.vae = AutoencoderKLCogVideoX.from_pretrained(
model_path, subfolder="vae"
)
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
model_path, subfolder="scheduler"
)
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
return components
@override
def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
pipe = CogVideoXImageToVideoPipeline(
@ -59,7 +49,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler
scheduler=self.components.scheduler,
)
return pipe
@ -71,7 +61,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
latent_dist = vae.encode(video).latent_dist
latent = latent_dist.sample() * vae.config.scaling_factor
return latent
@override
def encode_text(self, prompt: str) -> torch.Tensor:
prompt_token_ids = self.components.tokenizer(
@ -88,12 +78,8 @@ class CogVideoXI2VLoraTrainer(Trainer):
@override
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
ret = {
"encoded_videos": [],
"prompt_embedding": [],
"images": []
}
ret = {"encoded_videos": [], "prompt_embedding": [], "images": []}
for sample in samples:
encoded_video = sample["encoded_video"]
prompt_embedding = sample["prompt_embedding"]
@ -102,13 +88,13 @@ class CogVideoXI2VLoraTrainer(Trainer):
ret["encoded_videos"].append(encoded_video)
ret["prompt_embedding"].append(prompt_embedding)
ret["images"].append(image)
ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
ret["images"] = torch.stack(ret["images"])
return ret
@override
def compute_loss(self, batch) -> torch.Tensor:
prompt_embedding = batch["prompt_embedding"]
@ -144,8 +130,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
# Sample a random timestep for each sample
timesteps = torch.randint(
0, self.components.scheduler.config.num_train_timesteps,
(batch_size,), device=self.accelerator.device
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
)
timesteps = timesteps.long()
@ -183,7 +168,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
)
# Predict noise
ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
ofs_emb = (
None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
)
predicted_noise = self.components.transformer(
hidden_states=latent_img_noisy,
encoder_hidden_states=prompt_embedding,
@ -222,7 +209,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
width=self.state.train_width,
prompt=prompt,
image=image,
generator=self.state.generator
generator=self.state.generator,
).frames[0]
return [("video", video_generate)]
@ -233,7 +220,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
num_frames: int,
transformer_config: Dict,
vae_scale_factor_spatial: int,
device: torch.device
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
@ -256,4 +243,4 @@ class CogVideoXI2VLoraTrainer(Trainer):
return freqs_cos, freqs_sin
register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
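
collate_fn above gathers per-sample tensors into lists and stacks them into batch tensors. A minimal standalone illustration; the shapes are made up for the example:

```python
import torch

samples = [
    {"encoded_video": torch.randn(16, 13, 60, 90), "prompt_embedding": torch.randn(226, 4096)}
    for _ in range(2)
]
batch = {
    "encoded_videos": torch.stack([s["encoded_video"] for s in samples]),
    "prompt_embedding": torch.stack([s["prompt_embedding"] for s in samples]),
}
print(batch["encoded_videos"].shape)  # torch.Size([2, 16, 13, 60, 90])
```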

View File

@ -31,25 +31,15 @@ class CogVideoXT2VLoraTrainer(Trainer):
components.pipeline_cls = CogVideoXPipeline
components.tokenizer = AutoTokenizer.from_pretrained(
model_path, subfolder="tokenizer"
)
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
components.text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder"
)
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer"
)
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
components.vae = AutoencoderKLCogVideoX.from_pretrained(
model_path, subfolder="vae"
)
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
model_path, subfolder="scheduler"
)
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
return components
@ -60,10 +50,10 @@ class CogVideoXT2VLoraTrainer(Trainer):
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler
scheduler=self.components.scheduler,
)
return pipe
@override
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
# shape of input video: [B, C, F, H, W]
@ -86,21 +76,18 @@ class CogVideoXT2VLoraTrainer(Trainer):
prompt_token_ids = prompt_token_ids.input_ids
prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
return prompt_embedding
@override
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
ret = {
"encoded_videos": [],
"prompt_embedding": []
}
ret = {"encoded_videos": [], "prompt_embedding": []}
for sample in samples:
encoded_video = sample["encoded_video"]
prompt_embedding = sample["prompt_embedding"]
ret["encoded_videos"].append(encoded_video)
ret["prompt_embedding"].append(prompt_embedding)
ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
@ -116,10 +103,20 @@ class CogVideoXT2VLoraTrainer(Trainer):
patch_size_t = self.state.transformer_config.patch_size_t
if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
raise ValueError("Number of frames in latent must be divisible by patch size, please check your args for training.")
raise ValueError(
"Number of frames in latent must be divisible by patch size, please check your args for training."
)
# Add 2 random noise frames at the beginning of frame dimension
noise_frames = torch.randn(latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
noise_frames = torch.randn(
latent.shape[0],
latent.shape[1],
2,
latent.shape[3],
latent.shape[4],
device=latent.device,
dtype=latent.dtype,
)
latent = torch.cat([noise_frames, latent], dim=2)
batch_size, num_channels, num_frames, height, width = latent.shape
@ -130,8 +127,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
# Sample a random timestep for each sample
timesteps = torch.randint(
0, self.components.scheduler.config.num_train_timesteps,
(batch_size,), device=self.accelerator.device
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
)
timesteps = timesteps.long()
@ -193,7 +189,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
height=self.state.train_height,
width=self.state.train_width,
prompt=prompt,
generator=self.state.generator
generator=self.state.generator,
).frames[0]
return [("video", video_generate)]
@ -204,7 +200,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
num_frames: int,
transformer_config: Dict,
vae_scale_factor_spatial: int,
device: torch.device
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
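
The loss computation above prepends two random noise frames along the latent frame dimension before the timesteps are sampled. A standalone illustration with made-up latent shapes:

```python
import torch

latent = torch.randn(1, 16, 13, 60, 90)  # [B, C, F, H, W], illustrative
noise_frames = torch.randn(
    latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4],
    device=latent.device, dtype=latent.dtype,
)
latent = torch.cat([noise_frames, latent], dim=2)
print(latent.shape)  # torch.Size([1, 16, 15, 60, 90]); F grew from 13 to 15
```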

View File

@ -2,4 +2,4 @@ from .args import Args
from .state import State
from .components import Components
__all__ = ["Args", "State", "Components"]
__all__ = ["Args", "State", "Components"]

View File

@ -78,10 +78,10 @@ class Args(BaseModel):
########## Validation ##########
do_validation: bool = False
validation_steps: int | None = None # if set, should be a multiple of checkpointing_steps
validation_dir: Path | None # if set do_validation, should not be None
validation_dir: Path | None # if set do_validation, should not be None
validation_prompts: str | None # if set do_validation, should not be None
validation_images: str | None # if set do_validation and model_type == i2v, should not be None
validation_videos: str | None # if set do_validation and model_type == v2v, should not be None
validation_images: str | None # if set do_validation and model_type == i2v, should not be None
validation_videos: str | None # if set do_validation and model_type == v2v, should not be None
gen_fps: int = 15
#### deprecated args: gen_video_resolution
@ -115,7 +115,7 @@ class Args(BaseModel):
raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
return v
@field_validator("validation_videos")
@field_validator("validation_videos")
def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
values = info.data
if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
@ -131,31 +131,32 @@ class Args(BaseModel):
if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
raise ValueError("validation_steps must be a multiple of checkpointing_steps")
return v
@field_validator("train_resolution")
def validate_train_resolution(cls, v: Tuple[int, int, int], info: ValidationInfo) -> str:
try:
frames, height, width = v
# Check if (frames - 1) is multiple of 8
if (frames - 1) % 8 != 0:
raise ValueError("Number of frames - 1 must be a multiple of 8")
# Check resolution for cogvideox-5b models
model_name = info.data.get("model_name", "")
if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
if (height, width) != (480, 720):
raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")
return v
except ValueError as e:
if str(e) == "not enough values to unpack (expected 3, got 0)" or \
str(e) == "invalid literal for int() with base 10":
if (
str(e) == "not enough values to unpack (expected 3, got 0)"
or str(e) == "invalid literal for int() with base 10"
):
raise ValueError("train_resolution must be in format 'frames x height x width'")
raise e
@classmethod
def parse_args(cls):
"""Parse command line arguments and return Args instance"""
@ -208,8 +209,7 @@ class Args(BaseModel):
# LoRA parameters
parser.add_argument("--rank", type=int, default=128)
parser.add_argument("--lora_alpha", type=int, default=64)
parser.add_argument("--target_modules", type=str, nargs="+",
default=["to_q", "to_k", "to_v", "to_out.0"])
parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])
# Checkpointing
parser.add_argument("--checkpointing_steps", type=int, default=200)
@ -226,7 +226,7 @@ class Args(BaseModel):
parser.add_argument("--gen_fps", type=int, default=15)
args = parser.parse_args()
# Convert video_resolution_buckets string to list of tuples
frames, height, width = args.train_resolution.split("x")
args.train_resolution = (int(frames), int(height), int(width))
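
parse_args splits the --train_resolution string on "x", and the validator requires (frames - 1) to be a multiple of 8 (plus a fixed 480x720 resolution for the 5B models). A hedged standalone sketch of that parsing and check; the example value is illustrative:

```python
from typing import Tuple

def parse_train_resolution(value: str) -> Tuple[int, int, int]:
    frames, height, width = (int(part) for part in value.split("x"))
    if (frames - 1) % 8 != 0:
        raise ValueError("Number of frames - 1 must be a multiple of 8")
    return frames, height, width

print(parse_train_resolution("49x480x720"))  # (49, 480, 720)
```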

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import List, Dict, Any
from pydantic import BaseModel, field_validator
class State(BaseModel):
model_config = {"arbitrary_types_allowed": True}

View File

@ -3,11 +3,15 @@ import os
from pathlib import Path
import cv2
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory")
parser.add_argument(
"--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory"
)
return parser.parse_args()
args = parse_args()
# Create data/images directory if it doesn't exist
@ -24,24 +28,24 @@ with open(videos_file, "r") as f:
image_paths = []
for video_rel_path in video_paths:
video_path = data_dir / video_rel_path
# Open video
cap = cv2.VideoCapture(str(video_path))
# Read first frame
ret, frame = cap.read()
if not ret:
print(f"Failed to read video: {video_path}")
continue
# Save frame as PNG with same name as video
image_name = f"images/{video_path.stem}.png"
image_path = data_dir / image_name
cv2.imwrite(str(image_path), frame)
# Release video capture
cap.release()
print(f"Extracted first frame from {video_path} to {image_path}")
image_paths.append(image_name)
@ -49,4 +53,4 @@ for video_rel_path in video_paths:
images_file = data_dir / "images.txt"
with open(images_file, "w") as f:
for path in image_paths:
f.write(f"{path}\n")
f.write(f"{path}\n")

View File

@ -1,4 +1,3 @@
import os
import logging
import math
import json
@ -32,24 +31,23 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic
from finetune.schemas import Args, State, Components
from finetune.utils import (
unwrap_model, cast_training_params,
unwrap_model,
cast_training_params,
get_optimizer,
get_memory_statistics,
free_memory,
unload_model,
get_latest_ckpt_path_to_resume_from,
get_intermediate_ckpt_path,
get_latest_ckpt_path_to_resume_from,
get_intermediate_ckpt_path,
string_to_filename
string_to_filename,
)
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
from finetune.datasets.utils import (
load_prompts, load_images, load_videos,
preprocess_image_with_resize, preprocess_video_with_resize
load_prompts,
load_images,
load_videos,
preprocess_image_with_resize,
preprocess_video_with_resize,
)
from finetune.constants import LOG_NAME, LOG_LEVEL
@ -59,22 +57,22 @@ logger = get_logger(LOG_NAME, LOG_LEVEL)
_DTYPE_MAP = {
"fp32": torch.float32,
"fp16": torch.float16,
"fp16": torch.float16, # FP16 is Only Support for CogVideoX-2B
"bf16": torch.bfloat16,
}
class Trainer:
# If set, should be a list of components to unload (refer to `Components`)
UNLOAD_LIST: List[str] = None
UNLOAD_LIST: List[str] = None
def __init__(self, args: Args) -> None:
self.args = args
self.args = args
self.state = State(
weight_dtype=self.__get_training_dtype(),
train_frames=self.args.train_resolution[0],
train_height=self.args.train_resolution[1],
train_width=self.args.train_resolution[2]
train_width=self.args.train_resolution[2],
)
self.components = Components()
@ -136,11 +134,13 @@ class Trainer:
if self.accelerator.is_main_process:
self.args.output_dir = Path(self.args.output_dir)
self.args.output_dir.mkdir(parents=True, exist_ok=True)
def check_setting(self) -> None:
# Check for unload_list
if self.UNLOAD_LIST is None:
logger.warning("\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m")
logger.warning(
"\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
)
else:
for name in self.UNLOAD_LIST:
if name not in self.components.model_fields:
@ -174,7 +174,7 @@ class Trainer:
max_num_frames=sample_frames,
height=self.state.train_height,
width=self.state.train_width,
trainer=self
trainer=self,
)
elif self.args.model_type == "t2v":
self.dataset = T2VDatasetWithResize(
@ -183,7 +183,7 @@ class Trainer:
max_num_frames=sample_frames,
height=self.state.train_height,
width=self.state.train_width,
trainer=self
trainer=self,
)
else:
raise ValueError(f"Invalid model type: {self.args.model_type}")
@ -204,7 +204,8 @@ class Trainer:
pin_memory=self.args.pin_memory,
)
tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
for _ in tmp_data_loader: ...
for _ in tmp_data_loader:
...
self.accelerator.wait_for_everyone()
logger.info("Precomputing latent for video and prompt embedding ... Done")
@ -218,16 +219,15 @@ class Trainer:
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=self.args.pin_memory,
shuffle=True
shuffle=True,
)
def prepare_trainable_parameters(self):
logger.info("Initializing trainable parameters")
# For now only lora is supported
for attr_name, component in vars(self.components).items():
if hasattr(component, 'requires_grad_'):
if hasattr(component, "requires_grad_"):
component.requires_grad_(False)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
@ -332,7 +332,7 @@ class Trainer:
# Afterwards we recalculate our number of training epochs
self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
def prepare_for_validation(self):
validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)
@ -452,10 +452,7 @@ class Trainer:
progress_bar.set_postfix(logs)
# Maybe run validation
should_run_validation = (
self.args.do_validation
and global_step % self.args.validation_steps == 0
)
should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0
if should_run_validation:
del loss
free_memory()
@ -500,7 +497,7 @@ class Trainer:
##### Initialize pipeline #####
pipe = self.initialize_pipeline()
# Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
pipe.enable_model_cpu_offload(device=self.accelerator.device)
@ -520,9 +517,7 @@ class Trainer:
video = self.state.validation_videos[i]
if image is not None:
image = preprocess_image_with_resize(
image, self.state.train_height, self.state.train_width
)
image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
# Convert image tensor (C, H, W) to PIL images
image = image.to(torch.uint8)
image = image.permute(1, 2, 0).cpu().numpy()
@ -534,17 +529,13 @@ class Trainer:
)
# Convert video tensor (F, C, H, W) to list of PIL images
video = (video * 255).round().clamp(0, 255).to(torch.uint8)
video = [Image.fromarray(frame.permute(1,2,0).cpu().numpy()) for frame in video]
video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
logger.debug(
f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
main_process_only=False,
)
validation_artifacts = self.validation_step({
"prompt": prompt,
"image": image,
"video": video
}, pipe)
validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
prompt_filename = string_to_filename(prompt)[:25]
artifacts = {
"image": {"type": "image", "value": image},
@ -611,7 +602,7 @@ class Trainer:
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
torch.cuda.reset_peak_memory_stats(accelerator.device)
torch.set_grad_enabled(True)
torch.set_grad_enabled(True)
self.components.transformer.train()
def fit(self):
@ -628,10 +619,10 @@ class Trainer:
def collate_fn(self, examples: List[Dict[str, Any]]):
raise NotImplementedError
def load_components(self) -> Components:
raise NotImplementedError
def initialize_pipeline(self) -> DiffusionPipeline:
raise NotImplementedError
@ -643,7 +634,7 @@ class Trainer:
def encode_text(self, text: str) -> torch.Tensor:
# shape of output text: [batch size, sequence length, embedding dimension]
raise NotImplementedError
def compute_loss(self, batch) -> torch.Tensor:
raise NotImplementedError
@ -663,18 +654,18 @@ class Trainer:
def __load_components(self):
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, 'to'):
if not isinstance(component, type) and hasattr(component, "to"):
if name in self.UNLOAD_LIST:
continue
# setattr(self.components, name, component.to(self.accelerator.device))
setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
def __unload_components(self):
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, 'to'):
if not isinstance(component, type) and hasattr(component, "to"):
if name in self.UNLOAD_LIST:
setattr(self.components, name, component.to('cpu'))
setattr(self.components, name, component.to("cpu"))
def __prepare_saving_loading_hooks(self, transformer_lora_config):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@ -711,9 +702,7 @@ class Trainer:
):
transformer_ = unwrap_model(self.accelerator, model)
else:
raise ValueError(
f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
)
raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
else:
transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
self.args.model_path, subfolder="transformer"
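
The _DTYPE_MAP comment earlier in this file records that fp16 is only supported for CogVideoX-2B, which is the point of this commit. The exact guard is not fully visible in this excerpt; a hedged sketch of what such a check could look like, where the argument names follow the Args schema in this diff and the 2B name matching is an assumption:

```python
def check_fp16_support(model_name: str, mixed_precision: str) -> None:
    # Hypothetical guard consistent with the _DTYPE_MAP comment: reject fp16
    # mixed precision unless the checkpoint is a CogVideoX-2B variant.
    if mixed_precision == "fp16" and "cogvideox-2b" not in model_name.lower():
        raise ValueError("fp16 training is only supported for CogVideoX-2B models")
```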

View File

@ -49,4 +49,4 @@ def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], d
for param in m.parameters():
# only upcast trainable parameters into fp32
if param.requires_grad:
param.data = param.to(dtype)
param.data = param.to(dtype)
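
cast_training_params leaves frozen weights in half precision but upcasts the trainable (LoRA) parameters to fp32, the usual recipe for stable optimizer updates under fp16/bf16 training. A small usage sketch with a toy module:

```python
import torch

model = torch.nn.Linear(8, 8).to(torch.float16)
model.bias.requires_grad_(False)  # stand-in for frozen base weights

for param in model.parameters():
    if param.requires_grad:  # only upcast trainable parameters, as above
        param.data = param.to(torch.float32)

assert model.weight.dtype == torch.float32 and model.bias.dtype == torch.float16
```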