Merge pull request #133 from THUDM/CogVideoX_dev

new req and dev
This commit is contained in:
zR 2024-08-16 11:00:49 +08:00 committed by GitHub
commit 96a997a777
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 88 additions and 917 deletions

View File

@ -22,7 +22,11 @@
## Update and News
- 🔥🔥 **News**: ```2024/8/12```: The CogVideoX paper has been uploaded to arxiv. Feel free to check out
- 🔥🔥 **News**: ```2024/8/15```: The `SwissArmyTransformer` dependency in CogVideoX has been upgraded to `0.4.12`. Fine-tuning
no longer requires installing `SwissArmyTransformer` from source. Additionally, the `Tiled VAE` technique has been
applied in the `diffusers` implementation. Please install the `diffusers` and `accelerate` libraries from source.
Inference for CogVideoX now requires only 12GB of VRAM (see the sketch after this list).
- 🔥 **News**: ```2024/8/12```: The CogVideoX paper has been uploaded to arxiv. Feel free to check out
the [paper](https://arxiv.org/abs/2408.06072).
- 🔥 **News**: ```2024/8/7```: CogVideoX has been integrated into `diffusers` version 0.30.0. Inference can now be performed
on a single 3090 GPU. For details, see the [code](inference/cli_demo.py).
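
For reference, a minimal sketch of the low-VRAM setup described in the 2024/8/15 entry, assuming `diffusers` and `accelerate` are installed from source as noted; the model id matches the CLI demo in this commit, and the prompt is a placeholder:

```python
import torch
from diffusers import CogVideoXPipeline

# Load the 2B pipeline in fp16; CPU offload keeps only the active submodule on the GPU.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
# Tiled VAE decoding (requires diffusers installed from source at the time of this commit);
# together with CPU offload this is what brings inference down to roughly 12GB of VRAM.
pipe.vae.enable_tiling()

video = pipe(
    prompt="A panda playing a guitar in a bamboo forest.",  # placeholder prompt
    num_inference_steps=50,
    guidance_scale=6.0,
    num_frames=48,  # 49 for diffusers 0.31.0 and later (see inference/cli_demo.py)
).frames[0]
# Export with the imageio helper from inference/cli_demo.py to avoid the OpenCV green-screen issue.
```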

View File

@ -21,8 +21,8 @@
</p>
## Update and News
- 🔥🔥 **News**: ```2024/8/12```: The CogVideoX paper has been uploaded to arxiv. Feel free to check out the [paper](https://arxiv.org/abs/2408.06072).
- 🔥🔥 **News**: ```2024/8/15```: The `SwissArmyTransformer` dependency in CogVideoX has been upgraded to `0.4.12`, so fine-tuning no longer requires installing `SwissArmyTransformer` from source. At the same time, the `Tiled VAE` technique has been applied to the implementation in the `diffusers` library; please install the `diffusers` and `accelerate` libraries from source. Inference for CogVideoX requires only 12GB of VRAM.
- 🔥 **News**: ```2024/8/12```: The CogVideoX paper has been uploaded to arxiv. Feel free to check out the [paper](https://arxiv.org/abs/2408.06072).
- 🔥 **News**: ```2024/8/7```: CogVideoX has been integrated into `diffusers` version 0.30.0. Inference can be performed
on a single 3090 GPU. For details, see the [code](inference/cli_demo.py).
- 🔥 **News**: ```2024/8/6```: We have also open-sourced the **3D Causal VAE** used in **CogVideoX-2B**, which can reconstruct videos almost losslessly.

View File

@ -23,7 +23,10 @@
## Project Updates
- 🔥🔥 **News**: ```2024/8/12```: The CogVideoX paper has been uploaded to arxiv. Feel free to check out the [paper](https://arxiv.org/abs/2408.06072).
- 🔥🔥 **News**: ```2024/8/15```: The `SwissArmyTransformer` dependency of CogVideoX has been upgraded to `0.4.12`,
so fine-tuning no longer requires installing `SwissArmyTransformer` from source. At the same time, the `Tiled VAE` technique has been applied to
the implementation in the `diffusers` library; please install the `diffusers` and `accelerate` libraries from source. Inference for CogVideoX requires only 12GB of VRAM.
- 🔥 **News**: ```2024/8/12```: The CogVideoX paper has been uploaded to arxiv. Feel free to check out the [paper](https://arxiv.org/abs/2408.06072).
- 🔥 **News**: ```2024/8/7```: CogVideoX has been merged into `diffusers`
version 0.30.0 and can run inference on a single 3090 GPU. For details, see the [code](inference/cli_demo.py).
- 🔥 **News**: ```2024/8/6```: We have open-sourced the **3D Causal VAE** used in **CogVideoX-2B**, which can reconstruct videos almost losslessly.

View File

@ -2,14 +2,15 @@
This script demonstrates how to generate a video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline.
Note:
This script requires the `diffusers>=0.30.0` library to be installed.
If the video exported using OpenCV appears completely green and cannot be viewed, please switch to a different player to watch it. This is a normal phenomenon.
This script requires the `diffusers>=0.30.0` library to be installed; after the `diffusers 0.31.0` release it will
need to be updated (see the `num_frames` note below).
Run the script:
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-2b
"""
import gc
import argparse
import tempfile
from typing import Union, List
@ -18,11 +19,10 @@ import PIL
import imageio
import numpy as np
import torch
from diffusers import CogVideoXPipeline
from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler
def export_to_video_imageio(
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
) -> str:
"""
Export the video frames to a video file using the imageio library, to avoid the "green screen" issue (for example with CogVideoX)
@ -38,14 +38,13 @@ def export_to_video_imageio(
def generate_video(
prompt: str,
model_path: str,
output_path: str = "./output.mp4",
num_inference_steps: int = 50,
guidance_scale: float = 6.0,
num_videos_per_prompt: int = 1,
device: str = "cuda",
dtype: torch.dtype = torch.float16,
prompt: str,
model_path: str,
output_path: str = "./output.mp4",
num_inference_steps: int = 50,
guidance_scale: float = 6.0,
num_videos_per_prompt: int = 1,
dtype: torch.dtype = torch.float16,
):
"""
Generates a video based on the given prompt and saves it to the specified path.
@ -57,36 +56,46 @@ def generate_video(
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
- num_videos_per_prompt (int): Number of videos to generate per prompt.
- device (str): The device to use for computation (e.g., "cuda" or "cpu").
- dtype (torch.dtype): The data type for computation (default is torch.float16).
"""
# Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device.
# To enable multi-GPU inference (2 or more GPUs, each with more than 20GB of memory), add device_map="balanced"
# to the from_pretrained call and remove `pipe.enable_model_cpu_offload()`.
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (float16).
# To use multiple GPUs, add device_map="balanced" to the from_pretrained call and remove the
# enable_model_cpu_offload() call.
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
# 2. Set Scheduler.
# Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
# We recommend using `CogVideoXDDIMScheduler` for better results.
pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
# 3. Enable CPU offload for the model, reset the memory, and enable tiling.
pipe.enable_model_cpu_offload()
# Encode the prompt to get the prompt embeddings
prompt_embeds, _ = pipe.encode_prompt(
prompt=prompt, # The textual description for video generation
negative_prompt=None, # The negative prompt to guide the video generation
do_classifier_free_guidance=True, # Whether to use classifier-free guidance
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
max_sequence_length=226, # Maximum length of the sequence, must be 226
device=device, # Device to use for computation
dtype=dtype, # Data type for computation
)
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.reset_peak_memory_stats()
# Generate the video frames using the pipeline
# Using with diffusers branch `main` to enable tiling. This will cost ONLY 12GB GPU memory.
# pipe.vae.enable_tiling()
# 4. Generate the video frames based on the prompt.
# `num_frames` is the number of frames to generate.
# The default corresponds to a 6-second video at 8 fps (48 frames), plus 1 extra frame for the first frame.
# For diffusers version `0.30.0` this should be 48; for `0.31.0` and later it should be 49.
video = pipe(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
num_inference_steps=num_inference_steps, # Number of inference steps
num_frames=48, # Number of frames to generate; change to 49 for diffusers version `0.31.0` and later.
guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance
prompt_embeds=prompt_embeds, # Encoded prompt embeddings
negative_prompt_embeds=torch.zeros_like(prompt_embeds), # Negative prompts are not supported; pass zeros
generator=torch.Generator().manual_seed(42), # Set the seed for reproducibility
).frames[0]
# Export the generated frames to a video file. fps must be 8
# 5. Export the generated frames to a video file. fps must be 8
export_to_video_imageio(video, output_path, fps=8)
@ -104,10 +113,6 @@ if __name__ == "__main__":
)
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
parser.add_argument(
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
)
parser.add_argument(
"--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
)
@ -125,6 +130,5 @@ if __name__ == "__main__":
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
num_videos_per_prompt=args.num_videos_per_prompt,
device=args.device,
dtype=dtype,
)

View File

@ -1,14 +1,14 @@
diffusers>=0.30.0
transformers>=4.43.4
accelerate>=0.33.0
diffusers==0.30.0
transformers==4.44.0
accelerate==0.33.0
sentencepiece==0.2.0 # T5
SwissArmyTransformer==0.4.11 # Inference
SwissArmyTransformer==0.4.12 # Inference
torch==2.4.0 # Tested in 2.2 2.3 2.4 and 2.5
torchvision==0.19.0
gradio==4.40.0 # For HF gradio demo
pillow==9.5.0 # For HF gradio demo
streamlit==1.37.0 # For streamlit web demo
opencv-python==4.10 # For the original diffusers video export
imageio==2.34.2 # For diffusers inference export video
imageio-ffmpeg==0.5.1 # For diffusers inference export video
openai==1.38.0 # For prompt refiner
openai==1.40.6 # For prompt refiner
moviepy==1.0.3 # For export video

View File

@ -120,22 +120,6 @@ bash inference.sh
## Fine-Tuning the Model
### Preparing the Environment
Please note that currently, SAT needs to be installed from the source code for proper fine-tuning.
You need to get the code from the source to support the fine-tuning functionality, as these features have not yet been
released in the Pip package.
We will address this issue in future stable releases.
```
git clone https://github.com/THUDM/SwissArmyTransformer.git
cd SwissArmyTransformer
pip install -e .
```
### Preparing the Dataset
The dataset format should be as follows:

View File

@ -118,17 +118,6 @@ bash inference.sh
## Fine-Tuning the Model
### Preparing the Environment
Please note that, at present, SAT must be installed from source for fine-tuning to work properly.
This is because the latest features required have not yet been released in the pip package; we will address this in a future stable release.
```
git clone https://github.com/THUDM/SwissArmyTransformer.git
cd SwissArmyTransformer
pip install -e .
```
### Preparing the Dataset
The dataset format should be as follows:

View File

@ -114,18 +114,6 @@ bash inference.sh
## Fine-Tuning the Model
### Preparing the Environment
Please note that SAT currently needs to be installed from source for fine-tuning to work properly.
This is because you need features supported by the latest code, which has not yet been released to the pip package.
We will resolve this issue in a future stable release.
```
git clone https://github.com/THUDM/SwissArmyTransformer.git
cd SwissArmyTransformer
pip install -e .
```
### Preparing the Dataset
The dataset format should be as follows:

View File

@ -1,3 +1,4 @@
In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.
A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart of the city, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.

View File

@ -240,17 +240,15 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
text_length,
theta=10000,
rot_v=False,
pnp=False,
learnable_pos_embed=False,
):
super().__init__()
self.rot_v = rot_v
self.text_length = text_length
dim_t = hidden_size_head // 4
dim_h = hidden_size_head // 8 * 3
dim_w = hidden_size_head // 8 * 3
# 'lang':
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
@ -268,12 +266,7 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
# (T H W D)
self.pnp = pnp
if not self.pnp:
freqs = rearrange(freqs, "t h w d -> (t h w) d")
freqs = rearrange(freqs, "t h w d -> (t h w) d")
freqs = freqs.contiguous()
freqs_sin = freqs.sin()
@ -281,40 +274,15 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
self.register_buffer("freqs_sin", freqs_sin)
self.register_buffer("freqs_cos", freqs_cos)
self.text_length = text_length
if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
else:
self.pos_embedding = None
def rotary(self, t, **kwargs):
if self.pnp:
t_coords = kwargs["rope_position_ids"][:, :, 0]
x_coords = kwargs["rope_position_ids"][:, :, 1]
y_coords = kwargs["rope_position_ids"][:, :, 2]
mask = (x_coords != -1) & (y_coords != -1) & (t_coords != -1)
freqs = torch.zeros([t.shape[0], t.shape[2], t.shape[3]], dtype=t.dtype, device=t.device)
freqs[mask] = self.freqs[t_coords[mask], x_coords[mask], y_coords[mask]]
else:
def reshape_freq(freqs):
frame = t.shape[2]
freqs = freqs[:frame].contiguous()
freqs = freqs.unsqueeze(0).unsqueeze(0)
return freqs
freqs_cos = reshape_freq(self.freqs_cos)
freqs_sin = reshape_freq(self.freqs_sin)
seq_len = t.shape[2]
freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
return t * freqs_cos + rotate_half(t) * freqs_sin
def position_embedding_forward(self, position_ids, **kwargs):
if self.pos_embedding is not None:
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
else:
return None
return None
def attention_fn(
self,
@ -329,64 +297,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
):
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
if self.pnp:
query_layer = self.rotary(query_layer, **kwargs)
key_layer = self.rotary(key_layer, **kwargs)
if self.rot_v:
value_layer = self.rotary(value_layer)
else:
query_layer = torch.cat(
(
query_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
query_layer[
:,
:,
kwargs["text_length"] :,
]
),
),
dim=2,
)
key_layer = torch.cat(
(
key_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
key_layer[
:,
:,
kwargs["text_length"] :,
]
),
),
dim=2,
)
if self.rot_v:
value_layer = torch.cat(
(
value_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
value_layer[
:,
:,
kwargs["text_length"] :,
]
),
),
dim=2,
)
query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
if self.rot_v:
value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])
return attention_fn_default(
query_layer,

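For reference, the simplified hunk above applies rotary embeddings only to the tokens after the text prefix. A minimal, self-contained sketch of that pattern with toy shapes (the `rotate_half`/`apply_rotary` helpers and the frequency construction here are illustrative, not the repo's exact code):

```python
import torch

def rotate_half(x):
    # Pair adjacent channels and rotate each pair: (x0, x1) -> (-x1, x0).
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rotary(t, freqs_cos, freqs_sin, text_length):
    # Rotate only the vision tokens; the text prefix keeps its unrotated features,
    # mirroring `query_layer[:, :, self.text_length:] = self.rotary(...)` above.
    out = t.clone()
    out[:, :, text_length:] = (
        t[:, :, text_length:] * freqs_cos + rotate_half(t[:, :, text_length:]) * freqs_sin
    )
    return out

# Toy shapes: batch=1, heads=2, seq = 3 text tokens + 5 vision tokens, head_dim=8.
text_length, vision_length, head_dim = 3, 5, 8
t = torch.randn(1, 2, text_length + vision_length, head_dim)
pos = torch.arange(vision_length).float()
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum("n,d->nd", pos, inv_freq).repeat_interleave(2, dim=-1)  # interleaved frequencies
rotated = apply_rotary(t, freqs.cos(), freqs.sin(), text_length)
print(rotated.shape)  # torch.Size([1, 2, 8, 8])
```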
View File

@ -1,17 +1,16 @@
SwissArmyTransformer==0.4.11
diffusers>=0.29.2
omegaconf>=2.3.0
torch>=2.3.1
torchvision>=0.19.0
pytorch_lightning>=2.3.3
kornia>=0.7.3
beartype>=0.18.5
numpy>=2.0.1
fsspec>=2024.5.0
safetensors>=0.4.3
imageio-ffmpeg>=0.5.1
imageio>=2.34.2
scipy>=1.14.0
decord>=0.6.0
wandb>=0.17.5
deepspeed>=0.14.4
SwissArmyTransformer==0.4.12
omegaconf==2.3.0
torch==2.4.0
torchvision==0.19.0
pytorch_lightning==2.3.3
kornia==0.7.3
beartype==0.18.5
numpy==2.0.1
fsspec==2024.5.0
safetensors==0.4.3
imageio-ffmpeg==0.5.1
imageio==2.34.2
scipy==1.14.0
decord==0.6.0
wandb==0.17.5
deepspeed==0.14.4

View File

@ -2,25 +2,12 @@ import math
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from einops import rearrange
from ..util import (
get_context_parallel_group,
get_context_parallel_rank,
get_context_parallel_world_size,
get_context_parallel_group_rank,
)
# try:
from ..util import SafeConv3d as Conv3d
# except:
# # Degrade to normal Conv3d if SafeConv3d is not available
# from torch.nn import Conv3d
)
_USE_CP = True
@ -192,706 +179,4 @@ def _conv_gather(input_, dim, kernel_size):
# print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
def _pass_from_previous_rank(input_, dim, kernel_size):
# Bypass the function if kernel size is 1
if kernel_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
cp_group_rank = get_context_parallel_group_rank()
cp_world_size = get_context_parallel_world_size()
# print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
global_rank = torch.distributed.get_rank()
global_world_size = torch.distributed.get_world_size()
input_ = input_.transpose(0, dim)
# pass from last rank
send_rank = global_rank + 1
recv_rank = global_rank - 1
if send_rank % cp_world_size == 0:
send_rank -= cp_world_size
if recv_rank % cp_world_size == cp_world_size - 1:
recv_rank += cp_world_size
if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
if cp_rank > 0:
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
if cp_rank == 0:
input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
else:
req_recv.wait()
input_ = torch.cat([recv_buffer, input_], dim=0)
input_ = input_.transpose(0, dim).contiguous()
# print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
return input_
def _drop_from_previous_rank(input_, dim, kernel_size):
input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
return input_
class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _conv_split(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _conv_gather(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _pass_from_previous_rank(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
def conv_pass_from_last_rank(input_, dim, kernel_size):
return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
class ContextParallelCausalConv3d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
time_pad = time_kernel_size - 1
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
self.time_kernel_size = time_kernel_size
self.temporal_dim = 2
stride = (stride, stride, stride)
dilation = (1, 1, 1)
self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, input_):
# temporal padding inside
if _USE_CP:
input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
else:
input_ = input_.transpose(0, self.temporal_dim)
input_parallel = torch.cat([input_[:1]] * (self.time_kernel_size - 1) + [input_], dim=0)
input_parallel = input_parallel.transpose(0, self.temporal_dim)
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
output_parallel = self.conv(input_parallel)
output = output_parallel
return output
class ContextParallelGroupNorm(torch.nn.GroupNorm):
def forward(self, input_):
if _USE_CP:
input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1)
output = super().forward(input_)
if _USE_CP:
output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
return output
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
if gather:
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
else:
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class SpatialNorm3D(nn.Module):
def __init__(
self,
f_channels,
zq_channels,
freeze_norm_layer=False,
add_conv=False,
pad_mode="constant",
gather=False,
**norm_layer_params,
):
super().__init__()
if gather:
self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
else:
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
# self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
if freeze_norm_layer:
for p in self.norm_layer.parameters():
p.requires_grad = False
self.add_conv = add_conv
if add_conv:
self.conv = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=zq_channels,
kernel_size=3,
)
self.conv_y = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=f_channels,
kernel_size=1,
)
self.conv_b = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=f_channels,
kernel_size=1,
)
def forward(self, f, zq):
if f.shape[2] == 1 and not _USE_CP:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
elif get_context_parallel_rank() == 0:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
if self.add_conv:
zq = self.conv(zq)
# f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
norm_f = self.norm_layer(f)
# norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
def Normalize3D(
in_channels,
zq_ch,
add_conv,
gather=False,
):
return SpatialNorm3D(
in_channels,
zq_ch,
gather=gather,
# norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
affine=True,
)
class Upsample3D(nn.Module):
def __init__(
self,
in_channels,
with_conv,
compress_time=False,
):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time:
if x.shape[2] == 1 and not _USE_CP:
x = torch.nn.functional.interpolate(x[:, :, 0], scale_factor=2.0, mode="nearest")[:, :, None, :, :]
elif get_context_parallel_rank() == 0:
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
else:
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.with_conv:
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.conv(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
return x
class DownSample3D(nn.Module):
def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
super().__init__()
self.with_conv = with_conv
if out_channels is None:
out_channels = in_channels
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time and x.shape[2] > 1:
h, w = x.shape[-2:]
x = rearrange(x, "b c t h w -> (b h w) c t")
if x.shape[-1] % 2 == 1:
# split first frame
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
else:
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.conv(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
else:
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
return x
class ContextParallelResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
zq_ch=None,
add_conv=False,
gather_norm=False,
normalization=Normalize,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = normalization(
in_channels,
zq_ch=zq_ch,
add_conv=add_conv,
gather=gather_norm,
)
self.conv1 = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=out_channels,
kernel_size=3,
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = normalization(
out_channels,
zq_ch=zq_ch,
add_conv=add_conv,
gather=gather_norm,
)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = ContextParallelCausalConv3d(
chan_in=out_channels,
chan_out=out_channels,
kernel_size=3,
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=out_channels,
kernel_size=3,
)
else:
self.nin_shortcut = Conv3d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
)
def forward(self, x, temb, zq=None):
h = x
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm1(h, zq)
else:
h = self.norm1(h)
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm2(h, zq)
else:
h = self.norm2(h)
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class ContextParallelEncoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
pad_mode="first",
temporal_compress_times=4,
gather_norm=False,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
self.conv_in = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=self.ch,
kernel_size=3,
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
temb_channels=self.temb_ch,
gather_norm=gather_norm,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level < self.temporal_compress_level:
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
else:
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
gather_norm=gather_norm,
)
self.mid.block_2 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
gather_norm=gather_norm,
)
# end
self.norm_out = Normalize(block_in, gather=gather_norm)
self.conv_out = ContextParallelCausalConv3d(
chan_in=block_in,
chan_out=2 * z_channels if double_z else z_channels,
kernel_size=3,
)
def forward(self, x, use_cp=True):
global _USE_CP
_USE_CP = use_cp
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
# end
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
h = self.norm_out(h)
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class ContextParallelDecoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
pad_mode="first",
temporal_compress_times=4,
gather_norm=False,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
self.conv_in = ContextParallelCausalConv3d(
chan_in=z_channels,
chan_out=block_in,
kernel_size=3,
)
# middle
self.mid = nn.Module()
self.mid.block_1 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
self.mid.block_2 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
else:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
self.up.insert(0, up)
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
self.conv_out = ContextParallelCausalConv3d(
chan_in=block_in,
chan_out=out_ch,
kernel_size=3,
)
def forward(self, z, use_cp=True):
global _USE_CP
_USE_CP = use_cp
self.last_z_shape = z.shape
# timestep embedding
temb = None
t = z.shape[2]
# z to block_in
zq = z
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = nonlinearity(h)
h = self.conv_out(h)
_USE_CP = True
return h
def get_last_layer(self):
return self.conv_out.conv.weight
return output