Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-09-25 16:55:58 +08:00)

Commit 1789f07256 (parent 1b886326b2)
Commit message: format and check fp16 for cogvideox2b

The changes below are mostly mechanical reformatting of the finetune code (line rewrapping, trailing commas, double-quoted strings), plus a comment noting that fp16 training is only supported for CogVideoX-2B.
@@ -8,5 +8,5 @@ __all__ = [
     "I2VDatasetWithBuckets",
     "T2VDatasetWithResize",
     "T2VDatasetWithBuckets",
-    "BucketSampler"
+    "BucketSampler",
 ]
@@ -37,7 +37,6 @@ class BucketSampler(Sampler):
 
         self._raised_warning_for_drop_last = False
 
-
     def __len__(self):
         if self.drop_last and not self._raised_warning_for_drop_last:
             self._raised_warning_for_drop_last = True
@@ -46,7 +45,6 @@ class BucketSampler(Sampler):
             )
         return (len(self.data_source) + self.batch_size - 1) // self.batch_size
 
-
    def __iter__(self):
        for index, data in enumerate(self.data_source):
            video_metadata = data["video_metadata"]
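Note: the `__len__` shown above counts a trailing partial batch via ceiling division when drop_last is False. A minimal sketch of that arithmetic, with illustrative numbers not taken from the commit:

    # Ceiling division: number of batches needed to cover all samples.
    num_samples, batch_size = 10, 4
    num_batches = (num_samples + batch_size - 1) // batch_size
    assert num_batches == 3                 # batches of 4, 4, 2
    assert num_samples // batch_size == 2   # what drop_last=True would yield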
@@ -13,11 +13,12 @@ from safetensors.torch import save_file, load_file
 from finetune.constants import LOG_NAME, LOG_LEVEL
 
 from .utils import (
-    load_prompts, load_videos, load_images,
+    load_prompts,
+    load_videos,
+    load_images,
     preprocess_image_with_resize,
     preprocess_video_with_resize,
-    preprocess_video_with_buckets
+    preprocess_video_with_buckets,
 )
 
 if TYPE_CHECKING:
@@ -46,6 +47,7 @@ class BaseI2VDataset(Dataset):
         device (torch.device): Device to load the data on
         encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
     """
 
     def __init__(
         self,
         data_root: str,
@@ -55,7 +57,7 @@ class BaseI2VDataset(Dataset):
         device: torch.device,
         trainer: "Trainer" = None,
         *args,
-        **kwargs
+        **kwargs,
     ) -> None:
         super().__init__()
 
@@ -120,7 +122,10 @@ class BaseI2VDataset(Dataset):
 
         if prompt_embedding_path.exists():
             prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
-            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+            logger.debug(
+                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
+                main_process_only=False,
+            )
         else:
             prompt_embedding = self.encode_text(prompt)
             prompt_embedding = prompt_embedding.to("cpu")
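Note: the reformatted `logger.debug` call sits inside a prompt-embedding cache: embeddings are stored as safetensors files and the text encoder runs only on a cache miss. A minimal sketch of that pattern, assuming a hypothetical per-prompt file layout (the repository's actual naming scheme may differ):

    from pathlib import Path

    import torch
    from safetensors.torch import load_file, save_file

    def get_prompt_embedding(prompt: str, cache_dir: Path, encode_text) -> torch.Tensor:
        # Hypothetical layout: one .safetensors file per prompt hash.
        path = cache_dir / f"{abs(hash(prompt))}.safetensors"
        if path.exists():
            return load_file(str(path))["prompt_embedding"]
        embedding = encode_text(prompt).to("cpu")
        save_file({"prompt_embedding": embedding}, str(path))
        return embedding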
@@ -235,6 +240,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
         height (int): Target height for resizing videos and images
         width (int): Target width for resizing videos and images
     """
 
     def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
 
@@ -242,11 +248,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
         self.height = height
         self.width = width
 
-        self.__frame_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
         self.__image_transforms = self.__frame_transforms
 
     @override
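Note: collapsing `transforms.Compose([...])` onto one line does not change behavior: the lambda maps uint8 frame values from [0, 255] to the [-1, 1] range used downstream. A quick check of that normalization:

    import torch
    from torchvision import transforms

    frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
    print(frame_transform(torch.tensor([0.0, 127.5, 255.0])))  # tensor([-1., 0., 1.])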
@@ -271,14 +273,14 @@ class I2VDatasetWithResize(BaseI2VDataset):
 
 
 class I2VDatasetWithBuckets(BaseI2VDataset):
-
     def __init__(
         self,
         video_resolution_buckets: List[Tuple[int, int, int]],
         vae_temporal_compression_ratio: int,
         vae_height_compression_ratio: int,
         vae_width_compression_ratio: int,
-        *args, **kwargs
+        *args,
+        **kwargs,
     ) -> None:
         super().__init__(*args, **kwargs)
 
@@ -290,11 +292,7 @@ class I2VDatasetWithBuckets(BaseI2VDataset):
             )
             for b in video_resolution_buckets
         ]
-        self.__frame_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
         self.__image_transforms = self.__frame_transforms
 
     @override
@@ -12,11 +12,7 @@ from safetensors.torch import save_file, load_file
 
 from finetune.constants import LOG_NAME, LOG_LEVEL
 
-from .utils import (
-    load_prompts, load_videos,
-    preprocess_video_with_resize,
-    preprocess_video_with_buckets
-)
+from .utils import load_prompts, load_videos, preprocess_video_with_resize, preprocess_video_with_buckets
 
 if TYPE_CHECKING:
     from finetune.trainer import Trainer
@@ -52,7 +48,7 @@ class BaseT2VDataset(Dataset):
         device: torch.device = None,
         trainer: "Trainer" = None,
         *args,
-        **kwargs
+        **kwargs,
     ) -> None:
         super().__init__()
 
@@ -108,7 +104,10 @@ class BaseT2VDataset(Dataset):
 
         if prompt_embedding_path.exists():
             prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
-            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+            logger.debug(
+                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
+                main_process_only=False,
+            )
         else:
             prompt_embedding = self.encode_text(prompt)
             prompt_embedding = prompt_embedding.to("cpu")
@@ -203,16 +202,15 @@ class T2VDatasetWithResize(BaseT2VDataset):
         self.height = height
         self.width = width
 
-        self.__frame_transform = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
 
     @override
     def preprocess(self, video_path: Path) -> torch.Tensor:
         return preprocess_video_with_resize(
-            video_path, self.max_num_frames, self.height, self.width,
+            video_path,
+            self.max_num_frames,
+            self.height,
+            self.width,
         )
 
     @override
@@ -221,18 +219,16 @@ class T2VDatasetWithResize(BaseT2VDataset):
 
 
 class T2VDatasetWithBuckets(BaseT2VDataset):
-
     def __init__(
         self,
         video_resolution_buckets: List[Tuple[int, int, int]],
         vae_temporal_compression_ratio: int,
         vae_height_compression_ratio: int,
         vae_width_compression_ratio: int,
-        *args, **kwargs
+        *args,
+        **kwargs,
     ) -> None:
-        """
-
-        """
+        """ """
         super().__init__(*args, **kwargs)
 
         self.video_resolution_buckets = [
@@ -244,17 +240,11 @@ class T2VDatasetWithBuckets(BaseT2VDataset):
             for b in video_resolution_buckets
         ]
 
-        self.__frame_transform = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
 
     @override
     def preprocess(self, video_path: Path) -> torch.Tensor:
-        return preprocess_video_with_buckets(
-            video_path, self.video_resolution_buckets
-        )
+        return preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
 
     @override
     def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
@@ -15,6 +15,7 @@ decord.bridge.set_bridge("torch")
 
 ########## loaders ##########
 
+
 def load_prompts(prompt_path: Path) -> List[str]:
     with open(prompt_path, "r", encoding="utf-8") as file:
         return [line.strip() for line in file.readlines() if len(line.strip()) > 0]
@@ -32,6 +33,7 @@ def load_images(image_path: Path) -> List[Path]:
 
 ########## preprocessors ##########
 
+
 def preprocess_image_with_resize(
     image_path: Path | str,
     height: int,
@@ -96,7 +98,7 @@ def preprocess_video_with_resize(
 
     indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
     frames = video_reader.get_batch(indices)
-    frames = frames[: max_num_frames].float()
+    frames = frames[:max_num_frames].float()
    frames = frames.permute(0, 3, 1, 2).contiguous()
    return frames
 
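Note: only the slice spacing changed above (`frames[: max_num_frames]` → `frames[:max_num_frames]`). For context, the surrounding code strides through the clip and then truncates; a worked example of the index math with illustrative numbers:

    # Suppose the clip has 100 frames and at most 49 are wanted.
    video_num_frames, max_num_frames = 100, 49
    stride = video_num_frames // max_num_frames          # 2
    indices = list(range(0, video_num_frames, stride))   # 0, 2, ..., 98 -> 50 indices
    print(len(indices[:max_num_frames]))                 # 49, mirroring frames[:max_num_frames]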
@@ -5,8 +5,8 @@ from pathlib import Path
 package_dir = Path(__file__).parent
 
 for subdir in package_dir.iterdir():
-    if subdir.is_dir() and not subdir.name.startswith('_'):
-        for module_path in subdir.glob('*.py'):
+    if subdir.is_dir() and not subdir.name.startswith("_"):
+        for module_path in subdir.glob("*.py"):
             module_name = module_path.stem
             full_module_name = f".{subdir.name}.{module_name}"
             importlib.import_module(full_module_name, package=__name__)
@@ -30,25 +30,15 @@ class CogVideoXI2VLoraTrainer(Trainer):
 
         components.pipeline_cls = CogVideoXImageToVideoPipeline
 
-        components.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, subfolder="tokenizer"
-        )
+        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
 
-        components.text_encoder = T5EncoderModel.from_pretrained(
-            model_path, subfolder="text_encoder"
-        )
+        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
 
-        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
-            model_path, subfolder="transformer"
-        )
+        components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
 
-        components.vae = AutoencoderKLCogVideoX.from_pretrained(
-            model_path, subfolder="vae"
-        )
+        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
 
-        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
-            model_path, subfolder="scheduler"
-        )
+        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
 
         return components
 
@@ -59,7 +49,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
             text_encoder=self.components.text_encoder,
             vae=self.components.vae,
             transformer=unwrap_model(self.accelerator, self.components.transformer),
-            scheduler=self.components.scheduler
+            scheduler=self.components.scheduler,
         )
         return pipe
 
@@ -88,11 +78,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
 
     @override
     def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
-        ret = {
-            "encoded_videos": [],
-            "prompt_embedding": [],
-            "images": []
-        }
+        ret = {"encoded_videos": [], "prompt_embedding": [], "images": []}
 
         for sample in samples:
             encoded_video = sample["encoded_video"]
@@ -144,8 +130,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
 
         # Sample a random timestep for each sample
         timesteps = torch.randint(
-            0, self.components.scheduler.config.num_train_timesteps,
-            (batch_size,), device=self.accelerator.device
+            0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
         )
         timesteps = timesteps.long()
 
@@ -183,7 +168,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
         )
 
         # Predict noise
-        ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
+        ofs_emb = (
+            None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
+        )
         predicted_noise = self.components.transformer(
             hidden_states=latent_img_noisy,
             encoder_hidden_states=prompt_embedding,
@@ -222,7 +209,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
             width=self.state.train_width,
             prompt=prompt,
             image=image,
-            generator=self.state.generator
+            generator=self.state.generator,
         ).frames[0]
         return [("video", video_generate)]
 
@@ -233,7 +220,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
         num_frames: int,
         transformer_config: Dict,
         vae_scale_factor_spatial: int,
-        device: torch.device
+        device: torch.device,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
         grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
@@ -31,25 +31,15 @@ class CogVideoXT2VLoraTrainer(Trainer):
 
         components.pipeline_cls = CogVideoXPipeline
 
-        components.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, subfolder="tokenizer"
-        )
+        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
 
-        components.text_encoder = T5EncoderModel.from_pretrained(
-            model_path, subfolder="text_encoder"
-        )
+        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
 
-        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
-            model_path, subfolder="transformer"
-        )
+        components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
 
-        components.vae = AutoencoderKLCogVideoX.from_pretrained(
-            model_path, subfolder="vae"
-        )
+        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
 
-        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
-            model_path, subfolder="scheduler"
-        )
+        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
 
         return components
 
@@ -60,7 +50,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
             text_encoder=self.components.text_encoder,
             vae=self.components.vae,
             transformer=unwrap_model(self.accelerator, self.components.transformer),
-            scheduler=self.components.scheduler
+            scheduler=self.components.scheduler,
         )
         return pipe
 
@@ -89,10 +79,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
 
     @override
     def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
-        ret = {
-            "encoded_videos": [],
-            "prompt_embedding": []
-        }
+        ret = {"encoded_videos": [], "prompt_embedding": []}
 
         for sample in samples:
             encoded_video = sample["encoded_video"]
@@ -116,10 +103,20 @@ class CogVideoXT2VLoraTrainer(Trainer):
 
         patch_size_t = self.state.transformer_config.patch_size_t
         if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
-            raise ValueError("Number of frames in latent must be divisible by patch size, please check your args for training.")
+            raise ValueError(
+                "Number of frames in latent must be divisible by patch size, please check your args for training."
+            )
 
         # Add 2 random noise frames at the beginning of frame dimension
-        noise_frames = torch.randn(latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
+        noise_frames = torch.randn(
+            latent.shape[0],
+            latent.shape[1],
+            2,
+            latent.shape[3],
+            latent.shape[4],
+            device=latent.device,
+            dtype=latent.dtype,
+        )
         latent = torch.cat([noise_frames, latent], dim=2)
 
         batch_size, num_channels, num_frames, height, width = latent.shape
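Note: the expanded `torch.randn(...)` call builds a (batch, channels, 2, height, width) tensor so that two noise frames can be prepended along the frame dimension (dim=2). A small shape check under an assumed latent size:

    import torch

    latent = torch.zeros(1, 16, 13, 60, 90)  # assumed (B, C, F, H, W) latent
    noise_frames = torch.randn(
        latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4],
        device=latent.device, dtype=latent.dtype,
    )
    latent = torch.cat([noise_frames, latent], dim=2)
    print(latent.shape)  # torch.Size([1, 16, 15, 60, 90])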
@@ -130,8 +127,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
 
         # Sample a random timestep for each sample
         timesteps = torch.randint(
-            0, self.components.scheduler.config.num_train_timesteps,
-            (batch_size,), device=self.accelerator.device
+            0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
         )
         timesteps = timesteps.long()
 
@@ -193,7 +189,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
             height=self.state.train_height,
             width=self.state.train_width,
             prompt=prompt,
-            generator=self.state.generator
+            generator=self.state.generator,
         ).frames[0]
         return [("video", video_generate)]
 
@@ -204,7 +200,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
         num_frames: int,
         transformer_config: Dict,
         vae_scale_factor_spatial: int,
-        device: torch.device
+        device: torch.device,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
         grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
@@ -78,10 +78,10 @@ class Args(BaseModel):
     ########## Validation ##########
     do_validation: bool = False
     validation_steps: int | None = None  # if set, should be a multiple of checkpointing_steps
     validation_dir: Path | None  # if set do_validation, should not be None
     validation_prompts: str | None  # if set do_validation, should not be None
     validation_images: str | None  # if set do_validation and model_type == i2v, should not be None
     validation_videos: str | None  # if set do_validation and model_type == v2v, should not be None
     gen_fps: int = 15
 
     #### deprecated args: gen_video_resolution
@@ -150,12 +150,13 @@ class Args(BaseModel):
             return v
 
         except ValueError as e:
-            if str(e) == "not enough values to unpack (expected 3, got 0)" or \
-                str(e) == "invalid literal for int() with base 10":
+            if (
+                str(e) == "not enough values to unpack (expected 3, got 0)"
+                or str(e) == "invalid literal for int() with base 10"
+            ):
                 raise ValueError("train_resolution must be in format 'frames x height x width'")
             raise e
 
-
     @classmethod
     def parse_args(cls):
         """Parse command line arguments and return Args instance"""
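Note: the rewrapped condition above normalizes the two ValueError messages raised when `train_resolution` cannot be split into three integers. A minimal sketch of the expected parsing, with the "x" separator and the sample value assumed for illustration:

    def parse_train_resolution(value: str) -> tuple[int, int, int]:
        # Expected format: "frames x height x width", e.g. "49x480x720" (illustrative).
        frames, height, width = (int(part) for part in value.split("x"))
        return frames, height, width

    print(parse_train_resolution("49x480x720"))  # (49, 480, 720)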
@@ -208,8 +209,7 @@ class Args(BaseModel):
         # LoRA parameters
         parser.add_argument("--rank", type=int, default=128)
         parser.add_argument("--lora_alpha", type=int, default=64)
-        parser.add_argument("--target_modules", type=str, nargs="+",
-                            default=["to_q", "to_k", "to_v", "to_out.0"])
+        parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])
 
         # Checkpointing
         parser.add_argument("--checkpointing_steps", type=int, default=200)
@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import List, Dict, Any
 from pydantic import BaseModel, field_validator
 
+
 class State(BaseModel):
     model_config = {"arbitrary_types_allowed": True}
 
@@ -3,11 +3,15 @@ import os
 from pathlib import Path
 import cv2
 
+
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory")
+    parser.add_argument(
+        "--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory"
+    )
     return parser.parse_args()
 
+
 args = parse_args()
 
 # Create data/images directory if it doesn't exist
@@ -1,4 +1,3 @@
-import os
 import logging
 import math
 import json
@@ -32,24 +31,23 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic
 
 from finetune.schemas import Args, State, Components
 from finetune.utils import (
-    unwrap_model, cast_training_params,
+    unwrap_model,
+    cast_training_params,
     get_optimizer,
-
     get_memory_statistics,
     free_memory,
     unload_model,
-
     get_latest_ckpt_path_to_resume_from,
     get_intermediate_ckpt_path,
-    get_latest_ckpt_path_to_resume_from,
-    get_intermediate_ckpt_path,
-
-    string_to_filename
+    string_to_filename,
 )
 from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
 from finetune.datasets.utils import (
-    load_prompts, load_images, load_videos,
-    preprocess_image_with_resize, preprocess_video_with_resize
+    load_prompts,
+    load_images,
+    load_videos,
+    preprocess_image_with_resize,
+    preprocess_video_with_resize,
 )
 
 from finetune.constants import LOG_NAME, LOG_LEVEL
@@ -59,7 +57,7 @@ logger = get_logger(LOG_NAME, LOG_LEVEL)
 
 _DTYPE_MAP = {
     "fp32": torch.float32,
-    "fp16": torch.float16,
+    "fp16": torch.float16,  # FP16 is Only Support for CogVideoX-2B
     "bf16": torch.bfloat16,
 }
 
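Note: the added comment records that fp16 training is only supported for CogVideoX-2B; the larger CogVideoX models are trained in bf16. The commit title ("check fp16 for cogvideox2b") suggests a guard along these lines; this is a hedged sketch, not the repository's actual check:

    import torch

    _DTYPE_MAP = {
        "fp32": torch.float32,
        "fp16": torch.float16,  # FP16 is Only Support for CogVideoX-2B
        "bf16": torch.bfloat16,
    }

    def resolve_training_dtype(mixed_precision: str, model_name: str) -> torch.dtype:
        # Illustrative guard: refuse fp16 unless the model path names the 2B variant.
        dtype = _DTYPE_MAP[mixed_precision]
        if dtype == torch.float16 and "cogvideox-2b" not in model_name.lower():
            raise ValueError("fp16 is only supported for CogVideoX-2B; use bf16 for other variants.")
        return dtype

    print(resolve_training_dtype("fp16", "THUDM/CogVideoX-2b"))  # torch.float16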
@@ -69,12 +67,12 @@ class Trainer:
     UNLOAD_LIST: List[str] = None
 
     def __init__(self, args: Args) -> None:
         self.args = args
         self.state = State(
             weight_dtype=self.__get_training_dtype(),
             train_frames=self.args.train_resolution[0],
             train_height=self.args.train_resolution[1],
-            train_width=self.args.train_resolution[2]
+            train_width=self.args.train_resolution[2],
         )
 
         self.components = Components()
@@ -140,7 +138,9 @@ class Trainer:
     def check_setting(self) -> None:
         # Check for unload_list
         if self.UNLOAD_LIST is None:
-            logger.warning("\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m")
+            logger.warning(
+                "\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
+            )
         else:
             for name in self.UNLOAD_LIST:
                 if name not in self.components.model_fields:
@@ -174,7 +174,7 @@ class Trainer:
                 max_num_frames=sample_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
-                trainer=self
+                trainer=self,
             )
         elif self.args.model_type == "t2v":
             self.dataset = T2VDatasetWithResize(
@@ -183,7 +183,7 @@ class Trainer:
                 max_num_frames=sample_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
-                trainer=self
+                trainer=self,
             )
         else:
             raise ValueError(f"Invalid model type: {self.args.model_type}")
@@ -204,7 +204,8 @@ class Trainer:
             pin_memory=self.args.pin_memory,
         )
         tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
-        for _ in tmp_data_loader: ...
+        for _ in tmp_data_loader:
+            ...
         self.accelerator.wait_for_everyone()
         logger.info("Precomputing latent for video and prompt embedding ... Done")
 
@@ -218,16 +219,15 @@ class Trainer:
             batch_size=self.args.batch_size,
             num_workers=self.args.num_workers,
             pin_memory=self.args.pin_memory,
-            shuffle=True
+            shuffle=True,
         )
 
-
     def prepare_trainable_parameters(self):
         logger.info("Initializing trainable parameters")
 
         # For now only lora is supported
         for attr_name, component in vars(self.components).items():
-            if hasattr(component, 'requires_grad_'):
+            if hasattr(component, "requires_grad_"):
                 component.requires_grad_(False)
 
         # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
@@ -452,10 +452,7 @@ class Trainer:
                 progress_bar.set_postfix(logs)
 
                 # Maybe run validation
-                should_run_validation = (
-                    self.args.do_validation
-                    and global_step % self.args.validation_steps == 0
-                )
+                should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0
                 if should_run_validation:
                     del loss
                     free_memory()
@@ -520,9 +517,7 @@ class Trainer:
             video = self.state.validation_videos[i]
 
             if image is not None:
-                image = preprocess_image_with_resize(
-                    image, self.state.train_height, self.state.train_width
-                )
+                image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
                 # Convert image tensor (C, H, W) to PIL images
                 image = image.to(torch.uint8)
                 image = image.permute(1, 2, 0).cpu().numpy()
@@ -534,17 +529,13 @@ class Trainer:
                 )
             # Convert video tensor (F, C, H, W) to list of PIL images
             video = (video * 255).round().clamp(0, 255).to(torch.uint8)
-            video = [Image.fromarray(frame.permute(1,2,0).cpu().numpy()) for frame in video]
+            video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
 
             logger.debug(
                 f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                 main_process_only=False,
             )
-            validation_artifacts = self.validation_step({
-                "prompt": prompt,
-                "image": image,
-                "video": video
-            }, pipe)
+            validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
             prompt_filename = string_to_filename(prompt)[:25]
             artifacts = {
                 "image": {"type": "image", "value": image},
@@ -663,7 +654,7 @@ class Trainer:
     def __load_components(self):
         components = self.components.model_dump()
         for name, component in components.items():
-            if not isinstance(component, type) and hasattr(component, 'to'):
+            if not isinstance(component, type) and hasattr(component, "to"):
                 if name in self.UNLOAD_LIST:
                     continue
                 # setattr(self.components, name, component.to(self.accelerator.device))
@@ -672,9 +663,9 @@ class Trainer:
     def __unload_components(self):
         components = self.components.model_dump()
         for name, component in components.items():
-            if not isinstance(component, type) and hasattr(component, 'to'):
+            if not isinstance(component, type) and hasattr(component, "to"):
                 if name in self.UNLOAD_LIST:
-                    setattr(self.components, name, component.to('cpu'))
+                    setattr(self.components, name, component.to("cpu"))
 
     def __prepare_saving_loading_hooks(self, transformer_lora_config):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -711,9 +702,7 @@ class Trainer:
                     ):
                         transformer_ = unwrap_model(self.accelerator, model)
                     else:
-                        raise ValueError(
-                            f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
-                        )
+                        raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
             else:
                 transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
                     self.args.model_path, subfolder="transformer"