Mirror of https://github.com/THUDM/CogVideo.git, synced 2025-09-25 16:55:58 +08:00

format and check fp16 for cogvideox2b

This commit is contained in:
parent 1b886326b2
commit 1789f07256
@@ -1,2 +1,2 @@
 LOG_NAME = "trainer"
 LOG_LEVEL = "INFO"

@@ -8,5 +8,5 @@ __all__ = [
     "I2VDatasetWithBuckets",
     "T2VDatasetWithResize",
     "T2VDatasetWithBuckets",
-    "BucketSampler"
+    "BucketSampler",
 ]

@@ -37,7 +37,6 @@ class BucketSampler(Sampler):

         self._raised_warning_for_drop_last = False

-
     def __len__(self):
         if self.drop_last and not self._raised_warning_for_drop_last:
             self._raised_warning_for_drop_last = True

@@ -46,7 +45,6 @@ class BucketSampler(Sampler):
             )
         return (len(self.data_source) + self.batch_size - 1) // self.batch_size

-
     def __iter__(self):
         for index, data in enumerate(self.data_source):
             video_metadata = data["video_metadata"]
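The `__len__` shown above uses ceiling division, so a trailing partial batch still counts as one batch; the `_raised_warning_for_drop_last` flag only exists so the mismatch with `drop_last` is warned about once. A standalone check of just that arithmetic, not the repository's sampler:

```python
def num_batches(num_samples: int, batch_size: int) -> int:
    # Ceiling division: (n + b - 1) // b counts a trailing partial batch as one batch.
    return (num_samples + batch_size - 1) // batch_size


assert num_batches(10, 4) == 3  # batches of 4, 4, 2
assert num_batches(8, 4) == 2   # exact fit
assert num_batches(1, 4) == 1   # a single sample still forms one batch
```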
@@ -13,11 +13,12 @@ from safetensors.torch import save_file, load_file
 from finetune.constants import LOG_NAME, LOG_LEVEL

 from .utils import (
-    load_prompts, load_videos, load_images,
+    load_prompts,
+    load_videos,
+    load_images,
     preprocess_image_with_resize,
     preprocess_video_with_resize,
-    preprocess_video_with_buckets
+    preprocess_video_with_buckets,
 )

 if TYPE_CHECKING:

@@ -46,6 +47,7 @@ class BaseI2VDataset(Dataset):
         device (torch.device): Device to load the data on
         encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
     """
+
     def __init__(
         self,
         data_root: str,

@@ -55,7 +57,7 @@ class BaseI2VDataset(Dataset):
         device: torch.device,
         trainer: "Trainer" = None,
         *args,
-        **kwargs
+        **kwargs,
     ) -> None:
         super().__init__()

@@ -120,7 +122,10 @@ class BaseI2VDataset(Dataset):

         if prompt_embedding_path.exists():
             prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
-            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+            logger.debug(
+                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
+                main_process_only=False,
+            )
         else:
             prompt_embedding = self.encode_text(prompt)
             prompt_embedding = prompt_embedding.to("cpu")
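The branch above reuses a prompt embedding cached in a `.safetensors` file and only calls `encode_text` on a cache miss. A minimal sketch of that cache pattern, assuming a caller-supplied `encode_text` callable and cache path (hypothetical names, not the dataset's exact code):

```python
from pathlib import Path
from typing import Callable

import torch
from safetensors.torch import load_file, save_file


def get_prompt_embedding(
    prompt: str, cache_path: Path, encode_text: Callable[[str], torch.Tensor]
) -> torch.Tensor:
    # Cache hit: reuse the embedding computed by a previous run.
    if cache_path.exists():
        return load_file(cache_path)["prompt_embedding"]
    # Cache miss: encode once, keep the result on CPU, and persist it for next time.
    embedding = encode_text(prompt).to("cpu")
    save_file({"prompt_embedding": embedding}, cache_path)
    return embedding
```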
@@ -187,7 +192,7 @@ class BaseI2VDataset(Dataset):
             - image(torch.Tensor) of shape [C, H, W]
         """
         raise NotImplementedError("Subclass must implement this method")

     def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
         """
         Applies transformations to a video.

@@ -197,14 +202,14 @@ class BaseI2VDataset(Dataset):
             with shape [F, C, H, W] where:
             - F is number of frames
             - C is number of channels (3 for RGB)
             - H is height
             - W is width

         Returns:
             torch.Tensor: The transformed video tensor
         """
         raise NotImplementedError("Subclass must implement this method")

     def image_transform(self, image: torch.Tensor) -> torch.Tensor:
         """
         Applies transformations to an image.

@@ -213,7 +218,7 @@ class BaseI2VDataset(Dataset):
             image (torch.Tensor): A 3D tensor representing an image
             with shape [C, H, W] where:
             - C is number of channels (3 for RGB)
             - H is height
             - W is width

         Returns:

@@ -235,6 +240,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
         height (int): Target height for resizing videos and images
         width (int): Target width for resizing videos and images
     """
+
     def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)

@@ -242,11 +248,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
         self.height = height
         self.width = width

-        self.__frame_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
         self.__image_transforms = self.__frame_transforms

     @override
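The `transforms.Lambda` collapsed onto a single line here rescales uint8 pixel values from [0, 255] into [-1, 1] before the frames go on to be encoded. A quick sanity check of that mapping on a dummy tensor:

```python
import torch

pixels = torch.tensor([0.0, 127.5, 255.0])  # dummy pixel values
normalized = pixels / 255.0 * 2.0 - 1.0

# 0 -> -1.0, 127.5 -> 0.0, 255 -> 1.0
assert torch.allclose(normalized, torch.tensor([-1.0, 0.0, 1.0]))
```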
@@ -260,25 +262,25 @@ class I2VDatasetWithResize(BaseI2VDataset):
         else:
             image = None
         return video, image

     @override
     def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
         return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

     @override
     def image_transform(self, image: torch.Tensor) -> torch.Tensor:
         return self.__image_transforms(image)


 class I2VDatasetWithBuckets(BaseI2VDataset):
-
     def __init__(
         self,
         video_resolution_buckets: List[Tuple[int, int, int]],
         vae_temporal_compression_ratio: int,
         vae_height_compression_ratio: int,
         vae_width_compression_ratio: int,
-        *args, **kwargs
+        *args,
+        **kwargs,
     ) -> None:
         super().__init__(*args, **kwargs)

@@ -290,23 +292,19 @@ class I2VDatasetWithBuckets(BaseI2VDataset):
             )
             for b in video_resolution_buckets
         ]
-        self.__frame_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
         self.__image_transforms = self.__frame_transforms

     @override
     def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
         video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
         image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
         return video, image

     @override
     def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
         return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

     @override
     def image_transform(self, image: torch.Tensor) -> torch.Tensor:
         return self.__image_transforms(image)

@@ -12,11 +12,7 @@ from safetensors.torch import save_file, load_file

 from finetune.constants import LOG_NAME, LOG_LEVEL

-from .utils import (
-    load_prompts, load_videos,
-    preprocess_video_with_resize,
-    preprocess_video_with_buckets
-)
+from .utils import load_prompts, load_videos, preprocess_video_with_resize, preprocess_video_with_buckets

 if TYPE_CHECKING:
     from finetune.trainer import Trainer
@@ -52,7 +48,7 @@ class BaseT2VDataset(Dataset):
         device: torch.device = None,
         trainer: "Trainer" = None,
         *args,
-        **kwargs
+        **kwargs,
     ) -> None:
         super().__init__()

@@ -108,7 +104,10 @@ class BaseT2VDataset(Dataset):

         if prompt_embedding_path.exists():
             prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
-            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+            logger.debug(
+                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
+                main_process_only=False,
+            )
         else:
             prompt_embedding = self.encode_text(prompt)
             prompt_embedding = prompt_embedding.to("cpu")

@@ -164,7 +163,7 @@ class BaseT2VDataset(Dataset):
             - W is width
         """
         raise NotImplementedError("Subclass must implement this method")

     def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
         """
         Applies transformations to a video.

@@ -174,7 +173,7 @@ class BaseT2VDataset(Dataset):
             with shape [F, C, H, W] where:
             - F is number of frames
             - C is number of channels (3 for RGB)
             - H is height
             - W is width

         Returns:

@@ -203,36 +202,33 @@ class T2VDatasetWithResize(BaseT2VDataset):
         self.height = height
         self.width = width

-        self.__frame_transform = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])

     @override
     def preprocess(self, video_path: Path) -> torch.Tensor:
         return preprocess_video_with_resize(
-            video_path, self.max_num_frames, self.height, self.width,
+            video_path,
+            self.max_num_frames,
+            self.height,
+            self.width,
         )

     @override
     def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
         return torch.stack([self.__frame_transform(f) for f in frames], dim=0)


 class T2VDatasetWithBuckets(BaseT2VDataset):
-
     def __init__(
         self,
         video_resolution_buckets: List[Tuple[int, int, int]],
         vae_temporal_compression_ratio: int,
         vae_height_compression_ratio: int,
         vae_width_compression_ratio: int,
-        *args, **kwargs
+        *args,
+        **kwargs,
     ) -> None:
-        """
-
-        """
+        """ """
         super().__init__(*args, **kwargs)

         self.video_resolution_buckets = [

@@ -244,18 +240,12 @@ class T2VDatasetWithBuckets(BaseT2VDataset):
             for b in video_resolution_buckets
         ]

-        self.__frame_transform = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
-            ]
-        )
+        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])

     @override
     def preprocess(self, video_path: Path) -> torch.Tensor:
-        return preprocess_video_with_buckets(
-            video_path, self.video_resolution_buckets
-        )
+        return preprocess_video_with_buckets(video_path, self.video_resolution_buckets)

     @override
     def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
         return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
@@ -15,6 +15,7 @@ decord.bridge.set_bridge("torch")

 ########## loaders ##########

+
 def load_prompts(prompt_path: Path) -> List[str]:
     with open(prompt_path, "r", encoding="utf-8") as file:
         return [line.strip() for line in file.readlines() if len(line.strip()) > 0]

@@ -32,6 +33,7 @@ def load_images(image_path: Path) -> List[Path]:

 ########## preprocessors ##########

+
 def preprocess_image_with_resize(
     image_path: Path | str,
     height: int,

@@ -96,7 +98,7 @@ def preprocess_video_with_resize(

     indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
     frames = video_reader.get_batch(indices)
-    frames = frames[: max_num_frames].float()
+    frames = frames[:max_num_frames].float()
     frames = frames.permute(0, 3, 1, 2).contiguous()
     return frames

@@ -144,4 +146,4 @@ def preprocess_video_with_buckets(
     nearest_res = (nearest_res[1], nearest_res[2])
     frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)

     return frames
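In `preprocess_video_with_resize` above, frames are subsampled with a stride of `video_num_frames // max_num_frames` and then truncated to `max_num_frames`; the slice edit (`[: max_num_frames]` to `[:max_num_frames]`) is purely cosmetic. The same index arithmetic in isolation, without any video decoding:

```python
def sample_indices(video_num_frames: int, max_num_frames: int) -> list[int]:
    # Stride through the video, then keep at most max_num_frames indices.
    stride = video_num_frames // max_num_frames
    return list(range(0, video_num_frames, stride))[:max_num_frames]


# 100 frames sampled down to 8: every 12th frame, truncated to the first 8 hits.
assert sample_indices(100, 8) == [0, 12, 24, 36, 48, 60, 72, 84]
```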
@@ -5,8 +5,8 @@ from pathlib import Path
 package_dir = Path(__file__).parent

 for subdir in package_dir.iterdir():
-    if subdir.is_dir() and not subdir.name.startswith('_'):
-        for module_path in subdir.glob('*.py'):
+    if subdir.is_dir() and not subdir.name.startswith("_"):
+        for module_path in subdir.glob("*.py"):
             module_name = module_path.stem
             full_module_name = f".{subdir.name}.{module_name}"
             importlib.import_module(full_module_name, package=__name__)
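This loop imports every module in every non-underscore subpackage so that the `register(...)` calls at the bottom of the trainer modules (for example `register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)` further down in this diff) run as an import side effect. A minimal sketch of the registry half of that pattern, with hypothetical names rather than the repository's exact implementation:

```python
# Hypothetical registry keyed by (model_name, training_type) -> trainer class.
_TRAINERS: dict[tuple[str, str], type] = {}


def register(model_name: str, training_type: str, trainer_cls: type) -> None:
    # Called at module import time, so importing a module is enough to register it.
    _TRAINERS[(model_name, training_type)] = trainer_cls


def get_trainer_cls(model_name: str, training_type: str) -> type:
    return _TRAINERS[(model_name, training_type)]
```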
@@ -30,28 +30,18 @@ class CogVideoXI2VLoraTrainer(Trainer):

         components.pipeline_cls = CogVideoXImageToVideoPipeline

-        components.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, subfolder="tokenizer"
-        )
+        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")

-        components.text_encoder = T5EncoderModel.from_pretrained(
-            model_path, subfolder="text_encoder"
-        )
+        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")

-        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
-            model_path, subfolder="transformer"
-        )
+        components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")

-        components.vae = AutoencoderKLCogVideoX.from_pretrained(
-            model_path, subfolder="vae"
-        )
+        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")

-        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
-            model_path, subfolder="scheduler"
-        )
+        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

         return components

     @override
     def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
         pipe = CogVideoXImageToVideoPipeline(

@@ -59,7 +49,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
             text_encoder=self.components.text_encoder,
             vae=self.components.vae,
             transformer=unwrap_model(self.accelerator, self.components.transformer),
-            scheduler=self.components.scheduler
+            scheduler=self.components.scheduler,
         )
         return pipe

@@ -71,7 +61,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
         latent_dist = vae.encode(video).latent_dist
         latent = latent_dist.sample() * vae.config.scaling_factor
         return latent

     @override
     def encode_text(self, prompt: str) -> torch.Tensor:
         prompt_token_ids = self.components.tokenizer(

@@ -88,12 +78,8 @@ class CogVideoXI2VLoraTrainer(Trainer):

     @override
     def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
-        ret = {
-            "encoded_videos": [],
-            "prompt_embedding": [],
-            "images": []
-        }
+        ret = {"encoded_videos": [], "prompt_embedding": [], "images": []}

         for sample in samples:
             encoded_video = sample["encoded_video"]
             prompt_embedding = sample["prompt_embedding"]
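The `collate_fn` reformatted here collects per-sample tensors into lists and then stacks each list into a batched tensor. The same stacking step in a self-contained form, with dummy tensors standing in for the encoded videos and prompt embeddings:

```python
import torch

samples = [
    {"encoded_video": torch.zeros(4, 2, 8, 8), "prompt_embedding": torch.zeros(16, 32)},
    {"encoded_video": torch.ones(4, 2, 8, 8), "prompt_embedding": torch.ones(16, 32)},
]

batch = {
    "encoded_videos": torch.stack([s["encoded_video"] for s in samples]),
    "prompt_embedding": torch.stack([s["prompt_embedding"] for s in samples]),
}

assert batch["encoded_videos"].shape == (2, 4, 2, 8, 8)  # a leading batch dimension is added
assert batch["prompt_embedding"].shape == (2, 16, 32)
```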
@@ -102,13 +88,13 @@ class CogVideoXI2VLoraTrainer(Trainer):
             ret["encoded_videos"].append(encoded_video)
             ret["prompt_embedding"].append(prompt_embedding)
             ret["images"].append(image)

         ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
         ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
         ret["images"] = torch.stack(ret["images"])

         return ret

     @override
     def compute_loss(self, batch) -> torch.Tensor:
         prompt_embedding = batch["prompt_embedding"]

@@ -144,8 +130,7 @@ class CogVideoXI2VLoraTrainer(Trainer):

         # Sample a random timestep for each sample
         timesteps = torch.randint(
-            0, self.components.scheduler.config.num_train_timesteps,
-            (batch_size,), device=self.accelerator.device
+            0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
         )
         timesteps = timesteps.long()

@@ -183,7 +168,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
         )

         # Predict noise
-        ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
+        ofs_emb = (
+            None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
+        )
         predicted_noise = self.components.transformer(
             hidden_states=latent_img_noisy,
             encoder_hidden_states=prompt_embedding,

@@ -222,7 +209,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
             width=self.state.train_width,
             prompt=prompt,
             image=image,
-            generator=self.state.generator
+            generator=self.state.generator,
         ).frames[0]
         return [("video", video_generate)]

@@ -233,7 +220,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
         num_frames: int,
         transformer_config: Dict,
         vae_scale_factor_spatial: int,
-        device: torch.device
+        device: torch.device,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
         grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)

@@ -256,4 +243,4 @@ class CogVideoXI2VLoraTrainer(Trainer):
         return freqs_cos, freqs_sin


 register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
@@ -31,25 +31,15 @@ class CogVideoXT2VLoraTrainer(Trainer):

         components.pipeline_cls = CogVideoXPipeline

-        components.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, subfolder="tokenizer"
-        )
+        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")

-        components.text_encoder = T5EncoderModel.from_pretrained(
-            model_path, subfolder="text_encoder"
-        )
+        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")

-        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
-            model_path, subfolder="transformer"
-        )
+        components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")

-        components.vae = AutoencoderKLCogVideoX.from_pretrained(
-            model_path, subfolder="vae"
-        )
+        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")

-        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
-            model_path, subfolder="scheduler"
-        )
+        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

         return components

@@ -60,10 +50,10 @@ class CogVideoXT2VLoraTrainer(Trainer):
             text_encoder=self.components.text_encoder,
             vae=self.components.vae,
             transformer=unwrap_model(self.accelerator, self.components.transformer),
-            scheduler=self.components.scheduler
+            scheduler=self.components.scheduler,
         )
         return pipe

     @override
     def encode_video(self, video: torch.Tensor) -> torch.Tensor:
         # shape of input video: [B, C, F, H, W]

@@ -86,21 +76,18 @@ class CogVideoXT2VLoraTrainer(Trainer):
         prompt_token_ids = prompt_token_ids.input_ids
         prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
         return prompt_embedding

     @override
     def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
-        ret = {
-            "encoded_videos": [],
-            "prompt_embedding": []
-        }
+        ret = {"encoded_videos": [], "prompt_embedding": []}

         for sample in samples:
             encoded_video = sample["encoded_video"]
             prompt_embedding = sample["prompt_embedding"]

             ret["encoded_videos"].append(encoded_video)
             ret["prompt_embedding"].append(prompt_embedding)

         ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
         ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])

@@ -116,10 +103,20 @@ class CogVideoXT2VLoraTrainer(Trainer):

         patch_size_t = self.state.transformer_config.patch_size_t
         if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
-            raise ValueError("Number of frames in latent must be divisible by patch size, please check your args for training.")
+            raise ValueError(
+                "Number of frames in latent must be divisible by patch size, please check your args for training."
+            )

         # Add 2 random noise frames at the beginning of frame dimension
-        noise_frames = torch.randn(latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
+        noise_frames = torch.randn(
+            latent.shape[0],
+            latent.shape[1],
+            2,
+            latent.shape[3],
+            latent.shape[4],
+            device=latent.device,
+            dtype=latent.dtype,
+        )
         latent = torch.cat([noise_frames, latent], dim=2)

         batch_size, num_channels, num_frames, height, width = latent.shape
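The check above requires the latent frame count to be divisible by `patch_size_t`, and the following lines prepend two random noise frames along the frame dimension. A toy illustration of both steps; the tensor sizes are made up for the example and are not CogVideoX's real latent shapes:

```python
import torch

latent = torch.randn(1, 16, 12, 8, 8)  # [B, C, F, H, W], toy sizes
patch_size_t = 2

if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
    raise ValueError("Number of frames in latent must be divisible by patch size")

# Prepend 2 random noise frames along the frame dimension (dim=2).
noise_frames = torch.randn(
    latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4],
    device=latent.device, dtype=latent.dtype,
)
latent = torch.cat([noise_frames, latent], dim=2)
assert latent.shape == (1, 16, 14, 8, 8)  # 2 noise frames + 12 original latent frames
```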
@@ -130,8 +127,7 @@ class CogVideoXT2VLoraTrainer(Trainer):

         # Sample a random timestep for each sample
         timesteps = torch.randint(
-            0, self.components.scheduler.config.num_train_timesteps,
-            (batch_size,), device=self.accelerator.device
+            0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
         )
         timesteps = timesteps.long()

@@ -193,7 +189,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
             height=self.state.train_height,
             width=self.state.train_width,
             prompt=prompt,
-            generator=self.state.generator
+            generator=self.state.generator,
         ).frames[0]
         return [("video", video_generate)]

@@ -204,7 +200,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
         num_frames: int,
         transformer_config: Dict,
         vae_scale_factor_spatial: int,
-        device: torch.device
+        device: torch.device,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
         grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)

@@ -2,4 +2,4 @@ from .args import Args
 from .state import State
 from .components import Components

 __all__ = ["Args", "State", "Components"]
@@ -78,10 +78,10 @@ class Args(BaseModel):
     ########## Validation ##########
     do_validation: bool = False
     validation_steps: int | None = None  # if set, should be a multiple of checkpointing_steps
     validation_dir: Path | None  # if set do_validation, should not be None
     validation_prompts: str | None  # if set do_validation, should not be None
     validation_images: str | None  # if set do_validation and model_type == i2v, should not be None
     validation_videos: str | None  # if set do_validation and model_type == v2v, should not be None
     gen_fps: int = 15

     #### deprecated args: gen_video_resolution

@@ -115,7 +115,7 @@ class Args(BaseModel):
             raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
         return v

     @field_validator("validation_videos")
     def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
         values = info.data
         if values.get("do_validation") and values.get("model_type") == "v2v" and not v:

@@ -131,31 +131,32 @@ class Args(BaseModel):
         if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
             raise ValueError("validation_steps must be a multiple of checkpointing_steps")
         return v

     @field_validator("train_resolution")
     def validate_train_resolution(cls, v: Tuple[int, int, int], info: ValidationInfo) -> str:
         try:
             frames, height, width = v

             # Check if (frames - 1) is multiple of 8
             if (frames - 1) % 8 != 0:
                 raise ValueError("Number of frames - 1 must be a multiple of 8")

             # Check resolution for cogvideox-5b models
             model_name = info.data.get("model_name", "")
             if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
                 if (height, width) != (480, 720):
                     raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")

             return v

         except ValueError as e:
-            if str(e) == "not enough values to unpack (expected 3, got 0)" or \
-                str(e) == "invalid literal for int() with base 10":
+            if (
+                str(e) == "not enough values to unpack (expected 3, got 0)"
+                or str(e) == "invalid literal for int() with base 10"
+            ):
                 raise ValueError("train_resolution must be in format 'frames x height x width'")
             raise e

-
     @classmethod
     def parse_args(cls):
         """Parse command line arguments and return Args instance"""
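The validator above unpacks `train_resolution` into (frames, height, width), requires `frames - 1` to be a multiple of 8, and pins the cogvideox-5b models to 480x720; `parse_args` later splits the command-line string on "x". A condensed, standalone version of the parsing plus the frame-count check (helper name is made up for the example):

```python
def parse_train_resolution(value: str) -> tuple[int, int, int]:
    # "--train_resolution 49x480x720" -> (49, 480, 720)
    frames, height, width = (int(part) for part in value.split("x"))
    if (frames - 1) % 8 != 0:
        raise ValueError("Number of frames - 1 must be a multiple of 8")
    return frames, height, width


assert parse_train_resolution("49x480x720") == (49, 480, 720)  # (49 - 1) % 8 == 0
assert parse_train_resolution("17x480x720") == (17, 480, 720)  # (17 - 1) % 8 == 0
```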
@@ -208,8 +209,7 @@ class Args(BaseModel):
         # LoRA parameters
         parser.add_argument("--rank", type=int, default=128)
         parser.add_argument("--lora_alpha", type=int, default=64)
-        parser.add_argument("--target_modules", type=str, nargs="+",
-                            default=["to_q", "to_k", "to_v", "to_out.0"])
+        parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])

         # Checkpointing
         parser.add_argument("--checkpointing_steps", type=int, default=200)

@@ -226,7 +226,7 @@ class Args(BaseModel):
         parser.add_argument("--gen_fps", type=int, default=15)

         args = parser.parse_args()

         # Convert video_resolution_buckets string to list of tuples
         frames, height, width = args.train_resolution.split("x")
         args.train_resolution = (int(frames), int(height), int(width))

@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import List, Dict, Any
 from pydantic import BaseModel, field_validator

+
 class State(BaseModel):
     model_config = {"arbitrary_types_allowed": True}

@@ -3,11 +3,15 @@ import os
 from pathlib import Path
 import cv2

+
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory")
+    parser.add_argument(
+        "--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory"
+    )
     return parser.parse_args()

+
 args = parse_args()

 # Create data/images directory if it doesn't exist

@@ -24,24 +28,24 @@ with open(videos_file, "r") as f:
 image_paths = []
 for video_rel_path in video_paths:
     video_path = data_dir / video_rel_path

     # Open video
     cap = cv2.VideoCapture(str(video_path))

     # Read first frame
     ret, frame = cap.read()
     if not ret:
         print(f"Failed to read video: {video_path}")
         continue

     # Save frame as PNG with same name as video
     image_name = f"images/{video_path.stem}.png"
     image_path = data_dir / image_name
     cv2.imwrite(str(image_path), frame)

     # Release video capture
     cap.release()

     print(f"Extracted first frame from {video_path} to {image_path}")
     image_paths.append(image_name)

@@ -49,4 +53,4 @@ for video_rel_path in video_paths:
 images_file = data_dir / "images.txt"
 with open(images_file, "w") as f:
     for path in image_paths:
         f.write(f"{path}\n")

@@ -1,4 +1,3 @@
-import os
 import logging
 import math
 import json
@@ -32,24 +31,23 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic

 from finetune.schemas import Args, State, Components
 from finetune.utils import (
-    unwrap_model, cast_training_params,
+    unwrap_model,
+    cast_training_params,
     get_optimizer,
-
     get_memory_statistics,
     free_memory,
     unload_model,
-
     get_latest_ckpt_path_to_resume_from,
     get_intermediate_ckpt_path,
-    get_latest_ckpt_path_to_resume_from,
-    get_intermediate_ckpt_path,
-
-    string_to_filename
+    string_to_filename,
 )
 from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
 from finetune.datasets.utils import (
-    load_prompts, load_images, load_videos,
-    preprocess_image_with_resize, preprocess_video_with_resize
+    load_prompts,
+    load_images,
+    load_videos,
+    preprocess_image_with_resize,
+    preprocess_video_with_resize,
 )

 from finetune.constants import LOG_NAME, LOG_LEVEL
@@ -59,22 +57,22 @@ logger = get_logger(LOG_NAME, LOG_LEVEL)

 _DTYPE_MAP = {
     "fp32": torch.float32,
-    "fp16": torch.float16,
+    "fp16": torch.float16,  # FP16 is Only Support for CogVideoX-2B
     "bf16": torch.bfloat16,
 }


 class Trainer:
     # If set, should be a list of components to unload (refer to `Components``)
     UNLOAD_LIST: List[str] = None

     def __init__(self, args: Args) -> None:
         self.args = args
         self.state = State(
             weight_dtype=self.__get_training_dtype(),
             train_frames=self.args.train_resolution[0],
             train_height=self.args.train_resolution[1],
-            train_width=self.args.train_resolution[2]
+            train_width=self.args.train_resolution[2],
         )

         self.components = Components()
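This hunk carries the change the commit title refers to: the `_DTYPE_MAP` entry for `"fp16"` is annotated to record that fp16 training is only supported for CogVideoX-2B, while the larger checkpoints are normally run in bf16. The title also mentions a check; a minimal sketch of what such a guard could look like (hypothetical helper, not necessarily the repository's actual implementation):

```python
import torch

_DTYPE_MAP = {
    "fp32": torch.float32,
    "fp16": torch.float16,  # fp16 is only supported for CogVideoX-2B
    "bf16": torch.bfloat16,
}


def resolve_weight_dtype(mixed_precision: str, model_name: str) -> torch.dtype:
    # Hypothetical guard: refuse fp16 for anything other than a 2B model.
    if mixed_precision == "fp16" and "2b" not in model_name.lower():
        raise ValueError("fp16 mixed precision is only supported for CogVideoX-2B; use bf16 instead")
    return _DTYPE_MAP[mixed_precision]
```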
@@ -136,11 +134,13 @@ class Trainer:
         if self.accelerator.is_main_process:
             self.args.output_dir = Path(self.args.output_dir)
             self.args.output_dir.mkdir(parents=True, exist_ok=True)

     def check_setting(self) -> None:
         # Check for unload_list
         if self.UNLOAD_LIST is None:
-            logger.warning("\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m")
+            logger.warning(
+                "\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
+            )
         else:
             for name in self.UNLOAD_LIST:
                 if name not in self.components.model_fields:

@@ -174,7 +174,7 @@ class Trainer:
                 max_num_frames=sample_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
-                trainer=self
+                trainer=self,
             )
         elif self.args.model_type == "t2v":
             self.dataset = T2VDatasetWithResize(

@@ -183,7 +183,7 @@ class Trainer:
                 max_num_frames=sample_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
-                trainer=self
+                trainer=self,
             )
         else:
             raise ValueError(f"Invalid model type: {self.args.model_type}")
@@ -204,7 +204,8 @@ class Trainer:
             pin_memory=self.args.pin_memory,
         )
         tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
-        for _ in tmp_data_loader: ...
+        for _ in tmp_data_loader:
+            ...
         self.accelerator.wait_for_everyone()
         logger.info("Precomputing latent for video and prompt embedding ... Done")

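The loop body here is just `...`: a single pass over `tmp_data_loader` forces every `__getitem__` call, which is what makes the datasets encode and cache video latents and prompt embeddings before real training starts. The same warm-up idiom in isolation, with a small hypothetical dataset that caches inside `__getitem__`:

```python
from torch.utils.data import DataLoader, Dataset


class CachingDataset(Dataset):
    """Hypothetical dataset whose __getitem__ computes and caches each sample."""

    def __init__(self, items):
        self.items = items
        self.cache = {}

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        if idx not in self.cache:
            self.cache[idx] = self.items[idx] * 2  # stand-in for an expensive encode step
        return self.cache[idx]


dataset = CachingDataset(list(range(8)))
warmup_loader = DataLoader(dataset, batch_size=4, num_workers=0)
for _ in warmup_loader:  # one full pass populates the cache; the batches themselves are discarded
    ...
assert len(dataset.cache) == 8
```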
@@ -218,16 +219,15 @@ class Trainer:
             batch_size=self.args.batch_size,
             num_workers=self.args.num_workers,
             pin_memory=self.args.pin_memory,
-            shuffle=True
+            shuffle=True,
         )

-
     def prepare_trainable_parameters(self):
         logger.info("Initializing trainable parameters")

         # For now only lora is supported
         for attr_name, component in vars(self.components).items():
-            if hasattr(component, 'requires_grad_'):
+            if hasattr(component, "requires_grad_"):
                 component.requires_grad_(False)

         # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision

@@ -332,7 +332,7 @@ class Trainer:
         # Afterwards we recalculate our number of training epochs
         self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
         self.state.num_update_steps_per_epoch = num_update_steps_per_epoch

     def prepare_for_validation(self):
         validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)

@@ -452,10 +452,7 @@ class Trainer:
                 progress_bar.set_postfix(logs)

                 # Maybe run validation
-                should_run_validation = (
-                    self.args.do_validation
-                    and global_step % self.args.validation_steps == 0
-                )
+                should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0
                 if should_run_validation:
                     del loss
                     free_memory()

@@ -500,7 +497,7 @@ class Trainer:

         ##### Initialize pipeline #####
         pipe = self.initialize_pipeline()

         # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
         pipe.enable_model_cpu_offload(device=self.accelerator.device)

@@ -520,9 +517,7 @@ class Trainer:
             video = self.state.validation_videos[i]

             if image is not None:
-                image = preprocess_image_with_resize(
-                    image, self.state.train_height, self.state.train_width
-                )
+                image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
                 # Convert image tensor (C, H, W) to PIL images
                 image = image.to(torch.uint8)
                 image = image.permute(1, 2, 0).cpu().numpy()

@@ -534,17 +529,13 @@ class Trainer:
                 )
                 # Convert video tensor (F, C, H, W) to list of PIL images
                 video = (video * 255).round().clamp(0, 255).to(torch.uint8)
-                video = [Image.fromarray(frame.permute(1,2,0).cpu().numpy()) for frame in video]
+                video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

             logger.debug(
                 f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                 main_process_only=False,
             )
-            validation_artifacts = self.validation_step({
-                "prompt": prompt,
-                "image": image,
-                "video": video
-            }, pipe)
+            validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
             prompt_filename = string_to_filename(prompt)[:25]
             artifacts = {
                 "image": {"type": "image", "value": image},

@@ -611,7 +602,7 @@ class Trainer:
         logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
         torch.cuda.reset_peak_memory_stats(accelerator.device)

         torch.set_grad_enabled(True)
         self.components.transformer.train()

     def fit(self):

@@ -628,10 +619,10 @@ class Trainer:

     def collate_fn(self, examples: List[Dict[str, Any]]):
         raise NotImplementedError

     def load_components(self) -> Components:
         raise NotImplementedError

     def initialize_pipeline(self) -> DiffusionPipeline:
         raise NotImplementedError

@@ -643,7 +634,7 @@ class Trainer:
     def encode_text(self, text: str) -> torch.Tensor:
         # shape of output text: [batch size, sequence length, embedding dimension]
         raise NotImplementedError

     def compute_loss(self, batch) -> torch.Tensor:
         raise NotImplementedError

@@ -663,18 +654,18 @@ class Trainer:
     def __load_components(self):
         components = self.components.model_dump()
         for name, component in components.items():
-            if not isinstance(component, type) and hasattr(component, 'to'):
+            if not isinstance(component, type) and hasattr(component, "to"):
                 if name in self.UNLOAD_LIST:
                     continue
                 # setattr(self.components, name, component.to(self.accelerator.device))
                 setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))

     def __unload_components(self):
         components = self.components.model_dump()
         for name, component in components.items():
-            if not isinstance(component, type) and hasattr(component, 'to'):
+            if not isinstance(component, type) and hasattr(component, "to"):
                 if name in self.UNLOAD_LIST:
-                    setattr(self.components, name, component.to('cpu'))
+                    setattr(self.components, name, component.to("cpu"))

     def __prepare_saving_loading_hooks(self, transformer_lora_config):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format

@@ -711,9 +702,7 @@ class Trainer:
                 ):
                     transformer_ = unwrap_model(self.accelerator, model)
                 else:
-                    raise ValueError(
-                        f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
-                    )
+                    raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
             else:
                 transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
                     self.args.model_path, subfolder="transformer"

@@ -49,4 +49,4 @@ def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], d
         for param in m.parameters():
             # only upcast trainable parameters into fp32
             if param.requires_grad:
                 param.data = param.to(dtype)
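The final hunk touches `cast_training_params`, which upcasts only trainable parameters to fp32 while frozen weights stay in the half-precision dtype selected through `_DTYPE_MAP`. A small sketch of that selective upcast on a toy module, using the same requires_grad convention:

```python
import torch

model = torch.nn.Linear(4, 4).to(torch.float16)  # stand-in for a frozen fp16/bf16 backbone
model.weight.requires_grad_(False)
model.bias.requires_grad_(True)  # pretend the bias is the trainable adapter parameter

# Only upcast trainable parameters into fp32, mirroring cast_training_params.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.to(torch.float32)

assert model.weight.dtype == torch.float16
assert model.bias.dtype == torch.float32
```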