Mirror of https://github.com/THUDM/CogVideo.git

Commit b8001a769f: finetune I2V sat
Parent: 6f37641ee3
@@ -362,7 +362,7 @@ class SFTDataset(Dataset):
             skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
         """
         super(SFTDataset, self).__init__()
 
         self.video_size = video_size
         self.fps = fps
         self.max_num_frames = max_num_frames
@@ -385,7 +385,6 @@ class SFTDataset(Dataset):
             self.captions.append(caption)
 
     def __getitem__(self, index):
-
         decord.bridge.set_bridge("torch")
 
         video_path = self.video_paths[index]
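(For reference: decord's bridge setting controls what VideoReader.get_batch returns. With the "torch" bridge set as above, frames typically come back as torch.Tensor already, which is why the diff keeps a type check on temp_frms before calling torch.from_numpy. A minimal standalone sketch, with a hypothetical video path:)

    import decord
    import numpy as np

    decord.bridge.set_bridge("torch")          # get_batch now returns torch tensors
    vr = decord.VideoReader("example.mp4")     # hypothetical sample file
    frames = vr.get_batch(np.arange(0, 8))     # shape (8, H, W, 3), dtype uint8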
@@ -411,9 +410,7 @@ class SFTDataset(Dataset):
             indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
             temp_frms = vr.get_batch(np.arange(start, end))
             assert temp_frms is not None
-            tensor_frms = (
-                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-            )
+            tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
             tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
         else:
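(A quick worked example of the stride computation in the indices line above, with illustrative numbers: for start=3, end=83 and num_frames=20 the stride is (83 - 3) // 20 = 4, yielding exactly 20 evenly spaced frame indices:)

    import numpy as np

    start, end, num_frames = 3, 83, 20
    indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
    print(indices[:5], len(indices))  # [ 3  7 11 15 19] 20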
@@ -426,15 +423,11 @@ class SFTDataset(Dataset):
 
             start = int(self.skip_frms_num)
             end = int(ori_vlen - self.skip_frms_num)
-            num_frames = nearest_smaller_4k_plus_1(
-                end - start
-            )  # 3D VAE requires the number of frames to be 4k+1
+            num_frames = nearest_smaller_4k_plus_1(end - start)  # 3D VAE requires the number of frames to be 4k+1
             end = int(start + num_frames)
             temp_frms = vr.get_batch(np.arange(start, end))
             assert temp_frms is not None
-            tensor_frms = (
-                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-            )
+            tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
 
         tensor_frms = pad_last_frame(
             tensor_frms, self.max_num_frames
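(The helper nearest_smaller_4k_plus_1 is called above but its body is not part of this diff; going by its name and the 4k+1 comment, a minimal sketch of what it presumably computes is:)

    def nearest_smaller_4k_plus_1(n):
        # Largest value <= n of the form 4k + 1, since the 3D VAE
        # requires the number of frames to be 4k + 1.
        return (n - 1) // 4 * 4 + 1

    # nearest_smaller_4k_plus_1(16) == 13, nearest_smaller_4k_plus_1(17) == 17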
@ -1,3 +1,5 @@
|
||||
import random
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from omegaconf import ListConfig
|
||||
@@ -137,7 +139,6 @@ class SATVideoDiffusionEngine(nn.Module):
         image = image + image_noise
         return image
 
-
     def shared_step(self, batch: Dict) -> Any:
         x = self.get_input(batch)
         if self.lr_scale is not None:
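(The two context lines above are the tail of add_noise_to_first_frame; its body is not shown in this hunk. A sketch consistent with those lines, where the per-sample log-normal noise-scale sampling is an assumption rather than something visible in this diff:)

    def add_noise_to_first_frame(self, image):
        # Assumed schedule: per-sample log-normal noise scale (not shown in this hunk).
        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
        sigma = torch.exp(sigma).to(image.dtype)
        image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
        image = image + image_noise
        return image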
@@ -147,8 +148,22 @@ class SATVideoDiffusionEngine(nn.Module):
             batch["lr_input"] = lr_z
 
         x = x.permute(0, 2, 1, 3, 4).contiguous()
+        if self.noised_image_input:
+            image = x[:, :, 0:1]
+            image = self.add_noise_to_first_frame(image)
+            image = self.encode_first_stage(image, batch)
+
         x = self.encode_first_stage(x, batch)
         x = x.permute(0, 2, 1, 3, 4).contiguous()
+        if self.noised_image_input:
+            image = image.permute(0, 2, 1, 3, 4).contiguous()
+            if self.noised_image_all_concat:
+                image = image.repeat(1, x.shape[1], 1, 1, 1)
+            else:
+                image = torch.concat([image, torch.zeros_like(x[:, 1:])], dim=1)
+            if random.random() < self.noised_image_dropout:
+                image = torch.zeros_like(image)
+            batch["concat_images"] = image
 
         gc.collect()
         torch.cuda.empty_cache()
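(A shape walk-through of the zero-padding branch added above, with illustrative sizes taken from the [2, 13, 16, 60, 90] comment later in this commit: the encoded first frame spans one latent frame and is padded with zeros along the time axis so it matches the latent video, and zeroing it out with probability noised_image_dropout supports classifier-free guidance on the image condition:)

    import torch

    b, t, c, h, w = 2, 13, 16, 60, 90           # b t c h w latents
    x = torch.randn(b, t, c, h, w)
    image = torch.randn(b, 1, c, h, w)          # encoded (noised) first frame
    image = torch.concat([image, torch.zeros_like(x[:, 1:])], dim=1)
    assert image.shape == x.shape               # padded along time to match x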
@@ -311,6 +326,7 @@ class SATVideoDiffusionEngine(nn.Module):
         samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs)  # b t c h w
         samples = samples.permute(0, 2, 1, 3, 4).contiguous()
         if self.noised_image_input:
+            print("Adding noise to first frame")
             image = x[:, :, 0:1]
             image = self.add_noise_to_first_frame(image)
             image = self.encode_first_stage(image, batch)
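(The permutes here swap between the (b, t, c, h, w) layout the diffusion network uses, per the # b t c h w comment, and the (b, c, t, h, w) layout the 3D VAE expects; a quick shape check with the same illustrative sizes:)

    import torch

    samples = torch.randn(2, 13, 16, 60, 90)    # b t c h w
    samples = samples.permute(0, 2, 1, 3, 4).contiguous()
    print(samples.shape)                        # torch.Size([2, 16, 13, 60, 90])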
@@ -546,4 +546,4 @@ class VideoAutoencodingEngine(AutoencodingEngine):
         missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
         print("Missing keys: ", missing_keys)
         print("Unexpected keys: ", unexpected_keys)
-        print(f"Restored from {path}")
\ No newline at end of file
+        print(f"Restored from {path}")
@@ -6,7 +6,6 @@ from ..util import (
     get_context_parallel_group,
     get_context_parallel_rank,
     get_context_parallel_world_size,
-
 )
 
 _USE_CP = True
@@ -179,4 +178,4 @@ def _conv_gather(input_, dim, kernel_size):
 
     # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
 
-    return output
\ No newline at end of file
+    return output
@@ -100,8 +100,9 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
         )
 
         if "concat_images" in batch.keys():
+            additional_model_inputs["concat_images"] = batch["concat_images"]
             cond["concat"] = batch["concat_images"]
 
         # [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx'])
         model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
         w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim)  # v-pred
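(On the w line above: with alphas_cumprod_sqrt = sqrt(alpha_bar), the weight 1 / (1 - alphas_cumprod_sqrt**2) = 1 / (1 - alpha_bar) is the usual v-prediction loss weighting. append_dims is a standard sgm/k-diffusion utility, sketched here for reference; it right-pads singleton dims so the per-sample weight broadcasts over the video tensor:)

    def append_dims(x, target_dims):
        # Append trailing singleton dimensions to x until x.ndim == target_dims.
        dims_to_append = target_dims - x.ndim
        return x[(...,) + (None,) * dims_to_append]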
@@ -117,11 +118,3 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
         elif self.type == "lpips":
             loss = self.lpips(model_output, target).reshape(-1)
         return loss
 
-
-def get_3d_position_ids(frame_len, h, w):
-    i = torch.arange(frame_len).view(frame_len, 1, 1).expand(frame_len, h, w)
-    j = torch.arange(h).view(1, h, 1).expand(frame_len, h, w)
-    k = torch.arange(w).view(1, 1, w).expand(frame_len, h, w)
-    position_ids = torch.stack([i, j, k], dim=-1).reshape(-1, 3)
-    return position_ids