finetune I2V sat

zR 2024-09-13 16:46:47 +08:00
parent 6f37641ee3
commit b8001a769f
5 changed files with 25 additions and 24 deletions

View File

@@ -362,7 +362,7 @@ class SFTDataset(Dataset):
             skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
         """
         super(SFTDataset, self).__init__()
         self.video_size = video_size
         self.fps = fps
         self.max_num_frames = max_num_frames
@@ -385,7 +385,6 @@ class SFTDataset(Dataset):
                 self.captions.append(caption)

-
     def __getitem__(self, index):
         decord.bridge.set_bridge("torch")
         video_path = self.video_paths[index]
@@ -411,9 +410,7 @@ class SFTDataset(Dataset):
             indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
             temp_frms = vr.get_batch(np.arange(start, end))
             assert temp_frms is not None
-            tensor_frms = (
-                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-            )
+            tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
             tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
         else:
@@ -426,15 +423,11 @@ class SFTDataset(Dataset):
                 start = int(self.skip_frms_num)
                 end = int(ori_vlen - self.skip_frms_num)
-                num_frames = nearest_smaller_4k_plus_1(
-                    end - start
-                )  # 3D VAE requires the number of frames to be 4k+1
+                num_frames = nearest_smaller_4k_plus_1(end - start)  # 3D VAE requires the number of frames to be 4k+1
                 end = int(start + num_frames)
                 temp_frms = vr.get_batch(np.arange(start, end))
                 assert temp_frms is not None
-                tensor_frms = (
-                    torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-                )
+                tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms

         tensor_frms = pad_last_frame(
             tensor_frms, self.max_num_frames
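
The helpers nearest_smaller_4k_plus_1 and pad_last_frame are called here but not shown in the diff. A minimal sketch of their presumed behavior, inferred from the inline comment (the 3D VAE wants frame counts of the form 4k+1) and from the call sites; the bodies below are assumptions, not the repo's actual code:

import torch

def nearest_smaller_4k_plus_1(n: int) -> int:
    # Largest value of the form 4k + 1 that does not exceed n (e.g. 50 -> 49),
    # matching the 3D VAE's frame-count requirement noted in the comment above.
    return (n - 1) // 4 * 4 + 1

def pad_last_frame(tensor_frms: torch.Tensor, num_frames: int) -> torch.Tensor:
    # Repeat the final frame until the clip has num_frames frames; rounding in
    # the index arithmetic above can leave the sampled clip slightly short.
    if tensor_frms.shape[0] < num_frames:
        pad = tensor_frms[-1:].repeat(num_frames - tensor_frms.shape[0], 1, 1, 1)
        tensor_frms = torch.cat([tensor_frms, pad], dim=0)
    return tensor_frms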

View File

@@ -1,3 +1,5 @@
+import random
 import math
 from typing import Any, Dict, List, Tuple, Union
 from omegaconf import ListConfig
@@ -137,7 +139,6 @@ class SATVideoDiffusionEngine(nn.Module):
         image = image + image_noise
         return image

-
     def shared_step(self, batch: Dict) -> Any:
         x = self.get_input(batch)
         if self.lr_scale is not None:
@@ -147,8 +148,22 @@ class SATVideoDiffusionEngine(nn.Module):
             batch["lr_input"] = lr_z

         x = x.permute(0, 2, 1, 3, 4).contiguous()
+        if self.noised_image_input:
+            image = x[:, :, 0:1]
+            image = self.add_noise_to_first_frame(image)
+            image = self.encode_first_stage(image, batch)
+
         x = self.encode_first_stage(x, batch)
         x = x.permute(0, 2, 1, 3, 4).contiguous()
+        if self.noised_image_input:
+            image = image.permute(0, 2, 1, 3, 4).contiguous()
+            if self.noised_image_all_concat:
+                image = image.repeat(1, x.shape[1], 1, 1, 1)
+            else:
+                image = torch.concat([image, torch.zeros_like(x[:, 1:])], dim=1)
+            if random.random() < self.noised_image_dropout:
+                image = torch.zeros_like(image)
+            batch["concat_images"] = image

         gc.collect()
         torch.cuda.empty_cache()
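
For orientation, a toy shape walkthrough of the conditioning tensor the new branch builds. The sizes follow the [2, 13, 16, 60, 90] (b t c h w) debug comment in the loss hunk further down and are illustrative assumptions only:

import torch

B, T, C, H, W = 2, 13, 16, 60, 90       # assumed latent sizes, b t c h w
x = torch.randn(B, T, C, H, W)          # encoded video latents
image = torch.randn(B, 1, C, H, W)      # encoded, noised first frame

# Zero-pad the single conditioning frame over the remaining T - 1 steps so it
# lines up with the video latents (the non-all-concat path above).
image = torch.concat([image, torch.zeros_like(x[:, 1:])], dim=1)
assert image.shape == x.shape           # [2, 13, 16, 60, 90]

# Classifier-free-style dropout of the image condition.
noised_image_dropout = 0.05             # assumed hyperparameter value
if torch.rand(()).item() < noised_image_dropout:
    image = torch.zeros_like(image)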
@@ -311,6 +326,7 @@ class SATVideoDiffusionEngine(nn.Module):
         samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs)  # b t c h w
         samples = samples.permute(0, 2, 1, 3, 4).contiguous()
         if self.noised_image_input:
+            print("Adding noise to first frame")
             image = x[:, :, 0:1]
             image = self.add_noise_to_first_frame(image)
             image = self.encode_first_stage(image, batch)
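
Both this hunk and shared_step call self.add_noise_to_first_frame, and the earlier hunk's context shows its tail (image = image + image_noise; return image). A plausible reconstruction, assuming a log-normal noise scale; the mean/std values are guesses:

def add_noise_to_first_frame(self, image):
    # Sample a per-example noise scale from a log-normal distribution
    # (mean and std here are assumed values) so the model never conditions
    # on a perfectly clean first frame.
    sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],), device=image.device)
    sigma = torch.exp(sigma).to(image.dtype)
    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
    image = image + image_noise
    return image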

View File

@@ -546,4 +546,4 @@ class VideoAutoencodingEngine(AutoencodingEngine):
         missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
         print("Missing keys: ", missing_keys)
         print("Unexpected keys: ", unexpected_keys)
-        print(f"Restored from {path}")
\ No newline at end of file
+        print(f"Restored from {path}")

View File

@@ -6,7 +6,6 @@ from ..util import (
     get_context_parallel_group,
     get_context_parallel_rank,
     get_context_parallel_world_size,
 )

 _USE_CP = True
@@ -179,4 +178,4 @@ def _conv_gather(input_, dim, kernel_size):
     # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)

-    return output
\ No newline at end of file
+    return output

View File

@@ -100,8 +100,9 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
         )

         if "concat_images" in batch.keys():
-            additional_model_inputs["concat_images"] = batch["concat_images"]
+            cond["concat"] = batch["concat_images"]

+        # [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx'])
         model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
         w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim)  # v-pred
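
Moving the conditioning from additional_model_inputs into cond["concat"] matters because, in sgm-style wrappers, a concat condition is conventionally concatenated with the noised input along the channel axis before the network forward. A generic sketch of that convention, not this repo's exact wrapper code:

import torch

def apply_model(network, noised_input, timesteps, cond):
    # cond["concat"] carries the encoded first-frame latents; join them with
    # the noised video latents on the channel axis (dim 2 for b t c h w).
    x = noised_input
    if "concat" in cond:
        x = torch.cat([x, cond["concat"]], dim=2)
    return network(x, timesteps, context=cond.get("crossattn"))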
@@ -117,11 +118,3 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
         elif self.type == "lpips":
             loss = self.lpips(model_output, target).reshape(-1)
         return loss
-
-
-def get_3d_position_ids(frame_len, h, w):
-    i = torch.arange(frame_len).view(frame_len, 1, 1).expand(frame_len, h, w)
-    j = torch.arange(h).view(1, h, 1).expand(frame_len, h, w)
-    k = torch.arange(w).view(1, 1, w).expand(frame_len, h, w)
-    position_ids = torch.stack([i, j, k], dim=-1).reshape(-1, 3)
-    return position_ids