mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
update sat
This commit is contained in:
parent
aeb7d9d056
commit
6f37641ee3
@ -130,6 +130,14 @@ class SATVideoDiffusionEngine(nn.Module):
|
|||||||
loss_dict = {"loss": loss_mean}
|
loss_dict = {"loss": loss_mean}
|
||||||
return loss_mean, loss_dict
|
return loss_mean, loss_dict
|
||||||
|
|
||||||
|
def add_noise_to_first_frame(self, image):
|
||||||
|
sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(self.device)
|
||||||
|
sigma = torch.exp(sigma).to(image.dtype)
|
||||||
|
image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
|
||||||
|
image = image + image_noise
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
def shared_step(self, batch: Dict) -> Any:
|
def shared_step(self, batch: Dict) -> Any:
|
||||||
x = self.get_input(batch)
|
x = self.get_input(batch)
|
||||||
if self.lr_scale is not None:
|
if self.lr_scale is not None:
|
||||||
|
@ -546,85 +546,4 @@ class VideoAutoencodingEngine(AutoencodingEngine):
|
|||||||
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
|
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
|
||||||
print("Missing keys: ", missing_keys)
|
print("Missing keys: ", missing_keys)
|
||||||
print("Unexpected keys: ", unexpected_keys)
|
print("Unexpected keys: ", unexpected_keys)
|
||||||
print(f"Restored from {path}")
|
print(f"Restored from {path}")
|
||||||
|
|
||||||
|
|
||||||
class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
cp_size=0,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.cp_size = cp_size
|
|
||||||
return super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def encode(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
return_reg_log: bool = False,
|
|
||||||
unregularized: bool = False,
|
|
||||||
input_cp: bool = False,
|
|
||||||
output_cp: bool = False,
|
|
||||||
use_cp: bool = True,
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
|
||||||
if self.cp_size <= 1:
|
|
||||||
use_cp = False
|
|
||||||
if self.cp_size > 0 and use_cp and not input_cp:
|
|
||||||
if not is_context_parallel_initialized:
|
|
||||||
initialize_context_parallel(self.cp_size)
|
|
||||||
|
|
||||||
global_src_rank = get_context_parallel_group_rank() * self.cp_size
|
|
||||||
torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
|
|
||||||
|
|
||||||
x = _conv_split(x, dim=2, kernel_size=1)
|
|
||||||
|
|
||||||
if return_reg_log:
|
|
||||||
z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
|
|
||||||
else:
|
|
||||||
z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
|
|
||||||
|
|
||||||
if self.cp_size > 0 and use_cp and not output_cp:
|
|
||||||
z = _conv_gather(z, dim=2, kernel_size=1)
|
|
||||||
|
|
||||||
if return_reg_log:
|
|
||||||
return z, reg_log
|
|
||||||
return z
|
|
||||||
|
|
||||||
def decode(
|
|
||||||
self,
|
|
||||||
z: torch.Tensor,
|
|
||||||
input_cp: bool = False,
|
|
||||||
output_cp: bool = False,
|
|
||||||
use_cp: bool = True,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if self.cp_size <= 1:
|
|
||||||
use_cp = False
|
|
||||||
if self.cp_size > 0 and use_cp and not input_cp:
|
|
||||||
if not is_context_parallel_initialized:
|
|
||||||
initialize_context_parallel(self.cp_size)
|
|
||||||
|
|
||||||
global_src_rank = get_context_parallel_group_rank() * self.cp_size
|
|
||||||
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
|
|
||||||
|
|
||||||
z = _conv_split(z, dim=2, kernel_size=1)
|
|
||||||
|
|
||||||
x = super().decode(z, use_cp=use_cp, **kwargs)
|
|
||||||
|
|
||||||
if self.cp_size > 0 and use_cp and not output_cp:
|
|
||||||
x = _conv_gather(x, dim=2, kernel_size=1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
input_cp: bool = False,
|
|
||||||
latent_cp: bool = False,
|
|
||||||
output_cp: bool = False,
|
|
||||||
**additional_decode_kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
|
||||||
z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
|
|
||||||
dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
|
|
||||||
return z, dec, reg_log
|
|
Loading…
x
Reference in New Issue
Block a user