Bring back the previous train_cogvideox_lora.py

Rodrigo Antonio de Araujo 2024-10-09 16:56:50 -03:00
parent f69c3c87ec
commit b848932c13

train_cogvideox_lora.py

@@ -401,94 +401,8 @@ def get_args():
             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
         ),
     )
-    parser.add_argument(
-        "--offload_to_cpu",
-        action="store_true",
-        help="Whether or not to offload the model to the CPU.",
-    )
-    parser.add_argument(
-        "--cache_preprocessed_data",
-        action="store_true",
-        help="Whether or not to cache preprocessed data.",
-    )
-    parsed_args = parser.parse_args()
-    return parsed_args
-
-
-class Offloader:
-    def __init__(self):
-        pass
-
-    def enable_sequential_cpu_offload(self, model):
-        from accelerate import cpu_offload
-
-        gpu_id = None
-        device = "cuda"
-        torch_device = torch.device(device)
-        device_index = torch_device.index
-        if gpu_id is not None and device_index is not None:
-            raise ValueError(
-                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
-                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
-            )
-        # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
-        self._offload_gpu_id = gpu_id or torch_device.index or 0
-        device_type = torch_device.type
-        device = torch.device(f"{device_type}:{self._offload_gpu_id}")
-        self._offload_device = device
-        offload_buffers = len(model._parameters) > 0
-        cpu_offload(model, device, offload_buffers=offload_buffers)
-
-
-class CachedVideoList:
-    ACCELERATOR_DEVICE = 'cpu'
-    CACHE_ENABLED = False
-    VAE = None
-    OUTPUT_DIR = None
-
-    @classmethod
-    def is_cached(cls, video_name):
-        if cls.CACHE_ENABLED:
-            return os.path.exists(os.path.join(cls.cache_dir(), f'{video_name}.pt'))
-        return False
-
-    @classmethod
-    def cache_dir(cls):
-        result = os.path.join(cls.OUTPUT_DIR, "cached_videos")
-        if not os.path.exists(result):
-            os.makedirs(result)
-        return result
-
-    def __init__(self):
-        if not self.CACHE_ENABLED:
-            raise ValueError("CachedVideoList is not enabled. Please enable the cache before using it.")
-        if self.OUTPUT_DIR is None:
-            raise ValueError("Output directory not set. Please set the output directory before using the CachedVideoList.")
-        if self.VAE is None:
-            raise ValueError("VAE model not set. Please set the VAE model before using the CachedVideoList.")
-        self.video_names = []
-
-    def __len__(self):
-        return len(self.video_names)
-
-    def append(self, video: Tuple[str, torch.Tensor]):
-        self.video_names.append(video[0])
-        if video[1] is not None:
-            torch.save(video[1], os.path.join(CachedVideoList.cache_dir(), f'{video[0]}.pt'))
-
-    def __getitem__(self, index):
-        if index >= len(self.video_names):
-            raise IndexError("Index out of bounds")
-        pt_path = os.path.join(CachedVideoList.cache_dir(), f'{self.video_names[index]}.pt')
-        if not os.path.exists(pt_path):
-            raise FileNotFoundError(f"Video {pt_path} not found in cache.")
-        return torch.load(pt_path)
+    return parser.parse_args()


 class VideoDataset(Dataset):
@@ -529,9 +443,6 @@ class VideoDataset(Dataset):
         else:
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()

-        if self.id_token is not None:
-            self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts]
-
         self.num_instance_videos = len(self.instance_video_paths)
         if self.num_instance_videos != len(self.instance_prompts):
             raise ValueError(
@@ -545,7 +456,7 @@ class VideoDataset(Dataset):
     def __getitem__(self, index):
         return {
-            "instance_prompt": self.instance_prompts[index],
+            "instance_prompt": self.id_token + self.instance_prompts[index],
             "instance_video": self.instance_videos[index],
         }
@@ -623,32 +534,6 @@ class VideoDataset(Dataset):
         return instance_prompts, instance_videos

-    def encode_prompts(self, tokenizer, text_encoder, device, dtype):
-        encoded_prompts = []
-        for index, prompt in enumerate(self.instance_prompts):
-            print(f"Encoding prompt {index + 1} of {len(self.instance_prompts)}")
-            prompt_embeds = compute_prompt_embeddings(
-                tokenizer,
-                text_encoder,
-                prompt,
-                max_sequence_length=226,
-                device=device,
-                dtype=dtype,
-            )
-            encoded_prompts.append(prompt_embeds)
-        self.instance_prompts = encoded_prompts
-
-    def encode_videos(self, vae, device):
-        if not CachedVideoList.CACHE_ENABLED:
-            self.instance_videos = [self.encode_video(video, vae, device) for video in self.instance_videos]
-
-    def encode_video(self, video, vae, device):
-        video = video.to(device, dtype=vae.dtype).unsqueeze(0)
-        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
-        vae.to(device)
-        latent_dist = vae.encode(video.to(device)).latent_dist
-        return latent_dist
-
     def _preprocess_data(self):
         try:
             import decord
@@ -659,7 +544,7 @@ class VideoDataset(Dataset):
         decord.bridge.set_bridge("torch")

-        videos = CachedVideoList() if CachedVideoList.CACHE_ENABLED else []
+        videos = []
         train_transforms = transforms.Compose(
             [
                 transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
@@ -667,11 +552,6 @@ class VideoDataset(Dataset):
         )

         for filename in self.instance_video_paths:
-            if CachedVideoList.CACHE_ENABLED and videos.is_cached(filename.stem):
-                videos.append((filename.stem, None))
-                continue
-            if CachedVideoList.CACHE_ENABLED:
-                print(f"Pre-processing video {filename.as_posix()}")
             video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
             video_num_frames = len(video_reader)
@@ -700,17 +580,7 @@ class VideoDataset(Dataset):
             # Training transforms
             frames = frames.float()
             frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            frames = frames.permute(0, 3, 1, 2).contiguous()  # [F, C, H, W]
-
-            if CachedVideoList.CACHE_ENABLED:
-                print(f"Encoding video {filename.stem} device {CachedVideoList.ACCELERATOR_DEVICE}")
-
-            if CachedVideoList.CACHE_ENABLED:
-                videos.append((filename.stem, self.encode_video(frames, CachedVideoList.VAE, CachedVideoList.ACCELERATOR_DEVICE)))
-            else:
-                videos.append(frames)  # [F, C, H, W]
-
-            if CachedVideoList.CACHE_ENABLED:
-                print(f"Video Encoded {filename.stem}")
+            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]

         return videos
@@ -887,7 +757,7 @@ def _get_t5_prompt_embeds(
         if text_input_ids is None:
             raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")

-    prompt_embeds = text_encoder(text_input_ids.to(device))[0].to(device)
+    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

     # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -1184,16 +1054,11 @@ def main(args):
             "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
         )

-    if args.offload_to_cpu:
-        text_encoder.to('cpu', dtype=weight_dtype)
-        transformer.to('cpu', dtype=weight_dtype)
-        vae.to('cpu', dtype=weight_dtype)
-    else:
-        text_encoder.to(accelerator.device, dtype=weight_dtype)
-        transformer.to(accelerator.device, dtype=weight_dtype)
-        vae.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    transformer.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)

-    if args.gradient_checkpointing or args.offload_to_cpu:
+    if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()

     # now we will add new LoRA weights to the attention layers
@@ -1298,12 +1163,6 @@ def main(args):
     optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)

-    if args.cache_preprocessed_data:
-        CachedVideoList.OUTPUT_DIR = args.output_dir
-        CachedVideoList.CACHE_ENABLED = True
-        CachedVideoList.VAE = vae
-        CachedVideoList.ACCELERATOR_DEVICE = accelerator.device
-
     # Dataset and DataLoader
     train_dataset = VideoDataset(
         instance_data_root=args.instance_data_root,
@@ -1321,20 +1180,13 @@ def main(args):
         id_token=args.id_token,
     )

-    if args.offload_to_cpu:
-        vae.to('cuda')
-
-    train_dataset.encode_videos(vae, accelerator.device)
-    if args.offload_to_cpu:
-        vae.to('cpu')
-
-    if args.offload_to_cpu:
-        text_encoder.to(accelerator.device)
-    train_dataset.encode_prompts(tokenizer, text_encoder, accelerator.device, weight_dtype)
-    if args.offload_to_cpu:
-        text_encoder.to('cpu')
+    def encode_video(video):
+        video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
+        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
+        latent_dist = vae.encode(video).latent_dist
+        return latent_dist
+
+    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]

     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
@@ -1383,15 +1235,9 @@ def main(args):
     )

     # Prepare everything with our `accelerator`.
-    if not args.offload_to_cpu:
-        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            transformer, optimizer, train_dataloader, lr_scheduler
-        )
-    else:
-        offloader = Offloader()
-        print(f"Offloading to CPU on device {CachedVideoList.ACCELERATOR_DEVICE}")
-        offloader.enable_sequential_cpu_offload(transformer)
+    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        transformer, optimizer, train_dataloader, lr_scheduler
+    )

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1470,20 +1316,18 @@ def main(args):
             with accelerator.accumulate(models_to_accumulate):
                 model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
                 prompts = batch["prompts"]

-                if args.offload_to_cpu:
-                    prompt_embeds = torch.cat(prompts)
-                else:
-                    # encode prompts
-                    prompt_embeds = compute_prompt_embeddings(
-                        tokenizer,
-                        text_encoder,
-                        prompts,
-                        model_config.max_text_seq_length,
-                        accelerator.device if not args.offload_to_cpu else "cpu",
-                        weight_dtype,
-                        requires_grad=False,
-                    )
-                prompt_embeds.to(device=accelerator.device)
+                # encode prompts
+                prompt_embeds = compute_prompt_embeddings(
+                    tokenizer,
+                    text_encoder,
+                    prompts,
+                    model_config.max_text_seq_length,
+                    accelerator.device,
+                    weight_dtype,
+                    requires_grad=False,
+                )

                 # Sample noise that will be added to the latents
                 noise = torch.randn_like(model_input)
                 batch_size, num_frames, num_channels, height, width = model_input.shape
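
For context, the restored right-hand side pre-encodes every training video into VAE latent distributions once, before the training loop, and collate_fn later samples from those distributions and multiplies by the VAE scaling factor. The sketch below illustrates that pattern in isolation; it is not part of the diff, the helper name precompute_video_latents is hypothetical, and it assumes frame tensors shaped [F, C, H, W] in [-1, 1] plus a CogVideoX-style diffusers VAE whose encode() returns an object with a latent_dist attribute.

    import torch

    def precompute_video_latents(frames_list, vae, device):
        # Hypothetical helper mirroring the restored encode_video closure.
        latents = []
        for frames in frames_list:  # each tensor: [F, C, H, W]
            video = frames.to(device, dtype=vae.dtype).unsqueeze(0)  # [1, F, C, H, W]
            video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
            with torch.no_grad():
                latents.append(vae.encode(video).latent_dist)
        return latents

    # During batching, as collate_fn does:
    #   videos = [dist.sample() * vae.config.scaling_factor for dist in latents]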