diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py
index 6bee4ce..b099504 100644
--- a/sat/diffusion_video.py
+++ b/sat/diffusion_video.py
@@ -130,6 +130,14 @@ class SATVideoDiffusionEngine(nn.Module):
         loss_dict = {"loss": loss_mean}
         return loss_mean, loss_dict
 
+    def add_noise_to_first_frame(self, image):
+        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(self.device)
+        sigma = torch.exp(sigma).to(image.dtype)
+        image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
+        image = image + image_noise
+        return image
+
+
     def shared_step(self, batch: Dict) -> Any:
         x = self.get_input(batch)
         if self.lr_scale is not None:
diff --git a/sat/sgm/models/autoencoder.py b/sat/sgm/models/autoencoder.py
index 9ae44d0..08b04e5 100644
--- a/sat/sgm/models/autoencoder.py
+++ b/sat/sgm/models/autoencoder.py
@@ -546,85 +546,4 @@ class VideoAutoencodingEngine(AutoencodingEngine):
         missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
         print("Missing keys: ", missing_keys)
         print("Unexpected keys: ", unexpected_keys)
-        print(f"Restored from {path}")
-
-
-class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
-    def __init__(
-        self,
-        cp_size=0,
-        *args,
-        **kwargs,
-    ):
-        self.cp_size = cp_size
-        return super().__init__(*args, **kwargs)
-
-    def encode(
-        self,
-        x: torch.Tensor,
-        return_reg_log: bool = False,
-        unregularized: bool = False,
-        input_cp: bool = False,
-        output_cp: bool = False,
-        use_cp: bool = True,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
-        if self.cp_size <= 1:
-            use_cp = False
-        if self.cp_size > 0 and use_cp and not input_cp:
-            if not is_context_parallel_initialized:
-                initialize_context_parallel(self.cp_size)
-
-            global_src_rank = get_context_parallel_group_rank() * self.cp_size
-            torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
-
-            x = _conv_split(x, dim=2, kernel_size=1)
-
-        if return_reg_log:
-            z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
-        else:
-            z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
-
-        if self.cp_size > 0 and use_cp and not output_cp:
-            z = _conv_gather(z, dim=2, kernel_size=1)
-
-        if return_reg_log:
-            return z, reg_log
-        return z
-
-    def decode(
-        self,
-        z: torch.Tensor,
-        input_cp: bool = False,
-        output_cp: bool = False,
-        use_cp: bool = True,
-        **kwargs,
-    ):
-        if self.cp_size <= 1:
-            use_cp = False
-        if self.cp_size > 0 and use_cp and not input_cp:
-            if not is_context_parallel_initialized:
-                initialize_context_parallel(self.cp_size)
-
-            global_src_rank = get_context_parallel_group_rank() * self.cp_size
-            torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
-
-            z = _conv_split(z, dim=2, kernel_size=1)
-
-        x = super().decode(z, use_cp=use_cp, **kwargs)
-
-        if self.cp_size > 0 and use_cp and not output_cp:
-            x = _conv_gather(x, dim=2, kernel_size=1)
-
-        return x
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        input_cp: bool = False,
-        latent_cp: bool = False,
-        output_cp: bool = False,
-        **additional_decode_kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
-        z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
-        dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
-        return z, dec, reg_log
+        print(f"Restored from {path}")
\ No newline at end of file
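
For reference, below is a minimal standalone sketch of the noise augmentation that the new add_noise_to_first_frame method applies: a per-sample noise scale sigma is drawn log-normally (log sigma ~ N(-3.0, 0.5), so sigma concentrates around exp(-3) ≈ 0.05) and Gaussian noise scaled by it is added to the conditioning frame. The 5-D [batch, channels, frames, height, width] layout is an assumption inferred from the five broadcast dimensions, and the dummy shapes in the usage example are illustrative only:

import torch

def add_noise_to_first_frame(image: torch.Tensor) -> torch.Tensor:
    # Per-sample log-sigma ~ N(mean=-3.0, std=0.5); exponentiating gives a
    # log-normal noise scale concentrated around exp(-3) ~= 0.05.
    sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
    sigma = torch.exp(sigma).to(image.dtype)
    # Broadcast sigma over channel/frame/spatial dims (assumed [B, C, T, H, W])
    # and add Gaussian noise scaled by it.
    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
    return image + image_noise

# Illustrative usage on a dummy single-frame batch (hypothetical shapes).
first_frame = torch.randn(2, 3, 1, 32, 32)  # [B, C, T, H, W]
noised = add_noise_to_first_frame(first_frame)
print(noised.shape)  # torch.Size([2, 3, 1, 32, 32])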