refactor: simplify dataset implementation and add latent precomputation

- Replace bucket-based dataset with simpler resize-based implementation
- Add video latent precomputation during dataset initialization
- Improve code readability and user experience
- Remove complexity of bucket sampling for better maintainability

This change makes the codebase more straightforward and easier to use while
maintaining functionality through resize-based video processing.
This commit is contained in:
OleehyO 2024-12-30 16:14:46 +00:00
parent 6eae5c201e
commit 45d40450a1

View File

@ -44,7 +44,7 @@ from finetune.utils import (
string_to_filename
)
from finetune.datasets import I2VDatasetWithBuckets, T2VDatasetWithBuckets, BucketSampler
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize, BucketSampler
from finetune.datasets.utils import (
load_prompts, load_images, load_videos,
preprocess_image_with_resize, preprocess_video_with_resize
@ -80,7 +80,6 @@ class Trainer:
self._init_logging()
self._init_directories()
def _init_distributed(self):
logging_dir = Path(self.args.output_dir, "logs")
project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
@ -108,7 +107,6 @@ class Trainer:
if self.args.seed is not None:
set_seed(self.args.seed)
def _init_logging(self) -> None:
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@ -124,33 +122,12 @@ class Trainer:
logger.info("Initialized Trainer")
logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False)
def _init_directories(self) -> None:
if self.accelerator.is_main_process:
self.args.output_dir = Path(self.args.output_dir)
self.args.output_dir.mkdir(parents=True, exist_ok=True)
def prepare_dataset(self) -> None:
logger.info("Initializing dataset and dataloader")
if self.args.model_type == "i2v":
self.dataset = I2VDatasetWithBuckets(**(self.args.model_dump()))
elif self.args.model_type == "t2v":
self.dataset = T2VDatasetWithBuckets(**(self.args.model_dump()))
else:
raise ValueError(f"Invalid model type: {self.args.model_type}")
self.data_loader = torch.utils.data.DataLoader(
self.dataset,
batch_size=1,
sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
collate_fn=self.collate_fn,
num_workers=self.args.num_workers,
pin_memory=self.args.pin_memory,
)
def prepare_models(self) -> None:
logger.info("Initializing models")
@ -163,6 +140,57 @@ class Trainer:
if self.args.enable_tiling:
self.components.vae.enable_tiling()
def prepare_dataset(self) -> None:
logger.info("Initializing dataset and dataloader")
if self.args.model_type == "i2v":
self.dataset = I2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
encode_fn=self.encode_video,
max_num_frames=self.args.train_resolution[0],
height=self.args.train_resolution[1],
width=self.args.train_resolution[2]
)
elif self.args.model_type == "t2v":
self.dataset = T2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
encode_fn=self.encode_video,
max_num_frames=self.args.train_resolution[0],
height=self.args.train_resolution[1],
width=self.args.train_resolution[2]
)
else:
raise ValueError(f"Invalid model type: {self.args.model_type}")
# Prepare VAE for encoding
self.components.vae = self.components.vae.to(self.accelerator.device)
self.components.vae.requires_grad_(False)
# Precompute latent for video
logger.info("Precomputing latent for video ...")
tmp_data_loader = torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_size=1,
num_workers=0,
pin_memory=self.args.pin_memory,
)
tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
for _ in tmp_data_loader: ...
logger.info("Precomputing latent for video ... Done")
self.data_loader = torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=self.args.pin_memory,
shuffle=True
)
def prepare_trainable_parameters(self):
logger.info("Initializing trainable parameters")
@ -275,7 +303,6 @@ class Trainer:
self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
def prepare_for_validation(self):
self.state.gen_frames, self.state.gen_height, self.state.gen_width = [int(elem) for elem in self.args.gen_video_resolution.split('x')]
validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)
if self.args.validation_images is not None:
@ -447,7 +474,7 @@ class Trainer:
if image is not None:
image = preprocess_image_with_resize(
image, self.state.gen_height, self.state.gen_width
image, self.args.train_resolution[1], self.args.train_resolution[2]
)
# Convert image tensor (C, H, W) to PIL images
image = image.to(torch.uint8)
@ -456,7 +483,7 @@ class Trainer:
if video is not None:
video = preprocess_video_with_resize(
video, self.state.gen_frames, self.state.gen_height, self.state.gen_width
video, self.args.train_resolution[0], self.args.train_resolution[1], self.args.train_resolution[2]
)
# Convert video tensor (F, C, H, W) to list of PIL images
video = (video * 255).round().clamp(0, 255).to(torch.uint8)
@ -531,8 +558,8 @@ class Trainer:
self.components.transformer.train()
def fit(self):
self.prepare_dataset()
self.prepare_models()
self.prepare_dataset()
self.prepare_trainable_parameters()
self.prepare_optimizer()
self.prepare_for_training()
@ -541,15 +568,15 @@ class Trainer:
self.prepare_trackers()
self.train()
def collate_fn(self, examples: List[List[Dict[str, Any]]]):
"""
Since we use BucketSampler, the examples parameter is a nested list where the outer list contains only one element,
which is the batch data we need. Therefore, when processing the data, we need to access the batch through examples[0].
"""
def collate_fn(self, examples: List[Dict[str, Any]]):
raise NotImplementedError
def load_components(self) -> Components:
raise NotImplementedError
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
# shape of input video: [B, C, F, H, W], where B = 1
raise NotImplementedError
def compute_loss(self, batch) -> torch.Tensor:
raise NotImplementedError
@ -617,7 +644,7 @@ class Trainer:
)
transformer_.add_adapter(transformer_lora_config)
lora_state_dict = self.components.piepeline_cls.lora_state_dict(input_dir)
lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v
for k, v in lora_state_dict.items()