Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 19:41:59 +08:00)

Commit 6f37641ee3: update sat
Parent: aeb7d9d056
```diff
@@ -130,6 +130,14 @@ class SATVideoDiffusionEngine(nn.Module):
         loss_dict = {"loss": loss_mean}
         return loss_mean, loss_dict
 
+    def add_noise_to_first_frame(self, image):
+        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(self.device)
+        sigma = torch.exp(sigma).to(image.dtype)
+        image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
+        image = image + image_noise
+        return image
+
+
     def shared_step(self, batch: Dict) -> Any:
         x = self.get_input(batch)
         if self.lr_scale is not None:
```
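The added add_noise_to_first_frame helper perturbs the conditioning image with per-sample Gaussian noise whose scale is drawn from a log-normal distribution (log sigma ~ N(-3.0, 0.5), so sigma clusters around exp(-3) ≈ 0.05) and broadcast over the remaining dimensions of the 5-D video tensor. Below is a minimal standalone sketch of the same step, assuming a (B, C, T, H, W) layout purely for illustration:

```python
import torch

def add_noise_to_first_frame(image: torch.Tensor) -> torch.Tensor:
    # One noise scale per batch element: log(sigma) ~ N(-3.0, 0.5),
    # so sigma is concentrated around exp(-3) ≈ 0.05.
    sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
    sigma = torch.exp(sigma).to(image.dtype)
    # Broadcast sigma across the other four dims of the 5-D tensor.
    return image + torch.randn_like(image) * sigma[:, None, None, None, None]

frames = torch.zeros(2, 3, 1, 32, 32)      # (B, C, T, H, W), hypothetical shapes
noised = add_noise_to_first_frame(frames)
print(noised.std(dim=(1, 2, 3, 4)))        # roughly the two sampled sigmas
```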
```diff
@@ -547,84 +547,3 @@ class VideoAutoencodingEngine(AutoencodingEngine):
         print("Missing keys: ", missing_keys)
         print("Unexpected keys: ", unexpected_keys)
         print(f"Restored from {path}")
-
-
-class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
-    def __init__(
-        self,
-        cp_size=0,
-        *args,
-        **kwargs,
-    ):
-        self.cp_size = cp_size
-        return super().__init__(*args, **kwargs)
-
-    def encode(
-        self,
-        x: torch.Tensor,
-        return_reg_log: bool = False,
-        unregularized: bool = False,
-        input_cp: bool = False,
-        output_cp: bool = False,
-        use_cp: bool = True,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
-        if self.cp_size <= 1:
-            use_cp = False
-        if self.cp_size > 0 and use_cp and not input_cp:
-            if not is_context_parallel_initialized:
-                initialize_context_parallel(self.cp_size)
-
-            global_src_rank = get_context_parallel_group_rank() * self.cp_size
-            torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
-
-            x = _conv_split(x, dim=2, kernel_size=1)
-
-        if return_reg_log:
-            z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
-        else:
-            z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
-
-        if self.cp_size > 0 and use_cp and not output_cp:
-            z = _conv_gather(z, dim=2, kernel_size=1)
-
-        if return_reg_log:
-            return z, reg_log
-        return z
-
-    def decode(
-        self,
-        z: torch.Tensor,
-        input_cp: bool = False,
-        output_cp: bool = False,
-        use_cp: bool = True,
-        **kwargs,
-    ):
-        if self.cp_size <= 1:
-            use_cp = False
-        if self.cp_size > 0 and use_cp and not input_cp:
-            if not is_context_parallel_initialized:
-                initialize_context_parallel(self.cp_size)
-
-            global_src_rank = get_context_parallel_group_rank() * self.cp_size
-            torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
-
-            z = _conv_split(z, dim=2, kernel_size=1)
-
-        x = super().decode(z, use_cp=use_cp, **kwargs)
-
-        if self.cp_size > 0 and use_cp and not output_cp:
-            x = _conv_gather(x, dim=2, kernel_size=1)
-
-        return x
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        input_cp: bool = False,
-        latent_cp: bool = False,
-        output_cp: bool = False,
-        **additional_decode_kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
-        z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
-        dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
-        return z, dec, reg_log
```
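The removed VideoAutoencoderInferenceWrapper sharded VAE inference across a context-parallel group: the group's source rank broadcasts the full tensor, _conv_split slices it along the temporal axis (dim=2), each rank runs the parent encode/decode on its slice, and _conv_gather reassembles the output. Below is a minimal single-process sketch of that split/gather pattern, using torch.chunk and torch.cat as stand-ins for the internal _conv_split/_conv_gather helpers (an assumption that ignores any halo or overlap handling those helpers may perform):

```python
import torch

def split_for_rank(x: torch.Tensor, cp_size: int, rank: int) -> torch.Tensor:
    # Stand-in for _conv_split(x, dim=2, kernel_size=1): each context-parallel
    # rank receives a contiguous slice of the temporal dimension.
    return x.chunk(cp_size, dim=2)[rank]

def gather_from_ranks(parts) -> torch.Tensor:
    # Stand-in for _conv_gather(z, dim=2, kernel_size=1): reassemble the
    # temporal dimension from the per-rank slices.
    return torch.cat(parts, dim=2)

cp_size = 2
x = torch.randn(1, 3, 8, 16, 16)    # (B, C, T, H, W), hypothetical shapes
parts = [split_for_rank(x, cp_size, r) for r in range(cp_size)]
# Each rank would run the parent encode()/decode() on its own slice here.
reassembled = gather_from_ranks(parts)
assert torch.equal(reassembled, x)
```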