Experiment: train lora lowvram

Rodrigo Antonio de Araujo 2024-10-09 10:32:16 -03:00
parent 8aae1e1e42
commit f69c3c87ec
3 changed files with 80 additions and 22 deletions

View File

@@ -10,7 +10,7 @@ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
 export CUDA_VISIBLE_DEVICES="0"
 # if you are not using 8 gpus, change `accelerate_config_machine_single.yaml` num_processes to your gpu number
-accelerate launch --config_file accelerate_config_machine_single_weak.yaml --multi_gpu \
+accelerate launch --config_file accelerate_config_machine_single_single_process.yaml --multi_gpu \
   train_cogvideox_lora.py \
   --gradient_checkpointing \
   --pretrained_model_name_or_path $MODEL_PATH \

View File

@@ -416,7 +416,34 @@ def get_args():
     return parsed_args
 
+class Offloader:
+    def __init__(self):
+        pass
+
+    def enable_sequential_cpu_offload(self, model):
+        from accelerate import cpu_offload
+
+        gpu_id = None
+        device = "cuda"
+        torch_device = torch.device(device)
+        device_index = torch_device.index
+        if gpu_id is not None and device_index is not None:
+            raise ValueError(
+                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
+                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
+            )
+
+        # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
+        self._offload_gpu_id = gpu_id or torch_device.index or 0
+
+        device_type = torch_device.type
+        device = torch.device(f"{device_type}:{self._offload_gpu_id}")
+        self._offload_device = device
+
+        offload_buffers = len(model._parameters) > 0
+        cpu_offload(model, device, offload_buffers=offload_buffers)
+
+
 class CachedVideoList:
     ACCELERATOR_DEVICE = 'cpu'
     CACHE_ENABLED = False
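
The Offloader added above is essentially a thin wrapper around accelerate's cpu_offload. A minimal, self-contained sketch of how that call is typically used, assuming a CUDA device is available (the toy model and the offload_buffers value are assumptions, not taken from this commit):

# Sketch only: not the commit's code.
import torch
import torch.nn as nn
from accelerate import cpu_offload

model = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))

execution_device = torch.device("cuda:0")  # device the forward pass runs on
cpu_offload(model, execution_device, offload_buffers=True)

x = torch.randn(4, 64, device=execution_device)
with torch.no_grad():
    y = model(x)  # weights are streamed from CPU to the GPU submodule by submodule
print(y.shape)

With sequential offload the module's weights stay in CPU RAM and are only copied to the execution device when each submodule runs, trading speed for a much smaller VRAM footprint.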
@@ -458,7 +485,7 @@ class CachedVideoList:
     def __getitem__(self, index):
         if index >= len(self.video_names):
             raise IndexError("Index out of bounds")
-        pt_path = os.path.join(self.cache_dir, f'{video_names[index]}.pt')
+        pt_path = os.path.join(CachedVideoList.cache_dir(), f'{self.video_names[index]}.pt')
         if not os.path.exists(pt_path):
             raise FileNotFoundError(f"Video {pt_path} not found in cache.")
         return torch.load(pt_path)
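
For reference, a small sketch of the cache layout this __getitem__ implies: one serialized tensor per video name under the cache directory. The path and tensor contents below are illustrative, not taken from the repository:

# Sketch only: hypothetical names and shapes.
import os
import torch

cache_dir = "/tmp/video_latent_cache"   # illustrative path
os.makedirs(cache_dir, exist_ok=True)

def save_latent(name, latent):
    torch.save(latent, os.path.join(cache_dir, f"{name}.pt"))

def load_latent(name):
    path = os.path.join(cache_dir, f"{name}.pt")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Video {path} not found in cache.")
    return torch.load(path)

save_latent("clip_0001", torch.randn(16, 13, 60, 90))  # dummy tensor
print(load_latent("clip_0001").shape)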
@@ -501,6 +528,9 @@ class VideoDataset(Dataset):
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
         else:
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
+
+        if self.id_token is not None:
+            self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts]
 
         self.num_instance_videos = len(self.instance_video_paths)
         if self.num_instance_videos != len(self.instance_prompts):
@@ -515,7 +545,7 @@ class VideoDataset(Dataset):
     def __getitem__(self, index):
         return {
-            "instance_prompt": self.id_token + self.instance_prompts[index],
+            "instance_prompt": self.instance_prompts[index],
             "instance_video": self.instance_videos[index],
         }
@@ -593,6 +623,21 @@ class VideoDataset(Dataset):
         return instance_prompts, instance_videos
 
+    def encode_prompts(self, tokenizer, text_encoder, device, dtype):
+        encoded_prompts = []
+        for index, prompt in enumerate(self.instance_prompts):
+            print(f"Encoding prompt {index + 1} of {len(self.instance_prompts)}")
+            prompt_embeds = compute_prompt_embeddings(
+                tokenizer,
+                text_encoder,
+                prompt,
+                max_sequence_length=226,
+                device=device,
+                dtype=dtype,
+            )
+            encoded_prompts.append(prompt_embeds)
+        self.instance_prompts = encoded_prompts
+
     def encode_videos(self, vae, device):
         if not CachedVideoList.CACHE_ENABLED:
             self.instance_videos = [self.encode_video(video, vae, device) for video in self.instance_videos]
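
encode_prompts pre-computes every prompt embedding once, so the text encoder does not need to stay on the GPU during training. A rough, self-contained sketch of that pattern, with a tiny module standing in for the tokenizer/T5 encoder pair used by train_cogvideox_lora.py (all names here are stand-ins):

# Sketch only: toy encoder, hypothetical helper names.
import torch
import torch.nn as nn

class ToyTextEncoder(nn.Module):
    def __init__(self, vocab=1000, dim=32):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)

    def forward(self, ids):
        return self.emb(ids)

def pre_encode(prompt_ids, encoder, device, dtype):
    encoder.to(device)                       # move encoder to the GPU only for this pass
    cached = []
    with torch.no_grad():
        for ids in prompt_ids:
            cached.append(encoder(ids.to(device)).to("cpu", dtype))
    encoder.to("cpu")                        # free GPU memory once everything is cached
    return cached

encoder = ToyTextEncoder()
prompt_ids = [torch.randint(0, 1000, (226,)) for _ in range(3)]
device = "cuda" if torch.cuda.is_available() else "cpu"
cache = pre_encode(prompt_ids, encoder, device, torch.float32)
print(len(cache), cache[0].shape)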
@@ -842,7 +887,7 @@ def _get_t5_prompt_embeds(
         if text_input_ids is None:
             raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
 
-    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+    prompt_embeds = text_encoder(text_input_ids.to(device))[0].to(device)
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -1142,9 +1187,7 @@ def main(args):
     if args.offload_to_cpu:
         text_encoder.to('cpu', dtype=weight_dtype)
         transformer.to('cpu', dtype=weight_dtype)
         vae.to('cpu', dtype=weight_dtype)
-        from accelerate import cpu_offload
-        cpu_offload(transformer, accelerator.device, offload_buffers=False)
     else:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
         transformer.to(accelerator.device, dtype=weight_dtype)
@@ -1285,6 +1328,13 @@ def main(args):
     if args.offload_to_cpu:
         vae.to('cpu')
 
+    if args.offload_to_cpu:
+        text_encoder.to(accelerator.device)
+
+    train_dataset.encode_prompts(tokenizer, text_encoder, accelerator.device, weight_dtype)
+
+    if args.offload_to_cpu:
+        text_encoder.to('cpu')
 
     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
@@ -1333,9 +1383,15 @@ def main(args):
     )
 
     # Prepare everything with our `accelerator`.
-    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        transformer, optimizer, train_dataloader, lr_scheduler
-    )
+    if not args.offload_to_cpu:
+        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            transformer, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        offloader = Offloader()
+        print(f"Offloading to CPU on device {CachedVideoList.ACCELERATOR_DEVICE}")
+        offloader.enable_sequential_cpu_offload(transformer)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1414,18 +1470,20 @@ def main(args):
             with accelerator.accumulate(models_to_accumulate):
                 model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
                 prompts = batch["prompts"]
 
-                # encode prompts
-                prompt_embeds = compute_prompt_embeddings(
-                    tokenizer,
-                    text_encoder,
-                    prompts,
-                    model_config.max_text_seq_length,
-                    accelerator.device,
-                    weight_dtype,
-                    requires_grad=False,
-                )
+                if args.offload_to_cpu:
+                    prompt_embeds = torch.cat(prompts)
+                else:
+                    # encode prompts
+                    prompt_embeds = compute_prompt_embeddings(
+                        tokenizer,
+                        text_encoder,
+                        prompts,
+                        model_config.max_text_seq_length,
+                        accelerator.device if not args.offload_to_cpu else "cpu",
+                        weight_dtype,
+                        requires_grad=False,
+                    )
+                prompt_embeds.to(device=accelerator.device)
 
                 # Sample noise that will be added to the latents
                 noise = torch.randn_like(model_input)
                 batch_size, num_frames, num_channels, height, width = model_input.shape
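
In the offload path the dataloader now yields pre-encoded prompt embeddings, so the training step only has to stack them and move them to the accelerator device instead of running the text encoder each step. A minimal sketch of that collation idea; the batch contents and the hidden size below are assumptions for illustration, not values from the commit:

# Sketch only: illustrative shapes and names.
import torch

def collate_prompt_embeds(examples, device, dtype):
    # each example carries a [1, seq_len, hidden] tensor produced offline
    embeds = torch.cat([ex["instance_prompt"] for ex in examples], dim=0)
    return embeds.to(device=device, dtype=dtype)

batch = [{"instance_prompt": torch.randn(1, 226, 4096)} for _ in range(2)]
prompt_embeds = collate_prompt_embeds(batch, "cpu", torch.float32)
print(prompt_embeds.shape)  # torch.Size([2, 226, 4096])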