import math

import torch
import torch.distributed
import torch.nn as nn

from ..util import (
    get_context_parallel_group,
    get_context_parallel_rank,
    get_context_parallel_world_size,
)

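# NOTE (assumption, not stated in this file): _USE_CP appears to be a module-level
# switch that other modules in this package check to enable or disable
# context-parallel processing.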
_USE_CP = True


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


def divisible_by(num, den):
    return (num % den) == 0


def is_odd(n):
    return not divisible_by(n, 2)


def exists(v):
    return v is not None


def pair(t):
    return t if isinstance(t, tuple) else (t, t)


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic Models
    (originally from Fairseq / tensor2tensor), but differs slightly from the
    description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
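
# Illustrative example (added comment): for a batch of 4 timesteps and a 128-dimensional
# embedding,
#   get_timestep_embedding(torch.tensor([0, 1, 2, 3]), 128)
# returns a tensor of shape (4, 128), with the sine components in the first 64 columns
# and the cosine components in the last 64.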


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def leaky_relu(p=0.1):
    return nn.LeakyReLU(p)
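

# Context-parallel helpers: the tensors handled below are split along one dimension
# (evidently the temporal/frame axis, judging by the "frame" naming) across the ranks
# of the context-parallel group. The shape examples in these comments are illustrative
# additions inferred from the code, not taken from upstream documentation.
#
# _split peels off the first frame, splits the remaining frames into equal chunks of
# size (T - 1) // cp_world_size, and returns the local rank's chunk; rank 0 gets the
# first frame prepended. E.g. with cp_world_size = 2 and T = 17 frames along `dim`,
# rank 0 ends up with 1 + 8 = 9 frames and rank 1 with 8.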
def _split(input_, dim):
    cp_world_size = get_context_parallel_world_size()

    if cp_world_size == 1:
        return input_

    cp_rank = get_context_parallel_rank()

    # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)

    input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
    input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
    dim_size = input_.size()[dim] // cp_world_size

    input_list = torch.split(input_, dim_size, dim=dim)
    output = input_list[cp_rank]

    if cp_rank == 0:
        output = torch.cat([input_first_frame_, output], dim=dim)
    output = output.contiguous()

    # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)

    return output
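

# _gather is the collective inverse of _split: each rank contributes its local chunk via
# torch.distributed.all_gather over the context-parallel group (rank 0's chunk already
# includes the shared first frame), and concatenating the gathered chunks along `dim`
# restores the full sequence on every rank. All ranks in the group must call it together.
# (Comment added for clarity; inferred from the code.)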
def _gather(input_, dim):
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()

    # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)

    input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
    if cp_rank == 0:
        input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()

    tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
        torch.empty_like(input_) for _ in range(cp_world_size - 1)
    ]

    if cp_rank == 0:
        input_ = torch.cat([input_first_frame_, input_], dim=dim)

    tensor_list[cp_rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=group)

    output = torch.cat(tensor_list, dim=dim).contiguous()

    # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)

    return output
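

# _conv_split slices frames for context-parallel temporal convolutions: the sequence is
# divided into chunks of dim_size = (T - kernel_size) // cp_world_size frames, and
# adjacent ranks overlap by kernel_size - 1 frames so a temporal convolution can run
# locally without extra communication. Illustrative example (inferred from the code):
# with cp_world_size = 2, kernel_size = 2 and T = 18, rank 0 keeps frames 0-9 and
# rank 1 keeps frames 9-17, overlapping on frame 9.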
def _conv_split(input_, dim, kernel_size):
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)

    cp_rank = get_context_parallel_rank()

    dim_size = (input_.size()[dim] - kernel_size) // cp_world_size

    if cp_rank == 0:
        output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
    else:
        output = input_.transpose(dim, 0)[
            cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size
        ].transpose(dim, 0)
    output = output.contiguous()

    # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)

    return output
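

# _conv_gather undoes _conv_split: rank 0 contributes its full local slice, every other
# rank drops its leading kernel_size - 1 halo frames, and an all_gather over the
# context-parallel group followed by concatenation along `dim` reproduces the original,
# non-overlapping frame sequence on every rank. Like _gather, it must be called
# collectively by all ranks in the group. (Comment added for clarity; inferred from the
# code.)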
def _conv_gather(input_, dim, kernel_size):
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()

    # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)

    input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
    if cp_rank == 0:
        input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
    else:
        input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim).contiguous()

    tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
        torch.empty_like(input_) for _ in range(cp_world_size - 1)
    ]
    if cp_rank == 0:
        input_ = torch.cat([input_first_kernel_, input_], dim=dim)

    tensor_list[cp_rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=dim).contiguous()

    # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)

    return output
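

# Illustrative round-trip sketch (added here; not part of the upstream file). Assuming
# torch.distributed and the context-parallel group are initialized with
# cp_world_size = 2, and that (T - kernel_size) divides evenly as required above:
#
#   x = torch.randn(1, 4, 18, 32, 32)                      # (B, C, T, H, W), T on dim 2
#   x_local = _conv_split(x, dim=2, kernel_size=2)          # local slice plus halo
#   x_full = _conv_gather(x_local, dim=2, kernel_size=2)    # equals x on every rank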