diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml new file mode 100644 index 0000000..19271ef --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -0,0 +1,51 @@ +name: "\U0001F41B Bug Report" +description: Submit a bug report to help us improve CogVideoX / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX 开源模型 +body: + - type: textarea + id: system-info + attributes: + label: System Info / 系統信息 + description: Your operating environment / 您的运行环境信息 + placeholder: Includes CUDA version, Diffusers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Diffusers,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)... + validations: + required: true + + - type: checkboxes + id: information-scripts-examples + attributes: + label: Information / 问题信息 + description: 'The problem arises when using: / 问题出现在' + options: + - label: "The official example scripts / 官方的示例脚本" + - label: "My own modified scripts / 我自己修改的脚本和任务" + + - type: textarea + id: reproduction + validations: + required: true + attributes: + label: Reproduction / 复现过程 + description: | + Please provide a code example that reproduces the problem you encountered, preferably a minimal reproduction unit. + If you have code snippets, error messages, or stack traces, please provide them here as well. + Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting + Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code. + + 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。 + 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。 + 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting + 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。 + placeholder: | + Steps to reproduce the behavior/复现Bug的步骤: + + 1. + 2. + 3. + + - type: textarea + id: expected-behavior + validations: + required: true + attributes: + label: Expected behavior / 期待表现 + description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。" \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature-request.yaml b/.github/ISSUE_TEMPLATE/feature-request.yaml new file mode 100644 index 0000000..7e09bee --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.yaml @@ -0,0 +1,34 @@ +name: "\U0001F680 Feature request" +description: Submit a request for a new CogVideoX feature / 提交一个新的 CogVideoX开源模型的功能建议 +labels: [ "feature" ] +body: + - type: textarea + id: feature-request + validations: + required: true + attributes: + label: Feature request / 功能建议 + description: | + A brief description of the proposed feature. Links to the corresponding papers and code are welcome. + 对功能建议的简述。最好提供对应的论文和代码链接。 + + - type: textarea + id: motivation + validations: + required: true + attributes: + label: Motivation / 动机 + description: | + Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here. + 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。 + + - type: textarea + id: contribution + validations: + required: true + attributes: + label: Your contribution / 您的贡献 + description: | + + A link to your PR, or any other link that shows how you can help.
+ 您的PR链接或者其他您能提供帮助的链接。 \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE/pr_template.md b/.github/PULL_REQUEST_TEMPLATE/pr_template.md new file mode 100644 index 0000000..0c3140a --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/pr_template.md @@ -0,0 +1,34 @@ +# Raise a valuable PR / 提出有价值的PR + +## Caution / 注意事项: +Users should keep the following points in mind when submitting PRs: + +1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md). +2. The proposed PR should be focused; if there are multiple different ideas or optimizations, they should be split into separate PRs. + +用户在提交PR时候应该注意以下几点: + +1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。 +2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。 + +## PRs that should not be proposed / 不应该提出的PR + +If a developer proposes a PR that falls into any of the following categories, it may be closed or rejected. + +1. PRs that do not describe the proposed improvement. +2. PRs that combine multiple issues of different types. +3. PRs that largely duplicate already existing PRs. + +如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。 + +1. 没有说明改进方案的。 +2. 多个不同类型的问题合并在一个PR中的。 +3. 提出的PR与已经存在的PR高度重复的。 + + +# Check your PR / 检查您的PR +- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分? +- [ ] Has this been discussed/approved via a GitHub issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。 +- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。 +- [ ] Did you write the necessary new tests? / 您是否编写了新的必要测试? +- [ ] Does your PR address only one issue? / 您的PR是否仅针对一个问题 \ No newline at end of file diff --git a/sat/data_video.py b/sat/data_video.py index 540f7f7..ccfea46 100644 --- a/sat/data_video.py +++ b/sat/data_video.py @@ -425,7 +425,7 @@ class SFTDataset(Dataset): self.videos_list.append(tensor_frms) # caption - caption_path = os.path.join(root, filename.replace('videos', 'labels').replace('.mp4', '.txt')) + caption_path = os.path.join(root, filename.replace("videos", "labels").replace(".mp4", ".txt")) if os.path.exists(caption_path): caption = open(caption_path, "r").read().splitlines()[0] else: diff --git a/sat/sample_video.py b/sat/sample_video.py index 8ca4b5a..ad1940c 100644 --- a/sat/sample_video.py +++ b/sat/sample_video.py @@ -178,18 +178,18 @@ def sampling_main(args, model_cls): shape=(T, C, H // F, W // F), ) samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous() - + # Unload the model from GPU to save GPU memory - model.to('cpu') + model.to("cpu") torch.cuda.empty_cache() first_stage_model = model.first_stage_model first_stage_model = first_stage_model.to(device) latent = 1.0 / model.scale_factor * samples_z - + # Decode latent serial to save GPU memory recons = [] - loop_num = (T-1)//2 + loop_num = (T - 1) // 2 for i in range(loop_num): if i == 0: start_frame, end_frame = 0, 3 @@ -200,7 +200,9 @@ def sampling_main(args, model_cls): else: clear_fake_cp_cache = False with torch.no_grad(): - recon = first_stage_model.decode(latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache) + recon = first_stage_model.decode( + latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache + ) recons.append(recon) @@ -208,12 +210,13 @@ def sampling_main(args, model_cls): samples_x = recon.permute(0, 2, 1, 3, 4).contiguous() samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0,
max=1.0).cpu() - save_path = os.path.join(args.output_dir, str(cnt) + '_' + text.replace(' ', '_').replace('/', '')[:120], str(index)) + save_path = os.path.join( + args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) + ) if mpu.get_model_parallel_rank() == 0: save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) - if __name__ == "__main__": if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] diff --git a/sat/vae_modules/attention.py b/sat/vae_modules/attention.py index 19f5485..52bbba5 100644 --- a/sat/vae_modules/attention.py +++ b/sat/vae_modules/attention.py @@ -52,6 +52,7 @@ except: from modules.utils import checkpoint + def exists(val): return val is not None @@ -93,15 +94,9 @@ class FeedForward(nn.Module): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = ( - nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) - if not glu - else GEGLU(dim, inner_dim) - ) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) - self.net = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -117,9 +112,7 @@ def zero_module(module): def Normalize(in_channels): - return torch.nn.GroupNorm( - num_groups=32, num_channels=in_channels, eps=1e-6, affine=True - ) + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) class LinearAttention(nn.Module): @@ -133,15 +126,11 @@ class LinearAttention(nn.Module): def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) - q, k, v = rearrange( - qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 - ) + q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) k = k.softmax(dim=-1) context = torch.einsum("bhdn,bhen->bhde", k, v) out = torch.einsum("bhde,bhdn->bhen", context, q) - out = rearrange( - out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w - ) + out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) return self.to_out(out) @@ -151,18 +140,10 @@ class SpatialSelfAttention(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -211,9 +192,7 @@ class CrossAttention(nn.Module): self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), 
nn.Dropout(dropout)) self.backend = backend def forward( @@ -241,12 +220,8 @@ class CrossAttention(nn.Module): # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 assert x.shape[0] % n_times_crossframe_attn_in_self == 0 n_cp = x.shape[0] // n_times_crossframe_attn_in_self - k = repeat( - k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp - ) - v = repeat( - v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp - ) + k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp) + v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) @@ -269,9 +244,7 @@ class CrossAttention(nn.Module): ## new with sdp_kernel(**BACKEND_MAP[self.backend]): # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) - out = F.scaled_dot_product_attention( - q, k, v, attn_mask=mask - ) # scale is dim_head ** -0.5 per default + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default del q, k, v out = rearrange(out, "b h n d -> b n (h d)", h=h) @@ -284,9 +257,7 @@ class CrossAttention(nn.Module): class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__( - self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs - ): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs): super().__init__() print( f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " @@ -302,9 +273,7 @@ class MemoryEfficientCrossAttention(nn.Module): self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None def forward( @@ -351,9 +320,7 @@ class MemoryEfficientCrossAttention(nn.Module): ) # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention( - q, k, v, attn_bias=None, op=self.attention_op - ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) # TODO: Use this directly in the attention operation, as a bias if exists(mask): @@ -398,13 +365,9 @@ class BasicTransformerBlock(nn.Module): ) attn_mode = "softmax" elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: - print( - "We do not support vanilla attention anymore, as it is too expensive. Sorry." - ) + print("We do not support vanilla attention anymore, as it is too expensive. Sorry.") if not XFORMERS_IS_AVAILABLE: - assert ( - False - ), "Please install xformers via e.g. 'pip install xformers==0.0.16'" + assert False, "Please install xformers via e.g. 
'pip install xformers==0.0.16'" else: print("Falling back to xformers efficient attention.") attn_mode = "softmax-xformers" @@ -438,9 +401,7 @@ class BasicTransformerBlock(nn.Module): if self.checkpoint: print(f"{self.__class__.__name__} is using checkpointing") - def forward( - self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 - ): + def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): kwargs = {"x": x} if context is not None: @@ -450,35 +411,22 @@ class BasicTransformerBlock(nn.Module): kwargs.update({"additional_tokens": additional_tokens}) if n_times_crossframe_attn_in_self: - kwargs.update( - {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self} - ) + kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}) # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint) - return checkpoint( - self._forward, (x, context), self.parameters(), self.checkpoint - ) + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) - def _forward( - self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 - ): + def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): x = ( self.attn1( self.norm1(x), context=context if self.disable_self_attn else None, additional_tokens=additional_tokens, - n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self - if not self.disable_self_attn - else 0, - ) - + x - ) - x = ( - self.attn2( - self.norm2(x), context=context, additional_tokens=additional_tokens + n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, ) + x ) + x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x x = self.ff(self.norm3(x)) + x return x @@ -486,7 +434,7 @@ class BasicTransformerBlock(nn.Module): class BasicTransformerSingleLayerBlock(nn.Module): ATTENTION_MODES = { "softmax": CrossAttention, # vanilla attention - "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version + "softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128]) } @@ -517,9 +465,7 @@ class BasicTransformerSingleLayerBlock(nn.Module): self.checkpoint = checkpoint def forward(self, x, context=None): - return checkpoint( - self._forward, (x, context), self.parameters(), self.checkpoint - ) + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) def _forward(self, x, context=None): x = self.attn1(self.norm1(x), context=context) + x @@ -553,9 +499,7 @@ class SpatialTransformer(nn.Module): sdp_backend=None, ): super().__init__() - print( - f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads" - ) + print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads") from omegaconf import ListConfig if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): @@ -577,9 +521,7 @@ class SpatialTransformer(nn.Module): inner_dim = n_heads * d_head self.norm = Normalize(in_channels) if not use_linear: - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) else: self.proj_in 
= nn.Linear(in_channels, inner_dim) @@ -600,9 +542,7 @@ class SpatialTransformer(nn.Module): ] ) if not use_linear: - self.proj_out = zero_module( - nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - ) + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) else: # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) diff --git a/sat/vae_modules/autoencoder.py b/sat/vae_modules/autoencoder.py index 2b70840..7c0cc80 100644 --- a/sat/vae_modules/autoencoder.py +++ b/sat/vae_modules/autoencoder.py @@ -16,7 +16,7 @@ from packaging import version from vae_modules.ema import LitEma from sgm.util import ( - instantiate_from_config, + instantiate_from_config, get_obj_from_str, default, is_context_parallel_initialized, @@ -48,14 +48,14 @@ class AbstractAutoencoder(pl.LightningModule): self.use_ema = ema_decay is not None if monitor is not None: self.monitor = monitor - + if self.use_ema: self.model_ema = LitEma(self, decay=ema_decay) logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - + if version.parse(torch.__version__) >= version.parse("2.0.0"): self.automatic_optimization = False - + # def apply_ckpt(self, ckpt: Union[None, str, dict]): # if ckpt is None: # return @@ -66,14 +66,14 @@ class AbstractAutoencoder(pl.LightningModule): # } # engine = instantiate_from_config(ckpt) # engine(self) - + def apply_ckpt(self, ckpt: Union[None, str, dict]): if ckpt is None: return self.init_from_ckpt(ckpt) def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")['state_dict'] + sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: @@ -119,9 +119,7 @@ class AbstractAutoencoder(pl.LightningModule): def instantiate_optimizer_from_config(self, params, lr, cfg): logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") - return get_obj_from_str(cfg["target"])( - params, lr=lr, **cfg.get("params", dict()) - ) + return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict())) def configure_optimizers(self) -> Any: raise NotImplementedError() @@ -160,12 +158,8 @@ class AutoencodingEngine(AbstractAutoencoder): self.encoder = instantiate_from_config(encoder_config) self.decoder = instantiate_from_config(decoder_config) self.loss = instantiate_from_config(loss_config) - self.regularization = instantiate_from_config( - regularizer_config - ) - self.optimizer_config = default( - optimizer_config, {"target": "torch.optim.Adam"} - ) + self.regularization = instantiate_from_config(regularizer_config) + self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"}) self.diff_boost_factor = diff_boost_factor self.disc_start_iter = disc_start_iter self.lr_g_factor = lr_g_factor @@ -178,7 +172,7 @@ class AutoencodingEngine(AbstractAutoencoder): assert len(self.ae_optimizer_args) == len(self.trainable_ae_params) else: self.ae_optimizer_args = [{}] # makes type consitent - + self.trainable_disc_params = trainable_disc_params if self.trainable_disc_params is not None: self.disc_optimizer_args = default( @@ -210,7 +204,7 @@ class AutoencodingEngine(AbstractAutoencoder): params = params + list(self.encoder.parameters()) params = params + list(self.decoder.parameters()) return params - + def get_discriminator_params(self) -> list: if hasattr(self.loss, "get_trainable_parameters"): params = 
list(self.loss.get_trainable_parameters()) # e.g., discriminator @@ -234,25 +228,19 @@ class AutoencodingEngine(AbstractAutoencoder): if return_reg_log: return z, reg_log return z - + def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: x = self.decoder(z, **kwargs) return x - def forward( - self, x: torch.Tensor, **additional_decode_kwargs - ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]: z, reg_log = self.encode(x, return_reg_log=True) dec = self.decode(z, **additional_decode_kwargs) return z, dec, reg_log - def inner_training_step( - self, batch: dict, batch_idx: int, optimizer_idx: int = 0 - ) -> torch.Tensor: + def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor: x = self.get_input(batch) - additional_decode_kwargs = { - key: batch[key] for key in self.additional_decode_keys.intersection(batch) - } + additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} z, xrec, regularization_log = self(x, **additional_decode_kwargs) if hasattr(self.loss, "forward_keys"): extra_info = { @@ -267,7 +255,7 @@ class AutoencodingEngine(AbstractAutoencoder): extra_info = {k: extra_info[k] for k in self.loss.forward_keys} else: extra_info = dict() - + if optimizer_idx == 0: # autoencode out_loss = self.loss(x, xrec, **extra_info) @@ -299,13 +287,11 @@ class AutoencodingEngine(AbstractAutoencoder): # discriminator discloss, log_dict_disc = self.loss(x, xrec, **extra_info) # -> discriminator always needs to return a tuple - self.log_dict( - log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True - ) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) return discloss else: raise NotImplementedError(f"Unknown optimizer {optimizer_idx}") - + def training_step(self, batch: dict, batch_idx: int): opts = self.optimizers() if not isinstance(opts, list): @@ -317,9 +303,7 @@ class AutoencodingEngine(AbstractAutoencoder): opt = opts[optimizer_idx] opt.zero_grad() with opt.toggle_model(): - loss = self.inner_training_step( - batch, batch_idx, optimizer_idx=optimizer_idx - ) + loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx) self.manual_backward(loss) opt.step() @@ -329,7 +313,7 @@ class AutoencodingEngine(AbstractAutoencoder): log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") log_dict.update(log_dict_ema) return log_dict - + def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict: x = self.get_input(batch) @@ -387,24 +371,18 @@ class AutoencodingEngine(AbstractAutoencoder): params.extend(pattern_params) groups.append({"params": params, **args}) return groups, num_params - + def configure_optimizers(self) -> List[torch.optim.Optimizer]: if self.trainable_ae_params is None: ae_params = self.get_autoencoder_params() else: - ae_params, num_ae_params = self.get_param_groups( - self.trainable_ae_params, self.ae_optimizer_args - ) + ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args) logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}") if self.trainable_disc_params is None: disc_params = self.get_discriminator_params() else: - disc_params, num_disc_params = self.get_param_groups( - self.trainable_disc_params, self.disc_optimizer_args - ) - logpy.info( - f"Number of trainable discriminator parameters: {num_disc_params:,}" 
- ) + disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args) + logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}") opt_ae = self.instantiate_optimizer_from_config( ae_params, default(self.lr_g_factor, 1.0) * self.learning_rate, @@ -412,23 +390,17 @@ class AutoencodingEngine(AbstractAutoencoder): ) opts = [opt_ae] if len(disc_params) > 0: - opt_disc = self.instantiate_optimizer_from_config( - disc_params, self.learning_rate, self.optimizer_config - ) + opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config) opts.append(opt_disc) return opts - + @torch.no_grad() - def log_images( - self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs - ) -> dict: + def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: log = dict() additional_decode_kwargs = {} x = self.get_input(batch) - additional_decode_kwargs.update( - {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} - ) + additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)}) _, xrec, _ = self(x, **additional_decode_kwargs) log["inputs"] = x @@ -438,9 +410,7 @@ class AutoencodingEngine(AbstractAutoencoder): log["diff"] = 2.0 * diff - 1.0 # diff_boost shows location of small errors, by boosting their # brightness. - log["diff_boost"] = ( - 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1 - ) + log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1 if hasattr(self.loss, "log_images"): log.update(self.loss.log_images(x, xrec)) with self.ema_scope(): @@ -449,9 +419,7 @@ class AutoencodingEngine(AbstractAutoencoder): diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) diff_ema.clamp_(0, 1.0) log["diff_ema"] = 2.0 * diff_ema - 1.0 - log["diff_boost_ema"] = ( - 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 - ) + log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 if additional_log_kwargs: additional_decode_kwargs.update(additional_log_kwargs) _, xrec_add, _ = self(x, **additional_decode_kwargs) @@ -493,9 +461,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine): params = super().get_autoencoder_params() return params - def encode( - self, x: torch.Tensor, return_reg_log: bool = False - ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: if self.max_batch_size is None: z = self.encoder(x) z = self.quant_conv(z) @@ -538,16 +504,11 @@ class AutoencoderKL(AutoencodingEngineLegacy): if "lossconfig" in kwargs: kwargs["loss_config"] = kwargs.pop("lossconfig") super().__init__( - regularizer_config={ - "target": ( - "sgm.modules.autoencoding.regularizers" - ".DiagonalGaussianRegularizer" - ) - }, + regularizer_config={"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")}, **kwargs, ) - + class IdentityFirstStage(AbstractAutoencoder): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -560,33 +521,31 @@ class IdentityFirstStage(AbstractAutoencoder): def decode(self, x: Any, *args, **kwargs) -> Any: return x - + class VideoAutoencodingEngine(AutoencodingEngine): def __init__( self, ckpt_path: Union[None, str] = None, ignore_keys: Union[Tuple, list] = (), - image_video_weights=[1,1], + image_video_weights=[1, 1], 
only_train_decoder=False, context_parallel_size=0, - **kwargs, + **kwargs, ): super().__init__(**kwargs) self.context_parallel_size = context_parallel_size if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - def log_videos( - self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs - ) -> dict: + + def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: return self.log_images(batch, additional_log_kwargs, **kwargs) def get_input(self, batch: dict) -> torch.Tensor: if self.context_parallel_size > 0: if not is_context_parallel_initialized(): initialize_context_parallel(self.context_parallel_size) - + batch = batch[self.input_key] global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size @@ -594,16 +553,16 @@ class VideoAutoencodingEngine(AutoencodingEngine): batch = _conv_split(batch, dim=2, kernel_size=1) return batch - + return batch[self.input_key] - + def apply_ckpt(self, ckpt: Union[None, str, dict]): if ckpt is None: return self.init_from_ckpt(ckpt) - + def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")['state_dict'] + sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: @@ -615,6 +574,7 @@ class VideoAutoencodingEngine(AutoencodingEngine): print("Unexpected keys: ", unexpected_keys) print(f"Restored from {path}") + class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): def __init__( self, @@ -633,16 +593,15 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): input_cp: bool = False, output_cp: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: - if self.cp_size > 0 and not input_cp: if not is_context_parallel_initialized: initialize_context_parallel(self.cp_size) - + global_src_rank = get_context_parallel_group_rank() * self.cp_size torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group()) x = _conv_split(x, dim=2, kernel_size=1) - + if return_reg_log: z, reg_log = super().encode(x, return_reg_log, unregularized) else: @@ -650,7 +609,7 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): if self.cp_size > 0 and not output_cp: z = _conv_gather(z, dim=2, kernel_size=1) - + if return_reg_log: return z, reg_log return z @@ -671,7 +630,7 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group()) z = _conv_split(z, dim=2, kernel_size=split_kernel_size) - + x = super().decode(z, **kwargs) if self.cp_size > 0 and not output_cp: @@ -680,15 +639,13 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): return x def forward( - self, - x: torch.Tensor, + self, + x: torch.Tensor, input_cp: bool = False, latent_cp: bool = False, output_cp: bool = False, - **additional_decode_kwargs + **additional_decode_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor, dict]: - z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp) dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs) return z, dec, reg_log - \ No newline at end of file diff --git a/sat/vae_modules/cp_enc_dec.py b/sat/vae_modules/cp_enc_dec.py index 7a3509d..28d9738 100644 --- a/sat/vae_modules/cp_enc_dec.py +++ b/sat/vae_modules/cp_enc_dec.py @@ -9,7 +9,12 @@ from beartype import beartype from beartype.typing import Union, Tuple, Optional, List from einops import 
rearrange -from sgm.util import get_context_parallel_group, get_context_parallel_rank, get_context_parallel_world_size, get_context_parallel_group_rank +from sgm.util import ( + get_context_parallel_group, + get_context_parallel_rank, + get_context_parallel_world_size, + get_context_parallel_group_rank, +) # try: from vae_modules.utils import SafeConv3d as Conv3d @@ -17,12 +22,15 @@ from vae_modules.utils import SafeConv3d as Conv3d # # Degrade to normal Conv3d if SafeConv3d is not available # from torch.nn import Conv3d -def cast_tuple(t, length = 1): + +def cast_tuple(t, length=1): return t if isinstance(t, tuple) else ((t,) * length) + def divisible_by(num, den): return (num % den) == 0 + def is_odd(n): return not divisible_by(n, 2) @@ -30,6 +38,7 @@ def is_odd(n): def exists(v): return v is not None + def pair(t): return t if isinstance(t, tuple) else (t, t) @@ -51,15 +60,16 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0,1,0,0)) + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def nonlinearity(x): # swish - return x*torch.sigmoid(x) + return x * torch.sigmoid(x) -def leaky_relu(p = 0.1): + +def leaky_relu(p=0.1): return nn.LeakyReLU(p) @@ -68,7 +78,7 @@ def _split(input_, dim): if cp_world_size == 1: return input_ - + cp_rank = get_context_parallel_rank() # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape) @@ -88,27 +98,30 @@ def _split(input_, dim): return output + def _gather(input_, dim): cp_world_size = get_context_parallel_world_size() # Bypass the function if context parallel is 1 if cp_world_size == 1: return input_ - + group = get_context_parallel_group() cp_rank = get_context_parallel_rank() - + # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape) input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() if cp_rank == 0: input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() - tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)] + tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [ + torch.empty_like(input_) for _ in range(cp_world_size - 1) + ] if cp_rank == 0: input_ = torch.cat([input_first_frame_, input_], dim=dim) - + tensor_list[cp_rank] = input_ torch.distributed.all_gather(tensor_list, input_, group=group) @@ -118,6 +131,7 @@ def _gather(input_, dim): return output + def _conv_split(input_, dim, kernel_size): cp_world_size = get_context_parallel_world_size() @@ -132,23 +146,26 @@ def _conv_split(input_, dim, kernel_size): dim_size = (input_.size()[dim] - kernel_size) // cp_world_size if cp_rank == 0: - output = input_.transpose(dim, 0)[:dim_size + kernel_size].transpose(dim, 0) + output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) else: # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0) - output = input_.transpose(dim, 0)[cp_rank * dim_size + kernel_size:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0) + output = input_.transpose(dim, 0)[ + cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size + ].transpose(dim, 0) output = output.contiguous() # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) return output + def 
_conv_gather(input_, dim, kernel_size): cp_world_size = get_context_parallel_world_size() # Bypass the function if context parallel is 1 if cp_world_size == 1: return input_ - + group = get_context_parallel_group() cp_rank = get_context_parallel_rank() @@ -158,9 +175,11 @@ def _conv_gather(input_, dim, kernel_size): if cp_rank == 0: input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous() else: - input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0):].transpose(0, dim).contiguous() + input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous() - tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)] + tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [ + torch.empty_like(input_) for _ in range(cp_world_size - 1) + ] if cp_rank == 0: input_ = torch.cat([input_first_kernel_, input_], dim=dim) @@ -174,12 +193,12 @@ def _conv_gather(input_, dim, kernel_size): return output -def _pass_from_previous_rank(input_, dim, kernel_size): +def _pass_from_previous_rank(input_, dim, kernel_size): # Bypass the function if kernel size is 1 if kernel_size == 1: return input_ - + group = get_context_parallel_group() cp_rank = get_context_parallel_rank() cp_group_rank = get_context_parallel_group_rank() @@ -191,7 +210,7 @@ def _pass_from_previous_rank(input_, dim, kernel_size): global_world_size = torch.distributed.get_world_size() input_ = input_.transpose(0, dim) - + # pass from last rank send_rank = global_rank + 1 recv_rank = global_rank - 1 @@ -201,11 +220,11 @@ def _pass_from_previous_rank(input_, dim, kernel_size): recv_rank += cp_world_size if cp_rank < cp_world_size - 1: - req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) + req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) if cp_rank > 0: - recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous() + recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) - + if cp_rank == 0: input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0) else: @@ -218,8 +237,8 @@ def _pass_from_previous_rank(input_, dim, kernel_size): return input_ -def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None): +def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None): # Bypass the function if kernel size is 1 if kernel_size == 1: return input_ @@ -248,9 +267,9 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous() # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group) # req_recv.wait() - recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous() + recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() if cp_rank < cp_world_size - 1: - req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) + req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) if cp_rank > 0: req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) @@ -264,18 +283,17 @@ def _fake_cp_pass_from_previous_rank(input_, dim, 
kernel_size, cache_padding=Non else: req_recv.wait() input_ = torch.cat([recv_buffer, input_], dim=0) - + input_ = input_.transpose(0, dim).contiguous() return input_ def _drop_from_previous_rank(input_, dim, kernel_size): - input_ = input_.transpose(0, dim)[kernel_size - 1:].transpose(0, dim) + input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim) return input_ class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function): - @staticmethod def forward(ctx, input_, dim, kernel_size): ctx.dim = dim @@ -286,8 +304,8 @@ class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function): def backward(ctx, grad_output): return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None -class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function): +class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function): @staticmethod def forward(ctx, input_, dim, kernel_size): ctx.dim = dim @@ -298,26 +316,26 @@ class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function): def backward(ctx, grad_output): return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None -class _ConvolutionPassFromPreviousRank(torch.autograd.Function): +class _ConvolutionPassFromPreviousRank(torch.autograd.Function): @staticmethod def forward(ctx, input_, dim, kernel_size): ctx.dim = dim ctx.kernel_size = kernel_size return _pass_from_previous_rank(input_, dim, kernel_size) - + @staticmethod def backward(ctx, grad_output): return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None -class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function): +class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function): @staticmethod def forward(ctx, input_, dim, kernel_size, cache_padding): ctx.dim = dim ctx.kernel_size = kernel_size return _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding) - + @staticmethod def backward(ctx, grad_output): return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None, None @@ -326,25 +344,21 @@ class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function): def conv_scatter_to_context_parallel_region(input_, dim, kernel_size): return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size) + def conv_gather_from_context_parallel_region(input_, dim, kernel_size): return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size) + def conv_pass_from_last_rank(input_, dim, kernel_size): return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size) + def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding): return _FakeCPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size, cache_padding) class ContextParallelCausalConv3d(nn.Module): - def __init__( - self, - chan_in, - chan_out, - kernel_size: Union[int, Tuple[int, int, int]], - stride = 1, - **kwargs - ): + def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs): super().__init__() kernel_size = cast_tuple(kernel_size, 3) @@ -364,7 +378,7 @@ class ContextParallelCausalConv3d(nn.Module): stride = (stride, stride, stride) dilation = (1, 1, 1) - self.conv = Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs) + self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) self.cache_padding = None def forward(self, input_, clear_cache=True): @@ -381,30 +395,40 @@ class 
ContextParallelCausalConv3d(nn.Module): # output = output_parallel # return output - input_parallel = fake_cp_pass_from_previous_rank(input_, self.temporal_dim, self.time_kernel_size, self.cache_padding) - + input_parallel = fake_cp_pass_from_previous_rank( + input_, self.temporal_dim, self.time_kernel_size, self.cache_padding + ) + del self.cache_padding self.cache_padding = None if not clear_cache: cp_rank, cp_world_size = get_context_parallel_rank(), get_context_parallel_world_size() global_rank = torch.distributed.get_rank() if cp_world_size == 1: - self.cache_padding = input_parallel[:, :, -self.time_kernel_size + 1:].contiguous().detach().clone().cpu() + self.cache_padding = ( + input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() + ) else: if cp_rank == cp_world_size - 1: - torch.distributed.isend(input_parallel[:, :, -self.time_kernel_size + 1:].contiguous(), global_rank + 1 - cp_world_size, group=get_context_parallel_group()) + torch.distributed.isend( + input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous(), + global_rank + 1 - cp_world_size, + group=get_context_parallel_group(), + ) if cp_rank == 0: - recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1:]).contiguous() - torch.distributed.recv(recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group()) + recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous() + torch.distributed.recv( + recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group() + ) self.cache_padding = recv_buffer.contiguous().detach().clone().cpu() padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0) + input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0) output_parallel = self.conv(input_parallel) output = output_parallel return output - + class ContextParallelGroupNorm(torch.nn.GroupNorm): def forward(self, input_): @@ -416,26 +440,24 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm): output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1) return output -def Normalize( - in_channels, - gather=False, - **kwargs -): # same for 3D and 2D + +def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D if gather: return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) else: return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + class SpatialNorm3D(nn.Module): def __init__( - self, - f_channels, - zq_channels, - freeze_norm_layer=False, - add_conv=False, - pad_mode='constant', + self, + f_channels, + zq_channels, + freeze_norm_layer=False, + add_conv=False, + pad_mode="constant", gather=False, - **norm_layer_params + **norm_layer_params, ): super().__init__() if gather: @@ -446,7 +468,7 @@ class SpatialNorm3D(nn.Module): if freeze_norm_layer: for p in self.norm_layer.parameters: p.requires_grad = False - + self.add_conv = add_conv if add_conv: self.conv = ContextParallelCausalConv3d( @@ -454,7 +476,7 @@ class SpatialNorm3D(nn.Module): chan_out=zq_channels, kernel_size=3, ) - + self.conv_y = ContextParallelCausalConv3d( chan_in=zq_channels, chan_out=f_channels, @@ -476,10 +498,10 @@ class SpatialNorm3D(nn.Module): zq = torch.cat([zq_first, zq_rest], dim=2) else: zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest") - + if self.add_conv: zq = 
self.conv(zq, clear_cache=clear_fake_cp_cache) - + # f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1) norm_f = self.norm_layer(f) # norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1) @@ -487,20 +509,21 @@ class SpatialNorm3D(nn.Module): new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f + def Normalize3D( - in_channels, - zq_ch, + in_channels, + zq_ch, add_conv, gather=False, ): return SpatialNorm3D( - in_channels, - zq_ch, + in_channels, + zq_ch, gather=gather, - freeze_norm_layer=False, - add_conv=add_conv, - num_groups=32, - eps=1e-6, + freeze_norm_layer=False, + add_conv=add_conv, + num_groups=32, + eps=1e-6, affine=True, ) @@ -515,13 +538,9 @@ class Upsample3D(nn.Module): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.compress_time = compress_time - + def forward(self, x): if self.compress_time and x.shape[2] > 1: if x.shape[2] % 2 == 1: @@ -533,46 +552,37 @@ class Upsample3D(nn.Module): x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) else: x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - + else: - # only interpolate 2D + # only interpolate 2D t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - x = rearrange(x, '(b t) c h w -> b c t h w', t=t) - + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + if self.with_conv: t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = self.conv(x) - x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) return x + class DownSample3D(nn.Module): - def __init__( - self, - in_channels, - with_conv, - compress_time=False, - out_channels=None - ): + def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None): super().__init__() self.with_conv = with_conv if out_channels is None: out_channels = in_channels if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=2, - padding=0) + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) self.compress_time = compress_time - + def forward(self, x): if self.compress_time and x.shape[2] > 1: h, w = x.shape[-2:] - x = rearrange(x, 'b c t h w -> (b h w) c t') + x = rearrange(x, "b c t h w -> (b h w) c t") if x.shape[-1] % 2 == 1: # split first frame @@ -581,29 +591,40 @@ class DownSample3D(nn.Module): if x_rest.shape[-1] > 0: x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) x = torch.cat([x_first[..., None], x_rest], dim=-1) - x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w) + x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) else: x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) - x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w) - + x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) + if self.with_conv: - pad = (0,1,0,1) + pad = (0, 1, 0, 1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = self.conv(x) - x = rearrange(x, '(b t) 
c h w -> b c t h w', t=t) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) else: t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) return x - + class ContextParallelResnetBlock3D(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512, zq_ch=None, add_conv=False, gather_norm=False, normalization=Normalize): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + zq_ch=None, + add_conv=False, + gather_norm=False, + normalization=Normalize, + ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -623,8 +644,7 @@ class ContextParallelResnetBlock3D(nn.Module): kernel_size=3, ) if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = normalization( out_channels, zq_ch=zq_ch, @@ -652,7 +672,7 @@ class ContextParallelResnetBlock3D(nn.Module): stride=1, padding=0, ) - + def forward(self, x, temb, zq=None, clear_fake_cp_cache=True): h = x @@ -669,7 +689,7 @@ class ContextParallelResnetBlock3D(nn.Module): h = self.conv1(h, clear_cache=clear_fake_cp_cache) if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None,None] + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] # if isinstance(self.norm2, torch.nn.GroupNorm): # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) @@ -690,13 +710,29 @@ class ContextParallelResnetBlock3D(nn.Module): else: x = self.nin_shortcut(x) - return x+h + return x + h + - class ContextParallelEncoder3D(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, double_z=True, pad_mode='first', temporal_compress_times=4, gather_norm=False, **ignore_kwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + pad_mode="first", + temporal_compress_times=4, + gather_norm=False, + **ignore_kwargs, + ): super().__init__() self.ch = ch self.temb_ch = 0 @@ -715,13 +751,13 @@ class ContextParallelEncoder3D(nn.Module): ) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1,) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( ContextParallelResnetBlock3D( @@ -736,14 +772,14 @@ class ContextParallelEncoder3D(nn.Module): down = nn.Module() down.block = block down.attn = attn - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: if i_level < self.temporal_compress_level: down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True) else: down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False) curr_res = curr_res // 2 self.down.append(down) - + # middle self.mid = 
nn.Module() self.mid.block_1 = ContextParallelResnetBlock3D( @@ -767,12 +803,11 @@ class ContextParallelEncoder3D(nn.Module): self.conv_out = ContextParallelCausalConv3d( chan_in=block_in, - chan_out=2*z_channels if double_z else z_channels, + chan_out=2 * z_channels if double_z else z_channels, kernel_size=3, ) def forward(self, x): - # timestep embedding temb = None @@ -783,9 +818,9 @@ class ContextParallelEncoder3D(nn.Module): h = self.down[i_level].block[i_block](h, temb) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: h = self.down[i_level].downsample(h) - + # middle h = self.mid.block_1(h, temb) h = self.mid.block_2(h, temb) @@ -794,7 +829,7 @@ class ContextParallelEncoder3D(nn.Module): # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) h = self.norm_out(h) # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) - + h = nonlinearity(h) h = self.conv_out(h) @@ -802,8 +837,27 @@ class ContextParallelEncoder3D(nn.Module): class ContextParallelDecoder3D(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, zq_ch=None, add_conv=False, pad_mode='first', temporal_compress_times=4, gather_norm=False, **ignorekwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + zq_ch=None, + add_conv=False, + pad_mode="first", + temporal_compress_times=4, + gather_norm=False, + **ignorekwargs, + ): super().__init__() self.ch = ch self.temb_ch = 0 @@ -820,12 +874,11 @@ class ContextParallelDecoder3D(nn.Module): zq_ch = z_channels # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,)+tuple(ch_mult) - block_in = ch*ch_mult[self.num_resolutions-1] - curr_res = resolution // 2**(self.num_resolutions-1) - self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) self.conv_in = ContextParallelCausalConv3d( chan_in=z_channels, @@ -862,8 +915,8 @@ class ContextParallelDecoder3D(nn.Module): for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): block.append( ContextParallelResnetBlock3D( in_channels=block_in, @@ -886,7 +939,7 @@ class ContextParallelDecoder3D(nn.Module): else: up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True) self.up.insert(0, up) - + self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm) self.conv_out = ContextParallelCausalConv3d( @@ -896,7 +949,6 @@ class ContextParallelDecoder3D(nn.Module): ) def forward(self, z, clear_fake_cp_cache=True): - self.last_z_shape = z.shape # timestep embedding @@ -914,7 +966,7 @@ class ContextParallelDecoder3D(nn.Module): # upsampling for 
i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): + for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h, zq) diff --git a/sat/vae_modules/ema.py b/sat/vae_modules/ema.py index 97b5ae2..9f1f760 100644 --- a/sat/vae_modules/ema.py +++ b/sat/vae_modules/ema.py @@ -12,9 +12,7 @@ class LitEma(nn.Module): self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) self.register_buffer( "num_updates", - torch.tensor(0, dtype=torch.int) - if use_num_upates - else torch.tensor(-1, dtype=torch.int), + torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), ) for name, p in model.named_parameters(): @@ -47,9 +45,7 @@ class LitEma(nn.Module): if m_param[key].requires_grad: sname = self.m_name2s_name[key] shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_( - one_minus_decay * (shadow_params[sname] - m_param[key]) - ) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) else: assert not key in self.m_name2s_name diff --git a/sat/vae_modules/regularizers.py b/sat/vae_modules/regularizers.py index d95cd53..205bd4a 100644 --- a/sat/vae_modules/regularizers.py +++ b/sat/vae_modules/regularizers.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F from torch import nn + class DiagonalGaussianDistribution(object): def __init__(self, parameters, deterministic=False): self.parameters = parameters @@ -15,9 +16,7 @@ class DiagonalGaussianDistribution(object): self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to( - device=self.parameters.device - ) + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) def sample(self): # x = self.mean + self.std * torch.randn(self.mean.shape).to( @@ -57,6 +56,7 @@ class DiagonalGaussianDistribution(object): def mode(self): return self.mean + class AbstractRegularizer(nn.Module): def __init__(self): super().__init__() @@ -77,14 +77,10 @@ class IdentityRegularizer(AbstractRegularizer): yield from () -def measure_perplexity( - predicted_indices: torch.Tensor, num_centroids: int -) -> Tuple[torch.Tensor, torch.Tensor]: +def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally - encodings = ( - F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) - ) + encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) avg_probs = encodings.mean(0) perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() cluster_use = torch.sum(avg_probs > 0) diff --git a/sat/vae_modules/utils.py b/sat/vae_modules/utils.py index 51c4fd1..8c8dba6 100644 --- a/sat/vae_modules/utils.py +++ b/sat/vae_modules/utils.py @@ -14,17 +14,19 @@ import torch.distributed _CONTEXT_PARALLEL_GROUP = None _CONTEXT_PARALLEL_SIZE = None + def is_context_parallel_initialized(): if _CONTEXT_PARALLEL_GROUP is None: return False else: return True - + + def initialize_context_parallel(context_parallel_size): global _CONTEXT_PARALLEL_GROUP global _CONTEXT_PARALLEL_SIZE - assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized' + assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized" _CONTEXT_PARALLEL_SIZE = context_parallel_size rank = torch.distributed.get_rank() @@ -36,41 +38,49 @@ def initialize_context_parallel(context_parallel_size): if rank in ranks: _CONTEXT_PARALLEL_GROUP = group break - + + def get_context_parallel_group(): - assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized' + assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized" return _CONTEXT_PARALLEL_GROUP + def get_context_parallel_world_size(): - assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized' + assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" return _CONTEXT_PARALLEL_SIZE + def get_context_parallel_rank(): - assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized' + assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" rank = torch.distributed.get_rank() cp_rank = rank % _CONTEXT_PARALLEL_SIZE return cp_rank + def get_context_parallel_group_rank(): - assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized' + assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" rank = torch.distributed.get_rank() cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE return cp_group_rank + class SafeConv3d(torch.nn.Conv3d): def forward(self, input): memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 if memory_count > 2: kernel_size = self.kernel_size[0] part_num = int(memory_count / 2) + 1 - input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW + input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW if kernel_size > 1: - input_chunks = [input_chunks[0]] + [torch.cat((input_chunks[i-1][:, :, -kernel_size+1:], input_chunks[i]), dim=2) for i in range(1, len(input_chunks))] + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] output_chunks = [] for input_chunk in input_chunks: @@ -149,9 +159,7 @@ def log_txt_as_img(wh, xc, size=10): text_seq = xc[bi][0] else: text_seq = xc[bi] - lines = "\n".join( - text_seq[start : start + nc] for start in range(0, len(text_seq), nc) - ) + lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc)) try: draw.text((0, 0), lines, fill="black", font=font) @@ -263,9 +271,7 @@ def append_dims(x, 
target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: - raise ValueError( - f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" - ) + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") return x[(...,) + (None,) * dims_to_append] @@ -360,7 +366,8 @@ def checkpoint(func, inputs, params, flag): return CheckpointFunction.apply(func, len(inputs), *args) else: return func(*inputs) - + + class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, length, *args): @@ -395,4 +402,3 @@ class CheckpointFunction(torch.autograd.Function): del ctx.input_params del output_tensors return (None, None) + input_grads - \ No newline at end of file