update with test code

zR 2024-11-04 14:34:36 +08:00
parent b6abbeab97
commit 3a9af5bdd9
11 changed files with 588 additions and 423 deletions

View File

@@ -36,6 +36,7 @@ def add_sampling_config_args(parser):
     group.add_argument("--input-dir", type=str, default=None)
     group.add_argument("--input-type", type=str, default="cli")
     group.add_argument("--input-file", type=str, default="input.txt")
+    group.add_argument("--sampling-image-size", type=list, default=[768, 1360])
     group.add_argument("--final-size", type=int, default=2048)
     group.add_argument("--sdedit", action="store_true")
     group.add_argument("--grid-num-rows", type=int, default=1)

BIN  sat/configs/images.jpg  (new binary file, 35 KiB; contents not shown)

View File

@@ -185,7 +185,12 @@ class SATVideoDiffusionEngine(nn.Module):
                 kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
             else:
                 kwargs = {}
-            out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
+            frame = z.shape[2] * 4 - 3
+            if frame <= 9:
+                use_cp = False
+            else:
+                use_cp = True
+            out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], use_cp=use_cp, **kwargs)
             all_out.append(out)
         out = torch.cat(all_out, dim=0)
         return out
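The line frame = z.shape[2] * 4 - 3 recovers the pixel-space frame count from the temporally compressed latent (consistent with a VAE that compresses time 4x and keeps the first frame uncompressed: T latent frames decode to 4*T - 3 frames), and context parallelism (use_cp) is only switched on for clips longer than 9 frames. A minimal sketch of that bookkeeping; the 4x factor and the threshold are taken from the diff, the helper name is hypothetical:

def decode_plan(latent_frames: int, cp_threshold: int = 9):
    """Hypothetical helper mirroring the diff's frame/use_cp logic."""
    pixel_frames = latent_frames * 4 - 3   # 4x temporal compression, first frame uncompressed
    use_cp = pixel_frames > cp_threshold   # context parallel only for longer clips
    return pixel_frames, use_cp

print(decode_plan(3))   # (9, False)  -> single-GPU decode
print(decode_plan(13))  # (49, True)  -> decode with context parallelism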
@@ -218,6 +223,7 @@ class SATVideoDiffusionEngine(nn.Module):
         shape: Union[None, Tuple, List] = None,
         prefix=None,
         concat_images=None,
+        ofs=None,
         **kwargs,
     ):
         randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
@@ -241,7 +247,7 @@ class SATVideoDiffusionEngine(nn.Module):
                 self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
             )
-        samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
+        samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs)
         samples = samples.to(self.dtype)
         return samples

View File

@@ -1,5 +1,7 @@
 from functools import partial
 from einops import rearrange, repeat
+from functools import reduce
+from operator import mul

 import numpy as np
 import torch
@@ -13,38 +15,34 @@ from sat.mpu.layers import ColumnParallelLinear
 from sgm.util import instantiate_from_config
 from sgm.modules.diffusionmodules.openaimodel import Timestep
-from sgm.modules.diffusionmodules.util import (
-    linear,
-    timestep_embedding,
-)
+from sgm.modules.diffusionmodules.util import linear, timestep_embedding
 from sat.ops.layernorm import LayerNorm, RMSNorm


 class ImagePatchEmbeddingMixin(BaseMixin):
-    def __init__(
-        self,
-        in_channels,
-        hidden_size,
-        patch_size,
-        bias=True,
-        text_hidden_size=None,
-    ):
+    def __init__(self, in_channels, hidden_size, patch_size, text_hidden_size=None):
         super().__init__()
-        self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
+        self.patch_size = patch_size
+        self.proj = nn.Linear(in_channels * reduce(mul, patch_size), hidden_size)
         if text_hidden_size is not None:
             self.text_proj = nn.Linear(text_hidden_size, hidden_size)
         else:
             self.text_proj = None

     def word_embedding_forward(self, input_ids, **kwargs):
+        # now is 3d patch
         images = kwargs["images"]  # (b,t,c,h,w)
-        B, T = images.shape[:2]
-        emb = images.view(-1, *images.shape[2:])
-        emb = self.proj(emb)  # ((b t),d,h/2,w/2)
-        emb = emb.view(B, T, *emb.shape[1:])
-        emb = emb.flatten(3).transpose(2, 3)  # (b,t,n,d)
-        emb = rearrange(emb, "b t n d -> b (t n) d")
+        emb = rearrange(images, "b t c h w -> b (t h w) c")
+        emb = rearrange(
+            emb,
+            "b (t o h p w q) c -> b (t h w) (c o p q)",
+            t=kwargs["rope_T"],
+            h=kwargs["rope_H"],
+            w=kwargs["rope_W"],
+            o=self.patch_size[0],
+            p=self.patch_size[1],
+            q=self.patch_size[2],
+        )
+        emb = self.proj(emb)

         if self.text_proj is not None:
             text_emb = self.text_proj(kwargs["encoder_outputs"])
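The new patch embedding flattens 3D (temporal o, spatial p x q) patches and projects them with a single Linear layer instead of a 2D convolution. The shape bookkeeping is easy to get wrong, so here is a minimal, self-contained sketch of the same rearrange-then-project idea under assumed sizes (tensor sizes and names below are illustrative, not values from the commit):

import torch
from torch import nn
from functools import reduce
from operator import mul
from einops import rearrange

b, t, c, h, w = 1, 4, 16, 8, 8           # assumed latent shape (b, t, c, h, w)
patch_size = (2, 2, 2)                    # (o, p, q): temporal, height, width patch
hidden_size = 64
o, p, q = patch_size

images = torch.randn(b, t, c, h, w)
proj = nn.Linear(c * reduce(mul, patch_size), hidden_size)

emb = rearrange(images, "b t c h w -> b (t h w) c")
emb = rearrange(
    emb,
    "b (t o h p w q) c -> b (t h w) (c o p q)",
    t=t // o, h=h // p, w=w // q, o=o, p=p, q=q,
)
tokens = proj(emb)
print(tokens.shape)  # torch.Size([1, 32, 64]) -> (b, (t/o)*(h/p)*(w/q), hidden_size)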
@@ -74,7 +72,8 @@ def get_3d_sincos_pos_embed(
     grid_size: int of the grid height and width
     t_size: int of the temporal size
     return:
-    pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    pos_embed: [t_size*grid_size * grid_size, embed_dim] or [1+t_size*grid_size * grid_size, embed_dim]
+    (w/ or w/o cls_token)
     """
     assert embed_dim % 4 == 0
     embed_dim_spatial = embed_dim // 4 * 3
@@ -100,7 +99,6 @@ def get_3d_sincos_pos_embed(
     pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0)  # [T, H*W, D // 4 * 3]

     pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
-    # pos_embed = pos_embed.reshape([-1, embed_dim])  # [T*H*W, D]

     return pos_embed  # [T, H*W, D]
@@ -259,6 +257,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
         text_length,
         theta=10000,
         rot_v=False,
+        height_interpolation=1.0,
+        width_interpolation=1.0,
+        time_interpolation=1.0,
         learnable_pos_embed=False,
     ):
         super().__init__()
@@ -285,14 +286,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
         freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)

         freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
-        freqs = rearrange(freqs, "t h w d -> (t h w) d")

         freqs = freqs.contiguous()
-        freqs_sin = freqs.sin()
-        freqs_cos = freqs.cos()
-        self.register_buffer("freqs_sin", freqs_sin)
-        self.register_buffer("freqs_cos", freqs_cos)
+        self.freqs_sin = freqs.sin().cuda()
+        self.freqs_cos = freqs.cos().cuda()

         self.text_length = text_length
         if learnable_pos_embed:
             num_patches = height * width * compressed_num_frames + text_length
@@ -301,15 +298,20 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
             self.pos_embedding = None

     def rotary(self, t, **kwargs):
-        seq_len = t.shape[2]
-        freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
-        freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
+        def reshape_freq(freqs):
+            freqs = freqs[: kwargs["rope_T"], : kwargs["rope_H"], : kwargs["rope_W"]].contiguous()
+            freqs = rearrange(freqs, "t h w d -> (t h w) d")
+            freqs = freqs.unsqueeze(0).unsqueeze(0)
+            return freqs
+
+        freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
+        freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)

         return t * freqs_cos + rotate_half(t) * freqs_sin

     def position_embedding_forward(self, position_ids, **kwargs):
         if self.pos_embedding is not None:
-            return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]]
+            return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
         else:
             return None
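With this change the rotary tables are kept in their full 3D layout (t, h, w, d) and only cropped to the current sample's rope_T/rope_H/rope_W inside rotary, instead of being pre-flattened to one fixed sequence length. A small sketch of just the cropping idea, with made-up sizes (the table below is a random stand-in, not the mixin's real frequency computation):

import torch
from einops import rearrange

# Assumed precomputed 3D RoPE tables: one entry per (frame, row, col) position.
max_t, max_h, max_w, d = 13, 60, 90, 64
freqs_cos = torch.randn(max_t, max_h, max_w, d)   # placeholder values

def reshape_freq(freqs, rope_T, rope_H, rope_W):
    # Crop to the current sample's grid, then flatten to a token sequence.
    freqs = freqs[:rope_T, :rope_H, :rope_W].contiguous()
    freqs = rearrange(freqs, "t h w d -> (t h w) d")
    return freqs.unsqueeze(0).unsqueeze(0)         # (1, 1, t*h*w, d) for broadcasting

table = reshape_freq(freqs_cos, rope_T=4, rope_H=24, rope_W=40)
print(table.shape)  # torch.Size([1, 1, 3840, 64])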
@@ -326,10 +328,61 @@
     ):
         attention_fn_default = HOOKS_DEFAULT["attention_fn"]

-        query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
-        key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
+        query_layer = torch.cat(
+            (
+                query_layer[
+                    :,
+                    :,
+                    : kwargs["text_length"],
+                ],
+                self.rotary(
+                    query_layer[
+                        :,
+                        :,
+                        kwargs["text_length"] :,
+                    ],
+                    **kwargs,
+                ),
+            ),
+            dim=2,
+        )
+        key_layer = torch.cat(
+            (
+                key_layer[
+                    :,
+                    :,
+                    : kwargs["text_length"],
+                ],
+                self.rotary(
+                    key_layer[
+                        :,
+                        :,
+                        kwargs["text_length"] :,
+                    ],
+                    **kwargs,
+                ),
+            ),
+            dim=2,
+        )
         if self.rot_v:
-            value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])
+            value_layer = torch.cat(
+                (
+                    value_layer[
+                        :,
+                        :,
+                        : kwargs["text_length"],
+                    ],
+                    self.rotary(
+                        value_layer[
+                            :,
+                            :,
+                            kwargs["text_length"] :,
+                        ],
+                        **kwargs,
+                    ),
+                ),
+                dim=2,
+            )

         return attention_fn_default(
             query_layer,
@@ -347,21 +400,25 @@ def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


-def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
+def unpatchify(x, c, patch_size, w, h, **kwargs):
     """
     x: (N, T/2 * S, patch_size**3 * C)
     imgs: (N, T, H, W, C)
+    patch_size is split into three separate dimensions (o, p, q) for depth (o), height (p) and width (q),
+    so the patch size can differ per dimension, which adds flexibility.
     """
-    if rope_position_ids is not None:
-        assert NotImplementedError
-        # do pix2struct unpatchify
-        L = x.shape[1]
-        x = x.reshape(shape=(x.shape[0], L, p, p, c))
-        x = torch.einsum("nlpqc->ncplq", x)
-        imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
-    else:
-        b = x.shape[0]
-        imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p)
+    imgs = rearrange(
+        x,
+        "b (t h w) (c o p q) -> b (t o) c (h p) (w q)",
+        c=c,
+        o=patch_size[0],
+        p=patch_size[1],
+        q=patch_size[2],
+        t=kwargs["rope_T"],
+        h=kwargs["rope_H"],
+        w=kwargs["rope_W"],
+    )

     return imgs
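unpatchify is the exact inverse of the 3D patch embedding above: it regroups each token's (c o p q) channels back into o frames and a p x q spatial tile. A quick round-trip check under assumed shapes (illustrative values, mirroring the two rearrange patterns in the diff):

import torch
from einops import rearrange

b, c = 1, 16
rope_T, rope_H, rope_W = 2, 4, 4         # token grid after patching (assumed)
o, p, q = 2, 2, 2                         # patch_size (depth, height, width)

video = torch.randn(b, rope_T * o, c, rope_H * p, rope_W * q)   # (b, t, c, h, w)

# patchify: (b, t, c, h, w) -> (b, t*h*w tokens, c*o*p*q)
tokens = rearrange(video, "b (t o) c (h p) (w q) -> b (t h w) (c o p q)",
                   t=rope_T, h=rope_H, w=rope_W, o=o, p=p, q=q)

# unpatchify: inverse rearrange, as in the diff
recon = rearrange(tokens, "b (t h w) (c o p q) -> b (t o) c (h p) (w q)",
                  c=c, o=o, p=p, q=q, t=rope_T, h=rope_H, w=rope_W)

assert torch.equal(video, recon)   # exact round trip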
@@ -382,27 +439,17 @@ class FinalLayerMixin(BaseMixin):
         self.patch_size = patch_size
         self.out_channels = out_channels
         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
-        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+        self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))

-        self.spatial_length = latent_width * latent_height // patch_size**2
-        self.latent_width = latent_width
-        self.latent_height = latent_height
-
     def final_forward(self, logits, **kwargs):
-        x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"]  # x:(b,(t n),d)
+        x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"]  # x:(b,(t n),d); keep only the image tokens after the text prefix
         shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
         x = modulate(self.norm_final(x), shift, scale)
         x = self.linear(x)

         return unpatchify(
-            x,
-            c=self.out_channels,
-            p=self.patch_size,
-            w=self.latent_width // self.patch_size,
-            h=self.latent_height // self.patch_size,
-            rope_position_ids=kwargs.get("rope_position_ids", None),
-            **kwargs,
+            x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs
         )

     def reinit(self, parent_model=None):
@@ -440,8 +487,6 @@ class SwiGLUMixin(BaseMixin):
 class AdaLNMixin(BaseMixin):
     def __init__(
         self,
-        width,
-        height,
         hidden_size,
         num_layers,
         time_embed_dim,
@@ -452,8 +497,6 @@ class AdaLNMixin(BaseMixin):
     ):
         super().__init__()
         self.num_layers = num_layers
-        self.width = width
-        self.height = height
         self.compressed_num_frames = compressed_num_frames

         self.adaLN_modulations = nn.ModuleList(
@@ -611,7 +654,8 @@ class DiffusionTransformer(BaseModel):
         time_interpolation=1.0,
         use_SwiGLU=False,
         use_RMSNorm=False,
-        zero_init_y_embed=False,
+        cfg_embed_dim=None,
+        ofs_embed_dim=None,
         **kwargs,
     ):
         self.latent_width = latent_width
@@ -619,12 +663,14 @@ class DiffusionTransformer(BaseModel):
         self.patch_size = patch_size
         self.num_frames = num_frames
         self.time_compressed_rate = time_compressed_rate
-        self.spatial_length = latent_width * latent_height // patch_size**2
+        self.spatial_length = latent_width * latent_height // reduce(mul, patch_size[1:])
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.hidden_size = hidden_size
         self.model_channels = hidden_size
         self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
+        self.cfg_embed_dim = cfg_embed_dim
+        self.ofs_embed_dim = ofs_embed_dim
         self.num_classes = num_classes
         self.adm_in_channels = adm_in_channels
         self.input_time = input_time
@@ -636,7 +682,6 @@ class DiffusionTransformer(BaseModel):
         self.width_interpolation = width_interpolation
         self.time_interpolation = time_interpolation
         self.inner_hidden_size = hidden_size * 4
-        self.zero_init_y_embed = zero_init_y_embed
         try:
             self.dtype = str_to_dtype[kwargs.pop("dtype")]
         except:
@@ -669,7 +714,6 @@ class DiffusionTransformer(BaseModel):
     def _build_modules(self, module_configs):
         model_channels = self.hidden_size
-        # time_embed_dim = model_channels * 4
         time_embed_dim = self.time_embed_dim
         self.time_embed = nn.Sequential(
             linear(model_channels, time_embed_dim),
@@ -677,6 +721,20 @@ class DiffusionTransformer(BaseModel):
             linear(time_embed_dim, time_embed_dim),
         )

+        if self.ofs_embed_dim is not None:
+            self.ofs_embed = nn.Sequential(
+                linear(self.ofs_embed_dim, self.ofs_embed_dim),
+                nn.SiLU(),
+                linear(self.ofs_embed_dim, self.ofs_embed_dim),
+            )
+
+        if self.cfg_embed_dim is not None:
+            self.cfg_embed = nn.Sequential(
+                linear(self.cfg_embed_dim, self.cfg_embed_dim),
+                nn.SiLU(),
+                linear(self.cfg_embed_dim, self.cfg_embed_dim),
+            )
+
         if self.num_classes is not None:
             if isinstance(self.num_classes, int):
                 self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
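The new ofs_embed/cfg_embed branches add two small MLPs whose inputs are sinusoidal embeddings: in forward, ofs is passed through timestep_embedding(...) before self.ofs_embed, and scale_emb feeds self.cfg_embed, with both results added to the timestep embedding emb. A minimal sketch of that conditioning path, using an assumed standalone timestep_embedding (simplified stand-in for sgm's helper, not its exact signature):

import math
import torch
from torch import nn

def timestep_embedding(t, dim, max_period=10000):
    # Standard sinusoidal embedding (simplified stand-in for sgm's helper).
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

ofs_embed_dim = 512                     # assumed embedding width
ofs_embed = nn.Sequential(
    nn.Linear(ofs_embed_dim, ofs_embed_dim),
    nn.SiLU(),
    nn.Linear(ofs_embed_dim, ofs_embed_dim),
)

emb = torch.zeros(1, ofs_embed_dim)     # stands in for the timestep embedding
ofs = torch.tensor([2.0])               # the value sample_video.py passes for image2video
emb = emb + ofs_embed(timestep_embedding(ofs, ofs_embed_dim))
print(emb.shape)  # torch.Size([1, 512])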
@@ -701,9 +759,6 @@ class DiffusionTransformer(BaseModel):
                         linear(time_embed_dim, time_embed_dim),
                     )
                 )
-                if self.zero_init_y_embed:
-                    nn.init.constant_(self.label_emb[0][2].weight, 0)
-                    nn.init.constant_(self.label_emb[0][2].bias, 0)
             else:
                 raise ValueError()

@@ -712,10 +767,13 @@ class DiffusionTransformer(BaseModel):
             "pos_embed",
             instantiate_from_config(
                 pos_embed_config,
-                height=self.latent_height // self.patch_size,
-                width=self.latent_width // self.patch_size,
+                height=self.latent_height // self.patch_size[1],
+                width=self.latent_width // self.patch_size[2],
                 compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
                 hidden_size=self.hidden_size,
+                height_interpolation=self.height_interpolation,
+                width_interpolation=self.width_interpolation,
+                time_interpolation=self.time_interpolation,
             ),
             reinit=True,
         )
@@ -737,8 +795,6 @@ class DiffusionTransformer(BaseModel):
             "adaln_layer",
             instantiate_from_config(
                 adaln_layer_config,
-                height=self.latent_height // self.patch_size,
-                width=self.latent_width // self.patch_size,
                 hidden_size=self.hidden_size,
                 num_layers=self.num_layers,
                 compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
@@ -749,7 +805,6 @@ class DiffusionTransformer(BaseModel):
             )
         else:
             raise NotImplementedError
-
         final_layer_config = module_configs["final_layer_config"]
         self.add_mixin(
             "final_layer",
@@ -766,25 +821,18 @@ class DiffusionTransformer(BaseModel):
             reinit=True,
         )

-        if "lora_config" in module_configs:
-            lora_config = module_configs["lora_config"]
-            self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
-
         return
     def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
         b, t, d, h, w = x.shape
         if x.dtype != self.dtype:
             x = x.to(self.dtype)
+        # This is not use in inference
         if "concat_images" in kwargs and kwargs["concat_images"] is not None:
             if kwargs["concat_images"].shape[0] != x.shape[0]:
                 concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1)
             else:
                 concat_images = kwargs["concat_images"]
             x = torch.cat([x, concat_images], dim=2)

         assert (y is not None) == (
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
@@ -792,17 +840,33 @@ class DiffusionTransformer(BaseModel):
         emb = self.time_embed(t_emb)

         if self.num_classes is not None:
-            # assert y.shape[0] == x.shape[0]
             assert x.shape[0] % y.shape[0] == 0
             y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
             emb = emb + self.label_emb(y)

-        kwargs["seq_length"] = t * h * w // (self.patch_size**2)
+        if self.ofs_embed_dim is not None:
+            ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
+            ofs_emb = self.ofs_embed(ofs_emb)
+            emb = emb + ofs_emb
+
+        if self.cfg_embed_dim is not None:
+            cfg_emb = kwargs["scale_emb"]
+            cfg_emb = self.cfg_embed(cfg_emb)
+            emb = emb + cfg_emb
+
+        if "ofs" in kwargs.keys():
+            ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
+            ofs_emb = self.ofs_embed(ofs_emb)
+
+        kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
         kwargs["images"] = x
         kwargs["emb"] = emb
         kwargs["encoder_outputs"] = context
         kwargs["text_length"] = context.shape[1]

+        kwargs["rope_T"] = t // self.patch_size[0]
+        kwargs["rope_H"] = h // self.patch_size[1]
+        kwargs["rope_W"] = w // self.patch_size[2]
+
         kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
         output = super().forward(**kwargs)[0]
         return output
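rope_T, rope_H, rope_W and seq_length are all derived from the latent shape and the 3-element patch_size, and they must agree with what the patch embedding and unpatchify expect. A tiny sanity-check sketch with assumed numbers (the latent sizes and patch_size=(2, 2, 2) are illustrative, not values read from a config):

from functools import reduce
from operator import mul

# Assumed latent shape (t, h, w) after the VAE, and an assumed 3D patch size.
t, h, w = 16, 96, 170
patch_size = (2, 2, 2)   # (temporal, height, width)

rope_T = t // patch_size[0]
rope_H = h // patch_size[1]
rope_W = w // patch_size[2]
seq_length = t * h * w // reduce(mul, patch_size)

assert seq_length == rope_T * rope_H * rope_W   # one token per 3D patch
print(rope_T, rope_H, rope_W, seq_length)       # 8 48 85 32640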

View File

@@ -1,16 +1,11 @@
-SwissArmyTransformer==0.4.12
-omegaconf==2.3.0
-torch==2.4.0
-torchvision==0.19.0
-pytorch_lightning==2.3.3
-kornia==0.7.3
-beartype==0.18.5
-numpy==2.0.1
-fsspec==2024.5.0
-safetensors==0.4.3
-imageio-ffmpeg==0.5.1
-imageio==2.34.2
-scipy==1.14.0
-decord==0.6.0
-wandb==0.17.5
-deepspeed==0.14.4
+SwissArmyTransformer>=0.4.12
+omegaconf>=2.3.0
+pytorch_lightning>=2.4.0
+kornia>=0.7.3
+beartype>=0.19.0
+fsspec>=2024.2.0
+safetensors>=0.4.5
+scipy>=1.14.1
+decord>=0.6.0
+wandb>=0.18.5
+deepspeed>=0.15.3

View File

@@ -4,24 +4,20 @@ import argparse
 from typing import List, Union
 from tqdm import tqdm
 from omegaconf import ListConfig
+from PIL import Image
 import imageio

 import torch
 import numpy as np
-from einops import rearrange
+from einops import rearrange, repeat
 import torchvision.transforms as TT

 from sat.model.base_model import get_model
 from sat.training.model_io import load_checkpoint
 from sat import mpu

 from diffusion_video import SATVideoDiffusionEngine
 from arguments import get_args
-from torchvision.transforms.functional import center_crop, resize
-from torchvision.transforms import InterpolationMode
-from PIL import Image


 def read_from_cli():
     cnt = 0
@@ -56,6 +52,42 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
         if key == "txt":
             batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
             batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
+        elif key == "original_size_as_tuple":
+            batch["original_size_as_tuple"] = (
+                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
+            )
+        elif key == "crop_coords_top_left":
+            batch["crop_coords_top_left"] = (
+                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
+            )
+        elif key == "aesthetic_score":
+            batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
+            batch_uc["aesthetic_score"] = (
+                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
+            )
+        elif key == "target_size_as_tuple":
+            batch["target_size_as_tuple"] = (
+                torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
+            )
+        elif key == "fps":
+            batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
+        elif key == "fps_id":
+            batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
+        elif key == "motion_bucket_id":
+            batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
+        elif key == "pool_image":
+            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
+        elif key == "cond_aug":
+            batch[key] = repeat(
+                torch.tensor([value_dict["cond_aug"]]).to("cuda"),
+                "1 -> b",
+                b=math.prod(N),
+            )
+        elif key == "cond_frames":
+            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
+        elif key == "cond_frames_without_noise":
+            batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
         else:
             batch[key] = value_dict[key]
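get_batch simply broadcasts each value_dict entry to the batch shape N; the new elif branches handle SVD-style conditioning keys. A small illustration of the repeat pattern for two of those keys (the values are made up, only the shapes matter):

import math
import torch

N = [1]                                   # num_samples, as in sampling_main
value_dict = {"fps": 16, "orig_height": 768, "orig_width": 1360}

fps_batch = torch.tensor([value_dict["fps"]]).repeat(math.prod(N))
size_batch = torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).repeat(*N, 1)

print(fps_batch.shape)   # torch.Size([1])      -> one scalar per sample
print(size_batch.shape)  # torch.Size([1, 2])   -> (batch, [height, width])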
@@ -83,37 +115,6 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: i
             writer.append_data(frame)


-def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
-    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
-        arr = resize(
-            arr,
-            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
-            interpolation=InterpolationMode.BICUBIC,
-        )
-    else:
-        arr = resize(
-            arr,
-            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
-            interpolation=InterpolationMode.BICUBIC,
-        )
-
-    h, w = arr.shape[2], arr.shape[3]
-    arr = arr.squeeze(0)
-    delta_h = h - image_size[0]
-    delta_w = w - image_size[1]
-
-    if reshape_mode == "random" or reshape_mode == "none":
-        top = np.random.randint(0, delta_h + 1)
-        left = np.random.randint(0, delta_w + 1)
-    elif reshape_mode == "center":
-        top, left = delta_h // 2, delta_w // 2
-    else:
-        raise NotImplementedError
-    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
-    return arr
-
-
 def sampling_main(args, model_cls):
     if isinstance(model_cls, type):
         model = get_model(args, model_cls)
@@ -127,44 +128,65 @@ def sampling_main(args, model_cls):
         data_iter = read_from_cli()
     elif args.input_type == "txt":
         rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
+        print("rank and world_size", rank, world_size)
         data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
     else:
         raise NotImplementedError

-    image_size = [480, 720]
-
-    if args.image2video:
-        chained_trainsforms = []
-        chained_trainsforms.append(TT.ToTensor())
-        transform = TT.Compose(chained_trainsforms)
-
     sample_func = model.sample
-    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
     num_samples = [1]
     force_uc_zero_embeddings = ["txt"]
-    device = model.device
     with torch.no_grad():
         for text, cnt in tqdm(data_iter):
             if args.image2video:
-                text, image_path = text.split("@@")
+                # use with input image shape
+                text, image_path = text.split('@@')
                 assert os.path.exists(image_path), image_path
-                image = Image.open(image_path).convert("RGB")
-                image = transform(image).unsqueeze(0).to("cuda")
-                image = resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0)
+                image = Image.open(image_path).convert('RGB')
+                (img_W, img_H) = image.size
+
+                def nearest_multiple_of_16(n):
+                    lower_multiple = (n // 16) * 16
+                    upper_multiple = (n // 16 + 1) * 16
+                    if abs(n - lower_multiple) < abs(n - upper_multiple):
+                        return lower_multiple
+                    else:
+                        return upper_multiple
+
+                if img_H < img_W:
+                    H = 96
+                    W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8
+                else:
+                    W = 96
+                    H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
+
+                chained_trainsforms = []
+                chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
+                chained_trainsforms.append(TT.ToTensor())
+                transform = TT.Compose(chained_trainsforms)
+                image = transform(image).unsqueeze(0).to('cuda')
                 image = image * 2.0 - 1.0
                 image = image.unsqueeze(2).to(torch.bfloat16)
                 image = model.encode_first_stage(image, None)
+                image = image / model.scale_factor
                 image = image.permute(0, 2, 1, 3, 4).contiguous()
-                pad_shape = (image.shape[0], T - 1, C, H // F, W // F)
+                pad_shape = (image.shape[0], T - 1, C, H, W)
                 image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
             else:
+                image_size = args.sampling_image_size
+                T, H, W, C = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels
+                F = 8  # 8x downsampled
                 image = None

+            text_cast = [text]
+            mp_size = mpu.get_model_parallel_world_size()
+            global_rank = torch.distributed.get_rank() // mp_size
+            src = global_rank * mp_size
+            torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
+            text = text_cast[0]
             value_dict = {
-                "prompt": text,
-                "negative_prompt": "",
-                "num_frames": torch.tensor(T).unsqueeze(0),
+                'prompt': text,
+                'negative_prompt': '',
+                'num_frames': torch.tensor(T).unsqueeze(0)
             }

             batch, batch_uc = get_batch(
@@ -187,64 +209,52 @@ def sampling_main(args, model_cls):
                 if not k == "crossattn":
                     c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
-            if args.image2video and image is not None:
+            if args.image2video:
                 c["concat"] = image
                 uc["concat"] = image

             for index in range(args.batch_size):
-                # reload model on GPU
-                model.to(device)
-                samples_z = sample_func(
-                    c,
-                    uc=uc,
-                    batch_size=1,
-                    shape=(T, C, H // F, W // F),
-                )
+                if args.image2video:
+                    samples_z = sample_func(
+                        c,
+                        uc=uc,
+                        batch_size=1,
+                        shape=(T, C, H, W),
+                        ofs=torch.tensor([2.0]).to('cuda')
+                    )
+                else:
+                    samples_z = sample_func(
+                        c,
+                        uc=uc,
+                        batch_size=1,
+                        shape=(T, C, H // F, W // F),
+                    ).to('cuda')
                 samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
-                if args.only_save_latents:
-                    samples_z = 1.0 / model.scale_factor * samples_z
-                    save_path = os.path.join(
-                        args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
-                    )
-                    os.makedirs(save_path, exist_ok=True)
-                    torch.save(samples_z, os.path.join(save_path, "latent.pt"))
-                    with open(os.path.join(save_path, "text.txt"), "w") as f:
-                        f.write(text)
-                else:
-                    samples_x = model.decode_first_stage(samples_z).to(torch.float32)
-                    samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
-                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
-                    save_path = os.path.join(
-                        args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
-                    )
-                    if mpu.get_model_parallel_rank() == 0:
-                        save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
-
-                # Unload the model from GPU to save GPU memory
-                model.to("cpu")
-                torch.cuda.empty_cache()
-                first_stage_model = model.first_stage_model
-                first_stage_model = first_stage_model.to(device)
-
-                latent = 1.0 / model.scale_factor * samples_z
-
-                # Decode latent serial to save GPU memory
-                recons = []
-                loop_num = (T - 1) // 2
-                for i in range(loop_num):
-                    if i == 0:
-                        start_frame, end_frame = 0, 3
-                    else:
-                        start_frame, end_frame = i * 2 + 1, i * 2 + 3
-                    if i == loop_num - 1:
-                        clear_fake_cp_cache = True
-                    else:
-                        clear_fake_cp_cache = False
-                    with torch.no_grad():
-                        recon = first_stage_model.decode(
-                            latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
-                        )
-
-                    recons.append(recon)
-
-                recon = torch.cat(recons, dim=2).to(torch.float32)
-                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
-
-                save_path = os.path.join(
-                    args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
-                )
-                if mpu.get_model_parallel_rank() == 0:
-                    save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)


-if __name__ == "__main__":
-    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
-        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
-        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
-        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
+if __name__ == '__main__':
+    if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
+        os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
+        os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
+        os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
     py_parser = argparse.ArgumentParser(add_help=False)
     known, args_list = py_parser.parse_known_args()
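In the image-to-video branch above, the latent height/width are derived from the input image's aspect ratio: the short side is fixed to 96 latent units (768 px after the 8x VAE downsampling) and the long side keeps the aspect ratio, rounded so the pixel size is a multiple of 16. A standalone sketch of that sizing rule, mirroring nearest_multiple_of_16 from the diff (the example resolutions are arbitrary):

def nearest_multiple_of_16(n):
    lower = (n // 16) * 16
    upper = lower + 16
    return lower if abs(n - lower) < abs(n - upper) else upper

def latent_hw(img_W, img_H):
    # Short side -> 96 latent units (96 * 8 = 768 px); long side keeps the aspect
    # ratio, rounded to a multiple of 16 px, then divided by the 8x VAE factor.
    if img_H < img_W:
        H = 96
        W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8
    else:
        W = 96
        H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
    return H, W

print(latent_hw(1920, 1080))  # (96, 170) -> 768 x 1360 px
print(latent_hw(1024, 1024))  # (96, 96)  -> 768 x 768 px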

View File

@@ -1,7 +1,8 @@
 """
 Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
 """
+
 from typing import Dict, Union

 import torch
@@ -16,7 +17,6 @@ from ...modules.diffusionmodules.sampling_utils import (
     to_sigma,
 )
 from ...util import append_dims, default, instantiate_from_config
-from ...util import SeededNoise

 from .guiders import DynamicCFG

@@ -44,7 +44,9 @@ class BaseDiffusionSampler:
         self.device = device

     def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
-        sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device)
+        sigmas = self.discretization(
+            self.num_steps if num_steps is None else num_steps, device=self.device
+        )
         uc = default(uc, cond)

         x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
@@ -83,7 +85,9 @@ class SingleStepDiffusionSampler(BaseDiffusionSampler):

 class EDMSampler(SingleStepDiffusionSampler):
-    def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
+    def __init__(
+        self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
+    ):
         super().__init__(*args, **kwargs)

         self.s_churn = s_churn
@@ -102,15 +106,21 @@ class EDMSampler(SingleStepDiffusionSampler):
         dt = append_dims(next_sigma - sigma_hat, x.ndim)

         euler_step = self.euler_step(x, d, dt)
-        x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
+        x = self.possible_correction_step(
+            euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+        )
         return x

     def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
-        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+            x, cond, uc, num_steps
+        )

         for i in self.get_sigma_gen(num_sigmas):
             gamma = (
-                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
+                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+                if self.s_tmin <= sigmas[i] <= self.s_tmax
+                else 0.0
             )
             x = self.sampler_step(
                 s_in * sigmas[i],
@@ -126,23 +136,30 @@ class EDMSampler(SingleStepDiffusionSampler):
 class DDIMSampler(SingleStepDiffusionSampler):
-    def __init__(self, s_noise=0.1, *args, **kwargs):
+    def __init__(
+        self, s_noise=0.1, *args, **kwargs
+    ):
         super().__init__(*args, **kwargs)

         self.s_noise = s_noise

     def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
         denoised = self.denoise(x, denoiser, sigma, cond, uc)

         d = to_d(x, sigma, denoised)
-        dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim)
+        dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim)

         euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)
-        x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
+        x = self.possible_correction_step(
+            euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+        )
         return x

     def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
-        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+            x, cond, uc, num_steps
+        )

         for i in self.get_sigma_gen(num_sigmas):
             x = self.sampler_step(
@@ -181,7 +198,9 @@ class AncestralSampler(SingleStepDiffusionSampler):
         return x

     def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
-        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+            x, cond, uc, num_steps
+        )

         for i in self.get_sigma_gen(num_sigmas):
             x = self.sampler_step(
@@ -208,32 +227,43 @@ class LinearMultistepSampler(BaseDiffusionSampler):
         self.order = order

     def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
-        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+            x, cond, uc, num_steps
+        )

         ds = []
         sigmas_cpu = sigmas.detach().cpu().numpy()
         for i in self.get_sigma_gen(num_sigmas):
             sigma = s_in * sigmas[i]
-            denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
+            denoised = denoiser(
+                *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
+            )
             denoised = self.guider(denoised, sigma)
             d = to_d(x, sigma, denoised)
             ds.append(d)
             if len(ds) > self.order:
                 ds.pop(0)
             cur_order = min(i + 1, self.order)
-            coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
+            coeffs = [
+                linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
+                for j in range(cur_order)
+            ]
             x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))

         return x


 class EulerEDMSampler(EDMSampler):
-    def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
+    def possible_correction_step(
+        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+    ):
         return euler_step


 class HeunEDMSampler(EDMSampler):
-    def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
+    def possible_correction_step(
+        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+    ):
         if torch.sum(next_sigma) < 1e-14:
             # Save a network evaluation if all noise levels are 0
             return euler_step
@@ -243,7 +273,9 @@ class HeunEDMSampler(EDMSampler):
             d_prime = (d + d_new) / 2.0

             # apply correction if noise level is not 0
-            x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
+            x = torch.where(
+                append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
+            )
             return x

@@ -282,7 +314,9 @@ class DPMPP2SAncestralSampler(AncestralSampler):
             x = x_euler
         else:
             h, s, t, t_next = self.get_variables(sigma, sigma_down)
-            mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]
+            mult = [
+                append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
+            ]

             x2 = mult[0] * x - mult[1] * denoised
             denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
@@ -332,7 +366,10 @@ class DPMPP2MSampler(BaseDiffusionSampler):
         denoised = self.denoise(x, denoiser, sigma, cond, uc)

         h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
-        mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
+        mult = [
+            append_dims(mult, x.ndim)
+            for mult in self.get_mult(h, r, t, t_next, previous_sigma)
+        ]

         x_standard = mult[0] * x - mult[1] * denoised
         if old_denoised is None or torch.sum(next_sigma) < 1e-14:
@@ -343,12 +380,16 @@ class DPMPP2MSampler(BaseDiffusionSampler):
             x_advanced = mult[0] * x - mult[1] * denoised_d

             # apply correction if noise level is not 0 and not first step
-            x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
+            x = torch.where(
+                append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
+            )

         return x, denoised

     def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
-        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+            x, cond, uc, num_steps
+        )

         old_denoised = None
         for i in self.get_sigma_gen(num_sigmas):
@@ -365,7 +406,6 @@ class DPMPP2MSampler(BaseDiffusionSampler):
         return x

-
 class SDEDPMPP2MSampler(BaseDiffusionSampler):
     def get_variables(self, sigma, next_sigma, previous_sigma=None):
         t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
@@ -380,7 +420,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):

     def get_mult(self, h, r, t, t_next, previous_sigma):
         mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
-        mult2 = (-2 * h).expm1()
+        mult2 = (-2*h).expm1()

         if previous_sigma is not None:
             mult3 = 1 + 1 / (2 * r)
@@ -403,8 +443,11 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
         denoised = self.denoise(x, denoiser, sigma, cond, uc)

         h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
-        mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
-        mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
+        mult = [
+            append_dims(mult, x.ndim)
+            for mult in self.get_mult(h, r, t, t_next, previous_sigma)
+        ]
+        mult_noise = append_dims(next_sigma * (1 - (-2*h).exp())**0.5, x.ndim)

         x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
         if old_denoised is None or torch.sum(next_sigma) < 1e-14:
@@ -415,12 +458,16 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
             x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)

             # apply correction if noise level is not 0 and not first step
-            x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
+            x = torch.where(
+                append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
+            )

         return x, denoised

     def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
-        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+            x, cond, uc, num_steps
+        )

         old_denoised = None
         for i in self.get_sigma_gen(num_sigmas):
@@ -437,7 +484,6 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
         return x

-
 class SdeditEDMSampler(EulerEDMSampler):
     def __init__(self, edit_ratio=0.5, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -446,7 +492,9 @@ class SdeditEDMSampler(EulerEDMSampler):
     def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None):
         randn_unit = randn.clone()
-        randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps)
+        randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+            randn, cond, uc, num_steps
+        )

         if num_steps is None:
             num_steps = self.num_steps
@@ -461,7 +509,9 @@ class SdeditEDMSampler(EulerEDMSampler):
             x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))

             gamma = (
-                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
+                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+                if self.s_tmin <= sigmas[i] <= self.s_tmax
+                else 0.0
             )
             x = self.sampler_step(
                 s_in * sigmas[i],
@@ -475,8 +525,8 @@ class SdeditEDMSampler(EulerEDMSampler):
         return x


 class VideoDDIMSampler(BaseDiffusionSampler):
     def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
         super().__init__(**kwargs)
         self.fixed_frames = fixed_frames
@@ -484,13 +534,10 @@ class VideoDDIMSampler(BaseDiffusionSampler):

     def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
         alpha_cumprod_sqrt, timesteps = self.discretization(
-            self.num_steps if num_steps is None else num_steps,
-            device=self.device,
-            return_idx=True,
-            do_append_zero=False,
+            self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True, do_append_zero=False
         )
         alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
-        timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))])
+        timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1])-1, torch.tensor(list(timesteps))])

         uc = default(uc, cond)
@@ -500,51 +547,36 @@ class VideoDDIMSampler(BaseDiffusionSampler):

         return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps

-    def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None):
+    def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None, ofs=None):
         additional_model_inputs = {}
+        if ofs is not None:
+            additional_model_inputs['ofs'] = ofs

         if isinstance(scale, torch.Tensor) == False and scale == 1:
-            additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep
+            additional_model_inputs['idx'] = x.new_ones([x.shape[0]]) * timestep
             if scale_emb is not None:
-                additional_model_inputs["scale_emb"] = scale_emb
+                additional_model_inputs['scale_emb'] = scale_emb
             denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32)
         else:
-            additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
-            denoised = denoiser(
-                *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
-            ).to(torch.float32)
+            additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
+            denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(torch.float32)
             if isinstance(self.guider, DynamicCFG):
-                denoised = self.guider(
-                    denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale
-                )
+                denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, step_index=self.num_steps - timestep, scale=scale)
             else:
-                denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale)
+                denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, scale=scale)
         return denoised

-    def sampler_step(
-        self,
-        alpha_cumprod_sqrt,
-        next_alpha_cumprod_sqrt,
-        denoiser,
-        x,
-        cond,
-        uc=None,
-        idx=None,
-        timestep=None,
-        scale=None,
-        scale_emb=None,
-    ):
-        denoised = self.denoise(
-            x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
-        ).to(torch.float32)
+    def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, timestep=None, scale=None, scale_emb=None, ofs=None):
+        denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32)  # 1020

-        a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
+        a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5
         b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t

         x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
         return x

-    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
+    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None):  # 1020
         x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
             x, cond, uc, num_steps
         )
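The DDIM-style update in sampler_step can be read directly off a_t and b_t: with alpha_cumprod_sqrt playing the role of the signal scale, a_t rescales the remaining noise to the next level and b_t re-injects the freshly predicted clean sample. A tiny numeric sketch of just that update rule (scalar values chosen arbitrarily):

import torch

def ddim_step(x, denoised, alpha_sqrt, next_alpha_sqrt):
    # Same a_t / b_t algebra as VideoDDIMSampler.sampler_step, on scalars for clarity.
    a_t = ((1 - next_alpha_sqrt**2) / (1 - alpha_sqrt**2)) ** 0.5   # ratio of noise scales
    b_t = next_alpha_sqrt - alpha_sqrt * a_t                        # weight on the x0 prediction
    return a_t * x + b_t * denoised

x = torch.tensor(1.5)          # current noisy latent (illustrative)
denoised = torch.tensor(0.2)   # model's x0 prediction (illustrative)
print(ddim_step(x, denoised, alpha_sqrt=torch.tensor(0.6), next_alpha_sqrt=torch.tensor(0.8)))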
@@ -558,25 +590,83 @@ class VideoDDIMSampler(BaseDiffusionSampler):
                 cond,
                 uc,
                 idx=self.num_steps - i,
-                timestep=timesteps[-(i + 1)],
+                timestep=timesteps[-(i+1)],
                 scale=scale,
                 scale_emb=scale_emb,
+                ofs=ofs  # 1020
             )

         return x


+class Image2VideoDDIMSampler(BaseDiffusionSampler):
+    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
+        alpha_cumprod_sqrt, timesteps = self.discretization(
+            self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True
+        )
+        uc = default(uc, cond)
+
+        num_sigmas = len(alpha_cumprod_sqrt)
+
+        s_in = x.new_ones([x.shape[0]])
+
+        return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps
+
+    def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None):
+        additional_model_inputs = {}
+        additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
+        denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(
+            torch.float32)
+        if isinstance(self.guider, DynamicCFG):
+            denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep)
+        else:
+            denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5)
+        return denoised
+
+    def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None,
+                     timestep=None):
+        # here "sigma" is actually alpha_cumprod_sqrt
+        denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(torch.float32)
+        if idx == 1:
+            return denoised
+
+        a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5
+        b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
+
+        x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
+        return x
+
+    def __call__(self, image, denoiser, x, cond, uc=None, num_steps=None):
+        x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
+            x, cond, uc, num_steps
+        )
+
+        for i in self.get_sigma_gen(num_sigmas):
+            x = self.sampler_step(
+                s_in * alpha_cumprod_sqrt[i],
+                s_in * alpha_cumprod_sqrt[i + 1],
+                denoiser,
+                x,
+                cond,
+                uc,
+                idx=self.num_steps - i,
+                timestep=timesteps[-(i + 1)]
+            )
+
+        return x
+
+
 class VPSDEDPMPP2MSampler(VideoDDIMSampler):
     def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
-        alpha_cumprod = alpha_cumprod_sqrt**2
-        lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
-        next_alpha_cumprod = next_alpha_cumprod_sqrt**2
-        lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
+        alpha_cumprod = alpha_cumprod_sqrt ** 2
+        lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
+        next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
+        lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
         h = lamb_next - lamb

         if previous_alpha_cumprod_sqrt is not None:
-            previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
-            lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
+            previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
+            lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
             h_last = lamb - lamb_previous
             r = h_last / h
             return h, r, lamb, lamb_next
@@ -584,8 +674,8 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
         return h, None, lamb, lamb_next

     def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
-        mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp()
-        mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt
+        mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 * (-h).exp()
+        mult2 = (-2*h).expm1() * next_alpha_cumprod_sqrt

         if previous_alpha_cumprod_sqrt is not None:
             mult3 = 1 + 1 / (2 * r)
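get_variables works in half-log-SNR space: lamb = log(alpha/sigma) with alpha = alpha_cumprod_sqrt and sigma = sqrt(1 - alpha^2); h is the step between two noise levels and r compares it with the previous step. A short numeric sketch of those quantities (the alpha values are made up):

import torch

def half_log_snr(alpha_cumprod_sqrt):
    # lamb = log(alpha / sigma), the quantity the sampler calls "lamb".
    alpha_cumprod = alpha_cumprod_sqrt ** 2
    return ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()

prev, cur, nxt = torch.tensor(0.30), torch.tensor(0.50), torch.tensor(0.70)  # alpha_cumprod_sqrt values
lamb_prev, lamb, lamb_next = map(half_log_snr, (prev, cur, nxt))

h = lamb_next - lamb          # current step in log-SNR
r = (lamb - lamb_prev) / h    # ratio used by the second-order correction
print(float(h), float(r))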
@@ -608,21 +698,18 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
         timestep=None,
         scale=None,
         scale_emb=None,
+        ofs=None  # 1020
     ):
-        denoised = self.denoise(
-            x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
-        ).to(torch.float32)
+        denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32)  # 1020
         if idx == 1:
             return denoised, denoised

-        h, r, lamb, lamb_next = self.get_variables(
-            alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
-        )
+        h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
         mult = [
             append_dims(mult, x.ndim)
             for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
         ]
-        mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
+        mult_noise = append_dims((1-next_alpha_cumprod_sqrt**2)**0.5 * (1 - (-2*h).exp())**0.5, x.ndim)

         x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
         if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
@ -636,24 +723,23 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
return x, denoised return x, denoised
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
x, cond, uc, num_steps x, cond, uc, num_steps
) )
if self.fixed_frames > 0: if self.fixed_frames > 0:
prefix_frames = x[:, : self.fixed_frames] prefix_frames = x[:, :self.fixed_frames]
old_denoised = None old_denoised = None
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
if self.fixed_frames > 0: if self.fixed_frames > 0:
if self.sdedit: if self.sdedit:
rd = torch.randn_like(prefix_frames) rd = torch.randn_like(prefix_frames)
noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims( noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(s_in * (1 - alpha_cumprod_sqrt[i] ** 2)**0.5, len(prefix_frames.shape))
s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape) x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1)
)
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
else: else:
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
x, old_denoised = self.sampler_step( x, old_denoised = self.sampler_step(
old_denoised, old_denoised,
None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
@ -664,28 +750,29 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
cond, cond,
uc=uc, uc=uc,
idx=self.num_steps - i, idx=self.num_steps - i,
timestep=timesteps[-(i + 1)], timestep=timesteps[-(i+1)],
scale=scale, scale=scale,
scale_emb=scale_emb, scale_emb=scale_emb,
ofs=ofs # 1020
) )
if self.fixed_frames > 0: if self.fixed_frames > 0:
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
return x return x
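In equation form, with $\sqrt{\bar\alpha_t}$ standing for alpha_cumprod_sqrt, $\sigma_t=\sqrt{1-\bar\alpha_t}$, $\lambda_t=\log(\sqrt{\bar\alpha_t}/\sigma_t)$, and $h=\lambda_{t'}-\lambda_t$ for the next step $t'$ (this is a hedged summary of the code above, not text from the repository), the first-order part of the update built from get_mult is

$$x_{t'} = \frac{\sigma_{t'}}{\sigma_t}\,e^{-h}\,x_t \;-\; \sqrt{\bar\alpha_{t'}}\,\bigl(e^{-2h}-1\bigr)\,\hat x_0 \;+\; \sigma_{t'}\sqrt{1-e^{-2h}}\;\epsilon, \qquad \epsilon\sim\mathcal N(0,I),$$

i.e. the SDE form of DPM-Solver++. The second-order (2M) branch additionally mixes the previous prediction old_denoised into $\hat x_0$ with weights built from the $1 + \tfrac{1}{2r}$ factor, where $r$ is the step-size ratio returned by get_variables.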
class VPODEDPMPP2MSampler(VideoDDIMSampler):
    def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
        alpha_cumprod = alpha_cumprod_sqrt**2
        lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
        next_alpha_cumprod = next_alpha_cumprod_sqrt**2
        lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
        h = lamb_next - lamb

        if previous_alpha_cumprod_sqrt is not None:
            previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
            lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
            h_last = lamb - lamb_previous
            r = h_last / h
            return h, r, lamb, lamb_next

@@ -693,7 +780,7 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
            return h, None, lamb, lamb_next

    def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
        mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
        mult2 = (-h).expm1() * next_alpha_cumprod_sqrt

        if previous_alpha_cumprod_sqrt is not None:

@@ -714,15 +801,13 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
        cond,
        uc=None,
        idx=None,
        timestep=None,
    ):
        denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
        if idx == 1:
            return denoised, denoised

        h, r, lamb, lamb_next = self.get_variables(
            alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
        )
        mult = [
            append_dims(mult, x.ndim)
            for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
        ]

@@ -757,7 +842,39 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
                cond,
                uc=uc,
                idx=self.num_steps - i,
                timestep=timesteps[-(i + 1)],
            )
        return x
class VideoDDPMSampler(VideoDDIMSampler):
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None):
    # The "sigma" passed around here is actually alpha_cumprod_sqrt
    denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, idx * 1000 // self.num_steps).to(torch.float32)
    if idx == 1:
        return denoised

    alpha_sqrt = alpha_cumprod_sqrt / next_alpha_cumprod_sqrt
    x = (
        append_dims(alpha_sqrt * (1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2), x.ndim) * x
        + append_dims(next_alpha_cumprod_sqrt * (1 - alpha_sqrt**2) / (1 - alpha_cumprod_sqrt**2), x.ndim) * denoised
        + append_dims(((1 - next_alpha_cumprod_sqrt**2) * (1 - alpha_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5, x.ndim) * torch.randn_like(x)
    )
return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step(
s_in * alpha_cumprod_sqrt[i],
s_in * alpha_cumprod_sqrt[i + 1],
denoiser,
x,
cond,
uc,
idx=self.num_steps - i
)
return x
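For orientation (a hedged summary of the step above, not text from the repository): with $\bar\alpha_t$ the cumulative alpha product and $\alpha_t=\bar\alpha_t/\bar\alpha_{t'}$ for the next, less noisy step $t'$, the update is the standard DDPM ancestral step

$$x_{t'} = \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t'})}{1-\bar\alpha_t}\,x_t \;+\; \frac{\sqrt{\bar\alpha_{t'}}\,(1-\alpha_t)}{1-\bar\alpha_t}\,\hat x_0 \;+\; \sqrt{\frac{(1-\bar\alpha_{t'})(1-\alpha_t)}{1-\bar\alpha_t}}\;\epsilon,$$

i.e. the posterior mean of $q(x_{t'}\mid x_t, x_0)$ with variance $\tilde\beta_t=\frac{1-\bar\alpha_{t'}}{1-\bar\alpha_t}(1-\alpha_t)$.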

View File

@@ -17,23 +17,20 @@ class EDMSampling:
class DiscreteSampling:
-   def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False):
+   def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False, group_num=0):
        self.num_idx = num_idx
        self.sigmas = instantiate_from_config(discretization_config)(
            num_idx, do_append_zero=do_append_zero, flip=flip
        )
        world_size = mpu.get_data_parallel_world_size()
-       if world_size <= 8:
-           uniform_sampling = False
        self.uniform_sampling = uniform_sampling
+       self.group_num = group_num
        if self.uniform_sampling:
-           i = 1
-           while True:
-               if world_size % i != 0 or num_idx % (world_size // i) != 0:
-                   i += 1
-               else:
-                   self.group_num = world_size // i
-                   break
            assert self.group_num > 0
-           assert world_size % self.group_num == 0
-           self.group_width = world_size // self.group_num  # the number of ranks in one group
+           assert world_size % group_num == 0
+           self.group_width = world_size // group_num  # the number of ranks in one group
            self.sigma_interval = self.num_idx // self.group_num

    def idx_to_sigma(self, idx):

@@ -45,9 +42,7 @@ class DiscreteSampling:
            group_index = rank // self.group_width
            idx = default(
                rand,
                torch.randint(
                    group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)
                ),
            )
        else:
            idx = default(

@@ -59,7 +54,6 @@ class DiscreteSampling:
        else:
            return self.idx_to_sigma(idx)

class PartialDiscreteSampling:
    def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
        self.total_num_idx = total_num_idx
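A standalone illustration of what uniform_sampling does with the new group_num argument (all values below are assumed for the example; mpu and the real configs are not used): each group of data-parallel ranks is restricted to its own slice of the timestep range, so a training step spread across the ranks covers the noise levels more evenly.

    import torch

    world_size, group_num, num_idx = 8, 4, 1000   # assumed example values
    group_width = world_size // group_num          # ranks that share one timestep slice
    sigma_interval = num_idx // group_num          # timesteps owned by each group

    for rank in range(world_size):
        group_index = rank // group_width
        low, high = group_index * sigma_interval, (group_index + 1) * sigma_interval
        idx = torch.randint(low, high, (1,)).item()
        print(f"rank {rank}: draws idx from [{low}, {high}), e.g. {idx}")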

View File

@@ -592,8 +592,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
        unregularized: bool = False,
        input_cp: bool = False,
        output_cp: bool = False,
+       use_cp: bool = True,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
-       if self.cp_size > 0 and not input_cp:
+       if self.cp_size <= 1:
+           use_cp = False
+       if self.cp_size > 0 and use_cp and not input_cp:
            if not is_context_parallel_initialized:
                initialize_context_parallel(self.cp_size)

@@ -603,11 +606,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
            x = _conv_split(x, dim=2, kernel_size=1)

        if return_reg_log:
-           z, reg_log = super().encode(x, return_reg_log, unregularized)
+           z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
        else:
-           z = super().encode(x, return_reg_log, unregularized)
+           z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)

-       if self.cp_size > 0 and not output_cp:
+       if self.cp_size > 0 and use_cp and not output_cp:
            z = _conv_gather(z, dim=2, kernel_size=1)

        if return_reg_log:

@@ -619,23 +622,24 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
        z: torch.Tensor,
        input_cp: bool = False,
        output_cp: bool = False,
-       split_kernel_size: int = 1,
+       use_cp: bool = True,
        **kwargs,
    ):
-       if self.cp_size > 0 and not input_cp:
+       if self.cp_size <= 1:
+           use_cp = False
+       if self.cp_size > 0 and use_cp and not input_cp:
            if not is_context_parallel_initialized:
                initialize_context_parallel(self.cp_size)

            global_src_rank = get_context_parallel_group_rank() * self.cp_size
            torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())

-           z = _conv_split(z, dim=2, kernel_size=split_kernel_size)
+           z = _conv_split(z, dim=2, kernel_size=1)

-       x = super().decode(z, **kwargs)
+       x = super().decode(z, use_cp=use_cp, **kwargs)

-       if self.cp_size > 0 and not output_cp:
-           x = _conv_gather(x, dim=2, kernel_size=split_kernel_size)
+       if self.cp_size > 0 and use_cp and not output_cp:
+           x = _conv_gather(x, dim=2, kernel_size=1)

        return x

    def forward(
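The shape of the new control flow, reduced to a sketch (the helper names below are placeholders, not the repository API): context parallelism is only engaged when a CP group actually exists and the caller did not opt out, and the temporal split/gather brackets the underlying decode call.

    def decode_with_optional_cp(z, cp_size, use_cp, decode_fn, broadcast_fn, split_fn, gather_fn):
        if cp_size <= 1:
            use_cp = False                  # no context-parallel group configured
        if cp_size > 0 and use_cp:
            broadcast_fn(z)                 # all ranks start from the same latent
            z = split_fn(z, dim=2)          # shard along the temporal dimension
        x = decode_fn(z, use_cp=use_cp)
        if cp_size > 0 and use_cp:
            x = gather_fn(x, dim=2)         # reassemble the full video
        return x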

View File

@@ -5,8 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np

-from beartype import beartype
-from beartype.typing import Union, Tuple, Optional, List
+from beartype.typing import Union, Tuple

from einops import rearrange

from sgm.util import (

@@ -16,11 +15,7 @@ from sgm.util import (
    get_context_parallel_group_rank,
)

-# try:
from vae_modules.utils import SafeConv3d as Conv3d
-# except:
-#     # Degrade to normal Conv3d if SafeConv3d is not available
-#     from torch.nn import Conv3d


def cast_tuple(t, length=1):

@@ -81,7 +76,6 @@ def _split(input_, dim):
    cp_rank = get_context_parallel_rank()

-   # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
    inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
    input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()

@@ -94,7 +88,6 @@ def _split(input_, dim):
    output = torch.cat([inpu_first_frame_, output], dim=dim)
    output = output.contiguous()

-   # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
    return output

@@ -382,19 +375,6 @@ class ContextParallelCausalConv3d(nn.Module):
        self.cache_padding = None

    def forward(self, input_, clear_cache=True):
-       # if input_.shape[2] == 1:  # handle image
-       #     # first frame padding
-       #     input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2)
-       # else:
-       #     input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
-       # padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
-       # input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
-       # output_parallel = self.conv(input_parallel)
-       # output = output_parallel
-       # return output
        input_parallel = fake_cp_pass_from_previous_rank(
            input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
        )

@@ -464,7 +444,6 @@ class SpatialNorm3D(nn.Module):
            self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
        else:
            self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
-       # self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
        if freeze_norm_layer:
            for p in self.norm_layer.parameters:
                p.requires_grad = False
@@ -543,21 +522,29 @@ class Upsample3D(nn.Module):
    def forward(self, x):
        if self.compress_time and x.shape[2] > 1:
-           if x.shape[2] % 2 == 1:
-               # split first frame
-               x_first, x_rest = x[:, :, 0], x[:, :, 1:]
-
-               x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
-               x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
-               x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
-           else:
-               x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+           # Process the time dimension first as x_first
+           x_first, x_rest = x[:, :, 0], x[:, :, 1:]
+           # print(x_first.shape)
+           x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
+
+           # split the rest of the frames to avoid MAX_INT overflow in Pytorch
+           splits = torch.split(x_rest, 16, dim=1)
+           interpolated_splits = [
+               torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+           ]
+           x_rest = torch.cat(interpolated_splits, dim=1)
+           # concatenate the first frame with the rest
+           x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
        else:
            # only interpolate 2D
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
-           x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+           splits = torch.split(x, 16, dim=1)
+           interpolated_splits = [
+               torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+           ]
+           x = torch.cat(interpolated_splits, dim=1)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

        if self.with_conv:
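The chunking pattern used above (and again with avg_pool1d in the DownSample3D hunk that follows) can be exercised on its own. A minimal sketch with assumed toy shapes, outside the repository:

    import torch
    import torch.nn.functional as F

    def chunked_interpolate(x: torch.Tensor, chunk: int = 16) -> torch.Tensor:
        # Interpolating the whole tensor in one call can hit PyTorch's per-call element limits
        # for long, high-resolution videos; chunking along the channel dimension avoids that.
        splits = torch.split(x, chunk, dim=1)
        return torch.cat([F.interpolate(s, scale_factor=2.0, mode="nearest") for s in splits], dim=1)

    x = torch.randn(1, 64, 32, 32)                      # toy (batch, channels, h, w) input
    assert chunked_interpolate(x).shape == (1, 64, 64, 64)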
@@ -585,15 +572,21 @@ class DownSample3D(nn.Module):
            x = rearrange(x, "b c t h w -> (b h w) c t")
            if x.shape[-1] % 2 == 1:
-               # split first frame
                x_first, x_rest = x[..., 0], x[..., 1:]

                if x_rest.shape[-1] > 0:
-                   x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
+                   splits = torch.split(x_rest, 16, dim=1)
+                   interpolated_splits = [
+                       torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+                   ]
+                   x_rest = torch.cat(interpolated_splits, dim=1)

                x = torch.cat([x_first[..., None], x_rest], dim=-1)
                x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
            else:
-               x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+               splits = torch.split(x, 16, dim=1)
+               interpolated_splits = [
+                   torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+               ]
+               x = torch.cat(interpolated_splits, dim=1)
                x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)

        if self.with_conv:
@@ -675,31 +668,19 @@ class ContextParallelResnetBlock3D(nn.Module):
    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
        h = x

-       # if isinstance(self.norm1, torch.nn.GroupNorm):
-       #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
        if zq is not None:
            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
        else:
            h = self.norm1(h)
-       # if isinstance(self.norm1, torch.nn.GroupNorm):
-       #     h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)

        h = nonlinearity(h)
        h = self.conv1(h, clear_cache=clear_fake_cp_cache)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]

-       # if isinstance(self.norm2, torch.nn.GroupNorm):
-       #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
        if zq is not None:
            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
        else:
            h = self.norm2(h)
-       # if isinstance(self.norm2, torch.nn.GroupNorm):
-       #     h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)

        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h, clear_cache=clear_fake_cp_cache)

@@ -826,10 +807,7 @@ class ContextParallelEncoder3D(nn.Module):
        h = self.mid.block_2(h, temb)

        # end
-       # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
        h = self.norm_out(h)
-       # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
        h = nonlinearity(h)
        h = self.conv_out(h)
@@ -934,10 +912,10 @@ class ContextParallelDecoder3D(nn.Module):
            up.block = block
            up.attn = attn
            if i_level != 0:
-               if i_level < self.num_resolutions - self.temporal_compress_level:
-                   up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
-               else:
+               if i_level <= self.temporal_compress_level:
                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
+               else:
+                   up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
            self.up.insert(0, up)

        self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
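To see what the changed condition does, a small standalone illustration (num_resolutions and temporal_compress_level are assumed example values, not taken from a real config):

    num_resolutions, temporal_compress_level = 4, 2      # assumed example values

    for i_level in range(1, num_resolutions):            # level 0 never gets an upsample block
        old_rule = not (i_level < num_resolutions - temporal_compress_level)
        new_rule = i_level <= temporal_compress_level
        print(f"level {i_level}: compress_time old={old_rule} new={new_rule}")

Under these assumed values the levels that perform temporal upsampling shift from the coarsest levels to the levels nearest the output, which is what the new indexing expresses.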
@@ -954,9 +932,7 @@ class ContextParallelDecoder3D(nn.Module):
        # timestep embedding
        temb = None

-       t = z.shape[2]
-
        # z to block_in
        zq = z
        h = self.conv_in(z, clear_cache=clear_fake_cp_cache)

View File

@@ -1,22 +1,15 @@
"""
-This script demonstrates how to convert and generate video from a text prompt
-using CogVideoX with 🤗Huggingface Diffusers Pipeline.
-
-This script requires the `diffusers>=0.30.2` library to be installed.
-
-Functions:
-    - reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
-    - reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place.
-    - reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place.
-    - remove_keys_inplace: Removes specified keys from the state_dict in-place.
-    - replace_up_keys_inplace: Replaces keys in the "up" block in-place.
-    - get_state_dict: Extracts the state_dict from a saved checkpoint.
-    - update_state_dict_inplace: Updates the state_dict with new key assignments in-place.
-    - convert_transformer: Converts a transformer checkpoint to the CogVideoX format.
-    - convert_vae: Converts a VAE checkpoint to the CogVideoX format.
-    - get_args: Parses command-line arguments for the script.
-    - generate_video: Generates a video from a text prompt using the CogVideoX pipeline.
-"""
+This script demonstrates how to convert the weights of the CogVideoX model from SAT to the Hugging Face format.
+
+This script supports the conversion of the following models:
+- CogVideoX-2B
+- CogVideoX-5B, CogVideoX-5B-I2V
+- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V
+
+Original Script:
+https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
+"""
import argparse
from typing import Any, Dict

@@ -153,12 +146,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
    ckpt_path: str,
    num_layers: int,
    num_attention_heads: int,
    use_rotary_positional_embeddings: bool,
    i2v: bool,
    dtype: torch.dtype,
):
    PREFIX_KEY = "model.diffusion_model."

@@ -172,7 +165,7 @@ def convert_transformer(
    ).to(dtype=dtype)

    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY):]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

@@ -209,7 +202,8 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
@@ -259,9 +253,10 @@ if __name__ == "__main__":
    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)

-   text_encoder_id = "google/t5-v1_1-xxl"
+   text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # Apparently, the conversion does not work anymore without this :shrug:
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

@@ -301,4 +296,7 @@ if __name__ == "__main__":
    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
    # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
    # is either fp16/bf16 here).
-   pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
+   # This is necessary for users with insufficient memory, such as those using Colab and notebooks,
+   # as it can save some memory used for model loading.
+   pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
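As a follow-up to the sharded save, a sketch of loading the converted pipeline on a memory-constrained machine (it assumes a diffusers release that ships CogVideoXPipeline, per the >=0.30.2 requirement in the old docstring, plus accelerate installed; the output path is a placeholder):

    import torch
    from diffusers import CogVideoXPipeline

    pipe = CogVideoXPipeline.from_pretrained("./cogvideox-hf", torch_dtype=torch.bfloat16)
    pipe.enable_sequential_cpu_offload()   # keep only the active layer on the GPU to reduce peak VRAM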