Mirror of https://github.com/THUDM/CogVideo.git, synced 2025-04-05 19:41:59 +08:00
refactor: simplify dataset implementation and add latent precomputation
- Replace bucket-based dataset with simpler resize-based implementation
- Add video latent precomputation during dataset initialization
- Improve code readability and user experience
- Remove complexity of bucket sampling for better maintainability

This change makes the codebase more straightforward and easier to use while maintaining functionality through resize-based video processing.
commit 45d40450a1 (parent 6eae5c201e)
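The diff below swaps the bucket-based datasets for resize-based ones: every clip is brought to a single fixed (frames, height, width) before batching. A minimal sketch of what such a resize step does, assuming a helper shaped roughly like finetune.datasets.utils.preprocess_video_with_resize (the real signature may differ):

import torch
import torch.nn.functional as F

def resize_video_sketch(video: torch.Tensor, max_num_frames: int,
                        height: int, width: int) -> torch.Tensor:
    # video: [F, C, H, W] float tensor -> [<=max_num_frames, C, height, width]
    video = video[:max_num_frames]                      # trim extra frames
    return F.interpolate(video, size=(height, width),   # resize each frame
                         mode="bilinear", align_corners=False)

Because every sample then shares one shape, a plain shuffled DataLoader can batch them directly, which is what lets this commit drop bucket sampling.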
@@ -44,7 +44,7 @@ from finetune.utils import (
 
     string_to_filename
 )
-from finetune.datasets import I2VDatasetWithBuckets, T2VDatasetWithBuckets, BucketSampler
+from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize, BucketSampler
 from finetune.datasets.utils import (
     load_prompts, load_images, load_videos,
     preprocess_image_with_resize, preprocess_video_with_resize
@@ -80,7 +80,6 @@ class Trainer:
         self._init_logging()
         self._init_directories()
 
-
     def _init_distributed(self):
         logging_dir = Path(self.args.output_dir, "logs")
         project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
@@ -108,7 +107,6 @@ class Trainer:
         if self.args.seed is not None:
             set_seed(self.args.seed)
 
-
    def _init_logging(self) -> None:
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -124,33 +122,12 @@ class Trainer:
 
         logger.info("Initialized Trainer")
         logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False)
 
-
     def _init_directories(self) -> None:
         if self.accelerator.is_main_process:
             self.args.output_dir = Path(self.args.output_dir)
             self.args.output_dir.mkdir(parents=True, exist_ok=True)
 
-
-    def prepare_dataset(self) -> None:
-        logger.info("Initializing dataset and dataloader")
-
-        if self.args.model_type == "i2v":
-            self.dataset = I2VDatasetWithBuckets(**(self.args.model_dump()))
-        elif self.args.model_type == "t2v":
-            self.dataset = T2VDatasetWithBuckets(**(self.args.model_dump()))
-        else:
-            raise ValueError(f"Invalid model type: {self.args.model_type}")
-
-        self.data_loader = torch.utils.data.DataLoader(
-            self.dataset,
-            batch_size=1,
-            sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
-            collate_fn=self.collate_fn,
-            num_workers=self.args.num_workers,
-            pin_memory=self.args.pin_memory,
-        )
-
     def prepare_models(self) -> None:
         logger.info("Initializing models")
 
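For context on the removed block: it passes batch_size=1 while delegating batching to BucketSampler, because a bucket sampler yields an entire batch of same-resolution indices at each step, so the DataLoader must not batch again. A hedged sketch of that behavior (assumed, not the actual finetune.datasets implementation; resolution_of is a hypothetical accessor):

import random
from collections import defaultdict
from torch.utils.data import Sampler

class BucketSamplerSketch(Sampler):
    def __init__(self, dataset, batch_size: int, shuffle: bool = True):
        self.buckets = defaultdict(list)        # resolution -> sample indices
        for i in range(len(dataset)):
            self.buckets[dataset.resolution_of(i)].append(i)
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        batches = []
        for indices in self.buckets.values():
            if self.shuffle:
                random.shuffle(indices)
            batches += [indices[i:i + self.batch_size]
                        for i in range(0, len(indices), self.batch_size)]
        if self.shuffle:
            random.shuffle(batches)
        yield from batches                      # each item is a full batch

This is also why the old collate_fn docstring, removed further down in this diff, says the batch arrives as examples[0].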
@@ -163,6 +140,57 @@ class Trainer:
         if self.args.enable_tiling:
             self.components.vae.enable_tiling()
 
+    def prepare_dataset(self) -> None:
+        logger.info("Initializing dataset and dataloader")
+
+        if self.args.model_type == "i2v":
+            self.dataset = I2VDatasetWithResize(
+                **(self.args.model_dump()),
+                device=self.accelerator.device,
+                encode_fn=self.encode_video,
+                max_num_frames=self.args.train_resolution[0],
+                height=self.args.train_resolution[1],
+                width=self.args.train_resolution[2]
+            )
+        elif self.args.model_type == "t2v":
+            self.dataset = T2VDatasetWithResize(
+                **(self.args.model_dump()),
+                device=self.accelerator.device,
+                encode_fn=self.encode_video,
+                max_num_frames=self.args.train_resolution[0],
+                height=self.args.train_resolution[1],
+                width=self.args.train_resolution[2]
+            )
+        else:
+            raise ValueError(f"Invalid model type: {self.args.model_type}")
+
+        # Prepare VAE for encoding
+        self.components.vae = self.components.vae.to(self.accelerator.device)
+        self.components.vae.requires_grad_(False)
+
+        # Precompute latent for video
+        logger.info("Precomputing latent for video ...")
+        tmp_data_loader = torch.utils.data.DataLoader(
+            self.dataset,
+            collate_fn=self.collate_fn,
+            batch_size=1,
+            num_workers=0,
+            pin_memory=self.args.pin_memory,
+        )
+        tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
+        for _ in tmp_data_loader: ...
+        logger.info("Precomputing latent for video ... Done")
+
+        self.data_loader = torch.utils.data.DataLoader(
+            self.dataset,
+            collate_fn=self.collate_fn,
+            batch_size=self.args.batch_size,
+            num_workers=self.args.num_workers,
+            pin_memory=self.args.pin_memory,
+            shuffle=True
+        )
+
+
     def prepare_trainable_parameters(self):
         logger.info("Initializing trainable parameters")
 
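The warm-up pass above (for _ in tmp_data_loader: ...) drains the loader purely for its side effect; the trailing ... is Python's Ellipsis used as a no-op loop body. With num_workers=0 every __getitem__ runs in the main process, so the dataset can push each video through encode_fn (the frozen VAE already moved to the accelerator device) and keep the latent around for the real training loader. A sketch of the cache-on-first-access pattern this relies on, assuming an in-memory cache (the actual I2V/T2V datasets may cache differently, e.g. to disk):

import torch

class LatentCachingSketch:
    # Assumed attributes: self._latents (list initialized to None),
    # self._prompts, self.encode_fn, self.device, and a hypothetical
    # _load_resized_video helper returning a fixed-shape [F, C, H, W] tensor.
    def __getitem__(self, index: int) -> dict:
        if self._latents[index] is None:                     # first pass: encode
            video = self._load_resized_video(index)
            frames = video.permute(1, 0, 2, 3).unsqueeze(0)  # [B, C, F, H, W], B = 1
            with torch.no_grad():
                self._latents[index] = self.encode_fn(frames.to(self.device)).cpu()
        return {"encoded_video": self._latents[index],       # later passes: cached
                "prompt": self._prompts[index]}

num_workers=0 matters here: DataLoader workers are separate processes, so an in-memory cache filled inside a worker would be lost to the parent process.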
@@ -275,7 +303,6 @@ class Trainer:
         self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
 
     def prepare_for_validation(self):
-        self.state.gen_frames, self.state.gen_height, self.state.gen_width = [int(elem) for elem in self.args.gen_video_resolution.split('x')]
         validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)
 
         if self.args.validation_images is not None:
@@ -447,7 +474,7 @@ class Trainer:
 
         if image is not None:
             image = preprocess_image_with_resize(
-                image, self.state.gen_height, self.state.gen_width
+                image, self.args.train_resolution[1], self.args.train_resolution[2]
             )
             # Convert image tensor (C, H, W) to PIL images
             image = image.to(torch.uint8)
@@ -456,7 +483,7 @@ class Trainer:
 
         if video is not None:
             video = preprocess_video_with_resize(
-                video, self.state.gen_frames, self.state.gen_height, self.state.gen_width
+                video, self.args.train_resolution[0], self.args.train_resolution[1], self.args.train_resolution[2]
             )
             # Convert video tensor (F, C, H, W) to list of PIL images
             video = (video * 255).round().clamp(0, 255).to(torch.uint8)
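Both validation preprocessing calls now read self.args.train_resolution instead of the deleted self.state.gen_* fields, so one (frames, height, width) triple drives training and validation alike. A hedged sketch of turning an "FxHxW" string (the format the removed gen_video_resolution line parsed) into that triple; how Args actually builds train_resolution is an assumption here:

def parse_resolution(spec: str) -> tuple[int, int, int]:
    frames, height, width = (int(part) for part in spec.split("x"))
    return frames, height, width

assert parse_resolution("49x480x720") == (49, 480, 720)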
@@ -531,8 +558,8 @@ class Trainer:
         self.components.transformer.train()
 
     def fit(self):
-        self.prepare_dataset()
         self.prepare_models()
+        self.prepare_dataset()
         self.prepare_trainable_parameters()
         self.prepare_optimizer()
         self.prepare_for_training()
@@ -541,15 +568,15 @@ class Trainer:
         self.prepare_trackers()
         self.train()
 
-    def collate_fn(self, examples: List[List[Dict[str, Any]]]):
-        """
-        Since we use BucketSampler, the examples parameter is a nested list where the outer list contains only one element,
-        which is the batch data we need. Therefore, when processing the data, we need to access the batch through examples[0].
-        """
+    def collate_fn(self, examples: List[Dict[str, Any]]):
         raise NotImplementedError
 
     def load_components(self) -> Components:
         raise NotImplementedError
 
     def encode_video(self, video: torch.Tensor) -> torch.Tensor:
+        # shape of input video: [B, C, F, H, W], where B = 1
         raise NotImplementedError
 
     def compute_loss(self, batch) -> torch.Tensor:
         raise NotImplementedError
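With BucketSampler out of the batching path, the DataLoader batches normally, so collate_fn now receives a flat List[Dict] (one dict per sample) rather than the old nested list. A minimal sketch of such a collate for cached-latent samples; the field names are assumptions, not the subclasses' actual keys:

from typing import Any, Dict, List
import torch

def collate_latents_sketch(examples: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Each cached latent is assumed to be [1, ...]; concatenate along the batch dim.
    return {
        "encoded_videos": torch.cat([ex["encoded_video"] for ex in examples], dim=0),
        "prompts": [ex["prompt"] for ex in examples],
    }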
@@ -617,7 +644,7 @@ class Trainer:
             )
             transformer_.add_adapter(transformer_lora_config)
 
-            lora_state_dict = self.components.piepeline_cls.lora_state_dict(input_dir)
+            lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
             transformer_state_dict = {
                 f'{k.replace("transformer.", "")}': v
                 for k, v in lora_state_dict.items()