Bring back the previous train_cogvideox_lora.py
parent f69c3c87ec
commit b848932c13
train_cogvideox_lora.py

@@ -401,94 +401,8 @@ def get_args():
             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
         ),
     )
-    parser.add_argument(
-        "--offload_to_cpu",
-        action="store_true",
-        help="Whether or not to offload the model to the CPU.",
-    )
-    parser.add_argument(
-        "--cache_preprocessed_data",
-        action="store_true",
-        help="Whether or not to cache preprocessed data.",
-    )
-
-    parsed_args = parser.parse_args()
-
-    return parsed_args
-
-
-class Offloader:
-    def __init__(self):
-        pass
-
-    def enable_sequential_cpu_offload(self, model):
-        from accelerate import cpu_offload
-        gpu_id = None
-        device = "cuda"
-
-        torch_device = torch.device(device)
-        device_index = torch_device.index
-
-        if gpu_id is not None and device_index is not None:
-            raise ValueError(
-                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
-                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
-            )
-
-        # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
-        self._offload_gpu_id = gpu_id or torch_device.index or 0
-
-        device_type = torch_device.type
-        device = torch.device(f"{device_type}:{self._offload_gpu_id}")
-        self._offload_device = device
-
-        offload_buffers = len(model._parameters) > 0
-        cpu_offload(model, device, offload_buffers=offload_buffers)
-
-
-class CachedVideoList:
-    ACCELERATOR_DEVICE = 'cpu'
-    CACHE_ENABLED = False
-    VAE = None
-    OUTPUT_DIR = None
-
-    @classmethod
-    def is_cached(cls, video_name):
-        if cls.CACHE_ENABLED:
-            return os.path.exists(os.path.join(cls.cache_dir(), f'{video_name}.pt'))
-        return False
-
-    @classmethod
-    def cache_dir(cls):
-        result = os.path.join(cls.OUTPUT_DIR, "cached_videos")
-        if not os.path.exists(result):
-            os.makedirs(result)
-        return result
-
-    def __init__(self):
-        if not self.CACHE_ENABLED:
-            raise ValueError("CachedVideoList is not enabled. Please enable the cache before using it.")
-        if self.OUTPUT_DIR is None:
-            raise ValueError("Output directory not set. Please set the output directory before using the CachedVideoList.")
-        if self.VAE is None:
-            raise ValueError("VAE model not set. Please set the VAE model before using the CachedVideoList.")
-
-        self.video_names = []
-
-    def __len__(self):
-        return len(self.video_names)
-
-    def append(self, video: Tuple[str, torch.Tensor]):
-        self.video_names.append(video[0])
-        if video[1] is not None:
-            torch.save(video[1], os.path.join(CachedVideoList.cache_dir(), f'{video[0]}.pt'))
-
-    def __getitem__(self, index):
-        if index >= len(self.video_names):
-            raise IndexError("Index out of bounds")
-        pt_path = os.path.join(CachedVideoList.cache_dir(), f'{self.video_names[index]}.pt')
-        if not os.path.exists(pt_path):
-            raise FileNotFoundError(f"Video {pt_path} not found in cache.")
-        return torch.load(pt_path)
-
-
+    return parser.parse_args()
 
 
 class VideoDataset(Dataset):
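
Aside on the removed Offloader: its enable_sequential_cpu_offload wrapped accelerate's cpu_offload, which keeps a module's weights on the CPU and moves each submodule to the GPU only for its forward pass. A minimal, self-contained sketch of that pattern (the toy module and device below are illustrative, not from this script):

# Minimal sketch of sequential CPU offload via accelerate (toy module, illustrative device).
import torch
from torch import nn
from accelerate import cpu_offload

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# Parameters stay on the CPU; hooks copy each submodule to cuda:0 only while it runs.
cpu_offload(model, execution_device=torch.device("cuda:0"))

with torch.no_grad():
    out = model(torch.randn(2, 16))  # hooks move inputs/weights to the execution device as needed

This trades speed for memory, which is why the removed code also enabled gradient checkpointing whenever offloading was requested.
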
@@ -529,9 +443,6 @@ class VideoDataset(Dataset):
         else:
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
-
-        if self.id_token is not None:
-            self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts]
 
         self.num_instance_videos = len(self.instance_video_paths)
         if self.num_instance_videos != len(self.instance_prompts):
             raise ValueError(
@@ -545,7 +456,7 @@ class VideoDataset(Dataset):
 
     def __getitem__(self, index):
         return {
-            "instance_prompt": self.instance_prompts[index],
+            "instance_prompt": self.id_token + self.instance_prompts[index],
             "instance_video": self.instance_videos[index],
         }
 
@@ -623,32 +534,6 @@ class VideoDataset(Dataset):
 
         return instance_prompts, instance_videos
 
-    def encode_prompts(self, tokenizer, text_encoder, device, dtype):
-        encoded_prompts = []
-        for index, prompt in enumerate(self.instance_prompts):
-            print(f"Encoding prompt {index + 1} of {len(self.instance_prompts)}")
-            prompt_embeds = compute_prompt_embeddings(
-                tokenizer,
-                text_encoder,
-                prompt,
-                max_sequence_length=226,
-                device=device,
-                dtype=dtype,
-            )
-            encoded_prompts.append(prompt_embeds)
-        self.instance_prompts = encoded_prompts
-
-    def encode_videos(self, vae, device):
-        if not CachedVideoList.CACHE_ENABLED:
-            self.instance_videos = [self.encode_video(video, vae, device) for video in self.instance_videos]
-
-    def encode_video(self, video, vae, device):
-        video = video.to(device, dtype=vae.dtype).unsqueeze(0)
-        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
-        vae.to(device)
-        latent_dist = vae.encode(video.to(device)).latent_dist
-        return latent_dist
-
     def _preprocess_data(self):
         try:
             import decord
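
Aside on the removed caching path: encode_videos ran each clip through the VAE once, and CachedVideoList persisted the result under <output_dir>/cached_videos/<name>.pt with torch.save so later runs could torch.load it instead of re-encoding. A compact sketch of the same idea (the helper name, cache directory, and encode_fn are hypothetical stand-ins, not part of the script):

# Hypothetical helper sketching disk-cached video encodings (not part of the script).
import os
import torch

def cached_encode(video_name, frames, encode_fn, cache_dir="cached_videos"):
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, f"{video_name}.pt")
    if os.path.exists(path):
        return torch.load(path)      # reuse the previously saved encoding
    encoded = encode_fn(frames)      # e.g. a VAE latent sample for this clip
    torch.save(encoded, path)        # persist for subsequent epochs/runs
    return encoded
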
@@ -659,7 +544,7 @@ class VideoDataset(Dataset):
 
         decord.bridge.set_bridge("torch")
 
-        videos = CachedVideoList() if CachedVideoList.CACHE_ENABLED else []
+        videos = []
         train_transforms = transforms.Compose(
             [
                 transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
@@ -667,11 +552,6 @@ class VideoDataset(Dataset):
         )
 
         for filename in self.instance_video_paths:
-            if CachedVideoList.CACHE_ENABLED and videos.is_cached(filename.stem):
-                videos.append((filename.stem, None))
-                continue
-            if CachedVideoList.CACHE_ENABLED:
-                print(f"Pre-processing video {filename.as_posix()}")
             video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
             video_num_frames = len(video_reader)
 
@@ -700,17 +580,7 @@ class VideoDataset(Dataset):
             # Training transforms
             frames = frames.float()
             frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            frames = frames.permute(0, 3, 1, 2).contiguous()  # [F, C, H, W]
-
-            if CachedVideoList.CACHE_ENABLED:
-                print(f"Encoding video {filename.stem} device {CachedVideoList.ACCELERATOR_DEVICE}")
-
-            if CachedVideoList.CACHE_ENABLED:
-                videos.append((filename.stem, self.encode_video(frames, CachedVideoList.VAE, CachedVideoList.ACCELERATOR_DEVICE)))
-            else:
-                videos.append(frames)  # [F, C, H, W]
-            if CachedVideoList.CACHE_ENABLED:
-                print(f"Video Encoded {filename.stem}")
+            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
 
         return videos
 
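
For reference, the preprocessing above uses decord to read and resize frames, maps pixel values from [0, 255] to [-1, 1], and stacks them as [F, C, H, W]. A rough standalone sketch of that flow (the file path, resolution, and frame count are placeholders):

# Rough sketch of the decord read-and-normalize step (placeholder path, size, and frame count).
import decord
import torch

decord.bridge.set_bridge("torch")

reader = decord.VideoReader(uri="example.mp4", width=720, height=480)
indices = list(range(min(len(reader), 49)))        # cap at 49 frames for this sketch
frames = reader.get_batch(indices).float()         # [F, H, W, C], values in [0, 255]
frames = frames / 255.0 * 2.0 - 1.0                # same normalization as train_transforms
frames = frames.permute(0, 3, 1, 2).contiguous()   # [F, C, H, W]
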
@@ -887,7 +757,7 @@ def _get_t5_prompt_embeds(
         if text_input_ids is None:
             raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
 
-    prompt_embeds = text_encoder(text_input_ids.to(device))[0].to(device)
+    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -1184,16 +1054,11 @@ def main(args):
             "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
         )
 
-    if args.offload_to_cpu:
-        text_encoder.to('cpu', dtype=weight_dtype)
-        transformer.to('cpu', dtype=weight_dtype)
-        vae.to('cpu', dtype=weight_dtype)
-    else:
-        text_encoder.to(accelerator.device, dtype=weight_dtype)
-        transformer.to(accelerator.device, dtype=weight_dtype)
-        vae.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    transformer.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
 
-    if args.gradient_checkpointing or args.offload_to_cpu:
+    if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
 
     # now we will add new LoRA weights to the attention layers
@@ -1298,12 +1163,6 @@ def main(args):
 
     optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
 
-    if args.cache_preprocessed_data:
-        CachedVideoList.OUTPUT_DIR = args.output_dir
-        CachedVideoList.CACHE_ENABLED = True
-        CachedVideoList.VAE = vae
-        CachedVideoList.ACCELERATOR_DEVICE = accelerator.device
-
     # Dataset and DataLoader
     train_dataset = VideoDataset(
         instance_data_root=args.instance_data_root,
@@ -1321,20 +1180,13 @@ def main(args):
         id_token=args.id_token,
     )
 
-    if args.offload_to_cpu:
-        vae.to('cuda')
-
-    train_dataset.encode_videos(vae, accelerator.device)
-
-    if args.offload_to_cpu:
-        vae.to('cpu')
-
-    if args.offload_to_cpu:
-        text_encoder.to(accelerator.device)
-    train_dataset.encode_prompts(tokenizer, text_encoder, accelerator.device, weight_dtype)
-
-    if args.offload_to_cpu:
-        text_encoder.to('cpu')
+    def encode_video(video):
+        video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
+        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
+        latent_dist = vae.encode(video).latent_dist
+        return latent_dist
+
+    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
 
     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
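
The restored setup encodes every training video into a VAE latent distribution once, up front, and collate_fn then draws a fresh sample per step and scales it by vae.config.scaling_factor. A small sketch of just that sample-and-scale step, using a synthetic distribution so it runs without a VAE (shapes and the scaling value are illustrative; the import path can vary across diffusers versions):

# Sketch of the sample-and-scale step done in collate_fn (synthetic latent stats, illustrative shapes).
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution  # path may differ by diffusers version

params = torch.randn(1, 32, 13, 60, 90)   # [B, 2 * latent_channels, F, H, W]: packed mean and logvar
latent_dist = DiagonalGaussianDistribution(params)

scaling_factor = 0.7                               # stand-in for vae.config.scaling_factor
latents = latent_dist.sample() * scaling_factor    # a new stochastic latent draw each training step
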
@@ -1383,15 +1235,9 @@ def main(args):
     )
 
     # Prepare everything with our `accelerator`.
-    if not args.offload_to_cpu:
-        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            transformer, optimizer, train_dataloader, lr_scheduler
-        )
-    else:
-        offloader = Offloader()
-        print(f"Offloading to CPU on device {CachedVideoList.ACCELERATOR_DEVICE}")
-        offloader.enable_sequential_cpu_offload(transformer)
-
+    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        transformer, optimizer, train_dataloader, lr_scheduler
+    )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1470,20 +1316,18 @@ def main(args):
             with accelerator.accumulate(models_to_accumulate):
                 model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
                 prompts = batch["prompts"]
-                if args.offload_to_cpu:
-                    prompt_embeds = torch.cat(prompts)
-                else:
-                    # encode prompts
-                    prompt_embeds = compute_prompt_embeddings(
-                        tokenizer,
-                        text_encoder,
-                        prompts,
-                        model_config.max_text_seq_length,
-                        accelerator.device if not args.offload_to_cpu else "cpu",
-                        weight_dtype,
-                        requires_grad=False,
-                    )
-                prompt_embeds.to(device=accelerator.device)
+
+                # encode prompts
+                prompt_embeds = compute_prompt_embeddings(
+                    tokenizer,
+                    text_encoder,
+                    prompts,
+                    model_config.max_text_seq_length,
+                    accelerator.device,
+                    weight_dtype,
+                    requires_grad=False,
+                )
 
                 # Sample noise that will be added to the latents
                 noise = torch.randn_like(model_input)
                 batch_size, num_frames, num_channels, height, width = model_input.shape
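
With the revert, prompt embeddings are recomputed every step through the script's compute_prompt_embeddings helper rather than read from a precomputed cache; under the hood that is a T5 encoder pass over tokenized prompts. A bare-bones sketch of that underlying call (the model id is the public CogVideoX-2b checkpoint, assumed here rather than taken from this diff; 226 matches the max sequence length used in the removed code above):

# Bare-bones T5 prompt-encoding sketch (model id assumed, not taken from this diff).
import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/CogVideoX-2b", subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained("THUDM/CogVideoX-2b", subfolder="text_encoder")

text_inputs = tokenizer(
    "a panda playing guitar in a bamboo forest",
    padding="max_length",
    max_length=226,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = text_encoder(text_inputs.input_ids)[0]  # [1, 226, hidden_dim]
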