mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 23:48:48 +08:00
148 lines
3.4 KiB
Python
148 lines
3.4 KiB
Python
import logging
|
|
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@torch.jit.script
def _fused_tanh_sigmoid(h):
    """Apply a tanh/sigmoid gated activation.

    ``h`` is split into two equal halves along dim 1; the first half goes
    through tanh, the second through sigmoid, and the two are multiplied
    element-wise. The result therefore has half as many entries along dim 1
    as the input.
    """
    tanh_half, sigmoid_half = torch.chunk(h, 2, dim=1)
    return torch.tanh(tanh_half) * torch.sigmoid(sigmoid_half)
|
|
|
|
|
|
class WNLayer(nn.Module):
    """
    A DiffWave-like WN layer.

    One gated, dilated 1-D convolution block with optional global and local
    conditioning. It produces a residual output and a skip output.
    """

    def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
        """Build the convolutions.

        Args:
            hidden_dim: channel count of the residual stream.
            local_dim: channel count of the local condition, or None to omit
                the local projection (``lconv``).
            global_dim: channel count of the global condition, or None to omit
                the global projection (``gconv``).
            kernel_size: kernel size of the dilated convolution.
            dilation: dilation of the dilated convolution.
        """
        super().__init__()

        # Twice the hidden width: one half per branch of the tanh/sigmoid gate.
        gated_channels = 2 * hidden_dim

        # Conditioning projections are only created when their dim is given.
        if global_dim is not None:
            self.gconv = nn.Conv1d(global_dim, hidden_dim, 1)
        if local_dim is not None:
            self.lconv = nn.Conv1d(local_dim, gated_channels, 1)

        self.dconv = nn.Conv1d(
            hidden_dim,
            gated_channels,
            kernel_size,
            dilation=dilation,
            padding="same",
        )
        # Doubles the channels again so the result splits into residual + skip.
        self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1)

    def forward(self, z, l, g):
        """Run the layer.

        Args:
            z: input to the residual stream.
            l: local condition, or None to skip it.
            g: global condition, or None to skip it; a 2-D ``g`` gets a
                trailing axis appended before projection.

        Returns:
            Tuple ``(o, s)``: the residual output ``(z_branch + input) / sqrt(2)``
            and the skip output.
        """
        residual = z

        if g is not None:
            global_cond = g.unsqueeze(-1) if g.dim() == 2 else g
            z = z + self.gconv(global_cond)

        gate_input = self.dconv(z)
        if l is not None:
            gate_input = gate_input + self.lconv(l)

        projected = self.out(_fused_tanh_sigmoid(gate_input))
        res_branch, skip = projected.chunk(2, dim=1)

        # Scale the residual sum by 1/sqrt(2).
        return (res_branch + residual) / math.sqrt(2), skip
|
|
|
|
|
|
class WN(nn.Module):
    """Stack of ``WNLayer`` blocks between a 1x1 input and a 1x1 output conv.

    Layer ``i`` uses dilation ``2 ** (i % dilation_cycle)``. The skip outputs
    of all layers are summed, scaled by ``1 / sqrt(n_layers)`` and projected
    to ``output_dim`` channels.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        local_dim=None,
        global_dim=None,
        n_layers=30,
        kernel_size=3,
        dilation_cycle=5,
        hidden_dim=512,
    ):
        """Build the network.

        Args:
            input_dim: channels of the input.
            output_dim: channels of the output.
            local_dim: channels of the local condition, or None for no local
                conditioning.
            global_dim: channels of the global condition, or None for no
                global conditioning.
            n_layers: number of ``WNLayer`` blocks.
            kernel_size: kernel size of each dilated convolution; must be odd.
            dilation_cycle: period after which the dilation pattern restarts.
            hidden_dim: channels of the residual stream; must be even.
        """
        super().__init__()
        assert kernel_size % 2 == 1
        assert hidden_dim % 2 == 0

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.local_dim = local_dim
        self.global_dim = global_dim

        self.start = nn.Conv1d(input_dim, hidden_dim, 1)
        if local_dim is not None:
            self.local_norm = nn.InstanceNorm1d(local_dim)

        # Dilations run 1, 2, 4, ... and wrap every `dilation_cycle` layers.
        dilations = [2 ** (idx % dilation_cycle) for idx in range(n_layers)]
        self.layers = nn.ModuleList(
            [
                WNLayer(
                    hidden_dim=hidden_dim,
                    local_dim=local_dim,
                    global_dim=global_dim,
                    kernel_size=kernel_size,
                    dilation=rate,
                )
                for rate in dilations
            ]
        )

        self.end = nn.Conv1d(hidden_dim, output_dim, 1)

    def forward(self, z, l=None, g=None):
        """
        Args:
            z: input (b c t)
            l: local condition (b c t)
            g: global condition (b d)

        Returns:
            The ``end`` projection of the scaled sum of all skip outputs.
        """
        hidden = self.start(z)
        local_cond = self.local_norm(l) if l is not None else None

        # Collect one skip tensor per layer while threading the residual through.
        skips = []
        for layer in self.layers:
            hidden, skip = layer(hidden, local_cond, g)
            skips.append(skip)

        skip_sum = torch.stack(skips, dim=0).sum(dim=0)
        skip_sum = skip_sum / math.sqrt(len(self.layers))

        return self.end(skip_sum)

    def summarize(self, length=100):
        """Print a ptflops complexity report for an input of ``length`` steps.

        Requires the third-party ``ptflops`` package, imported lazily here.
        """
        from ptflops import get_model_complexity_info

        x = torch.randn(1, self.input_dim, length)

        macs, params = get_model_complexity_info(
            self,
            (self.input_dim, length),
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )

        report = (
            f"Input shape: {x.shape}",
            f"Computational complexity: {macs}",
            f"Number of parameters: {params}",
        )
        for line in report:
            print(line)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry: build a 64-in / 64-out model with default settings and
    # print its complexity summary.
    WN(input_dim=64, output_dim=64).summarize()
|