update sat

This commit is contained in:
zR 2024-09-12 22:27:41 +08:00
parent aeb7d9d056
commit 6f37641ee3
2 changed files with 9 additions and 82 deletions

View File

@ -130,6 +130,14 @@ class SATVideoDiffusionEngine(nn.Module):
loss_dict = {"loss": loss_mean}
return loss_mean, loss_dict
def add_noise_to_first_frame(self, image):
sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(self.device)
sigma = torch.exp(sigma).to(image.dtype)
image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
image = image + image_noise
return image
def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch)
if self.lr_scale is not None:

View File

@ -547,84 +547,3 @@ class VideoAutoencodingEngine(AutoencodingEngine):
print("Missing keys: ", missing_keys)
print("Unexpected keys: ", unexpected_keys)
print(f"Restored from {path}")
class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
def __init__(
self,
cp_size=0,
*args,
**kwargs,
):
self.cp_size = cp_size
return super().__init__(*args, **kwargs)
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
input_cp: bool = False,
output_cp: bool = False,
use_cp: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.cp_size <= 1:
use_cp = False
if self.cp_size > 0 and use_cp and not input_cp:
if not is_context_parallel_initialized:
initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
x = _conv_split(x, dim=2, kernel_size=1)
if return_reg_log:
z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
else:
z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
if self.cp_size > 0 and use_cp and not output_cp:
z = _conv_gather(z, dim=2, kernel_size=1)
if return_reg_log:
return z, reg_log
return z
def decode(
self,
z: torch.Tensor,
input_cp: bool = False,
output_cp: bool = False,
use_cp: bool = True,
**kwargs,
):
if self.cp_size <= 1:
use_cp = False
if self.cp_size > 0 and use_cp and not input_cp:
if not is_context_parallel_initialized:
initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
z = _conv_split(z, dim=2, kernel_size=1)
x = super().decode(z, use_cp=use_cp, **kwargs)
if self.cp_size > 0 and use_cp and not output_cp:
x = _conv_gather(x, dim=2, kernel_size=1)
return x
def forward(
self,
x: torch.Tensor,
input_cp: bool = False,
latent_cp: bool = False,
output_cp: bool = False,
**additional_decode_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
return z, dec, reg_log