mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-06-22 08:59:17 +08:00
format and check fp16 for cogvideox2b
This commit is contained in:
parent
1b886326b2
commit
1789f07256
@ -1,2 +1,2 @@
|
||||
LOG_NAME = "trainer"
|
||||
LOG_LEVEL = "INFO"
|
||||
LOG_LEVEL = "INFO"
|
||||
|
@ -8,5 +8,5 @@ __all__ = [
|
||||
"I2VDatasetWithBuckets",
|
||||
"T2VDatasetWithResize",
|
||||
"T2VDatasetWithBuckets",
|
||||
"BucketSampler"
|
||||
"BucketSampler",
|
||||
]
|
||||
|
@ -37,7 +37,6 @@ class BucketSampler(Sampler):
|
||||
|
||||
self._raised_warning_for_drop_last = False
|
||||
|
||||
|
||||
def __len__(self):
|
||||
if self.drop_last and not self._raised_warning_for_drop_last:
|
||||
self._raised_warning_for_drop_last = True
|
||||
@ -46,7 +45,6 @@ class BucketSampler(Sampler):
|
||||
)
|
||||
return (len(self.data_source) + self.batch_size - 1) // self.batch_size
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
for index, data in enumerate(self.data_source):
|
||||
video_metadata = data["video_metadata"]
|
||||
|
@ -13,11 +13,12 @@ from safetensors.torch import save_file, load_file
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
|
||||
from .utils import (
|
||||
load_prompts, load_videos, load_images,
|
||||
|
||||
load_prompts,
|
||||
load_videos,
|
||||
load_images,
|
||||
preprocess_image_with_resize,
|
||||
preprocess_video_with_resize,
|
||||
preprocess_video_with_buckets
|
||||
preprocess_video_with_buckets,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -46,6 +47,7 @@ class BaseI2VDataset(Dataset):
|
||||
device (torch.device): Device to load the data on
|
||||
encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_root: str,
|
||||
@ -55,7 +57,7 @@ class BaseI2VDataset(Dataset):
|
||||
device: torch.device,
|
||||
trainer: "Trainer" = None,
|
||||
*args,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -120,7 +122,10 @@ class BaseI2VDataset(Dataset):
|
||||
|
||||
if prompt_embedding_path.exists():
|
||||
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
|
||||
logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
|
||||
logger.debug(
|
||||
f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
|
||||
main_process_only=False,
|
||||
)
|
||||
else:
|
||||
prompt_embedding = self.encode_text(prompt)
|
||||
prompt_embedding = prompt_embedding.to("cpu")
|
||||
@ -187,7 +192,7 @@ class BaseI2VDataset(Dataset):
|
||||
- image(torch.Tensor) of shape [C, H, W]
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement this method")
|
||||
|
||||
|
||||
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies transformations to a video.
|
||||
@ -197,14 +202,14 @@ class BaseI2VDataset(Dataset):
|
||||
with shape [F, C, H, W] where:
|
||||
- F is number of frames
|
||||
- C is number of channels (3 for RGB)
|
||||
- H is height
|
||||
- H is height
|
||||
- W is width
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The transformed video tensor
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement this method")
|
||||
|
||||
|
||||
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies transformations to an image.
|
||||
@ -213,7 +218,7 @@ class BaseI2VDataset(Dataset):
|
||||
image (torch.Tensor): A 3D tensor representing an image
|
||||
with shape [C, H, W] where:
|
||||
- C is number of channels (3 for RGB)
|
||||
- H is height
|
||||
- H is height
|
||||
- W is width
|
||||
|
||||
Returns:
|
||||
@ -235,6 +240,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
|
||||
height (int): Target height for resizing videos and images
|
||||
width (int): Target width for resizing videos and images
|
||||
"""
|
||||
|
||||
def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@ -242,11 +248,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
|
||||
self.height = height
|
||||
self.width = width
|
||||
|
||||
self.__frame_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
|
||||
]
|
||||
)
|
||||
self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
|
||||
self.__image_transforms = self.__frame_transforms
|
||||
|
||||
@override
|
||||
@ -260,25 +262,25 @@ class I2VDatasetWithResize(BaseI2VDataset):
|
||||
else:
|
||||
image = None
|
||||
return video, image
|
||||
|
||||
|
||||
@override
|
||||
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
|
||||
|
||||
|
||||
@override
|
||||
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
|
||||
return self.__image_transforms(image)
|
||||
|
||||
|
||||
class I2VDatasetWithBuckets(BaseI2VDataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_resolution_buckets: List[Tuple[int, int, int]],
|
||||
vae_temporal_compression_ratio: int,
|
||||
vae_height_compression_ratio: int,
|
||||
vae_width_compression_ratio: int,
|
||||
*args, **kwargs
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@ -290,23 +292,19 @@ class I2VDatasetWithBuckets(BaseI2VDataset):
|
||||
)
|
||||
for b in video_resolution_buckets
|
||||
]
|
||||
self.__frame_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
|
||||
]
|
||||
)
|
||||
self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
|
||||
self.__image_transforms = self.__frame_transforms
|
||||
|
||||
|
||||
@override
|
||||
def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
|
||||
image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
|
||||
return video, image
|
||||
|
||||
@override
|
||||
@override
|
||||
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
|
||||
|
||||
|
||||
@override
|
||||
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
|
||||
return self.__image_transforms(image)
|
||||
|
@ -12,11 +12,7 @@ from safetensors.torch import save_file, load_file
|
||||
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
|
||||
from .utils import (
|
||||
load_prompts, load_videos,
|
||||
preprocess_video_with_resize,
|
||||
preprocess_video_with_buckets
|
||||
)
|
||||
from .utils import load_prompts, load_videos, preprocess_video_with_resize, preprocess_video_with_buckets
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from finetune.trainer import Trainer
|
||||
@ -52,7 +48,7 @@ class BaseT2VDataset(Dataset):
|
||||
device: torch.device = None,
|
||||
trainer: "Trainer" = None,
|
||||
*args,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -108,7 +104,10 @@ class BaseT2VDataset(Dataset):
|
||||
|
||||
if prompt_embedding_path.exists():
|
||||
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
|
||||
logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
|
||||
logger.debug(
|
||||
f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
|
||||
main_process_only=False,
|
||||
)
|
||||
else:
|
||||
prompt_embedding = self.encode_text(prompt)
|
||||
prompt_embedding = prompt_embedding.to("cpu")
|
||||
@ -164,7 +163,7 @@ class BaseT2VDataset(Dataset):
|
||||
- W is width
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement this method")
|
||||
|
||||
|
||||
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies transformations to a video.
|
||||
@ -174,7 +173,7 @@ class BaseT2VDataset(Dataset):
|
||||
with shape [F, C, H, W] where:
|
||||
- F is number of frames
|
||||
- C is number of channels (3 for RGB)
|
||||
- H is height
|
||||
- H is height
|
||||
- W is width
|
||||
|
||||
Returns:
|
||||
@ -203,36 +202,33 @@ class T2VDatasetWithResize(BaseT2VDataset):
|
||||
self.height = height
|
||||
self.width = width
|
||||
|
||||
self.__frame_transform = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
|
||||
]
|
||||
)
|
||||
|
||||
self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
|
||||
|
||||
@override
|
||||
def preprocess(self, video_path: Path) -> torch.Tensor:
|
||||
return preprocess_video_with_resize(
|
||||
video_path, self.max_num_frames, self.height, self.width,
|
||||
video_path,
|
||||
self.max_num_frames,
|
||||
self.height,
|
||||
self.width,
|
||||
)
|
||||
|
||||
|
||||
@override
|
||||
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
|
||||
|
||||
|
||||
class T2VDatasetWithBuckets(BaseT2VDataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_resolution_buckets: List[Tuple[int, int, int]],
|
||||
vae_temporal_compression_ratio: int,
|
||||
vae_height_compression_ratio: int,
|
||||
vae_width_compression_ratio: int,
|
||||
*args, **kwargs
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
"""
|
||||
""" """
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.video_resolution_buckets = [
|
||||
@ -244,18 +240,12 @@ class T2VDatasetWithBuckets(BaseT2VDataset):
|
||||
for b in video_resolution_buckets
|
||||
]
|
||||
|
||||
self.__frame_transform = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
|
||||
]
|
||||
)
|
||||
|
||||
self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
|
||||
|
||||
@override
|
||||
def preprocess(self, video_path: Path) -> torch.Tensor:
|
||||
return preprocess_video_with_buckets(
|
||||
video_path, self.video_resolution_buckets
|
||||
)
|
||||
|
||||
return preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
|
||||
|
||||
@override
|
||||
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
|
||||
|
@ -15,6 +15,7 @@ decord.bridge.set_bridge("torch")
|
||||
|
||||
########## loaders ##########
|
||||
|
||||
|
||||
def load_prompts(prompt_path: Path) -> List[str]:
|
||||
with open(prompt_path, "r", encoding="utf-8") as file:
|
||||
return [line.strip() for line in file.readlines() if len(line.strip()) > 0]
|
||||
@ -32,6 +33,7 @@ def load_images(image_path: Path) -> List[Path]:
|
||||
|
||||
########## preprocessors ##########
|
||||
|
||||
|
||||
def preprocess_image_with_resize(
|
||||
image_path: Path | str,
|
||||
height: int,
|
||||
@ -96,7 +98,7 @@ def preprocess_video_with_resize(
|
||||
|
||||
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
|
||||
frames = video_reader.get_batch(indices)
|
||||
frames = frames[: max_num_frames].float()
|
||||
frames = frames[:max_num_frames].float()
|
||||
frames = frames.permute(0, 3, 1, 2).contiguous()
|
||||
return frames
|
||||
|
||||
@ -144,4 +146,4 @@ def preprocess_video_with_buckets(
|
||||
nearest_res = (nearest_res[1], nearest_res[2])
|
||||
frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)
|
||||
|
||||
return frames
|
||||
return frames
|
||||
|
@ -5,8 +5,8 @@ from pathlib import Path
|
||||
package_dir = Path(__file__).parent
|
||||
|
||||
for subdir in package_dir.iterdir():
|
||||
if subdir.is_dir() and not subdir.name.startswith('_'):
|
||||
for module_path in subdir.glob('*.py'):
|
||||
if subdir.is_dir() and not subdir.name.startswith("_"):
|
||||
for module_path in subdir.glob("*.py"):
|
||||
module_name = module_path.stem
|
||||
full_module_name = f".{subdir.name}.{module_name}"
|
||||
importlib.import_module(full_module_name, package=__name__)
|
||||
|
@ -30,28 +30,18 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
|
||||
components.pipeline_cls = CogVideoXImageToVideoPipeline
|
||||
|
||||
components.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path, subfolder="tokenizer"
|
||||
)
|
||||
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
|
||||
|
||||
components.text_encoder = T5EncoderModel.from_pretrained(
|
||||
model_path, subfolder="text_encoder"
|
||||
)
|
||||
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
|
||||
|
||||
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
|
||||
model_path, subfolder="transformer"
|
||||
)
|
||||
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
|
||||
|
||||
components.vae = AutoencoderKLCogVideoX.from_pretrained(
|
||||
model_path, subfolder="vae"
|
||||
)
|
||||
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
|
||||
|
||||
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
|
||||
model_path, subfolder="scheduler"
|
||||
)
|
||||
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
||||
|
||||
return components
|
||||
|
||||
|
||||
@override
|
||||
def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
|
||||
pipe = CogVideoXImageToVideoPipeline(
|
||||
@ -59,7 +49,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
text_encoder=self.components.text_encoder,
|
||||
vae=self.components.vae,
|
||||
transformer=unwrap_model(self.accelerator, self.components.transformer),
|
||||
scheduler=self.components.scheduler
|
||||
scheduler=self.components.scheduler,
|
||||
)
|
||||
return pipe
|
||||
|
||||
@ -71,7 +61,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
latent_dist = vae.encode(video).latent_dist
|
||||
latent = latent_dist.sample() * vae.config.scaling_factor
|
||||
return latent
|
||||
|
||||
|
||||
@override
|
||||
def encode_text(self, prompt: str) -> torch.Tensor:
|
||||
prompt_token_ids = self.components.tokenizer(
|
||||
@ -88,12 +78,8 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
|
||||
@override
|
||||
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
ret = {
|
||||
"encoded_videos": [],
|
||||
"prompt_embedding": [],
|
||||
"images": []
|
||||
}
|
||||
|
||||
ret = {"encoded_videos": [], "prompt_embedding": [], "images": []}
|
||||
|
||||
for sample in samples:
|
||||
encoded_video = sample["encoded_video"]
|
||||
prompt_embedding = sample["prompt_embedding"]
|
||||
@ -102,13 +88,13 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
ret["encoded_videos"].append(encoded_video)
|
||||
ret["prompt_embedding"].append(prompt_embedding)
|
||||
ret["images"].append(image)
|
||||
|
||||
|
||||
ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
|
||||
ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
|
||||
ret["images"] = torch.stack(ret["images"])
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@override
|
||||
def compute_loss(self, batch) -> torch.Tensor:
|
||||
prompt_embedding = batch["prompt_embedding"]
|
||||
@ -144,8 +130,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
|
||||
# Sample a random timestep for each sample
|
||||
timesteps = torch.randint(
|
||||
0, self.components.scheduler.config.num_train_timesteps,
|
||||
(batch_size,), device=self.accelerator.device
|
||||
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
@ -183,7 +168,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
)
|
||||
|
||||
# Predict noise
|
||||
ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
|
||||
ofs_emb = (
|
||||
None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
|
||||
)
|
||||
predicted_noise = self.components.transformer(
|
||||
hidden_states=latent_img_noisy,
|
||||
encoder_hidden_states=prompt_embedding,
|
||||
@ -222,7 +209,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
width=self.state.train_width,
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
generator=self.state.generator
|
||||
generator=self.state.generator,
|
||||
).frames[0]
|
||||
return [("video", video_generate)]
|
||||
|
||||
@ -233,7 +220,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
num_frames: int,
|
||||
transformer_config: Dict,
|
||||
vae_scale_factor_spatial: int,
|
||||
device: torch.device
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
|
||||
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
|
||||
@ -256,4 +243,4 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
|
||||
register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
|
||||
|
@ -31,25 +31,15 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
|
||||
components.pipeline_cls = CogVideoXPipeline
|
||||
|
||||
components.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path, subfolder="tokenizer"
|
||||
)
|
||||
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
|
||||
|
||||
components.text_encoder = T5EncoderModel.from_pretrained(
|
||||
model_path, subfolder="text_encoder"
|
||||
)
|
||||
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
|
||||
|
||||
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
|
||||
model_path, subfolder="transformer"
|
||||
)
|
||||
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
|
||||
|
||||
components.vae = AutoencoderKLCogVideoX.from_pretrained(
|
||||
model_path, subfolder="vae"
|
||||
)
|
||||
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
|
||||
|
||||
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
|
||||
model_path, subfolder="scheduler"
|
||||
)
|
||||
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
||||
|
||||
return components
|
||||
|
||||
@ -60,10 +50,10 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
text_encoder=self.components.text_encoder,
|
||||
vae=self.components.vae,
|
||||
transformer=unwrap_model(self.accelerator, self.components.transformer),
|
||||
scheduler=self.components.scheduler
|
||||
scheduler=self.components.scheduler,
|
||||
)
|
||||
return pipe
|
||||
|
||||
|
||||
@override
|
||||
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
|
||||
# shape of input video: [B, C, F, H, W]
|
||||
@ -86,21 +76,18 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
prompt_token_ids = prompt_token_ids.input_ids
|
||||
prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
|
||||
return prompt_embedding
|
||||
|
||||
|
||||
@override
|
||||
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
ret = {
|
||||
"encoded_videos": [],
|
||||
"prompt_embedding": []
|
||||
}
|
||||
|
||||
ret = {"encoded_videos": [], "prompt_embedding": []}
|
||||
|
||||
for sample in samples:
|
||||
encoded_video = sample["encoded_video"]
|
||||
prompt_embedding = sample["prompt_embedding"]
|
||||
|
||||
ret["encoded_videos"].append(encoded_video)
|
||||
ret["prompt_embedding"].append(prompt_embedding)
|
||||
|
||||
|
||||
ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
|
||||
ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
|
||||
|
||||
@ -116,10 +103,20 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
|
||||
patch_size_t = self.state.transformer_config.patch_size_t
|
||||
if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
|
||||
raise ValueError("Number of frames in latent must be divisible by patch size, please check your args for training.")
|
||||
|
||||
raise ValueError(
|
||||
"Number of frames in latent must be divisible by patch size, please check your args for training."
|
||||
)
|
||||
|
||||
# Add 2 random noise frames at the beginning of frame dimension
|
||||
noise_frames = torch.randn(latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
|
||||
noise_frames = torch.randn(
|
||||
latent.shape[0],
|
||||
latent.shape[1],
|
||||
2,
|
||||
latent.shape[3],
|
||||
latent.shape[4],
|
||||
device=latent.device,
|
||||
dtype=latent.dtype,
|
||||
)
|
||||
latent = torch.cat([noise_frames, latent], dim=2)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = latent.shape
|
||||
@ -130,8 +127,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
|
||||
# Sample a random timestep for each sample
|
||||
timesteps = torch.randint(
|
||||
0, self.components.scheduler.config.num_train_timesteps,
|
||||
(batch_size,), device=self.accelerator.device
|
||||
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
@ -193,7 +189,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
height=self.state.train_height,
|
||||
width=self.state.train_width,
|
||||
prompt=prompt,
|
||||
generator=self.state.generator
|
||||
generator=self.state.generator,
|
||||
).frames[0]
|
||||
return [("video", video_generate)]
|
||||
|
||||
@ -204,7 +200,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
num_frames: int,
|
||||
transformer_config: Dict,
|
||||
vae_scale_factor_spatial: int,
|
||||
device: torch.device
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
|
||||
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
|
||||
|
@ -2,4 +2,4 @@ from .args import Args
|
||||
from .state import State
|
||||
from .components import Components
|
||||
|
||||
__all__ = ["Args", "State", "Components"]
|
||||
__all__ = ["Args", "State", "Components"]
|
||||
|
@ -78,10 +78,10 @@ class Args(BaseModel):
|
||||
########## Validation ##########
|
||||
do_validation: bool = False
|
||||
validation_steps: int | None = None # if set, should be a multiple of checkpointing_steps
|
||||
validation_dir: Path | None # if set do_validation, should not be None
|
||||
validation_dir: Path | None # if set do_validation, should not be None
|
||||
validation_prompts: str | None # if set do_validation, should not be None
|
||||
validation_images: str | None # if set do_validation and model_type == i2v, should not be None
|
||||
validation_videos: str | None # if set do_validation and model_type == v2v, should not be None
|
||||
validation_images: str | None # if set do_validation and model_type == i2v, should not be None
|
||||
validation_videos: str | None # if set do_validation and model_type == v2v, should not be None
|
||||
gen_fps: int = 15
|
||||
|
||||
#### deprecated args: gen_video_resolution
|
||||
@ -115,7 +115,7 @@ class Args(BaseModel):
|
||||
raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
|
||||
return v
|
||||
|
||||
@field_validator("validation_videos")
|
||||
@field_validator("validation_videos")
|
||||
def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
|
||||
values = info.data
|
||||
if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
|
||||
@ -131,31 +131,32 @@ class Args(BaseModel):
|
||||
if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
|
||||
raise ValueError("validation_steps must be a multiple of checkpointing_steps")
|
||||
return v
|
||||
|
||||
|
||||
@field_validator("train_resolution")
|
||||
def validate_train_resolution(cls, v: Tuple[int, int, int], info: ValidationInfo) -> str:
|
||||
try:
|
||||
frames, height, width = v
|
||||
|
||||
|
||||
# Check if (frames - 1) is multiple of 8
|
||||
if (frames - 1) % 8 != 0:
|
||||
raise ValueError("Number of frames - 1 must be a multiple of 8")
|
||||
|
||||
|
||||
# Check resolution for cogvideox-5b models
|
||||
model_name = info.data.get("model_name", "")
|
||||
if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
|
||||
if (height, width) != (480, 720):
|
||||
raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")
|
||||
|
||||
|
||||
return v
|
||||
|
||||
|
||||
except ValueError as e:
|
||||
if str(e) == "not enough values to unpack (expected 3, got 0)" or \
|
||||
str(e) == "invalid literal for int() with base 10":
|
||||
if (
|
||||
str(e) == "not enough values to unpack (expected 3, got 0)"
|
||||
or str(e) == "invalid literal for int() with base 10"
|
||||
):
|
||||
raise ValueError("train_resolution must be in format 'frames x height x width'")
|
||||
raise e
|
||||
|
||||
|
||||
@classmethod
|
||||
def parse_args(cls):
|
||||
"""Parse command line arguments and return Args instance"""
|
||||
@ -208,8 +209,7 @@ class Args(BaseModel):
|
||||
# LoRA parameters
|
||||
parser.add_argument("--rank", type=int, default=128)
|
||||
parser.add_argument("--lora_alpha", type=int, default=64)
|
||||
parser.add_argument("--target_modules", type=str, nargs="+",
|
||||
default=["to_q", "to_k", "to_v", "to_out.0"])
|
||||
parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])
|
||||
|
||||
# Checkpointing
|
||||
parser.add_argument("--checkpointing_steps", type=int, default=200)
|
||||
@ -226,7 +226,7 @@ class Args(BaseModel):
|
||||
parser.add_argument("--gen_fps", type=int, default=15)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Convert video_resolution_buckets string to list of tuples
|
||||
frames, height, width = args.train_resolution.split("x")
|
||||
args.train_resolution = (int(frames), int(height), int(width))
|
||||
|
@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class State(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
|
@ -3,11 +3,15 @@ import os
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory")
|
||||
parser.add_argument(
|
||||
"--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
args = parse_args()
|
||||
|
||||
# Create data/images directory if it doesn't exist
|
||||
@ -24,24 +28,24 @@ with open(videos_file, "r") as f:
|
||||
image_paths = []
|
||||
for video_rel_path in video_paths:
|
||||
video_path = data_dir / video_rel_path
|
||||
|
||||
|
||||
# Open video
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
|
||||
|
||||
# Read first frame
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print(f"Failed to read video: {video_path}")
|
||||
continue
|
||||
|
||||
|
||||
# Save frame as PNG with same name as video
|
||||
image_name = f"images/{video_path.stem}.png"
|
||||
image_path = data_dir / image_name
|
||||
cv2.imwrite(str(image_path), frame)
|
||||
|
||||
|
||||
# Release video capture
|
||||
cap.release()
|
||||
|
||||
|
||||
print(f"Extracted first frame from {video_path} to {image_path}")
|
||||
image_paths.append(image_name)
|
||||
|
||||
@ -49,4 +53,4 @@ for video_rel_path in video_paths:
|
||||
images_file = data_dir / "images.txt"
|
||||
with open(images_file, "w") as f:
|
||||
for path in image_paths:
|
||||
f.write(f"{path}\n")
|
||||
f.write(f"{path}\n")
|
||||
|
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import logging
|
||||
import math
|
||||
import json
|
||||
@ -32,24 +31,23 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic
|
||||
|
||||
from finetune.schemas import Args, State, Components
|
||||
from finetune.utils import (
|
||||
unwrap_model, cast_training_params,
|
||||
unwrap_model,
|
||||
cast_training_params,
|
||||
get_optimizer,
|
||||
|
||||
get_memory_statistics,
|
||||
free_memory,
|
||||
unload_model,
|
||||
|
||||
get_latest_ckpt_path_to_resume_from,
|
||||
get_intermediate_ckpt_path,
|
||||
get_latest_ckpt_path_to_resume_from,
|
||||
get_intermediate_ckpt_path,
|
||||
|
||||
string_to_filename
|
||||
string_to_filename,
|
||||
)
|
||||
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
|
||||
from finetune.datasets.utils import (
|
||||
load_prompts, load_images, load_videos,
|
||||
preprocess_image_with_resize, preprocess_video_with_resize
|
||||
load_prompts,
|
||||
load_images,
|
||||
load_videos,
|
||||
preprocess_image_with_resize,
|
||||
preprocess_video_with_resize,
|
||||
)
|
||||
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
@ -59,22 +57,22 @@ logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||
|
||||
_DTYPE_MAP = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"fp16": torch.float16, # FP16 is Only Support for CogVideoX-2B
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
class Trainer:
|
||||
# If set, should be a list of components to unload (refer to `Components``)
|
||||
UNLOAD_LIST: List[str] = None
|
||||
UNLOAD_LIST: List[str] = None
|
||||
|
||||
def __init__(self, args: Args) -> None:
|
||||
self.args = args
|
||||
self.args = args
|
||||
self.state = State(
|
||||
weight_dtype=self.__get_training_dtype(),
|
||||
train_frames=self.args.train_resolution[0],
|
||||
train_height=self.args.train_resolution[1],
|
||||
train_width=self.args.train_resolution[2]
|
||||
train_width=self.args.train_resolution[2],
|
||||
)
|
||||
|
||||
self.components = Components()
|
||||
@ -136,11 +134,13 @@ class Trainer:
|
||||
if self.accelerator.is_main_process:
|
||||
self.args.output_dir = Path(self.args.output_dir)
|
||||
self.args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def check_setting(self) -> None:
|
||||
# Check for unload_list
|
||||
if self.UNLOAD_LIST is None:
|
||||
logger.warning("\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m")
|
||||
logger.warning(
|
||||
"\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
|
||||
)
|
||||
else:
|
||||
for name in self.UNLOAD_LIST:
|
||||
if name not in self.components.model_fields:
|
||||
@ -174,7 +174,7 @@ class Trainer:
|
||||
max_num_frames=sample_frames,
|
||||
height=self.state.train_height,
|
||||
width=self.state.train_width,
|
||||
trainer=self
|
||||
trainer=self,
|
||||
)
|
||||
elif self.args.model_type == "t2v":
|
||||
self.dataset = T2VDatasetWithResize(
|
||||
@ -183,7 +183,7 @@ class Trainer:
|
||||
max_num_frames=sample_frames,
|
||||
height=self.state.train_height,
|
||||
width=self.state.train_width,
|
||||
trainer=self
|
||||
trainer=self,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid model type: {self.args.model_type}")
|
||||
@ -204,7 +204,8 @@ class Trainer:
|
||||
pin_memory=self.args.pin_memory,
|
||||
)
|
||||
tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
|
||||
for _ in tmp_data_loader: ...
|
||||
for _ in tmp_data_loader:
|
||||
...
|
||||
self.accelerator.wait_for_everyone()
|
||||
logger.info("Precomputing latent for video and prompt embedding ... Done")
|
||||
|
||||
@ -218,16 +219,15 @@ class Trainer:
|
||||
batch_size=self.args.batch_size,
|
||||
num_workers=self.args.num_workers,
|
||||
pin_memory=self.args.pin_memory,
|
||||
shuffle=True
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
|
||||
def prepare_trainable_parameters(self):
|
||||
logger.info("Initializing trainable parameters")
|
||||
|
||||
# For now only lora is supported
|
||||
for attr_name, component in vars(self.components).items():
|
||||
if hasattr(component, 'requires_grad_'):
|
||||
if hasattr(component, "requires_grad_"):
|
||||
component.requires_grad_(False)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
||||
@ -332,7 +332,7 @@ class Trainer:
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
|
||||
self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
|
||||
|
||||
|
||||
def prepare_for_validation(self):
|
||||
validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)
|
||||
|
||||
@ -452,10 +452,7 @@ class Trainer:
|
||||
progress_bar.set_postfix(logs)
|
||||
|
||||
# Maybe run validation
|
||||
should_run_validation = (
|
||||
self.args.do_validation
|
||||
and global_step % self.args.validation_steps == 0
|
||||
)
|
||||
should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0
|
||||
if should_run_validation:
|
||||
del loss
|
||||
free_memory()
|
||||
@ -500,7 +497,7 @@ class Trainer:
|
||||
|
||||
##### Initialize pipeline #####
|
||||
pipe = self.initialize_pipeline()
|
||||
|
||||
|
||||
# Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
|
||||
pipe.enable_model_cpu_offload(device=self.accelerator.device)
|
||||
|
||||
@ -520,9 +517,7 @@ class Trainer:
|
||||
video = self.state.validation_videos[i]
|
||||
|
||||
if image is not None:
|
||||
image = preprocess_image_with_resize(
|
||||
image, self.state.train_height, self.state.train_width
|
||||
)
|
||||
image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
|
||||
# Convert image tensor (C, H, W) to PIL images
|
||||
image = image.to(torch.uint8)
|
||||
image = image.permute(1, 2, 0).cpu().numpy()
|
||||
@ -534,17 +529,13 @@ class Trainer:
|
||||
)
|
||||
# Convert video tensor (F, C, H, W) to list of PIL images
|
||||
video = (video * 255).round().clamp(0, 255).to(torch.uint8)
|
||||
video = [Image.fromarray(frame.permute(1,2,0).cpu().numpy()) for frame in video]
|
||||
video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
|
||||
|
||||
logger.debug(
|
||||
f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
|
||||
main_process_only=False,
|
||||
)
|
||||
validation_artifacts = self.validation_step({
|
||||
"prompt": prompt,
|
||||
"image": image,
|
||||
"video": video
|
||||
}, pipe)
|
||||
validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
|
||||
prompt_filename = string_to_filename(prompt)[:25]
|
||||
artifacts = {
|
||||
"image": {"type": "image", "value": image},
|
||||
@ -611,7 +602,7 @@ class Trainer:
|
||||
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
|
||||
torch.cuda.reset_peak_memory_stats(accelerator.device)
|
||||
|
||||
torch.set_grad_enabled(True)
|
||||
torch.set_grad_enabled(True)
|
||||
self.components.transformer.train()
|
||||
|
||||
def fit(self):
|
||||
@ -628,10 +619,10 @@ class Trainer:
|
||||
|
||||
def collate_fn(self, examples: List[Dict[str, Any]]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def load_components(self) -> Components:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def initialize_pipeline(self) -> DiffusionPipeline:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -643,7 +634,7 @@ class Trainer:
|
||||
def encode_text(self, text: str) -> torch.Tensor:
|
||||
# shape of output text: [batch size, sequence length, embedding dimension]
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def compute_loss(self, batch) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -663,18 +654,18 @@ class Trainer:
|
||||
def __load_components(self):
|
||||
components = self.components.model_dump()
|
||||
for name, component in components.items():
|
||||
if not isinstance(component, type) and hasattr(component, 'to'):
|
||||
if not isinstance(component, type) and hasattr(component, "to"):
|
||||
if name in self.UNLOAD_LIST:
|
||||
continue
|
||||
# setattr(self.components, name, component.to(self.accelerator.device))
|
||||
setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
|
||||
|
||||
|
||||
def __unload_components(self):
|
||||
components = self.components.model_dump()
|
||||
for name, component in components.items():
|
||||
if not isinstance(component, type) and hasattr(component, 'to'):
|
||||
if not isinstance(component, type) and hasattr(component, "to"):
|
||||
if name in self.UNLOAD_LIST:
|
||||
setattr(self.components, name, component.to('cpu'))
|
||||
setattr(self.components, name, component.to("cpu"))
|
||||
|
||||
def __prepare_saving_loading_hooks(self, transformer_lora_config):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@ -711,9 +702,7 @@ class Trainer:
|
||||
):
|
||||
transformer_ = unwrap_model(self.accelerator, model)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
|
||||
)
|
||||
raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
|
||||
else:
|
||||
transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
|
||||
self.args.model_path, subfolder="transformer"
|
||||
|
@ -49,4 +49,4 @@ def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], d
|
||||
for param in m.parameters():
|
||||
# only upcast trainable parameters into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(dtype)
|
||||
param.data = param.to(dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user