Mirror of https://github.com/THUDM/CogVideo.git, synced 2025-04-05 19:41:59 +08:00

Commit 3a9af5bdd9 ("update with test code"), parent b6abbeab97
@@ -36,6 +36,7 @@ def add_sampling_config_args(parser):
    group.add_argument("--input-dir", type=str, default=None)
    group.add_argument("--input-type", type=str, default="cli")
    group.add_argument("--input-file", type=str, default="input.txt")
    group.add_argument("--sampling-image-size", type=list, default=[768, 1360])
    group.add_argument("--final-size", type=int, default=2048)
    group.add_argument("--sdedit", action="store_true")
    group.add_argument("--grid-num-rows", type=int, default=1)
BIN  sat/configs/images.jpg (new binary file, 35 KiB; not shown)
@@ -185,7 +185,12 @@ class SATVideoDiffusionEngine(nn.Module):
                kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
            else:
                kwargs = {}
            out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
            frame = z.shape[2] * 4 - 3
            if frame <= 9:
                use_cp = False
            else:
                use_cp = True
            out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], use_cp=use_cp, **kwargs)
            all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out
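The new `use_cp` branch keys off the decoded frame count. A minimal sketch of the arithmetic, assuming the VAE upsamples time 4x and the first latent frame decodes to a single video frame:

```python
# frame = z.shape[2] * 4 - 3: latent frames -> decoded video frames
def decoded_frames(latent_frames: int) -> int:
    return latent_frames * 4 - 3

for z_t in (2, 3, 13):
    print(z_t, "->", decoded_frames(z_t))  # 2 -> 5, 3 -> 9, 13 -> 49
# use_cp (context parallelism in the VAE decoder) is enabled only for clips
# longer than 9 frames, i.e. more than 3 latent frames.
```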
@@ -218,6 +223,7 @@ class SATVideoDiffusionEngine(nn.Module):
        shape: Union[None, Tuple, List] = None,
        prefix=None,
        concat_images=None,
        ofs=None,
        **kwargs,
    ):
        randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
@@ -241,7 +247,7 @@ class SATVideoDiffusionEngine(nn.Module):
            self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
        )

        samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
        samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs)
        samples = samples.to(self.dtype)
        return samples
@@ -1,5 +1,7 @@
from functools import partial
from einops import rearrange, repeat
from functools import reduce
from operator import mul
import numpy as np

import torch
@@ -13,38 +15,34 @@ from sat.mpu.layers import ColumnParallelLinear
from sgm.util import instantiate_from_config

from sgm.modules.diffusionmodules.openaimodel import Timestep
from sgm.modules.diffusionmodules.util import (
    linear,
    timestep_embedding,
)
from sgm.modules.diffusionmodules.util import linear, timestep_embedding
from sat.ops.layernorm import LayerNorm, RMSNorm


class ImagePatchEmbeddingMixin(BaseMixin):
    def __init__(
        self,
        in_channels,
        hidden_size,
        patch_size,
        bias=True,
        text_hidden_size=None,
    ):
    def __init__(self, in_channels, hidden_size, patch_size, text_hidden_size=None):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.patch_size = patch_size
        self.proj = nn.Linear(in_channels * reduce(mul, patch_size), hidden_size)
        if text_hidden_size is not None:
            self.text_proj = nn.Linear(text_hidden_size, hidden_size)
        else:
            self.text_proj = None

    def word_embedding_forward(self, input_ids, **kwargs):
        # now a 3d patch
        images = kwargs["images"]  # (b,t,c,h,w)
        B, T = images.shape[:2]
        emb = images.view(-1, *images.shape[2:])
        emb = self.proj(emb)  # ((b t),d,h/2,w/2)
        emb = emb.view(B, T, *emb.shape[1:])
        emb = emb.flatten(3).transpose(2, 3)  # (b,t,n,d)
        emb = rearrange(emb, "b t n d -> b (t n) d")
        emb = rearrange(images, "b t c h w -> b (t h w) c")
        emb = rearrange(
            emb,
            "b (t o h p w q) c -> b (t h w) (c o p q)",
            t=kwargs["rope_T"],
            h=kwargs["rope_H"],
            w=kwargs["rope_W"],
            o=self.patch_size[0],
            p=self.patch_size[1],
            q=self.patch_size[2],
        )
        emb = self.proj(emb)

        if self.text_proj is not None:
            text_emb = self.text_proj(kwargs["encoder_outputs"])
@@ -74,7 +72,8 @@ def get_3d_sincos_pos_embed(
    grid_size: int of the grid height and width
    t_size: int of the temporal size
    return:
    pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    pos_embed: [t_size*grid_size * grid_size, embed_dim] or [1+t_size*grid_size * grid_size, embed_dim]
        (w/ or w/o cls_token)
    """
    assert embed_dim % 4 == 0
    embed_dim_spatial = embed_dim // 4 * 3
@@ -100,7 +99,6 @@ def get_3d_sincos_pos_embed(
    pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0)  # [T, H*W, D // 4 * 3]

    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
    # pos_embed = pos_embed.reshape([-1, embed_dim])  # [T*H*W, D]

    return pos_embed  # [T, H*W, D]
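For orientation, the shape bookkeeping of the 3-D sin-cos embedding (toy numbers; per the assert above, a quarter of embed_dim encodes time and three quarters encode space):

```python
import numpy as np

T, HW, D = 3, 16, 16                        # t_size, grid_size**2, embed_dim
spatial = np.zeros((1, HW, D // 4 * 3))     # [1, H*W, 3D/4]
temporal = np.zeros((T, 1, D // 4))         # [T, 1, D/4]
spatial = np.repeat(spatial, T, axis=0)     # [T, H*W, 3D/4]
temporal = np.repeat(temporal, HW, axis=1)  # [T, H*W, D/4]
pos_embed = np.concatenate([temporal, spatial], axis=-1)
print(pos_embed.shape)                      # (3, 16, 16) == [T, H*W, D]
```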
@@ -259,6 +257,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
        text_length,
        theta=10000,
        rot_v=False,
        height_interpolation=1.0,
        width_interpolation=1.0,
        time_interpolation=1.0,
        learnable_pos_embed=False,
    ):
        super().__init__()
@@ -285,14 +286,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
        freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)

        freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
        freqs = rearrange(freqs, "t h w d -> (t h w) d")

        freqs = freqs.contiguous()
        freqs_sin = freqs.sin()
        freqs_cos = freqs.cos()
        self.register_buffer("freqs_sin", freqs_sin)
        self.register_buffer("freqs_cos", freqs_cos)

        self.freqs_sin = freqs.sin().cuda()
        self.freqs_cos = freqs.cos().cuda()
        self.text_length = text_length
        if learnable_pos_embed:
            num_patches = height * width * compressed_num_frames + text_length
@@ -301,15 +298,20 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
            self.pos_embedding = None

    def rotary(self, t, **kwargs):
        seq_len = t.shape[2]
        freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
        freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
        def reshape_freq(freqs):
            freqs = freqs[: kwargs["rope_T"], : kwargs["rope_H"], : kwargs["rope_W"]].contiguous()
            freqs = rearrange(freqs, "t h w d -> (t h w) d")
            freqs = freqs.unsqueeze(0).unsqueeze(0)
            return freqs

        freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
        freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)

        return t * freqs_cos + rotate_half(t) * freqs_sin
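`rotate_half` is the standard RoPE pairing trick; a self-contained sketch of the helper used in `t * freqs_cos + rotate_half(t) * freqs_sin` (the conventional definition, not copied from this file):

```python
import torch
from einops import rearrange

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # pair up adjacent channels (x1, x2) and map them to (-x2, x1)
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    return rearrange(torch.stack((-x2, x1), dim=-1), "... d r -> ... (d r)")

print(rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0])))  # tensor([-2., 1., -4., 3.])
```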
    def position_embedding_forward(self, position_ids, **kwargs):
        if self.pos_embedding is not None:
            return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]]
            return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
        else:
            return None

@@ -326,10 +328,61 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
    ):
        attention_fn_default = HOOKS_DEFAULT["attention_fn"]

        query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
        key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
        query_layer = torch.cat(
            (
                query_layer[:, :, : kwargs["text_length"]],
                self.rotary(query_layer[:, :, kwargs["text_length"] :], **kwargs),
            ),
            dim=2,
        )
        key_layer = torch.cat(
            (
                key_layer[:, :, : kwargs["text_length"]],
                self.rotary(key_layer[:, :, kwargs["text_length"] :], **kwargs),
            ),
            dim=2,
        )
        if self.rot_v:
            value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])
            value_layer = torch.cat(
                (
                    value_layer[:, :, : kwargs["text_length"]],
                    self.rotary(value_layer[:, :, kwargs["text_length"] :], **kwargs),
                ),
                dim=2,
            )

        return attention_fn_default(
            query_layer,
@@ -347,21 +400,25 @@ def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
def unpatchify(x, c, patch_size, w, h, **kwargs):
    """
    x: (N, T/2 * S, patch_size**3 * C)
    imgs: (N, T, H, W, C)

    patch_size is split into three separate dimensions (o, p, q) for depth (o), height (p), and width (q),
    so the patch size can differ per axis, which adds flexibility.
    """
    if rope_position_ids is not None:
        assert NotImplementedError
        # do pix2struct unpatchify
        L = x.shape[1]
        x = x.reshape(shape=(x.shape[0], L, p, p, c))
        x = torch.einsum("nlpqc->ncplq", x)
        imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
    else:
        b = x.shape[0]
        imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p)

    imgs = rearrange(
        x,
        "b (t h w) (c o p q) -> b (t o) c (h p) (w q)",
        c=c,
        o=patch_size[0],
        p=patch_size[1],
        q=patch_size[2],
        t=kwargs["rope_T"],
        h=kwargs["rope_H"],
        w=kwargs["rope_W"],
    )

    return imgs
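Sanity check that the new `unpatchify` is the exact inverse of the patch embedding's rearrange (toy sizes, same axis layout as above):

```python
import torch
from einops import rearrange

B, T, C, H, W = 1, 2, 4, 4, 4
o, p, q = 2, 2, 2  # patch_size
video = torch.randn(B, T, C, H, W)

tok = rearrange(video, "b (t o) c (h p) (w q) -> b (t h w) (c o p q)", o=o, p=p, q=q)
back = rearrange(
    tok, "b (t h w) (c o p q) -> b (t o) c (h p) (w q)",
    c=C, o=o, p=p, q=q, t=T // o, h=H // p, w=W // q,
)
assert torch.equal(video, back)  # unpatchify inverts the 3-D patchify
```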
@@ -382,27 +439,17 @@ class FinalLayerMixin(BaseMixin):
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))

        self.spatial_length = latent_width * latent_height // patch_size**2
        self.latent_width = latent_width
        self.latent_height = latent_height

    def final_forward(self, logits, **kwargs):
        x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"]  # x:(b,(t n),d)
        x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"]  # x:(b,(t n),d); keep only the image tokens after the text prefix
        shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)

        return unpatchify(
            x,
            c=self.out_channels,
            p=self.patch_size,
            w=self.latent_width // self.patch_size,
            h=self.latent_height // self.patch_size,
            rope_position_ids=kwargs.get("rope_position_ids", None),
            **kwargs,
            x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs
        )

    def reinit(self, parent_model=None):
@@ -440,8 +487,6 @@ class SwiGLUMixin(BaseMixin):
class AdaLNMixin(BaseMixin):
    def __init__(
        self,
        width,
        height,
        hidden_size,
        num_layers,
        time_embed_dim,
@@ -452,8 +497,6 @@ class AdaLNMixin(BaseMixin):
    ):
        super().__init__()
        self.num_layers = num_layers
        self.width = width
        self.height = height
        self.compressed_num_frames = compressed_num_frames

        self.adaLN_modulations = nn.ModuleList(
@@ -611,7 +654,8 @@ class DiffusionTransformer(BaseModel):
        time_interpolation=1.0,
        use_SwiGLU=False,
        use_RMSNorm=False,
        zero_init_y_embed=False,
        cfg_embed_dim=None,
        ofs_embed_dim=None,
        **kwargs,
    ):
        self.latent_width = latent_width
@@ -619,12 +663,14 @@ class DiffusionTransformer(BaseModel):
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.time_compressed_rate = time_compressed_rate
        self.spatial_length = latent_width * latent_height // patch_size**2
        self.spatial_length = latent_width * latent_height // reduce(mul, patch_size[1:])
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size
        self.model_channels = hidden_size
        self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
        self.cfg_embed_dim = cfg_embed_dim
        self.ofs_embed_dim = ofs_embed_dim
        self.num_classes = num_classes
        self.adm_in_channels = adm_in_channels
        self.input_time = input_time
@@ -636,7 +682,6 @@ class DiffusionTransformer(BaseModel):
        self.width_interpolation = width_interpolation
        self.time_interpolation = time_interpolation
        self.inner_hidden_size = hidden_size * 4
        self.zero_init_y_embed = zero_init_y_embed
        try:
            self.dtype = str_to_dtype[kwargs.pop("dtype")]
        except:
@@ -669,7 +714,6 @@ class DiffusionTransformer(BaseModel):

    def _build_modules(self, module_configs):
        model_channels = self.hidden_size
        # time_embed_dim = model_channels * 4
        time_embed_dim = self.time_embed_dim
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
@@ -677,6 +721,20 @@ class DiffusionTransformer(BaseModel):
            linear(time_embed_dim, time_embed_dim),
        )

        if self.ofs_embed_dim is not None:
            self.ofs_embed = nn.Sequential(
                linear(self.ofs_embed_dim, self.ofs_embed_dim),
                nn.SiLU(),
                linear(self.ofs_embed_dim, self.ofs_embed_dim),
            )

        if self.cfg_embed_dim is not None:
            self.cfg_embed = nn.Sequential(
                linear(self.cfg_embed_dim, self.cfg_embed_dim),
                nn.SiLU(),
                linear(self.cfg_embed_dim, self.cfg_embed_dim),
            )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
@@ -701,9 +759,6 @@ class DiffusionTransformer(BaseModel):
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
                if self.zero_init_y_embed:
                    nn.init.constant_(self.label_emb[0][2].weight, 0)
                    nn.init.constant_(self.label_emb[0][2].bias, 0)
            else:
                raise ValueError()

@@ -712,10 +767,13 @@ class DiffusionTransformer(BaseModel):
            "pos_embed",
            instantiate_from_config(
                pos_embed_config,
                height=self.latent_height // self.patch_size,
                width=self.latent_width // self.patch_size,
                height=self.latent_height // self.patch_size[1],
                width=self.latent_width // self.patch_size[2],
                compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
                hidden_size=self.hidden_size,
                height_interpolation=self.height_interpolation,
                width_interpolation=self.width_interpolation,
                time_interpolation=self.time_interpolation,
            ),
            reinit=True,
        )
@@ -737,8 +795,6 @@ class DiffusionTransformer(BaseModel):
            "adaln_layer",
            instantiate_from_config(
                adaln_layer_config,
                height=self.latent_height // self.patch_size,
                width=self.latent_width // self.patch_size,
                hidden_size=self.hidden_size,
                num_layers=self.num_layers,
                compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
@@ -749,7 +805,6 @@ class DiffusionTransformer(BaseModel):
            )
        else:
            raise NotImplementedError

        final_layer_config = module_configs["final_layer_config"]
        self.add_mixin(
            "final_layer",
@@ -766,25 +821,18 @@ class DiffusionTransformer(BaseModel):
            reinit=True,
        )

        if "lora_config" in module_configs:
            lora_config = module_configs["lora_config"]
            self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)

        return

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        b, t, d, h, w = x.shape
        if x.dtype != self.dtype:
            x = x.to(self.dtype)

        # This is not used in inference
        if "concat_images" in kwargs and kwargs["concat_images"] is not None:
            if kwargs["concat_images"].shape[0] != x.shape[0]:
                concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1)
            else:
                concat_images = kwargs["concat_images"]
            x = torch.cat([x, concat_images], dim=2)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
@@ -792,17 +840,33 @@ class DiffusionTransformer(BaseModel):
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            # assert y.shape[0] == x.shape[0]
            assert x.shape[0] % y.shape[0] == 0
            y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
            emb = emb + self.label_emb(y)

        kwargs["seq_length"] = t * h * w // (self.patch_size**2)
        if self.ofs_embed_dim is not None:
            ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
            ofs_emb = self.ofs_embed(ofs_emb)
            emb = emb + ofs_emb
        if self.cfg_embed_dim is not None:
            cfg_emb = kwargs["scale_emb"]
            cfg_emb = self.cfg_embed(cfg_emb)
            emb = emb + cfg_emb

        if "ofs" in kwargs.keys():
            ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
            ofs_emb = self.ofs_embed(ofs_emb)

        kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
        kwargs["images"] = x
        kwargs["emb"] = emb
        kwargs["encoder_outputs"] = context
        kwargs["text_length"] = context.shape[1]

        kwargs["rope_T"] = t // self.patch_size[0]
        kwargs["rope_H"] = h // self.patch_size[1]
        kwargs["rope_W"] = w // self.patch_size[2]

        kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
        output = super().forward(**kwargs)[0]
        return output
@@ -1,16 +1,11 @@
SwissArmyTransformer==0.4.12
omegaconf==2.3.0
torch==2.4.0
torchvision==0.19.0
pytorch_lightning==2.3.3
kornia==0.7.3
beartype==0.18.5
numpy==2.0.1
fsspec==2024.5.0
safetensors==0.4.3
imageio-ffmpeg==0.5.1
imageio==2.34.2
scipy==1.14.0
decord==0.6.0
wandb==0.17.5
deepspeed==0.14.4
SwissArmyTransformer>=0.4.12
omegaconf>=2.3.0
pytorch_lightning>=2.4.0
kornia>=0.7.3
beartype>=0.19.0
fsspec>=2024.2.0
safetensors>=0.4.5
scipy>=1.14.1
decord>=0.6.0
wandb>=0.18.5
deepspeed>=0.15.3
@@ -4,24 +4,20 @@ import argparse
from typing import List, Union
from tqdm import tqdm
from omegaconf import ListConfig
from PIL import Image
import imageio

import torch
import numpy as np
from einops import rearrange
from einops import rearrange, repeat
import torchvision.transforms as TT


from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from sat import mpu

from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
from PIL import Image


def read_from_cli():
    cnt = 0
@@ -56,6 +52,42 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
        if key == "txt":
            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = (
                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
            )

        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
            )
        elif key == "fps":
            batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
        elif key == "fps_id":
            batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
        elif key == "motion_bucket_id":
            batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
        elif key == "pool_image":
            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to("cuda"),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
        else:
            batch[key] = value_dict[key]
@ -83,37 +115,6 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: i
|
||||
writer.append_data(frame)
|
||||
|
||||
|
||||
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
|
||||
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
else:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
|
||||
h, w = arr.shape[2], arr.shape[3]
|
||||
arr = arr.squeeze(0)
|
||||
|
||||
delta_h = h - image_size[0]
|
||||
delta_w = w - image_size[1]
|
||||
|
||||
if reshape_mode == "random" or reshape_mode == "none":
|
||||
top = np.random.randint(0, delta_h + 1)
|
||||
left = np.random.randint(0, delta_w + 1)
|
||||
elif reshape_mode == "center":
|
||||
top, left = delta_h // 2, delta_w // 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
|
||||
return arr
|
||||
|
||||
|
||||
def sampling_main(args, model_cls):
|
||||
if isinstance(model_cls, type):
|
||||
model = get_model(args, model_cls)
|
||||
@@ -127,44 +128,65 @@ def sampling_main(args, model_cls):
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        print("rank and world_size", rank, world_size)
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    image_size = [480, 720]

    if args.image2video:
        chained_trainsforms = []
        chained_trainsforms.append(TT.ToTensor())
        transform = TT.Compose(chained_trainsforms)

    sample_func = model.sample
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device

    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            if args.image2video:
                text, image_path = text.split("@@")
                # use with input image shape
                text, image_path = text.split('@@')
                assert os.path.exists(image_path), image_path
                image = Image.open(image_path).convert("RGB")
                image = transform(image).unsqueeze(0).to("cuda")
                image = resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0)
                image = Image.open(image_path).convert('RGB')
                (img_W, img_H) = image.size

                def nearest_multiple_of_16(n):
                    lower_multiple = (n // 16) * 16
                    upper_multiple = (n // 16 + 1) * 16
                    if abs(n - lower_multiple) < abs(n - upper_multiple):
                        return lower_multiple
                    else:
                        return upper_multiple

                if img_H < img_W:
                    H = 96
                    W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8
                else:
                    W = 96
                    H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
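A worked example of the resolution logic above for a 1920x1080 input (landscape, so the latent height is pinned at 96; the VAE downsamples 8x):

```python
def nearest_multiple_of_16(n):
    lower, upper = (n // 16) * 16, (n // 16 + 1) * 16
    return lower if abs(n - lower) < abs(n - upper) else upper

img_W, img_H = 1920, 1080
H = 96
W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8  # 1365.3 -> 1360 -> 170
print(H * 8, W * 8)  # 768 1360: the pixel size the image is resized to
```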
                chained_trainsforms = []
                chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
                chained_trainsforms.append(TT.ToTensor())
                transform = TT.Compose(chained_trainsforms)
                image = transform(image).unsqueeze(0).to('cuda')
                image = image * 2.0 - 1.0
                image = image.unsqueeze(2).to(torch.bfloat16)
                image = model.encode_first_stage(image, None)
                image = image / model.scale_factor
                image = image.permute(0, 2, 1, 3, 4).contiguous()
                pad_shape = (image.shape[0], T - 1, C, H // F, W // F)
                pad_shape = (image.shape[0], T - 1, C, H, W)
                image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
            else:
                image_size = args.sampling_image_size
                T, H, W, C = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels
                F = 8  # 8x downsampled
                image = None

            text_cast = [text]
            mp_size = mpu.get_model_parallel_world_size()
            global_rank = torch.distributed.get_rank() // mp_size
            src = global_rank * mp_size
            torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
            text = text_cast[0]
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
                'prompt': text,
                'negative_prompt': '',
                'num_frames': torch.tensor(T).unsqueeze(0)
            }

            batch, batch_uc = get_batch(
@@ -187,64 +209,52 @@ def sampling_main(args, model_cls):
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

            if args.image2video and image is not None:
            if args.image2video:
                c["concat"] = image
                uc["concat"] = image

            for index in range(args.batch_size):
                # reload model on GPU
                model.to(device)
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                )
                if args.image2video:
                    samples_z = sample_func(
                        c,
                        uc=uc,
                        batch_size=1,
                        shape=(T, C, H, W),
                        ofs=torch.tensor([2.0]).to('cuda')
                    )
                else:
                    samples_z = sample_func(
                        c,
                        uc=uc,
                        batch_size=1,
                        shape=(T, C, H // F, W // F),
                    ).to('cuda')

                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
                if args.only_save_latents:
                    samples_z = 1.0 / model.scale_factor * samples_z
                    save_path = os.path.join(
                        args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
                    )
                    os.makedirs(save_path, exist_ok=True)
                    torch.save(samples_z, os.path.join(save_path, "latent.pt"))
                    with open(os.path.join(save_path, "text.txt"), "w") as f:
                        f.write(text)
                else:
                    samples_x = model.decode_first_stage(samples_z).to(torch.float32)
                    samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                    save_path = os.path.join(
                        args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
                    )
                    if mpu.get_model_parallel_rank() == 0:
                        save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)

                    # Unload the model from GPU to save GPU memory
                    model.to("cpu")
                    torch.cuda.empty_cache()
                    first_stage_model = model.first_stage_model
                    first_stage_model = first_stage_model.to(device)

                    latent = 1.0 / model.scale_factor * samples_z

                    # Decode latents serially to save GPU memory
                    recons = []
                    loop_num = (T - 1) // 2
                    for i in range(loop_num):
                        if i == 0:
                            start_frame, end_frame = 0, 3
                        else:
                            start_frame, end_frame = i * 2 + 1, i * 2 + 3
                        if i == loop_num - 1:
                            clear_fake_cp_cache = True
                        else:
                            clear_fake_cp_cache = False
                        with torch.no_grad():
                            recon = first_stage_model.decode(
                                latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                            )

                        recons.append(recon)

                    recon = torch.cat(recons, dim=2).to(torch.float32)
                    samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                    save_path = os.path.join(
                        args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
                    )
                    if mpu.get_model_parallel_rank() == 0:
                        save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
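The serial decode walks the latent frames in small windows so only a slice sits on the GPU at a time: the first window takes three latent frames, every later one takes two, and the fake-CP cache is flushed on the last window. A sketch of the window indices for a 49-frame clip (13 latent frames):

```python
T = 13  # latent frames (decodes to 13 * 4 - 3 = 49 video frames)
loop_num = (T - 1) // 2
for i in range(loop_num):
    start, end = (0, 3) if i == 0 else (i * 2 + 1, i * 2 + 3)
    print(i, (start, end), "clear_cache" if i == loop_num - 1 else "")
# windows: (0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)
```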

if __name__ == "__main__":
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
if __name__ == '__main__':
    if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
        os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
        os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
        os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()
@@ -1,7 +1,8 @@
"""
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""


from typing import Dict, Union

import torch
@@ -16,7 +17,6 @@ from ...modules.diffusionmodules.sampling_utils import (
    to_sigma,
)
from ...util import append_dims, default, instantiate_from_config
from ...util import SeededNoise

from .guiders import DynamicCFG

@@ -44,7 +44,9 @@ class BaseDiffusionSampler:
        self.device = device

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device)
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device
        )
        uc = default(uc, cond)

        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
@@ -83,7 +85,9 @@ class SingleStepDiffusionSampler(BaseDiffusionSampler):


class EDMSampler(SingleStepDiffusionSampler):
    def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
    def __init__(
        self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.s_churn = s_churn
@@ -102,15 +106,21 @@ class EDMSampler(SingleStepDiffusionSampler):
        dt = append_dims(next_sigma - sigma_hat, x.ndim)

        euler_step = self.euler_step(x, d, dt)
        x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
        x = self.possible_correction_step(
            euler_step, x, d, dt, next_sigma, denoiser, cond, uc
        )
        return x

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        for i in self.get_sigma_gen(num_sigmas):
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
@@ -126,23 +136,30 @@ class EDMSampler(SingleStepDiffusionSampler):


class DDIMSampler(SingleStepDiffusionSampler):
    def __init__(self, s_noise=0.1, *args, **kwargs):
    def __init__(
        self, s_noise=0.1, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.s_noise = s_noise

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):

        denoised = self.denoise(x, denoiser, sigma, cond, uc)
        d = to_d(x, sigma, denoised)
        dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim)
        dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim)

        euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)

        x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
        x = self.possible_correction_step(
            euler_step, x, d, dt, next_sigma, denoiser, cond, uc
        )
        return x

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        for i in self.get_sigma_gen(num_sigmas):
            x = self.sampler_step(
@@ -181,7 +198,9 @@ class AncestralSampler(SingleStepDiffusionSampler):
        return x

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        for i in self.get_sigma_gen(num_sigmas):
            x = self.sampler_step(
@@ -208,32 +227,43 @@ class LinearMultistepSampler(BaseDiffusionSampler):
        self.order = order

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        ds = []
        sigmas_cpu = sigmas.detach().cpu().numpy()
        for i in self.get_sigma_gen(num_sigmas):
            sigma = s_in * sigmas[i]
            denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
            denoised = denoiser(
                *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
            )
            denoised = self.guider(denoised, sigma)
            d = to_d(x, sigma, denoised)
            ds.append(d)
            if len(ds) > self.order:
                ds.pop(0)
            cur_order = min(i + 1, self.order)
            coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
            coeffs = [
                linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
                for j in range(cur_order)
            ]
            x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))

        return x


class EulerEDMSampler(EDMSampler):
    def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
    def possible_correction_step(
        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
    ):
        return euler_step


class HeunEDMSampler(EDMSampler):
    def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
    def possible_correction_step(
        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
    ):
        if torch.sum(next_sigma) < 1e-14:
            # Save a network evaluation if all noise levels are 0
            return euler_step
@@ -243,7 +273,9 @@ class HeunEDMSampler(EDMSampler):
        d_prime = (d + d_new) / 2.0

        # apply correction if noise level is not 0
        x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
        x = torch.where(
            append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
        )
        return x

@@ -282,7 +314,9 @@ class DPMPP2SAncestralSampler(AncestralSampler):
            x = x_euler
        else:
            h, s, t, t_next = self.get_variables(sigma, sigma_down)
            mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]
            mult = [
                append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
            ]

            x2 = mult[0] * x - mult[1] * denoised
            denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
@@ -332,7 +366,10 @@ class DPMPP2MSampler(BaseDiffusionSampler):
        denoised = self.denoise(x, denoiser, sigma, cond, uc)

        h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
        mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
        mult = [
            append_dims(mult, x.ndim)
            for mult in self.get_mult(h, r, t, t_next, previous_sigma)
        ]

        x_standard = mult[0] * x - mult[1] * denoised
        if old_denoised is None or torch.sum(next_sigma) < 1e-14:
@@ -343,12 +380,16 @@ class DPMPP2MSampler(BaseDiffusionSampler):
            x_advanced = mult[0] * x - mult[1] * denoised_d

            # apply correction if noise level is not 0 and not first step
            x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
            x = torch.where(
                append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
            )

        return x, denoised

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        old_denoised = None
        for i in self.get_sigma_gen(num_sigmas):
@@ -365,7 +406,6 @@ class DPMPP2MSampler(BaseDiffusionSampler):

        return x


class SDEDPMPP2MSampler(BaseDiffusionSampler):
    def get_variables(self, sigma, next_sigma, previous_sigma=None):
        t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
@@ -380,7 +420,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):

    def get_mult(self, h, r, t, t_next, previous_sigma):
        mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
        mult2 = (-2 * h).expm1()
        mult2 = (-2*h).expm1()

        if previous_sigma is not None:
            mult3 = 1 + 1 / (2 * r)
@@ -403,8 +443,11 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
        denoised = self.denoise(x, denoiser, sigma, cond, uc)

        h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
        mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
        mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
        mult = [
            append_dims(mult, x.ndim)
            for mult in self.get_mult(h, r, t, t_next, previous_sigma)
        ]
        mult_noise = append_dims(next_sigma * (1 - (-2*h).exp())**0.5, x.ndim)

        x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
        if old_denoised is None or torch.sum(next_sigma) < 1e-14:
@@ -415,12 +458,16 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
            x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)

            # apply correction if noise level is not 0 and not first step
            x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
            x = torch.where(
                append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
            )

        return x, denoised

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        old_denoised = None
        for i in self.get_sigma_gen(num_sigmas):
@@ -437,7 +484,6 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):

        return x

class SdeditEDMSampler(EulerEDMSampler):
    def __init__(self, edit_ratio=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
@@ -446,7 +492,9 @@ class SdeditEDMSampler(EulerEDMSampler):

    def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None):
        randn_unit = randn.clone()
        randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps)
        randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            randn, cond, uc, num_steps
        )

        if num_steps is None:
            num_steps = self.num_steps
@@ -461,7 +509,9 @@ class SdeditEDMSampler(EulerEDMSampler):
            x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))

            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
@@ -475,8 +525,8 @@ class SdeditEDMSampler(EulerEDMSampler):

        return x


class VideoDDIMSampler(BaseDiffusionSampler):

    def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
        super().__init__(**kwargs)
        self.fixed_frames = fixed_frames
@@ -484,13 +534,10 @@ class VideoDDIMSampler(BaseDiffusionSampler):

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        alpha_cumprod_sqrt, timesteps = self.discretization(
            self.num_steps if num_steps is None else num_steps,
            device=self.device,
            return_idx=True,
            do_append_zero=False,
            self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True, do_append_zero=False
        )
        alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
        timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))])
        timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1])-1, torch.tensor(list(timesteps))])

        uc = default(uc, cond)

@@ -500,51 +547,36 @@ class VideoDDIMSampler(BaseDiffusionSampler):

        return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps

    def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None):
    def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None, ofs=None):
        additional_model_inputs = {}

        if ofs is not None:
            additional_model_inputs['ofs'] = ofs

        if isinstance(scale, torch.Tensor) == False and scale == 1:
            additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep
            additional_model_inputs['idx'] = x.new_ones([x.shape[0]]) * timestep
            if scale_emb is not None:
                additional_model_inputs["scale_emb"] = scale_emb
                additional_model_inputs['scale_emb'] = scale_emb
            denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32)
        else:
            additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
            denoised = denoiser(
                *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
            ).to(torch.float32)
            additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
            denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(torch.float32)
        if isinstance(self.guider, DynamicCFG):
            denoised = self.guider(
                denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale
            )
            denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, step_index=self.num_steps - timestep, scale=scale)
        else:
            denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale)
            denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, scale=scale)
        return denoised

    def sampler_step(
        self,
        alpha_cumprod_sqrt,
        next_alpha_cumprod_sqrt,
        denoiser,
        x,
        cond,
        uc=None,
        idx=None,
        timestep=None,
        scale=None,
        scale_emb=None,
    ):
        denoised = self.denoise(
            x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
        ).to(torch.float32)
    def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, timestep=None, scale=None, scale_emb=None, ofs=None):
        denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32)  # 1020

        a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
        a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5
        b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t

        x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
        return x
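A numeric check of the DDIM update above: with `a_t` and `b_t` as defined, a sample at noise level `alpha_cumprod_sqrt = s` is mapped exactly onto the next level when `denoised` equals the true clean signal:

```python
import torch

x0, eps = torch.randn(4), torch.randn(4)
s, s_next = torch.tensor(0.8), torch.tensor(0.9)  # alpha_cumprod_sqrt values
x = s * x0 + (1 - s**2) ** 0.5 * eps              # sample at level s

a_t = ((1 - s_next**2) / (1 - s**2)) ** 0.5
b_t = s_next - s * a_t
x_next = a_t * x + b_t * x0                        # denoised == x0 here

assert torch.allclose(x_next, s_next * x0 + (1 - s_next**2) ** 0.5 * eps, atol=1e-6)
```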

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None):  # 1020
        x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )
@@ -558,25 +590,83 @@ class VideoDDIMSampler(BaseDiffusionSampler):
                cond,
                uc,
                idx=self.num_steps - i,
                timestep=timesteps[-(i + 1)],
                timestep=timesteps[-(i+1)],
                scale=scale,
                scale_emb=scale_emb,
                ofs=ofs  # 1020
            )

        return x


class Image2VideoDDIMSampler(BaseDiffusionSampler):

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        alpha_cumprod_sqrt, timesteps = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True
        )
        uc = default(uc, cond)

        num_sigmas = len(alpha_cumprod_sqrt)

        s_in = x.new_ones([x.shape[0]])

        return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps

    def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None):
        additional_model_inputs = {}
        additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
        denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(
            torch.float32)
        if isinstance(self.guider, DynamicCFG):
            denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep)
        else:
            denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5)
        return denoised

    def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None,
                     timestep=None):
        # the "sigma" here is actually alpha_cumprod_sqrt
        denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(torch.float32)
        if idx == 1:
            return denoised

        a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5
        b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t

        x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
        return x

    def __call__(self, image, denoiser, x, cond, uc=None, num_steps=None):
        x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        for i in self.get_sigma_gen(num_sigmas):
            x = self.sampler_step(
                s_in * alpha_cumprod_sqrt[i],
                s_in * alpha_cumprod_sqrt[i + 1],
                denoiser,
                x,
                cond,
                uc,
                idx=self.num_steps - i,
                timestep=timesteps[-(i + 1)]
            )

        return x

class VPSDEDPMPP2MSampler(VideoDDIMSampler):
    def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
        alpha_cumprod = alpha_cumprod_sqrt**2
        lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
        next_alpha_cumprod = next_alpha_cumprod_sqrt**2
        lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
        alpha_cumprod = alpha_cumprod_sqrt ** 2
        lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
        next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
        lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
        h = lamb_next - lamb

        if previous_alpha_cumprod_sqrt is not None:
            previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
            lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
            previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
            lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
            h_last = lamb - lamb_previous
            r = h_last / h
            return h, r, lamb, lamb_next
@@ -584,8 +674,8 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
            return h, None, lamb, lamb_next

    def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
        mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp()
        mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt
        mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 * (-h).exp()
        mult2 = (-2*h).expm1() * next_alpha_cumprod_sqrt

        if previous_alpha_cumprod_sqrt is not None:
            mult3 = 1 + 1 / (2 * r)
@@ -608,21 +698,18 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
        timestep=None,
        scale=None,
        scale_emb=None,
        ofs=None  # 1020
    ):
        denoised = self.denoise(
            x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
        ).to(torch.float32)
        denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32)  # 1020
        if idx == 1:
            return denoised, denoised

        h, r, lamb, lamb_next = self.get_variables(
            alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
        )
        h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
        mult = [
            append_dims(mult, x.ndim)
            for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
        ]
        mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
        mult_noise = append_dims((1-next_alpha_cumprod_sqrt**2)**0.5 * (1 - (-2*h).exp())**0.5, x.ndim)

        x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
        if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
@ -636,24 +723,23 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
|
||||
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
if self.fixed_frames > 0:
|
||||
prefix_frames = x[:, : self.fixed_frames]
|
||||
prefix_frames = x[:, :self.fixed_frames]
|
||||
old_denoised = None
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
|
||||
if self.fixed_frames > 0:
|
||||
if self.sdedit:
|
||||
rd = torch.randn_like(prefix_frames)
|
||||
noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(
|
||||
s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape)
|
||||
)
|
||||
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(s_in * (1 - alpha_cumprod_sqrt[i] ** 2)**0.5, len(prefix_frames.shape))
|
||||
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1)
|
||||
else:
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
|
||||
x, old_denoised = self.sampler_step(
|
||||
old_denoised,
|
||||
None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
|
||||
@ -664,28 +750,29 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
cond,
|
||||
uc=uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i + 1)],
|
||||
timestep=timesteps[-(i+1)],
|
||||
scale=scale,
|
||||
scale_emb=scale_emb,
|
||||
ofs=ofs # 1020
|
||||
)
|
||||
|
||||
if self.fixed_frames > 0:
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
|
||||
|
||||
return x
class VPODEDPMPP2MSampler(VideoDDIMSampler):
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
alpha_cumprod = alpha_cumprod_sqrt**2
lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
next_alpha_cumprod = next_alpha_cumprod_sqrt**2
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
alpha_cumprod = alpha_cumprod_sqrt ** 2
lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
h = lamb_next - lamb
if previous_alpha_cumprod_sqrt is not None:
previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
h_last = lamb - lamb_previous
r = h_last / h
return h, r, lamb, lamb_next
@ -693,7 +780,7 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
return h, None, lamb, lamb_next
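
In math terms, get_variables works in half-log-SNR space: with \bar\alpha the cumulative alpha product,

\lambda_t = \log\frac{\sqrt{\bar\alpha_t}}{\sqrt{1-\bar\alpha_t}}, \qquad h = \lambda_{\text{next}} - \lambda_t, \qquad r = \frac{h_{\text{last}}}{h} = \frac{\lambda_t - \lambda_{\text{prev}}}{\lambda_{\text{next}} - \lambda_t},

so h is the DPM-Solver++ step size and r is the step-size ratio used by the second-order (2M) correction.
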
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5
mult2 = (-h).expm1() * next_alpha_cumprod_sqrt
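
mult1 and mult2 are the data-prediction DPM-Solver++ coefficients for the probability-flow ODE: sigma_next/sigma on the sample and (e^(-h) - 1) * alpha_next on the denoised estimate. A compact sketch of the resulting step, with the 2M correction folded into the denoised estimate (illustrative names; the truncated branch below builds the same correction from r):

import torch

def ode_dpmpp_2m_step(x, denoised, old_denoised, h, r, alpha_bar_sqrt, next_alpha_bar_sqrt):
    sigma = (1 - alpha_bar_sqrt**2) ** 0.5
    sigma_next = (1 - next_alpha_bar_sqrt**2) ** 0.5
    if old_denoised is not None and r is not None:
        # Second-order (2M) correction: linear extrapolation from the previous estimate.
        denoised = (1 + 1 / (2 * r)) * denoised - 1 / (2 * r) * old_denoised
    # x_next = (sigma_next / sigma) * x - alpha_next * (exp(-h) - 1) * denoised
    return (sigma_next / sigma) * x - next_alpha_bar_sqrt * (-h).expm1() * denoised
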
if previous_alpha_cumprod_sqrt is not None:
@ -714,15 +801,13 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
cond,
uc=None,
idx=None,
timestep=None,
timestep=None
):
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
if idx == 1:
return denoised, denoised
h, r, lamb, lamb_next = self.get_variables(
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
)
h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
mult = [
append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
@ -757,7 +842,39 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
cond,
uc=uc,
idx=self.num_steps - i,
timestep=timesteps[-(i + 1)],
timestep=timesteps[-(i+1)]
)
return x
class VideoDDPMSampler(VideoDDIMSampler):
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None):
# The "sigma" here is actually alpha_cumprod_sqrt
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, idx*1000//self.num_steps).to(torch.float32)
if idx == 1:
return denoised
alpha_sqrt = alpha_cumprod_sqrt / next_alpha_cumprod_sqrt
x = append_dims(alpha_sqrt * (1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * x \
+ append_dims(next_alpha_cumprod_sqrt * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * denoised \
+ append_dims(((1-next_alpha_cumprod_sqrt**2) * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5, x.ndim) * torch.randn_like(x)
return x
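
This is the standard DDPM ancestral step: writing \alpha_t = \bar\alpha_t / \bar\alpha_{t-1} (alpha_sqrt above is \sqrt{\alpha_t}), the three terms are the posterior-mean coefficients on x_t and on the predicted clean sample, plus the posterior standard deviation on fresh noise:

x_{t-1} = \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t + \frac{\sqrt{\bar\alpha_{t-1}}\,(1-\alpha_t)}{1-\bar\alpha_t}\,\hat x_0 + \sqrt{\frac{(1-\bar\alpha_{t-1})(1-\alpha_t)}{1-\bar\alpha_t}}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),

with next_alpha_cumprod_sqrt playing the role of \sqrt{\bar\alpha_{t-1}}.
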
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step(
s_in * alpha_cumprod_sqrt[i],
s_in * alpha_cumprod_sqrt[i + 1],
denoiser,
x,
cond,
uc,
idx=self.num_steps - i
)
return x
@ -17,23 +17,20 @@ class EDMSampling:
class DiscreteSampling:
def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False):
def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False, group_num=0):
self.num_idx = num_idx
self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
self.sigmas = instantiate_from_config(discretization_config)(
num_idx, do_append_zero=do_append_zero, flip=flip
)
world_size = mpu.get_data_parallel_world_size()
if world_size <= 8:
uniform_sampling = False
self.uniform_sampling = uniform_sampling
self.group_num = group_num
if self.uniform_sampling:
i = 1
while True:
if world_size % i != 0 or num_idx % (world_size // i) != 0:
i += 1
else:
self.group_num = world_size // i
break
assert self.group_num > 0
assert world_size % self.group_num == 0
self.group_width = world_size // self.group_num  # the number of ranks in one group
assert world_size % group_num == 0
self.group_width = world_size // group_num  # the number of ranks in one group
self.sigma_interval = self.num_idx // self.group_num
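
The loop above picks the largest group count whose slices divide both the world size and the timestep range evenly. A worked example with assumed values (not taken from this repo's configs):

# Worked example of the group_num search above; world_size and num_idx are assumed.
world_size, num_idx = 32, 1000
i = 1
while world_size % i != 0 or num_idx % (world_size // i) != 0:
    i += 1                              # 1000 % 32 != 0, 1000 % 16 != 0, 32 % 3 != 0, ...
group_num = world_size // i             # i == 4, so group_num == 8 (1000 % 8 == 0)
group_width = world_size // group_num   # 4 ranks per group
sigma_interval = num_idx // group_num   # 125 timestep indices per group
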
def idx_to_sigma(self, idx):
@ -45,9 +42,7 @@ class DiscreteSampling:
group_index = rank // self.group_width
idx = default(
rand,
torch.randint(
group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)
),
torch.randint(group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)),
)
else:
idx = default(
@ -59,7 +54,6 @@ class DiscreteSampling:
else:
return self.idx_to_sigma(idx)
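
With uniform_sampling on, each rank draws timestep indices only from its group's slice, so one global batch covers the whole noise schedule in stratified fashion. A standalone sketch of the index draw (assumed helper name, not this class's API):

import torch

def sample_group_index(rank, group_width, sigma_interval, n_samples=1):
    # Each group of ranks owns a contiguous slice of the timestep range.
    group_index = rank // group_width
    lo = group_index * sigma_interval
    hi = (group_index + 1) * sigma_interval
    return torch.randint(lo, hi, (n_samples,))
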
class PartialDiscreteSampling:
def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
self.total_num_idx = total_num_idx
@ -592,8 +592,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
unregularized: bool = False,
input_cp: bool = False,
output_cp: bool = False,
use_cp: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.cp_size > 0 and not input_cp:
if self.cp_size <= 1:
use_cp = False
if self.cp_size > 0 and use_cp and not input_cp:
if not is_context_parallel_initialized:
initialize_context_parallel(self.cp_size)
@ -603,11 +606,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
x = _conv_split(x, dim=2, kernel_size=1)
if return_reg_log:
z, reg_log = super().encode(x, return_reg_log, unregularized)
z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
else:
z = super().encode(x, return_reg_log, unregularized)
z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
if self.cp_size > 0 and not output_cp:
if self.cp_size > 0 and use_cp and not output_cp:
z = _conv_gather(z, dim=2, kernel_size=1)
if return_reg_log:
@ -619,23 +622,24 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
z: torch.Tensor,
input_cp: bool = False,
output_cp: bool = False,
split_kernel_size: int = 1,
use_cp: bool = True,
**kwargs,
):
if self.cp_size > 0 and not input_cp:
if self.cp_size <= 1:
use_cp = False
if self.cp_size > 0 and use_cp and not input_cp:
if not is_context_parallel_initialized:
initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
z = _conv_split(z, dim=2, kernel_size=split_kernel_size)
z = _conv_split(z, dim=2, kernel_size=1)
x = super().decode(z, **kwargs)
if self.cp_size > 0 and not output_cp:
x = _conv_gather(x, dim=2, kernel_size=split_kernel_size)
x = super().decode(z, use_cp=use_cp, **kwargs)
if self.cp_size > 0 and use_cp and not output_cp:
x = _conv_gather(x, dim=2, kernel_size=1)
return x
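
Both paths share one pattern: when context parallelism is enabled and the input is not already sharded, broadcast the latent from the group's source rank, split it along the time dimension (dim=2), run the heavy op locally, and gather the results; the use_cp flag added in this commit lets small inputs opt out. A single-process schematic of the split/gather bookkeeping (the real _conv_split/_conv_gather use torch.distributed and treat the first frame specially, as the _split code further down shows):

import torch

def split_time(z, cp_size, cp_rank):
    # Hand each context-parallel rank a contiguous chunk of frames (dim=2 is time).
    return torch.chunk(z, cp_size, dim=2)[cp_rank]

def gather_time(decoded_chunks):
    # Inverse of split_time once every rank's decoded chunk has been collected.
    return torch.cat(decoded_chunks, dim=2)
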
def forward(
@ -5,8 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from beartype.typing import Union, Tuple
from einops import rearrange
from sgm.util import (
@ -16,11 +15,7 @@ from sgm.util import (
get_context_parallel_group_rank,
)
# try:
from vae_modules.utils import SafeConv3d as Conv3d
# except:
# # Degrade to normal Conv3d if SafeConv3d is not available
# from torch.nn import Conv3d
def cast_tuple(t, length=1):
@ -81,7 +76,6 @@ def _split(input_, dim):
cp_rank = get_context_parallel_rank()
# print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
@ -94,7 +88,6 @@ def _split(input_, dim):
output = torch.cat([inpu_first_frame_, output], dim=dim)
output = output.contiguous()
# print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
return output
@ -382,19 +375,6 @@ class ContextParallelCausalConv3d(nn.Module):
self.cache_padding = None
def forward(self, input_, clear_cache=True):
# if input_.shape[2] == 1: # handle image
# # first frame padding
# input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2)
# else:
# input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
# padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
# input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
# output_parallel = self.conv(input_parallel)
# output = output_parallel
# return output
input_parallel = fake_cp_pass_from_previous_rank(
input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
)
@ -464,7 +444,6 @@ class SpatialNorm3D(nn.Module):
self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
else:
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
# self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
if freeze_norm_layer:
for p in self.norm_layer.parameters():
p.requires_grad = False
@ -543,21 +522,29 @@ class Upsample3D(nn.Module):
def forward(self, x):
if self.compress_time and x.shape[2] > 1:
if x.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
# Process the time dimension first as x_first
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
# print(x_first.shape)
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
# split the rest of the frames to avoid MAX_INT overflow in Pytorch
splits = torch.split(x_rest, 16, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
]
x_rest = torch.cat(interpolated_splits, dim=1)
# concatenate the first frame with the rest
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
splits = torch.split(x, 16, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
]
x = torch.cat(interpolated_splits, dim=1)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.with_conv:
@ -585,15 +572,21 @@ class DownSample3D(nn.Module):
x = rearrange(x, "b c t h w -> (b h w) c t")
if x.shape[-1] % 2 == 1:
# split first frame
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
splits = torch.split(x_rest, 16, dim=1)
interpolated_splits = [
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
]
x_rest = torch.cat(interpolated_splits, dim=1)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
else:
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
splits = torch.split(x, 16, dim=1)
interpolated_splits = [
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
]
x = torch.cat(interpolated_splits, dim=1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
if self.with_conv:
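
The torch.split(..., 16, dim=1) / torch.cat pattern that this commit introduces in both modules exists because one giant interpolate or avg_pool1d call on a video tensor can exceed PyTorch's internal 32-bit indexing (the MAX_INT comment above); chunking the channel dimension keeps each kernel launch small. A generic helper in the same spirit (assumed name, not this module's API):

import torch
import torch.nn.functional as F

def chunked_apply(x, fn, chunk=16, dim=1):
    # Apply fn to channel-chunks and re-concatenate; each call sees a tensor
    # small enough to stay under internal 32-bit element indexing.
    return torch.cat([fn(part) for part in torch.split(x, chunk, dim=dim)], dim=dim)

# Example: chunked nearest-neighbour upsampling of a (B, C, H, W) tensor.
# y = chunked_apply(x, lambda t: F.interpolate(t, scale_factor=2.0, mode="nearest"))
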
@ -675,31 +668,19 @@ class ContextParallelResnetBlock3D(nn.Module):
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
h = x
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
else:
h = self.norm1(h)
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv1(h, clear_cache=clear_fake_cp_cache)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
else:
h = self.norm2(h)
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h, clear_cache=clear_fake_cp_cache)
@ -826,10 +807,7 @@ class ContextParallelEncoder3D(nn.Module):
h = self.mid.block_2(h, temb)
# end
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
h = self.norm_out(h)
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv_out(h)
@ -934,10 +912,10 @@ class ContextParallelDecoder3D(nn.Module):
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
else:
if i_level <= self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
else:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
self.up.insert(0, up)
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
@ -954,9 +932,7 @@ class ContextParallelDecoder3D(nn.Module):
# timestep embedding
temb = None
t = z.shape[2]
# z to block_in
zq = z
h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
@ -1,22 +1,15 @@
"""
This script demonstrates how to convert and generate video from a text prompt
using CogVideoX with 🤗Huggingface Diffusers Pipeline.
This script requires the `diffusers>=0.30.2` library to be installed.

Functions:
- reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
- reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place.
- reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place.
- remove_keys_inplace: Removes specified keys from the state_dict in-place.
- replace_up_keys_inplace: Replaces keys in the "up" block in-place.
- get_state_dict: Extracts the state_dict from a saved checkpoint.
- update_state_dict_inplace: Updates the state_dict with new key assignments in-place.
- convert_transformer: Converts a transformer checkpoint to the CogVideoX format.
- convert_vae: Converts a VAE checkpoint to the CogVideoX format.
- get_args: Parses command-line arguments for the script.
- generate_video: Generates a video from a text prompt using the CogVideoX pipeline.
"""
The script demonstrates how to convert the weights of the CogVideoX model from SAT to Hugging Face format.
This script supports the conversion of the following models:
- CogVideoX-2B
- CogVideoX-5B, CogVideoX-5B-I2V
- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V

Original Script:
https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py

"""
import argparse
from typing import Any, Dict
@ -153,12 +146,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
ckpt_path: str,
num_layers: int,
num_attention_heads: int,
use_rotary_positional_embeddings: bool,
i2v: bool,
dtype: torch.dtype,
ckpt_path: str,
num_layers: int,
num_attention_heads: int,
use_rotary_positional_embeddings: bool,
i2v: bool,
dtype: torch.dtype,
):
PREFIX_KEY = "model.diffusion_model."
@ -172,7 +165,7 @@ def convert_transformer(
).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
new_key = key[len(PREFIX_KEY):]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
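
The loop strips the SAT checkpoint prefix and then applies substring renames from TRANSFORMER_KEYS_RENAME_DICT. The same pattern in isolation, with a hypothetical rename table:

from typing import Any, Dict

def rename_keys(state_dict: Dict[str, Any], prefix: str, rename_map: Dict[str, str]) -> Dict[str, Any]:
    # Strip a module prefix, then apply substring renames; pop/assign keeps the dict in place.
    for key in list(state_dict.keys()):
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        for old, new in rename_map.items():
            new_key = new_key.replace(old, new)
        state_dict[new_key] = state_dict.pop(key)
    return state_dict
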
@ -209,7 +202,8 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint")
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
@ -259,9 +253,10 @@ if __name__ == "__main__":
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
text_encoder_id = "google/t5-v1_1-xxl"
text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
@ -301,4 +296,7 @@ if __name__ == "__main__":
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here).
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
# This is necessary for users with insufficient memory, such as those using Colab and notebooks,
# as it can save some memory used for model loading.
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)