Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-06-26 03:59:17 +08:00)

Commit b8001a769f: finetune I2V sat
Parent: 6f37641ee3
@@ -362,7 +362,7 @@ class SFTDataset(Dataset):
            skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
        """
        super(SFTDataset, self).__init__()

        self.video_size = video_size
        self.fps = fps
        self.max_num_frames = max_num_frames
@@ -385,7 +385,6 @@ class SFTDataset(Dataset):
            self.captions.append(caption)

    def __getitem__(self, index):
-
        decord.bridge.set_bridge("torch")

        video_path = self.video_paths[index]
@@ -411,9 +410,7 @@ class SFTDataset(Dataset):
            indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
            temp_frms = vr.get_batch(np.arange(start, end))
            assert temp_frms is not None
-            tensor_frms = (
-                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-            )
+            tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
            tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
        else:

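For readers skimming the hunk above, the stride-based frame selection is easier to follow with concrete numbers. A small, self-contained sketch (values are made up; this is not part of the commit):

import numpy as np
import torch

# made-up stand-ins for skip_frms_num, clip length and target frame count
start, end, num_frames = 3, 83, 20

# same stride rule as the dataset code: roughly even spacing, stride at least 1
indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)

# pretend vr.get_batch(np.arange(start, end)) returned this (T, H, W, C) tensor
clip = torch.randn(end - start, 32, 48, 3)
frames = clip[torch.tensor((indices - start).tolist())]

print(indices.shape, frames.shape)  # (20,) torch.Size([20, 32, 48, 3])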
@@ -426,15 +423,11 @@ class SFTDataset(Dataset):

            start = int(self.skip_frms_num)
            end = int(ori_vlen - self.skip_frms_num)
-            num_frames = nearest_smaller_4k_plus_1(
-                end - start
-            )  # 3D VAE requires the number of frames to be 4k+1
+            num_frames = nearest_smaller_4k_plus_1(end - start)  # 3D VAE requires the number of frames to be 4k+1
            end = int(start + num_frames)
            temp_frms = vr.get_batch(np.arange(start, end))
            assert temp_frms is not None
-            tensor_frms = (
-                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-            )
+            tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms

        tensor_frms = pad_last_frame(
            tensor_frms, self.max_num_frames
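The nearest_smaller_4k_plus_1 helper itself is not shown in this diff. A minimal implementation consistent with the comment that the 3D VAE needs 4k+1 frames might look like the following; this is an illustration, and the repository's own helper may differ in detail:

def nearest_smaller_4k_plus_1(n: int) -> int:
    """Largest m <= n with m % 4 == 1, i.e. a frame count the 3D VAE accepts."""
    remainder = n % 4
    if remainder == 0:
        return n - 3
    return n - remainder + 1


# e.g. 49 -> 49, 50 -> 49, 51 -> 49, 52 -> 49, 53 -> 53
print([nearest_smaller_4k_plus_1(n) for n in range(49, 54)])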
@@ -1,3 +1,5 @@
+import random
+
 import math
 from typing import Any, Dict, List, Tuple, Union
 from omegaconf import ListConfig
@@ -137,7 +139,6 @@ class SATVideoDiffusionEngine(nn.Module):
            image = image + image_noise
        return image

-
    def shared_step(self, batch: Dict) -> Any:
        x = self.get_input(batch)
        if self.lr_scale is not None:
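Only the last two lines of add_noise_to_first_frame appear in the hunk above. As a rough, standalone sketch of what such a method typically does, assuming a log-normally sampled per-sample noise scale (the mean and std below are illustrative assumptions, not taken from the diff):

import torch

def add_noise_to_first_frame(image: torch.Tensor) -> torch.Tensor:
    # Sketch only: draw one noise scale per batch element and broadcast it
    # over the (B, 1, C, H, W) first-frame tensor before adding Gaussian noise.
    sigma = torch.exp(torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)))
    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None].to(image.dtype)
    return image + image_noise


first_frame = torch.randn(2, 1, 3, 64, 64)   # (B, 1, C, H, W) pixel-space first frame
print(add_noise_to_first_frame(first_frame).shape)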
@@ -147,8 +148,22 @@ class SATVideoDiffusionEngine(nn.Module):
            batch["lr_input"] = lr_z

        x = x.permute(0, 2, 1, 3, 4).contiguous()
+        if self.noised_image_input:
+            image = x[:, :, 0:1]
+            image = self.add_noise_to_first_frame(image)
+            image = self.encode_first_stage(image, batch)
+
        x = self.encode_first_stage(x, batch)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
+        if self.noised_image_input:
+            image = image.permute(0, 2, 1, 3, 4).contiguous()
+            if self.noised_image_all_concat:
+                image = image.repeat(1, x.shape[1], 1, 1, 1)
+            else:
+                image = torch.concat([image, torch.zeros_like(x[:, 1:])], dim=1)
+            if random.random() < self.noised_image_dropout:
+                image = torch.zeros_like(image)
+            batch["concat_images"] = image

        gc.collect()
        torch.cuda.empty_cache()
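In the new shared_step branch, the first pixel frame is noised and encoded before the full video is encoded; the resulting latent is then either repeated over all frames (noised_image_all_concat) or padded with zeros for frames 1..T-1, and occasionally zeroed out entirely as conditioning dropout. A toy, self-contained illustration of the tensor bookkeeping, with made-up shapes and flag values:

import random
import torch

b, t, c, h, w = 2, 13, 16, 60, 90                    # illustrative latent shape
x = torch.randn(b, t, c, h, w)                        # encoded video latents
image = torch.randn(b, 1, c, h, w)                    # noised + encoded first frame

noised_image_all_concat = False                       # illustrative flag values
noised_image_dropout = 0.05

if noised_image_all_concat:
    image = image.repeat(1, t, 1, 1, 1)               # condition every frame on frame 0
else:
    image = torch.concat([image, torch.zeros_like(x[:, 1:])], dim=1)  # zeros after frame 0

if random.random() < noised_image_dropout:
    image = torch.zeros_like(image)                   # drop the image condition, CFG-style

assert image.shape == x.shape                         # this is what becomes batch["concat_images"]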
@@ -311,6 +326,7 @@ class SATVideoDiffusionEngine(nn.Module):
        samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs)  # b t c h w
        samples = samples.permute(0, 2, 1, 3, 4).contiguous()
        if self.noised_image_input:
+            print("Adding noise to first frame")
            image = x[:, :, 0:1]
            image = self.add_noise_to_first_frame(image)
            image = self.encode_first_stage(image, batch)
@@ -546,4 +546,4 @@ class VideoAutoencodingEngine(AutoencodingEngine):
        missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
        print("Missing keys: ", missing_keys)
        print("Unexpected keys: ", unexpected_keys)
        print(f"Restored from {path}")
@@ -6,7 +6,6 @@ from ..util import (
    get_context_parallel_group,
    get_context_parallel_rank,
    get_context_parallel_world_size,
-
)

_USE_CP = True
@@ -179,4 +178,4 @@ def _conv_gather(input_, dim, kernel_size):

    # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)

    return output
@@ -100,8 +100,9 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
        )

        if "concat_images" in batch.keys():
-            additional_model_inputs["concat_images"] = batch["concat_images"]
+            cond["concat"] = batch["concat_images"]

+        # [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx'])
        model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
        w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim)  # v-pred
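The hunk above moves the image condition from additional_model_inputs into cond["concat"], but does not show how the denoiser consumes it. Under the common convention that a "concat" condition is stacked with the noised latents along the channel axis, the wiring would look roughly like this (an assumption for illustration, not the repository's denoiser code):

import torch

# Shapes follow the shape comment in the diff: (B, T, C, H, W) latents.
noised_input = torch.randn(2, 13, 16, 60, 90)
cond = {"concat": torch.randn(2, 13, 16, 60, 90)}     # from batch["concat_images"]

# Concat-style conditioning typically doubles the channel dimension (dim=2 here)
# before the network forward pass.
network_input = torch.cat([noised_input, cond["concat"]], dim=2)
print(network_input.shape)  # torch.Size([2, 13, 32, 60, 90])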
@@ -117,11 +118,3 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
        elif self.type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
        return loss
-
-
-def get_3d_position_ids(frame_len, h, w):
-    i = torch.arange(frame_len).view(frame_len, 1, 1).expand(frame_len, h, w)
-    j = torch.arange(h).view(1, h, 1).expand(frame_len, h, w)
-    k = torch.arange(w).view(1, 1, w).expand(frame_len, h, w)
-    position_ids = torch.stack([i, j, k], dim=-1).reshape(-1, 3)
-    return position_ids
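For reference, the removed get_3d_position_ids helper simply enumerated every (frame, row, column) coordinate triple; a quick check of its output:

import torch

def get_3d_position_ids(frame_len, h, w):
    i = torch.arange(frame_len).view(frame_len, 1, 1).expand(frame_len, h, w)
    j = torch.arange(h).view(1, h, 1).expand(frame_len, h, w)
    k = torch.arange(w).view(1, 1, w).expand(frame_len, h, w)
    position_ids = torch.stack([i, j, k], dim=-1).reshape(-1, 3)
    return position_ids


print(get_3d_position_ids(2, 2, 2))
# tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
#         [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]])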