From 45d40450a1ba50ceaf1d57b6fec7917cd13f8d80 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Mon, 30 Dec 2024 16:14:46 +0000
Subject: [PATCH] refactor: simplify dataset implementation and add latent precomputation

- Replace bucket-based dataset with simpler resize-based implementation
- Add video latent precomputation during dataset initialization
- Improve code readability and user experience
- Remove complexity of bucket sampling for better maintainability

This change makes the codebase more straightforward and easier to use
while maintaining functionality through resize-based video processing.
---
 finetune/trainer.py | 95 +++++++++++++++++++++++++++++----------------
 1 file changed, 61 insertions(+), 34 deletions(-)

diff --git a/finetune/trainer.py b/finetune/trainer.py
index 501a7e5..5b02f0e 100644
--- a/finetune/trainer.py
+++ b/finetune/trainer.py
@@ -44,7 +44,7 @@ from finetune.utils import (
     string_to_filename
 )
-from finetune.datasets import I2VDatasetWithBuckets, T2VDatasetWithBuckets, BucketSampler
+from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize, BucketSampler
 from finetune.datasets.utils import (
     load_prompts, load_images, load_videos,
     preprocess_image_with_resize, preprocess_video_with_resize
 )
@@ -80,7 +80,6 @@ class Trainer:
         self._init_logging()
         self._init_directories()
 
-
     def _init_distributed(self):
         logging_dir = Path(self.args.output_dir, "logs")
         project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
@@ -108,7 +107,6 @@ class Trainer:
         if self.args.seed is not None:
             set_seed(self.args.seed)
 
-
     def _init_logging(self) -> None:
         logging.basicConfig(
             format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -124,33 +122,12 @@ class Trainer:
 
         logger.info("Initialized Trainer")
         logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False)
 
-
     def _init_directories(self) -> None:
         if self.accelerator.is_main_process:
             self.args.output_dir = Path(self.args.output_dir)
             self.args.output_dir.mkdir(parents=True, exist_ok=True)
 
-
-    def prepare_dataset(self) -> None:
-        logger.info("Initializing dataset and dataloader")
-
-        if self.args.model_type == "i2v":
-            self.dataset = I2VDatasetWithBuckets(**(self.args.model_dump()))
-        elif self.args.model_type == "t2v":
-            self.dataset = T2VDatasetWithBuckets(**(self.args.model_dump()))
-        else:
-            raise ValueError(f"Invalid model type: {self.args.model_type}")
-
-        self.data_loader = torch.utils.data.DataLoader(
-            self.dataset,
-            batch_size=1,
-            sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
-            collate_fn=self.collate_fn,
-            num_workers=self.args.num_workers,
-            pin_memory=self.args.pin_memory,
-        )
-
     def prepare_models(self) -> None:
         logger.info("Initializing models")
 
@@ -163,6 +140,57 @@ class Trainer:
 
         if self.args.enable_tiling:
             self.components.vae.enable_tiling()
 
+    def prepare_dataset(self) -> None:
+        logger.info("Initializing dataset and dataloader")
+
+        if self.args.model_type == "i2v":
+            self.dataset = I2VDatasetWithResize(
+                **(self.args.model_dump()),
+                device=self.accelerator.device,
+                encode_fn=self.encode_video,
+                max_num_frames=self.args.train_resolution[0],
+                height=self.args.train_resolution[1],
+                width=self.args.train_resolution[2]
+            )
+        elif self.args.model_type == "t2v":
+            self.dataset = T2VDatasetWithResize(
+                **(self.args.model_dump()),
+                device=self.accelerator.device,
+                encode_fn=self.encode_video,
+                max_num_frames=self.args.train_resolution[0],
+                height=self.args.train_resolution[1],
+                width=self.args.train_resolution[2]
+            )
+        else:
+            raise ValueError(f"Invalid model type: {self.args.model_type}")
+
+        # Prepare VAE for encoding
+        self.components.vae = self.components.vae.to(self.accelerator.device)
+        self.components.vae.requires_grad_(False)
+
+        # Precompute latent for video
+        logger.info("Precomputing latent for video ...")
+        tmp_data_loader = torch.utils.data.DataLoader(
+            self.dataset,
+            collate_fn=self.collate_fn,
+            batch_size=1,
+            num_workers=0,
+            pin_memory=self.args.pin_memory,
+        )
+        tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
+        for _ in tmp_data_loader: ...
+        logger.info("Precomputing latent for video ... Done")
+
+        self.data_loader = torch.utils.data.DataLoader(
+            self.dataset,
+            collate_fn=self.collate_fn,
+            batch_size=self.args.batch_size,
+            num_workers=self.args.num_workers,
+            pin_memory=self.args.pin_memory,
+            shuffle=True
+        )
+
+
     def prepare_trainable_parameters(self):
         logger.info("Initializing trainable parameters")
@@ -275,7 +303,6 @@ class Trainer:
         self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
 
     def prepare_for_validation(self):
-        self.state.gen_frames, self.state.gen_height, self.state.gen_width = [int(elem) for elem in self.args.gen_video_resolution.split('x')]
         validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)
 
         if self.args.validation_images is not None:
@@ -447,7 +474,7 @@ class Trainer:
 
         if image is not None:
             image = preprocess_image_with_resize(
-                image, self.state.gen_height, self.state.gen_width
+                image, self.args.train_resolution[1], self.args.train_resolution[2]
             )
             # Convert image tensor (C, H, W) to PIL images
             image = image.to(torch.uint8)
@@ -456,7 +483,7 @@ class Trainer:
 
         if video is not None:
             video = preprocess_video_with_resize(
-                video, self.state.gen_frames, self.state.gen_height, self.state.gen_width
+                video, self.args.train_resolution[0], self.args.train_resolution[1], self.args.train_resolution[2]
            )
             # Convert video tensor (F, C, H, W) to list of PIL images
             video = (video * 255).round().clamp(0, 255).to(torch.uint8)
@@ -531,8 +558,8 @@ class Trainer:
         self.components.transformer.train()
 
     def fit(self):
-        self.prepare_dataset()
         self.prepare_models()
+        self.prepare_dataset()
         self.prepare_trainable_parameters()
         self.prepare_optimizer()
         self.prepare_for_training()
@@ -541,15 +568,15 @@ class Trainer:
         self.prepare_trackers()
         self.train()
 
-    def collate_fn(self, examples: List[List[Dict[str, Any]]]):
-        """
-        Since we use BucketSampler, the examples parameter is a nested list where the outer list contains only one element,
-        which is the batch data we need. Therefore, when processing the data, we need to access the batch through examples[0].
-        """
+    def collate_fn(self, examples: List[Dict[str, Any]]):
         raise NotImplementedError
 
     def load_components(self) -> Components:
         raise NotImplementedError
+
+    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
+        # shape of input video: [B, C, F, H, W], where B = 1
+        raise NotImplementedError
 
     def compute_loss(self, batch) -> torch.Tensor:
         raise NotImplementedError
@@ -617,7 +644,7 @@ class Trainer:
 
             )
             transformer_.add_adapter(transformer_lora_config)
 
-            lora_state_dict = self.components.piepeline_cls.lora_state_dict(input_dir)
+            lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
             transformer_state_dict = {
                 f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items()
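
For readers unfamiliar with the precomputation pattern the new prepare_dataset() relies on, here is a minimal, self-contained sketch. It is illustrative only: ToyResizeDataset and fake_vae_encode are hypothetical stand-ins for the real I2VDatasetWithResize/T2VDatasetWithResize classes and the trainer's encode_video. The idea mirrors the patch: iterate a throwaway DataLoader once so every item is pushed through the frozen encoder and its latent is cached up front; num_workers=0 keeps the cache in the main process, just like the temporary loader in the patch.

from typing import Callable, Dict

import torch
from torch.utils.data import DataLoader, Dataset


class ToyResizeDataset(Dataset):
    """Hypothetical stand-in: resizes videos and caches latents on first access."""

    def __init__(self, videos, encode_fn: Callable[[torch.Tensor], torch.Tensor],
                 max_num_frames: int, height: int, width: int):
        self.videos = videos                # list of [F, C, H, W] float tensors
        self.encode_fn = encode_fn          # e.g. a frozen VAE's encode function
        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width
        self._latent_cache: Dict[int, torch.Tensor] = {}

    def __len__(self) -> int:
        return len(self.videos)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        if idx not in self._latent_cache:
            video = self.videos[idx][: self.max_num_frames]      # trim frames
            video = torch.nn.functional.interpolate(             # resize H and W
                video, size=(self.height, self.width),
                mode="bilinear", align_corners=False,
            )
            # encode_fn expects [B, C, F, H, W] with B = 1, as in the patch
            video = video.permute(1, 0, 2, 3).unsqueeze(0)
            with torch.no_grad():
                self._latent_cache[idx] = self.encode_fn(video).squeeze(0)
        return {"latent": self._latent_cache[idx]}


def fake_vae_encode(video: torch.Tensor) -> torch.Tensor:
    # Stand-in for a real VAE encoder: just downsample spatially by 8x.
    return torch.nn.functional.avg_pool3d(video, kernel_size=(1, 8, 8))


if __name__ == "__main__":
    videos = [torch.rand(16, 3, 128, 128) for _ in range(4)]
    dataset = ToyResizeDataset(videos, fake_vae_encode,
                               max_num_frames=8, height=64, width=64)

    # One pass through a temporary loader precomputes every latent ...
    for _ in DataLoader(dataset, batch_size=1, num_workers=0):
        ...
    # ... so later training epochs read cached latents instead of re-encoding.
    print(len(dataset._latent_cache), "latents cached")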