mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-12-02 18:52:08 +08:00
Experiment: train lora lowvram
This commit is contained in:
parent
8aae1e1e42
commit
f69c3c87ec
@ -10,7 +10,7 @@ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
# if you are not using wth 8 gus, change `accelerate_config_machine_single.yaml` num_processes as your gpu number
|
||||
accelerate launch --config_file accelerate_config_machine_single_weak.yaml --multi_gpu \
|
||||
accelerate launch --config_file accelerate_config_machine_single_single_process.yaml --multi_gpu \
|
||||
train_cogvideox_lora.py \
|
||||
--gradient_checkpointing \
|
||||
--pretrained_model_name_or_path $MODEL_PATH \
|
||||
|
||||
@ -416,6 +416,33 @@ def get_args():
|
||||
|
||||
return parsed_args
|
||||
|
||||
class Offloader:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def enable_sequential_cpu_offload(self, model):
|
||||
from accelerate import cpu_offload
|
||||
gpu_id = None
|
||||
device = "cuda"
|
||||
|
||||
torch_device = torch.device(device)
|
||||
device_index = torch_device.index
|
||||
|
||||
if gpu_id is not None and device_index is not None:
|
||||
raise ValueError(
|
||||
f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
|
||||
f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
|
||||
)
|
||||
|
||||
# _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
|
||||
self._offload_gpu_id = gpu_id or torch_device.index or 0
|
||||
|
||||
device_type = torch_device.type
|
||||
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
||||
self._offload_device = device
|
||||
|
||||
offload_buffers = len(model._parameters) > 0
|
||||
cpu_offload(model, device, offload_buffers=offload_buffers)
|
||||
|
||||
class CachedVideoList:
|
||||
ACCELERATOR_DEVICE = 'cpu'
|
||||
@ -458,7 +485,7 @@ class CachedVideoList:
|
||||
def __getitem__(self, index):
|
||||
if index >= len(self.video_names):
|
||||
raise IndexError("Index out of bounds")
|
||||
pt_path = os.path.join(self.cache_dir, f'{video_names[index]}.pt')
|
||||
pt_path = os.path.join(CachedVideoList.cache_dir(), f'{self.video_names[index]}.pt')
|
||||
if not os.path.exists(pt_path):
|
||||
raise FileNotFoundError(f"Video {pt_path} not found in cache.")
|
||||
return torch.load(pt_path)
|
||||
@ -502,6 +529,9 @@ class VideoDataset(Dataset):
|
||||
else:
|
||||
self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
|
||||
|
||||
if self.id_token is not None:
|
||||
self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts]
|
||||
|
||||
self.num_instance_videos = len(self.instance_video_paths)
|
||||
if self.num_instance_videos != len(self.instance_prompts):
|
||||
raise ValueError(
|
||||
@ -515,7 +545,7 @@ class VideoDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
return {
|
||||
"instance_prompt": self.id_token + self.instance_prompts[index],
|
||||
"instance_prompt": self.instance_prompts[index],
|
||||
"instance_video": self.instance_videos[index],
|
||||
}
|
||||
|
||||
@ -593,6 +623,21 @@ class VideoDataset(Dataset):
|
||||
|
||||
return instance_prompts, instance_videos
|
||||
|
||||
def encode_prompts(self, tokenizer, text_encoder, device, dtype):
|
||||
encoded_prompts = []
|
||||
for index, prompt in enumerate(self.instance_prompts):
|
||||
print(f"Encoding prompt {index + 1} of {len(self.instance_prompts)}")
|
||||
prompt_embeds = compute_prompt_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompt,
|
||||
max_sequence_length=226,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
encoded_prompts.append(prompt_embeds)
|
||||
self.instance_prompts = encoded_prompts
|
||||
|
||||
def encode_videos(self, vae, device):
|
||||
if not CachedVideoList.CACHE_ENABLED:
|
||||
self.instance_videos = [self.encode_video(video, vae, device) for video in self.instance_videos]
|
||||
@ -842,7 +887,7 @@ def _get_t5_prompt_embeds(
|
||||
if text_input_ids is None:
|
||||
raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0].to(device)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
@ -1143,8 +1188,6 @@ def main(args):
|
||||
text_encoder.to('cpu', dtype=weight_dtype)
|
||||
transformer.to('cpu', dtype=weight_dtype)
|
||||
vae.to('cpu', dtype=weight_dtype)
|
||||
from accelerate import cpu_offload
|
||||
cpu_offload(transformer, accelerator.device, offload_buffers=False)
|
||||
else:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
transformer.to(accelerator.device, dtype=weight_dtype)
|
||||
@ -1286,6 +1329,13 @@ def main(args):
|
||||
if args.offload_to_cpu:
|
||||
vae.to('cpu')
|
||||
|
||||
if args.offload_to_cpu:
|
||||
text_encoder.to(accelerator.device)
|
||||
train_dataset.encode_prompts(tokenizer, text_encoder, accelerator.device, weight_dtype)
|
||||
|
||||
if args.offload_to_cpu:
|
||||
text_encoder.to('cpu')
|
||||
|
||||
def collate_fn(examples):
|
||||
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
|
||||
prompts = [example["instance_prompt"] for example in examples]
|
||||
@ -1333,9 +1383,15 @@ def main(args):
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
if not args.offload_to_cpu:
|
||||
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
transformer, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
offloader = Offloader()
|
||||
print(f"Offloading to CPU on device {CachedVideoList.ACCELERATOR_DEVICE}")
|
||||
offloader.enable_sequential_cpu_offload(transformer)
|
||||
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@ -1414,18 +1470,20 @@ def main(args):
|
||||
with accelerator.accumulate(models_to_accumulate):
|
||||
model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W]
|
||||
prompts = batch["prompts"]
|
||||
|
||||
if args.offload_to_cpu:
|
||||
prompt_embeds = torch.cat(prompts)
|
||||
else:
|
||||
# encode prompts
|
||||
prompt_embeds = compute_prompt_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompts,
|
||||
model_config.max_text_seq_length,
|
||||
accelerator.device,
|
||||
accelerator.device if not args.offload_to_cpu else "cpu",
|
||||
weight_dtype,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
prompt_embeds.to(device=accelerator.device)
|
||||
# Sample noise that will be added to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
batch_size, num_frames, num_channels, height, width = model_input.shape
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user