mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-06-27 20:59:24 +08:00
commit
96a997a777
@ -22,7 +22,11 @@
|
||||
|
||||
## Update and News
|
||||
|
||||
- 🔥🔥 **News**: ```2024/8/12```: The CogVideoX paper has been uploaded to arxiv. Feel free to check out
|
||||
- 🔥🔥 **News**: ```2024/8/15```: The `SwissArmyTransformer` dependency in CogVideoX has been upgraded to `0.4.12`. Fine-tuning
|
||||
no longer requires installing `SwissArmyTransformer` from source. Additionally, the `Tied VAE` technique has been
|
||||
applied in the implementation within the `diffusers` library. Please install `diffusers` and `accelerate` libraries
|
||||
from source. Inference for CogVideoX now requires only 12GB of VRAM.
|
||||
- 🔥 **News**: ```2024/8/12```: The CogVideoX paper has been uploaded to arxiv. Feel free to check out
|
||||
the [paper](https://arxiv.org/abs/2408.06072).
|
||||
- 🔥 **News**: ```2024/8/7```: CogVideoX has been integrated into `diffusers` version 0.30.0. Inference can now be
|
||||
performed
|
||||
|
@ -21,8 +21,8 @@
|
||||
</p>
|
||||
|
||||
## 更新とニュース
|
||||
|
||||
- 🔥🔥 **ニュース**: ```2024/8/12```: CogVideoX 論文がarxivにアップロードされました。ぜひ[論文](https://arxiv.org/abs/2408.06072)をご覧ください。
|
||||
- 🔥🔥 **ニュース**: 2024/8/15: CogVideoX の依存関係である`SwissArmyTransformer`の依存が`0.4.12`にアップグレードされました。これにより、微調整の際に`SwissArmyTransformer`をソースコードからインストールする必要がなくなりました。同時に、`Tied VAE` 技術が `diffusers` ライブラリの実装に適用されました。`diffusers` と `accelerate` ライブラリをソースコードからインストールしてください。CogVdideoX の推論には 12GB の VRAM だけが必要です。
|
||||
- 🔥 **ニュース**: ```2024/8/12```: CogVideoX 論文がarxivにアップロードされました。ぜひ[論文](https://arxiv.org/abs/2408.06072)をご覧ください。
|
||||
- 🔥 **ニュース**: ```2024/8/7```: CogVideoX は `diffusers` バージョン 0.30.0 に統合されました。単一の 3090 GPU
|
||||
で推論を実行できます。詳細については [コード](inference/cli_demo.py) を参照してください。
|
||||
- 🔥 **ニュース**: ```2024/8/6```: **CogVideoX-2B** で使用される **3D Causal VAE** もオープンソース化しました。これにより、ビデオをほぼ無損失で再構築できます。
|
||||
|
@ -23,7 +23,10 @@
|
||||
|
||||
## 项目更新
|
||||
|
||||
- 🔥🔥 **News**: ```2024/8/12```: CogVideoX 论文已上传到arxiv,欢迎查看[论文](https://arxiv.org/abs/2408.06072)。
|
||||
- 🔥🔥 **News**: ```2024/8/15```: CogVideoX 依赖中`SwissArmyTransformer`依赖升级到`0.4.12`,
|
||||
微调不再需要从源代码安装`SwissArmyTransformer`。同时,`Tied VAE` 技术已经被应用到 `diffusers`
|
||||
库中的实现,请从源代码安装 `diffusers` 和 `accelerate` 库,推理 CogVdideoX 仅需 12GB显存。
|
||||
- 🔥 **News**: ```2024/8/12```: CogVideoX 论文已上传到arxiv,欢迎查看[论文](https://arxiv.org/abs/2408.06072)。
|
||||
- 🔥 **News**: ```2024/8/7```: CogVideoX 已经合并入 `diffusers`
|
||||
0.30.0版本,单张3090可以推理,详情请见[代码](inference/cli_demo.py)。
|
||||
- 🔥 **News**: ```2024/8/6```: 我们开源 **3D Causal VAE**,用于 **CogVideoX-2B**,可以几乎无损地重构视频。
|
||||
|
@ -2,14 +2,15 @@
|
||||
This script demonstrates how to generate a video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline.
|
||||
|
||||
Note:
|
||||
This script requires the `diffusers>=0.30.0` library to be installed.
|
||||
If the video exported using OpenCV appears “completely green” and cannot be viewed, lease switch to a different player to watch it. This is a normal phenomenon.
|
||||
This script requires the `diffusers>=0.30.0` library to be installed, after `diffusers 0.31.0` release,
|
||||
need to update.
|
||||
|
||||
Run the script:
|
||||
$ python cli_demo.py --prompt "A girl ridding a bike." --model_path THUDM/CogVideoX-2b
|
||||
|
||||
"""
|
||||
|
||||
import gc
|
||||
import argparse
|
||||
import tempfile
|
||||
from typing import Union, List
|
||||
@ -18,11 +19,10 @@ import PIL
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
|
||||
from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler
|
||||
|
||||
def export_to_video_imageio(
|
||||
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
|
||||
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
|
||||
) -> str:
|
||||
"""
|
||||
Export the video frames to a video file using imageio lib to Avoid "green screen" issue (for example CogVideoX)
|
||||
@ -38,14 +38,13 @@ def export_to_video_imageio(
|
||||
|
||||
|
||||
def generate_video(
|
||||
prompt: str,
|
||||
model_path: str,
|
||||
output_path: str = "./output.mp4",
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 6.0,
|
||||
num_videos_per_prompt: int = 1,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.float16,
|
||||
prompt: str,
|
||||
model_path: str,
|
||||
output_path: str = "./output.mp4",
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 6.0,
|
||||
num_videos_per_prompt: int = 1,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
):
|
||||
"""
|
||||
Generates a video based on the given prompt and saves it to the specified path.
|
||||
@ -57,36 +56,46 @@ def generate_video(
|
||||
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
|
||||
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
|
||||
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
||||
- device (str): The device to use for computation (e.g., "cuda" or "cpu").
|
||||
- dtype (torch.dtype): The data type for computation (default is torch.float16).
|
||||
|
||||
"""
|
||||
|
||||
# Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
|
||||
# add device_map="balanced" in the from_pretrained function and remove
|
||||
# `pipe.enable_model_cpu_offload()` to enable Multi GPUs (2 or more and each one must have more than 20GB memory) inference.
|
||||
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (float16).
|
||||
# add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload()
|
||||
# function to use Multi GPUs.
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||
|
||||
# 2. Set Scheduler.
|
||||
# Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
|
||||
# We recommend using `CogVideoXDDIMScheduler` for better results.
|
||||
pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
||||
|
||||
# 3. Enable CPU offload for the model and reset the memory, enable tiling.
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Encode the prompt to get the prompt embeddings
|
||||
prompt_embeds, _ = pipe.encode_prompt(
|
||||
prompt=prompt, # The textual description for video generation
|
||||
negative_prompt=None, # The negative prompt to guide the video generation
|
||||
do_classifier_free_guidance=True, # Whether to use classifier-free guidance
|
||||
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
|
||||
max_sequence_length=226, # Maximum length of the sequence, must be 226
|
||||
device=device, # Device to use for computation
|
||||
dtype=dtype, # Data type for computation
|
||||
)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Generate the video frames using the pipeline
|
||||
# Using with diffusers branch `main` to enable tiling. This will cost ONLY 12GB GPU memory.
|
||||
# pipe.vae.enable_tiling()
|
||||
|
||||
# 4. Generate the video frames based on the prompt.
|
||||
# `num_frames` is the Number of frames to generate.
|
||||
# This is the default value for 6 seconds video and 8 fps,so 48 frames and will plus 1 frame for the first frame.
|
||||
# for diffusers version `0.30.0`, this should be 48. and for `0.31.0` and after, this should be 49.
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
|
||||
num_inference_steps=num_inference_steps, # Number of inference steps
|
||||
num_frames=48, # Number of frames to generate,changed to 49 for diffusers version `0.31.0` and after.
|
||||
guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance
|
||||
prompt_embeds=prompt_embeds, # Encoded prompt embeddings
|
||||
negative_prompt_embeds=torch.zeros_like(prompt_embeds), # Not Supported negative prompt
|
||||
generator=torch.Generator().manual_seed(42), # Set the seed for reproducibility
|
||||
).frames[0]
|
||||
|
||||
# Export the generated frames to a video file. fps must be 8
|
||||
# 5. Export the generated frames to a video file. fps must be 8
|
||||
export_to_video_imageio(video, output_path, fps=8)
|
||||
|
||||
|
||||
@ -104,10 +113,6 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
||||
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
|
||||
)
|
||||
@ -125,6 +130,5 @@ if __name__ == "__main__":
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
guidance_scale=args.guidance_scale,
|
||||
num_videos_per_prompt=args.num_videos_per_prompt,
|
||||
device=args.device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -1,14 +1,14 @@
|
||||
diffusers>=0.30.0
|
||||
transformers>=4.43.4
|
||||
accelerate>=0.33.0
|
||||
diffusers==0.30.0
|
||||
transformers==4.44.0
|
||||
accelerate==0.33.0
|
||||
sentencepiece==0.2.0 # T5
|
||||
SwissArmyTransformer==0.4.11 # Inference
|
||||
SwissArmyTransformer==0.4.12 # Inference
|
||||
torch==2.4.0 # Tested in 2.2 2.3 2.4 and 2.5
|
||||
torchvision==0.19.0
|
||||
gradio==4.40.0 # For HF gradio demo
|
||||
pillow==9.5.0 # For HF gradio demo
|
||||
streamlit==1.37.0 # For streamlit web demo
|
||||
opencv-python==4.10 # For diffusers inference origin export video
|
||||
imageio==2.34.2 # For diffusers inference export video
|
||||
imageio-ffmpeg==0.5.1 # For diffusers inference export video
|
||||
openai==1.38.0 # For prompt refiner
|
||||
openai==1.40.6 # For prompt refiner
|
||||
moviepy==1.0.3 # For export video
|
@ -120,22 +120,6 @@ bash inference.sh
|
||||
|
||||
## Fine-Tuning the Model
|
||||
|
||||
### Preparing the Environment
|
||||
|
||||
Please note that currently, SAT needs to be installed from the source code for proper fine-tuning.
|
||||
|
||||
You need to get the code from the source to support the fine-tuning functionality, as these features have not yet been
|
||||
released in the Pip package.
|
||||
|
||||
We will address this issue in future stable releases.
|
||||
|
||||
|
||||
```
|
||||
git clone https://github.com/THUDM/SwissArmyTransformer.git
|
||||
cd SwissArmyTransformer
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Preparing the Dataset
|
||||
|
||||
The dataset format should be as follows:
|
||||
|
@ -118,17 +118,6 @@ bash inference.sh
|
||||
|
||||
## モデルのファインチューニング
|
||||
|
||||
### 環境の準備
|
||||
|
||||
ご注意ください、現在、SATを正常にファインチューニングするためには、ソースコードからインストールする必要があります。
|
||||
これは、まだpipパッケージバージョンにリリースされていない最新の機能を使用する必要があるためです。この問題は、今後の安定版で解決する予定です。
|
||||
|
||||
```
|
||||
git clone https://github.com/THUDM/SwissArmyTransformer.git
|
||||
cd SwissArmyTransformer
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### データセットの準備
|
||||
|
||||
データセットの形式は次のようになります:
|
||||
|
@ -114,18 +114,6 @@ bash inference.sh
|
||||
|
||||
## 微调模型
|
||||
|
||||
### 准备环境
|
||||
|
||||
请注意,目前,SAT需要从源码安装,才能正常微调。
|
||||
这是因为你需要使用还没发型到pip包版本的最新代码所支持的功能。
|
||||
我们将会在未来的稳定版本解决这个问题。
|
||||
|
||||
```
|
||||
git clone https://github.com/THUDM/SwissArmyTransformer.git
|
||||
cd SwissArmyTransformer
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### 准备数据集
|
||||
|
||||
数据集格式应该如下:
|
||||
|
@ -1,3 +1,4 @@
|
||||
In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
|
||||
In the haunting backdrop of a warIn the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
|
||||
The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.
|
||||
A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
|
||||
A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
|
||||
A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
|
@ -240,17 +240,15 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
text_length,
|
||||
theta=10000,
|
||||
rot_v=False,
|
||||
pnp=False,
|
||||
learnable_pos_embed=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.rot_v = rot_v
|
||||
self.text_length = text_length
|
||||
|
||||
dim_t = hidden_size_head // 4
|
||||
dim_h = hidden_size_head // 8 * 3
|
||||
dim_w = hidden_size_head // 8 * 3
|
||||
|
||||
# 'lang':
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
|
||||
@ -268,12 +266,7 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
||||
|
||||
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
# (T H W D)
|
||||
|
||||
self.pnp = pnp
|
||||
|
||||
if not self.pnp:
|
||||
freqs = rearrange(freqs, "t h w d -> (t h w) d")
|
||||
freqs = rearrange(freqs, "t h w d -> (t h w) d")
|
||||
|
||||
freqs = freqs.contiguous()
|
||||
freqs_sin = freqs.sin()
|
||||
@ -281,40 +274,15 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
self.register_buffer("freqs_sin", freqs_sin)
|
||||
self.register_buffer("freqs_cos", freqs_cos)
|
||||
|
||||
self.text_length = text_length
|
||||
if learnable_pos_embed:
|
||||
num_patches = height * width * compressed_num_frames + text_length
|
||||
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
|
||||
else:
|
||||
self.pos_embedding = None
|
||||
|
||||
def rotary(self, t, **kwargs):
|
||||
if self.pnp:
|
||||
t_coords = kwargs["rope_position_ids"][:, :, 0]
|
||||
x_coords = kwargs["rope_position_ids"][:, :, 1]
|
||||
y_coords = kwargs["rope_position_ids"][:, :, 2]
|
||||
mask = (x_coords != -1) & (y_coords != -1) & (t_coords != -1)
|
||||
freqs = torch.zeros([t.shape[0], t.shape[2], t.shape[3]], dtype=t.dtype, device=t.device)
|
||||
freqs[mask] = self.freqs[t_coords[mask], x_coords[mask], y_coords[mask]]
|
||||
|
||||
else:
|
||||
|
||||
def reshape_freq(freqs):
|
||||
frame = t.shape[2]
|
||||
freqs = freqs[:frame].contiguous()
|
||||
freqs = freqs.unsqueeze(0).unsqueeze(0)
|
||||
return freqs
|
||||
|
||||
freqs_cos = reshape_freq(self.freqs_cos)
|
||||
freqs_sin = reshape_freq(self.freqs_sin)
|
||||
seq_len = t.shape[2]
|
||||
freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
|
||||
freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
|
||||
|
||||
return t * freqs_cos + rotate_half(t) * freqs_sin
|
||||
|
||||
def position_embedding_forward(self, position_ids, **kwargs):
|
||||
if self.pos_embedding is not None:
|
||||
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
def attention_fn(
|
||||
self,
|
||||
@ -329,64 +297,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
):
|
||||
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
|
||||
|
||||
if self.pnp:
|
||||
query_layer = self.rotary(query_layer, **kwargs)
|
||||
key_layer = self.rotary(key_layer, **kwargs)
|
||||
if self.rot_v:
|
||||
value_layer = self.rotary(value_layer)
|
||||
else:
|
||||
query_layer = torch.cat(
|
||||
(
|
||||
query_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
query_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
]
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
key_layer = torch.cat(
|
||||
(
|
||||
key_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
key_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
]
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
if self.rot_v:
|
||||
value_layer = torch.cat(
|
||||
(
|
||||
value_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
value_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
]
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
|
||||
key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
|
||||
if self.rot_v:
|
||||
value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])
|
||||
|
||||
return attention_fn_default(
|
||||
query_layer,
|
||||
|
@ -1,17 +1,16 @@
|
||||
SwissArmyTransformer==0.4.11
|
||||
diffusers>=0.29.2
|
||||
omegaconf>=2.3.0
|
||||
torch>=2.3.1
|
||||
torchvision>=0.19.0
|
||||
pytorch_lightning>=2.3.3
|
||||
kornia>=0.7.3
|
||||
beartype>=0.18.5
|
||||
numpy>=2.0.1
|
||||
fsspec>=2024.5.0
|
||||
safetensors>=0.4.3
|
||||
imageio-ffmpeg>=0.5.1
|
||||
imageio>=2.34.2
|
||||
scipy>=1.14.0
|
||||
decord>=0.6.0
|
||||
wandb>=0.17.5
|
||||
deepspeed>=0.14.4
|
||||
SwissArmyTransformer==0.4.12
|
||||
omegaconf==2.3.0
|
||||
torch==2.4.0
|
||||
torchvision==0.19.0
|
||||
pytorch_lightning==2.3.3
|
||||
kornia==0.7.3
|
||||
beartype==0.18.5
|
||||
numpy==2.0.1
|
||||
fsspec==2024.5.0
|
||||
safetensors==0.4.3
|
||||
imageio-ffmpeg==0.5.1
|
||||
imageio==2.34.2
|
||||
scipy==1.14.0
|
||||
decord==0.6.0
|
||||
wandb==0.17.5
|
||||
deepspeed==0.14.4
|
@ -2,25 +2,12 @@ import math
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from beartype import beartype
|
||||
from beartype.typing import Union, Tuple, Optional, List
|
||||
from einops import rearrange
|
||||
|
||||
from ..util import (
|
||||
get_context_parallel_group,
|
||||
get_context_parallel_rank,
|
||||
get_context_parallel_world_size,
|
||||
get_context_parallel_group_rank,
|
||||
)
|
||||
|
||||
# try:
|
||||
from ..util import SafeConv3d as Conv3d
|
||||
# except:
|
||||
# # Degrade to normal Conv3d if SafeConv3d is not available
|
||||
# from torch.nn import Conv3d
|
||||
)
|
||||
|
||||
_USE_CP = True
|
||||
|
||||
@ -192,706 +179,4 @@ def _conv_gather(input_, dim, kernel_size):
|
||||
|
||||
# print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _pass_from_previous_rank(input_, dim, kernel_size):
|
||||
# Bypass the function if kernel size is 1
|
||||
if kernel_size == 1:
|
||||
return input_
|
||||
|
||||
group = get_context_parallel_group()
|
||||
cp_rank = get_context_parallel_rank()
|
||||
cp_group_rank = get_context_parallel_group_rank()
|
||||
cp_world_size = get_context_parallel_world_size()
|
||||
|
||||
# print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
global_rank = torch.distributed.get_rank()
|
||||
global_world_size = torch.distributed.get_world_size()
|
||||
|
||||
input_ = input_.transpose(0, dim)
|
||||
|
||||
# pass from last rank
|
||||
send_rank = global_rank + 1
|
||||
recv_rank = global_rank - 1
|
||||
if send_rank % cp_world_size == 0:
|
||||
send_rank -= cp_world_size
|
||||
if recv_rank % cp_world_size == cp_world_size - 1:
|
||||
recv_rank += cp_world_size
|
||||
|
||||
if cp_rank < cp_world_size - 1:
|
||||
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
|
||||
if cp_rank > 0:
|
||||
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
|
||||
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
|
||||
|
||||
if cp_rank == 0:
|
||||
input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
|
||||
else:
|
||||
req_recv.wait()
|
||||
input_ = torch.cat([recv_buffer, input_], dim=0)
|
||||
|
||||
input_ = input_.transpose(0, dim).contiguous()
|
||||
|
||||
# print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
return input_
|
||||
|
||||
|
||||
def _drop_from_previous_rank(input_, dim, kernel_size):
|
||||
input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
|
||||
return input_
|
||||
|
||||
|
||||
class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, kernel_size):
|
||||
ctx.dim = dim
|
||||
ctx.kernel_size = kernel_size
|
||||
return _conv_split(input_, dim, kernel_size)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
|
||||
|
||||
|
||||
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, kernel_size):
|
||||
ctx.dim = dim
|
||||
ctx.kernel_size = kernel_size
|
||||
return _conv_gather(input_, dim, kernel_size)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
|
||||
|
||||
|
||||
class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, kernel_size):
|
||||
ctx.dim = dim
|
||||
ctx.kernel_size = kernel_size
|
||||
return _pass_from_previous_rank(input_, dim, kernel_size)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
|
||||
|
||||
|
||||
def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
|
||||
return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
|
||||
|
||||
|
||||
def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
|
||||
return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
|
||||
|
||||
|
||||
def conv_pass_from_last_rank(input_, dim, kernel_size):
|
||||
return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
|
||||
|
||||
|
||||
class ContextParallelCausalConv3d(nn.Module):
|
||||
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
|
||||
super().__init__()
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
|
||||
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
||||
|
||||
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
|
||||
|
||||
time_pad = time_kernel_size - 1
|
||||
height_pad = height_kernel_size // 2
|
||||
width_pad = width_kernel_size // 2
|
||||
|
||||
self.height_pad = height_pad
|
||||
self.width_pad = width_pad
|
||||
self.time_pad = time_pad
|
||||
self.time_kernel_size = time_kernel_size
|
||||
self.temporal_dim = 2
|
||||
|
||||
stride = (stride, stride, stride)
|
||||
dilation = (1, 1, 1)
|
||||
self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, input_):
|
||||
# temporal padding inside
|
||||
if _USE_CP:
|
||||
input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
|
||||
else:
|
||||
input_ = input_.transpose(0, self.temporal_dim)
|
||||
input_parallel = torch.cat([input_[:1]] * (self.time_kernel_size - 1) + [input_], dim=0)
|
||||
input_parallel = input_parallel.transpose(0, self.temporal_dim)
|
||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
|
||||
output_parallel = self.conv(input_parallel)
|
||||
output = output_parallel
|
||||
return output
|
||||
|
||||
|
||||
class ContextParallelGroupNorm(torch.nn.GroupNorm):
|
||||
def forward(self, input_):
|
||||
if _USE_CP:
|
||||
input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1)
|
||||
output = super().forward(input_)
|
||||
if _USE_CP:
|
||||
output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
|
||||
return output
|
||||
|
||||
|
||||
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
|
||||
if gather:
|
||||
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
else:
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class SpatialNorm3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
f_channels,
|
||||
zq_channels,
|
||||
freeze_norm_layer=False,
|
||||
add_conv=False,
|
||||
pad_mode="constant",
|
||||
gather=False,
|
||||
**norm_layer_params,
|
||||
):
|
||||
super().__init__()
|
||||
if gather:
|
||||
self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
|
||||
else:
|
||||
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
|
||||
# self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
|
||||
if freeze_norm_layer:
|
||||
for p in self.norm_layer.parameters:
|
||||
p.requires_grad = False
|
||||
|
||||
self.add_conv = add_conv
|
||||
if add_conv:
|
||||
self.conv = ContextParallelCausalConv3d(
|
||||
chan_in=zq_channels,
|
||||
chan_out=zq_channels,
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
self.conv_y = ContextParallelCausalConv3d(
|
||||
chan_in=zq_channels,
|
||||
chan_out=f_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
self.conv_b = ContextParallelCausalConv3d(
|
||||
chan_in=zq_channels,
|
||||
chan_out=f_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, f, zq):
|
||||
if f.shape[2] == 1 and not _USE_CP:
|
||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
|
||||
elif get_context_parallel_rank() == 0:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
|
||||
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
|
||||
zq = torch.cat([zq_first, zq_rest], dim=2)
|
||||
else:
|
||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
|
||||
|
||||
if self.add_conv:
|
||||
zq = self.conv(zq)
|
||||
|
||||
# f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
|
||||
norm_f = self.norm_layer(f)
|
||||
# norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
|
||||
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
|
||||
def Normalize3D(
|
||||
in_channels,
|
||||
zq_ch,
|
||||
add_conv,
|
||||
gather=False,
|
||||
):
|
||||
return SpatialNorm3D(
|
||||
in_channels,
|
||||
zq_ch,
|
||||
gather=gather,
|
||||
# norm_layer=nn.GroupNorm,
|
||||
freeze_norm_layer=False,
|
||||
add_conv=add_conv,
|
||||
num_groups=32,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
)
|
||||
|
||||
|
||||
class Upsample3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
with_conv,
|
||||
compress_time=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time:
|
||||
if x.shape[2] == 1 and not _USE_CP:
|
||||
x = torch.nn.functional.interpolate(x[:, :, 0], scale_factor=2.0, mode="nearest")[:, :, None, :, :]
|
||||
elif get_context_parallel_rank() == 0:
|
||||
# split first frame
|
||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||
|
||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
|
||||
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
|
||||
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
|
||||
else:
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
else:
|
||||
# only interpolate 2D
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
|
||||
if self.with_conv:
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = self.conv(x)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
return x
|
||||
|
||||
|
||||
class DownSample3D(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time and x.shape[2] > 1:
|
||||
h, w = x.shape[-2:]
|
||||
x = rearrange(x, "b c t h w -> (b h w) c t")
|
||||
|
||||
if x.shape[-1] % 2 == 1:
|
||||
# split first frame
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
|
||||
if x_rest.shape[-1] > 0:
|
||||
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
|
||||
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = self.conv(x)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
else:
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
return x
|
||||
|
||||
|
||||
class ContextParallelResnetBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
gather_norm=False,
|
||||
normalization=Normalize,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = normalization(
|
||||
in_channels,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
gather=gather_norm,
|
||||
)
|
||||
|
||||
self.conv1 = ContextParallelCausalConv3d(
|
||||
chan_in=in_channels,
|
||||
chan_out=out_channels,
|
||||
kernel_size=3,
|
||||
)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = normalization(
|
||||
out_channels,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
gather=gather_norm,
|
||||
)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = ContextParallelCausalConv3d(
|
||||
chan_in=out_channels,
|
||||
chan_out=out_channels,
|
||||
kernel_size=3,
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = ContextParallelCausalConv3d(
|
||||
chan_in=in_channels,
|
||||
chan_out=out_channels,
|
||||
kernel_size=3,
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = Conv3d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, x, temb, zq=None):
|
||||
h = x
|
||||
|
||||
# if isinstance(self.norm1, torch.nn.GroupNorm):
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
if zq is not None:
|
||||
h = self.norm1(h, zq)
|
||||
else:
|
||||
h = self.norm1(h)
|
||||
# if isinstance(self.norm1, torch.nn.GroupNorm):
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
# if isinstance(self.norm2, torch.nn.GroupNorm):
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
if zq is not None:
|
||||
h = self.norm2(h, zq)
|
||||
else:
|
||||
h = self.norm2(h)
|
||||
# if isinstance(self.norm2, torch.nn.GroupNorm):
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class ContextParallelEncoder3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
pad_mode="first",
|
||||
temporal_compress_times=4,
|
||||
gather_norm=False,
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# log2 of temporal_compress_times
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
self.conv_in = ContextParallelCausalConv3d(
|
||||
chan_in=in_channels,
|
||||
chan_out=self.ch,
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dropout=dropout,
|
||||
temb_channels=self.temb_ch,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
if i_level < self.temporal_compress_level:
|
||||
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
|
||||
else:
|
||||
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
|
||||
self.mid.block_2 = ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, gather=gather_norm)
|
||||
|
||||
self.conv_out = ContextParallelCausalConv3d(
|
||||
chan_in=block_in,
|
||||
chan_out=2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
def forward(self, x, use_cp=True):
|
||||
global _USE_CP
|
||||
_USE_CP = use_cp
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
h = self.norm_out(h)
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class ContextParallelDecoder3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
pad_mode="first",
|
||||
temporal_compress_times=4,
|
||||
gather_norm=False,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
|
||||
# log2 of temporal_compress_times
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
if zq_ch is None:
|
||||
zq_ch = z_channels
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
self.conv_in = ContextParallelCausalConv3d(
|
||||
chan_in=z_channels,
|
||||
chan_out=block_in,
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
normalization=Normalize3D,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
|
||||
self.mid.block_2 = ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
normalization=Normalize3D,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
normalization=Normalize3D,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
if i_level < self.num_resolutions - self.temporal_compress_level:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
|
||||
else:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
|
||||
self.up.insert(0, up)
|
||||
|
||||
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
|
||||
|
||||
self.conv_out = ContextParallelCausalConv3d(
|
||||
chan_in=block_in,
|
||||
chan_out=out_ch,
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
def forward(self, z, use_cp=True):
|
||||
global _USE_CP
|
||||
_USE_CP = use_cp
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
t = z.shape[2]
|
||||
# z to block_in
|
||||
|
||||
zq = z
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, zq)
|
||||
h = self.mid.block_2(h, temb, zq)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, zq)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
_USE_CP = True
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.conv.weight
|
||||
return output
|
Loading…
x
Reference in New Issue
Block a user