mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
commit
d0d35194a1
106
GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py
Normal file
106
GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
|
||||||
|
import os, sys
|
||||||
|
|
||||||
|
now_dir = os.getcwd()
|
||||||
|
sys.path.append(now_dir)
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch_lightning import LightningModule
|
||||||
|
from AR.models.t2s_model_onnx import Text2SemanticDecoder
|
||||||
|
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
||||||
|
from AR.modules.optim import ScaledAdam
|
||||||
|
|
||||||
|
|
||||||
|
class Text2SemanticLightningModule(LightningModule):
|
||||||
|
def __init__(self, config, output_dir, is_train=True):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.top_k = 3
|
||||||
|
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
|
||||||
|
pretrained_s1 = config.get("pretrained_s1")
|
||||||
|
if pretrained_s1 and is_train:
|
||||||
|
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
||||||
|
print(
|
||||||
|
self.load_state_dict(
|
||||||
|
torch.load(pretrained_s1, map_location="cpu")["weight"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if is_train:
|
||||||
|
self.automatic_optimization = False
|
||||||
|
self.save_hyperparameters()
|
||||||
|
self.eval_dir = output_dir / "eval"
|
||||||
|
self.eval_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def training_step(self, batch: Dict, batch_idx: int):
|
||||||
|
opt = self.optimizers()
|
||||||
|
scheduler = self.lr_schedulers()
|
||||||
|
loss, acc = self.model.forward(
|
||||||
|
batch["phoneme_ids"],
|
||||||
|
batch["phoneme_ids_len"],
|
||||||
|
batch["semantic_ids"],
|
||||||
|
batch["semantic_ids_len"],
|
||||||
|
batch["bert_feature"],
|
||||||
|
)
|
||||||
|
self.manual_backward(loss)
|
||||||
|
if batch_idx > 0 and batch_idx % 4 == 0:
|
||||||
|
opt.step()
|
||||||
|
opt.zero_grad()
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
self.log(
|
||||||
|
"total_loss",
|
||||||
|
loss,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
prog_bar=True,
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
"lr",
|
||||||
|
scheduler.get_last_lr()[0],
|
||||||
|
on_epoch=True,
|
||||||
|
prog_bar=True,
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
f"top_{self.top_k}_acc",
|
||||||
|
acc,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
prog_bar=True,
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def validation_step(self, batch: Dict, batch_idx: int):
|
||||||
|
return
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
model_parameters = self.model.parameters()
|
||||||
|
parameters_names = []
|
||||||
|
parameters_names.append(
|
||||||
|
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
|
||||||
|
)
|
||||||
|
lm_opt = ScaledAdam(
|
||||||
|
model_parameters,
|
||||||
|
lr=0.01,
|
||||||
|
betas=(0.9, 0.95),
|
||||||
|
clipping_scale=2.0,
|
||||||
|
parameters_names=parameters_names,
|
||||||
|
show_dominant_parameters=False,
|
||||||
|
clipping_update_period=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"optimizer": lm_opt,
|
||||||
|
"lr_scheduler": {
|
||||||
|
"scheduler": WarmupCosineLRSchedule(
|
||||||
|
lm_opt,
|
||||||
|
init_lr=self.config["optimizer"]["lr_init"],
|
||||||
|
peak_lr=self.config["optimizer"]["lr"],
|
||||||
|
end_lr=self.config["optimizer"]["lr_end"],
|
||||||
|
warmup_steps=self.config["optimizer"]["warmup_steps"],
|
||||||
|
total_steps=self.config["optimizer"]["decay_steps"],
|
||||||
|
)
|
||||||
|
},
|
||||||
|
}
|
337
GPT_SoVITS/AR/models/t2s_model_onnx.py
Normal file
337
GPT_SoVITS/AR/models/t2s_model_onnx.py
Normal file
@ -0,0 +1,337 @@
|
|||||||
|
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from AR.modules.embedding_onnx import SinePositionalEmbedding
|
||||||
|
from AR.modules.embedding_onnx import TokenEmbedding
|
||||||
|
from AR.modules.transformer_onnx import LayerNorm
|
||||||
|
from AR.modules.transformer_onnx import TransformerEncoder
|
||||||
|
from AR.modules.transformer_onnx import TransformerEncoderLayer
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torchmetrics.classification import MulticlassAccuracy
|
||||||
|
|
||||||
|
default_config = {
|
||||||
|
"embedding_dim": 512,
|
||||||
|
"hidden_dim": 512,
|
||||||
|
"num_head": 8,
|
||||||
|
"num_layers": 12,
|
||||||
|
"num_codebook": 8,
|
||||||
|
"p_dropout": 0.0,
|
||||||
|
"vocab_size": 1024 + 1,
|
||||||
|
"phoneme_vocab_size": 512,
|
||||||
|
"EOS": 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
|
||||||
|
|
||||||
|
def logits_to_probs(
|
||||||
|
logits,
|
||||||
|
previous_tokens = None,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_k = None,
|
||||||
|
top_p = None,
|
||||||
|
repetition_penalty: float = 1.0,
|
||||||
|
):
|
||||||
|
previous_tokens = previous_tokens.squeeze()
|
||||||
|
if previous_tokens is not None and repetition_penalty != 1.0:
|
||||||
|
previous_tokens = previous_tokens.long()
|
||||||
|
score = torch.gather(logits, dim=0, index=previous_tokens)
|
||||||
|
score = torch.where(
|
||||||
|
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||||
|
)
|
||||||
|
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
||||||
|
|
||||||
|
if top_p is not None and top_p < 1.0:
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
|
cum_probs = torch.cumsum(
|
||||||
|
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||||
|
)
|
||||||
|
sorted_indices_to_remove = cum_probs > top_p
|
||||||
|
sorted_indices_to_remove[0] = False # keep at least one option
|
||||||
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||||
|
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
||||||
|
)
|
||||||
|
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||||
|
|
||||||
|
logits = logits / max(temperature, 1e-5)
|
||||||
|
|
||||||
|
if top_k is not None:
|
||||||
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||||
|
pivot = v.select(-1, -1).unsqueeze(-1)
|
||||||
|
logits = torch.where(logits < pivot, inf_tensor_value, logits)
|
||||||
|
|
||||||
|
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
def multinomial_sample_one_no_sync(
|
||||||
|
probs_sort
|
||||||
|
): # Does multinomial sampling without a cuda synchronization
|
||||||
|
q = torch.randn_like(probs_sort)
|
||||||
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||||
|
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
logits,
|
||||||
|
previous_tokens,
|
||||||
|
**sampling_kwargs,
|
||||||
|
):
|
||||||
|
probs = logits_to_probs(
|
||||||
|
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
|
||||||
|
)
|
||||||
|
idx_next = multinomial_sample_one_no_sync(probs)
|
||||||
|
return idx_next, probs
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxEncoder(nn.Module):
|
||||||
|
def __init__(self, ar_text_embedding, bert_proj, ar_text_position):
|
||||||
|
super().__init__()
|
||||||
|
self.ar_text_embedding = ar_text_embedding
|
||||||
|
self.bert_proj = bert_proj
|
||||||
|
self.ar_text_position = ar_text_position
|
||||||
|
|
||||||
|
def forward(self, x, bert_feature):
|
||||||
|
x = self.ar_text_embedding(x)
|
||||||
|
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||||
|
return self.ar_text_position(x)
|
||||||
|
|
||||||
|
|
||||||
|
class T2SFirstStageDecoder(nn.Module):
|
||||||
|
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
|
||||||
|
top_k, early_stop_num, num_layers):
|
||||||
|
super().__init__()
|
||||||
|
self.ar_audio_embedding = ar_audio_embedding
|
||||||
|
self.ar_audio_position = ar_audio_position
|
||||||
|
self.h = h
|
||||||
|
self.ar_predict_layer = ar_predict_layer
|
||||||
|
self.loss_fct = loss_fct
|
||||||
|
self.ar_accuracy_metric = ar_accuracy_metric
|
||||||
|
self.top_k = top_k
|
||||||
|
self.early_stop_num = early_stop_num
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
def forward(self, x, prompt):
|
||||||
|
y = prompt
|
||||||
|
x_example = x[:,:,0] * 0.0
|
||||||
|
#N, 1, 512
|
||||||
|
cache = {
|
||||||
|
"all_stage": self.num_layers,
|
||||||
|
"k": None,
|
||||||
|
"v": None,
|
||||||
|
"y_emb": None,
|
||||||
|
"first_infer": 1,
|
||||||
|
"stage": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
y_emb = self.ar_audio_embedding(y)
|
||||||
|
|
||||||
|
cache["y_emb"] = y_emb
|
||||||
|
y_pos = self.ar_audio_position(y_emb)
|
||||||
|
|
||||||
|
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||||
|
|
||||||
|
y_example = y_pos[:,:,0] * 0.0
|
||||||
|
x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
|
||||||
|
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
|
||||||
|
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
|
||||||
|
torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
|
||||||
|
)
|
||||||
|
y_attn_mask = y_attn_mask > 0
|
||||||
|
|
||||||
|
x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
|
||||||
|
y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
|
||||||
|
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
|
||||||
|
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
|
||||||
|
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||||
|
cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
|
||||||
|
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
|
||||||
|
cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
|
||||||
|
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
|
||||||
|
|
||||||
|
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
||||||
|
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||||
|
samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
|
||||||
|
|
||||||
|
y = torch.concat([y, samples], dim=1)
|
||||||
|
|
||||||
|
return y, cache["k"], cache["v"], cache["y_emb"], x_example
|
||||||
|
|
||||||
|
|
||||||
|
class T2SStageDecoder(nn.Module):
|
||||||
|
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
|
||||||
|
top_k, early_stop_num, num_layers):
|
||||||
|
super().__init__()
|
||||||
|
self.ar_audio_embedding = ar_audio_embedding
|
||||||
|
self.ar_audio_position = ar_audio_position
|
||||||
|
self.h = h
|
||||||
|
self.ar_predict_layer = ar_predict_layer
|
||||||
|
self.loss_fct = loss_fct
|
||||||
|
self.ar_accuracy_metric = ar_accuracy_metric
|
||||||
|
self.top_k = top_k
|
||||||
|
self.early_stop_num = early_stop_num
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
def forward(self, y, k, v, y_emb, x_example):
|
||||||
|
cache = {
|
||||||
|
"all_stage": self.num_layers,
|
||||||
|
"k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
|
||||||
|
"v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
|
||||||
|
"y_emb": y_emb,
|
||||||
|
"first_infer": 0,
|
||||||
|
"stage": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
y_emb = torch.cat(
|
||||||
|
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
|
||||||
|
)
|
||||||
|
cache["y_emb"] = y_emb
|
||||||
|
y_pos = self.ar_audio_position(y_emb)
|
||||||
|
|
||||||
|
xy_pos = y_pos[:, -1:]
|
||||||
|
|
||||||
|
y_example = y_pos[:,:,0] * 0.0
|
||||||
|
|
||||||
|
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
|
||||||
|
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
|
||||||
|
|
||||||
|
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
||||||
|
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||||
|
samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
|
||||||
|
|
||||||
|
y = torch.concat([y, samples], dim=1)
|
||||||
|
|
||||||
|
return y, cache["k"], cache["v"], cache["y_emb"], logits, samples
|
||||||
|
|
||||||
|
|
||||||
|
class Text2SemanticDecoder(nn.Module):
|
||||||
|
def __init__(self, config, norm_first=False, top_k=3):
|
||||||
|
super(Text2SemanticDecoder, self).__init__()
|
||||||
|
self.model_dim = config["model"]["hidden_dim"]
|
||||||
|
self.embedding_dim = config["model"]["embedding_dim"]
|
||||||
|
self.num_head = config["model"]["head"]
|
||||||
|
self.num_layers = config["model"]["n_layer"]
|
||||||
|
self.norm_first = norm_first
|
||||||
|
self.vocab_size = config["model"]["vocab_size"]
|
||||||
|
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
|
||||||
|
self.p_dropout = float(config["model"]["dropout"])
|
||||||
|
self.EOS = config["model"]["EOS"]
|
||||||
|
self.norm_first = norm_first
|
||||||
|
assert self.EOS == self.vocab_size - 1
|
||||||
|
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||||
|
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
|
||||||
|
self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
|
||||||
|
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
|
||||||
|
self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
|
||||||
|
self.h = TransformerEncoder(
|
||||||
|
TransformerEncoderLayer(
|
||||||
|
d_model=self.model_dim,
|
||||||
|
nhead=self.num_head,
|
||||||
|
dim_feedforward=self.model_dim * 4,
|
||||||
|
dropout=0.1,
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=norm_first,
|
||||||
|
),
|
||||||
|
num_layers=self.num_layers,
|
||||||
|
norm=LayerNorm(self.model_dim) if norm_first else None,
|
||||||
|
)
|
||||||
|
self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
|
||||||
|
self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
|
||||||
|
self.ar_accuracy_metric = MulticlassAccuracy(
|
||||||
|
self.vocab_size,
|
||||||
|
top_k=top_k,
|
||||||
|
average="micro",
|
||||||
|
multidim_average="global",
|
||||||
|
ignore_index=self.EOS,
|
||||||
|
)
|
||||||
|
self.top_k = torch.LongTensor([1])
|
||||||
|
self.early_stop_num = torch.LongTensor([-1])
|
||||||
|
|
||||||
|
def init_onnx(self):
|
||||||
|
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
|
||||||
|
self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
|
||||||
|
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
|
||||||
|
self.num_layers)
|
||||||
|
self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
|
||||||
|
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
|
||||||
|
self.num_layers)
|
||||||
|
|
||||||
|
def forward(self, x, prompts, bert_feature):
|
||||||
|
early_stop_num = self.early_stop_num
|
||||||
|
prefix_len = prompts.shape[1]
|
||||||
|
|
||||||
|
x = self.onnx_encoder(x, bert_feature)
|
||||||
|
y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts)
|
||||||
|
|
||||||
|
stop = False
|
||||||
|
for idx in range(1, 1500):
|
||||||
|
enco = self.stage_decoder(y, k, v, y_emb, stage, x_example)
|
||||||
|
y, k, v, y_emb, stage, logits, samples = enco
|
||||||
|
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||||
|
stop = True
|
||||||
|
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
||||||
|
stop = True
|
||||||
|
if stop:
|
||||||
|
break
|
||||||
|
y[0, -1] = 0
|
||||||
|
return y, idx
|
||||||
|
|
||||||
|
def infer(self, x, prompts, bert_feature):
|
||||||
|
top_k = self.top_k
|
||||||
|
early_stop_num = self.early_stop_num
|
||||||
|
|
||||||
|
x = self.onnx_encoder(x, bert_feature)
|
||||||
|
|
||||||
|
y = prompts
|
||||||
|
prefix_len = y.shape[1]
|
||||||
|
x_len = x.shape[1]
|
||||||
|
x_example = x[:,:,0] * 0.0
|
||||||
|
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
|
||||||
|
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
|
||||||
|
|
||||||
|
stop = False
|
||||||
|
cache = {
|
||||||
|
"all_stage": self.num_layers,
|
||||||
|
"k": [None] * self.num_layers,
|
||||||
|
"v": [None] * self.num_layers,
|
||||||
|
"y_emb": None,
|
||||||
|
"first_infer": 1,
|
||||||
|
"stage": 0,
|
||||||
|
}
|
||||||
|
for idx in range(1500):
|
||||||
|
if cache["first_infer"] == 1:
|
||||||
|
y_emb = self.ar_audio_embedding(y)
|
||||||
|
else:
|
||||||
|
y_emb = torch.cat(
|
||||||
|
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
|
||||||
|
)
|
||||||
|
cache["y_emb"] = y_emb
|
||||||
|
y_pos = self.ar_audio_position(y_emb)
|
||||||
|
if cache["first_infer"] == 1:
|
||||||
|
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||||
|
else:
|
||||||
|
xy_pos = y_pos[:, -1:]
|
||||||
|
y_len = y_pos.shape[1]
|
||||||
|
if cache["first_infer"] == 1:
|
||||||
|
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
|
||||||
|
y_attn_mask = F.pad(
|
||||||
|
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||||
|
(x_len, 0), value=False
|
||||||
|
)
|
||||||
|
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||||
|
else:
|
||||||
|
xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
|
||||||
|
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
||||||
|
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||||
|
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
|
||||||
|
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||||
|
stop = True
|
||||||
|
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
||||||
|
stop = True
|
||||||
|
if stop:
|
||||||
|
if prompts.shape[1] == y.shape[1]:
|
||||||
|
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||||
|
break
|
||||||
|
y = torch.concat([y, samples], dim=1)
|
||||||
|
cache["first_infer"] = 0
|
||||||
|
return y, idx
|
178
GPT_SoVITS/AR/modules/activation_onnx.py
Normal file
178
GPT_SoVITS/AR/modules/activation_onnx.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import Linear
|
||||||
|
from torch.nn import Module
|
||||||
|
from torch.nn.init import constant_
|
||||||
|
from torch.nn.init import xavier_normal_
|
||||||
|
from torch.nn.init import xavier_uniform_
|
||||||
|
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(Module):
|
||||||
|
__constants__ = ["batch_first"]
|
||||||
|
bias_k: Optional[torch.Tensor]
|
||||||
|
bias_v: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
dropout=0.0,
|
||||||
|
bias=True,
|
||||||
|
add_bias_kv=False,
|
||||||
|
add_zero_attn=False,
|
||||||
|
kdim=None,
|
||||||
|
vdim=None,
|
||||||
|
batch_first=False,
|
||||||
|
linear1_cls=Linear,
|
||||||
|
linear2_cls=Linear,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super(MultiheadAttention, self).__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.kdim = kdim if kdim is not None else embed_dim
|
||||||
|
self.vdim = vdim if vdim is not None else embed_dim
|
||||||
|
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.batch_first = batch_first
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
assert (
|
||||||
|
self.head_dim * num_heads == self.embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
|
||||||
|
if add_bias_kv:
|
||||||
|
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||||
|
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||||
|
else:
|
||||||
|
self.bias_k = self.bias_v = None
|
||||||
|
|
||||||
|
if linear1_cls == Linear:
|
||||||
|
if not self._qkv_same_embed_dim:
|
||||||
|
self.q_proj_weight = Parameter(
|
||||||
|
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
||||||
|
)
|
||||||
|
self.k_proj_weight = Parameter(
|
||||||
|
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
||||||
|
)
|
||||||
|
self.v_proj_weight = Parameter(
|
||||||
|
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
||||||
|
)
|
||||||
|
self.register_parameter("in_proj_weight", None)
|
||||||
|
else:
|
||||||
|
self.in_proj_weight = Parameter(
|
||||||
|
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
||||||
|
)
|
||||||
|
self.register_parameter("q_proj_weight", None)
|
||||||
|
self.register_parameter("k_proj_weight", None)
|
||||||
|
self.register_parameter("v_proj_weight", None)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.in_proj_bias = Parameter(
|
||||||
|
torch.empty(3 * embed_dim, **factory_kwargs)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.register_parameter("in_proj_bias", None)
|
||||||
|
self.out_proj = NonDynamicallyQuantizableLinear(
|
||||||
|
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
else:
|
||||||
|
if not self._qkv_same_embed_dim:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
self.in_proj_linear = linear1_cls(
|
||||||
|
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.in_proj_weight = self.in_proj_linear.weight
|
||||||
|
|
||||||
|
self.register_parameter("q_proj_weight", None)
|
||||||
|
self.register_parameter("k_proj_weight", None)
|
||||||
|
self.register_parameter("v_proj_weight", None)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.in_proj_bias = self.in_proj_linear.bias
|
||||||
|
else:
|
||||||
|
self.register_parameter("in_proj_bias", None)
|
||||||
|
|
||||||
|
self.out_proj = linear2_cls(
|
||||||
|
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.bias_k is not None:
|
||||||
|
xavier_normal_(self.bias_k)
|
||||||
|
if self.bias_v is not None:
|
||||||
|
xavier_normal_(self.bias_v)
|
||||||
|
|
||||||
|
self.add_zero_attn = add_zero_attn
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
if self._qkv_same_embed_dim:
|
||||||
|
xavier_uniform_(self.in_proj_weight)
|
||||||
|
else:
|
||||||
|
xavier_uniform_(self.q_proj_weight)
|
||||||
|
xavier_uniform_(self.k_proj_weight)
|
||||||
|
xavier_uniform_(self.v_proj_weight)
|
||||||
|
|
||||||
|
if self.in_proj_bias is not None:
|
||||||
|
constant_(self.in_proj_bias, 0.0)
|
||||||
|
constant_(self.out_proj.bias, 0.0)
|
||||||
|
|
||||||
|
if self.bias_k is not None:
|
||||||
|
xavier_normal_(self.bias_k)
|
||||||
|
if self.bias_v is not None:
|
||||||
|
xavier_normal_(self.bias_v)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
||||||
|
if "_qkv_same_embed_dim" not in state:
|
||||||
|
state["_qkv_same_embed_dim"] = True
|
||||||
|
|
||||||
|
super(MultiheadAttention, self).__setstate__(state)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: Tensor,
|
||||||
|
key: Tensor,
|
||||||
|
value: Tensor,
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
average_attn_weights: bool = True,
|
||||||
|
cache=None,
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
any_nested = query.is_nested or key.is_nested or value.is_nested
|
||||||
|
query = key = value = query.transpose(1, 0)
|
||||||
|
attn_output = multi_head_attention_forward_patched(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
self.embed_dim,
|
||||||
|
self.num_heads,
|
||||||
|
self.in_proj_weight,
|
||||||
|
self.in_proj_bias,
|
||||||
|
self.bias_k,
|
||||||
|
self.bias_v,
|
||||||
|
self.add_zero_attn,
|
||||||
|
self.dropout,
|
||||||
|
self.out_proj.weight,
|
||||||
|
self.out_proj.bias,
|
||||||
|
training=self.training,
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
need_weights=need_weights,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
average_attn_weights=average_attn_weights,
|
||||||
|
cache=cache,
|
||||||
|
)
|
||||||
|
return attn_output.transpose(1, 0)
|
63
GPT_SoVITS/AR/modules/embedding_onnx.py
Normal file
63
GPT_SoVITS/AR/modules/embedding_onnx.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class TokenEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
vocab_size: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
|
||||||
|
self.dropout = torch.nn.Dropout(p=dropout)
|
||||||
|
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weight(self) -> torch.Tensor:
|
||||||
|
return self.word_embeddings.weight
|
||||||
|
|
||||||
|
def embedding(self, index: int) -> torch.Tensor:
|
||||||
|
return self.word_embeddings.weight[index : index + 1]
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = self.word_embeddings(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SinePositionalEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
scale: bool = False,
|
||||||
|
alpha: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
|
||||||
|
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
|
||||||
|
self.dropout = torch.nn.Dropout(p=dropout)
|
||||||
|
self.reverse = False
|
||||||
|
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
|
||||||
|
|
||||||
|
def extend_pe(self, x):
|
||||||
|
position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
|
||||||
|
scpe = (position * self.div_term).unsqueeze(0)
|
||||||
|
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
|
||||||
|
pe = pe.contiguous().view(1, -1, self.embedding_dim)
|
||||||
|
return pe
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
pe = self.extend_pe(x)
|
||||||
|
output = x.unsqueeze(-1) if x.ndim == 2 else x
|
||||||
|
output = output * self.x_scale + self.alpha * pe
|
||||||
|
return self.dropout(output)
|
92
GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
Normal file
92
GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
from torch.nn.functional import *
|
||||||
|
from torch.nn.functional import (
|
||||||
|
_mha_shape_check,
|
||||||
|
_canonical_mask,
|
||||||
|
_none_or_dtype,
|
||||||
|
_in_projection_packed,
|
||||||
|
)
|
||||||
|
|
||||||
|
def multi_head_attention_forward_patched(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
embed_dim_to_check: int,
|
||||||
|
num_heads: int,
|
||||||
|
in_proj_weight,
|
||||||
|
in_proj_bias: Optional[Tensor],
|
||||||
|
bias_k: Optional[Tensor],
|
||||||
|
bias_v: Optional[Tensor],
|
||||||
|
add_zero_attn: bool,
|
||||||
|
dropout_p: float,
|
||||||
|
out_proj_weight: Tensor,
|
||||||
|
out_proj_bias: Optional[Tensor],
|
||||||
|
training: bool = True,
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
use_separate_proj_weight: bool = False,
|
||||||
|
q_proj_weight: Optional[Tensor] = None,
|
||||||
|
k_proj_weight: Optional[Tensor] = None,
|
||||||
|
v_proj_weight: Optional[Tensor] = None,
|
||||||
|
static_k: Optional[Tensor] = None,
|
||||||
|
static_v: Optional[Tensor] = None,
|
||||||
|
average_attn_weights: bool = True,
|
||||||
|
is_causal: bool = False,
|
||||||
|
cache=None,
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
|
||||||
|
# set up shape vars
|
||||||
|
_, _, embed_dim = query.shape
|
||||||
|
attn_mask = _canonical_mask(
|
||||||
|
mask=attn_mask,
|
||||||
|
mask_name="attn_mask",
|
||||||
|
other_type=None,
|
||||||
|
other_name="",
|
||||||
|
target_type=query.dtype,
|
||||||
|
check_other=False,
|
||||||
|
)
|
||||||
|
head_dim = embed_dim // num_heads
|
||||||
|
|
||||||
|
proj_qkv = linear(query, in_proj_weight, in_proj_bias)
|
||||||
|
proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
|
||||||
|
q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
|
||||||
|
|
||||||
|
if cache["first_infer"] == 1:
|
||||||
|
cache["k"][cache["stage"]] = k
|
||||||
|
cache["v"][cache["stage"]] = v
|
||||||
|
else:
|
||||||
|
cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
|
||||||
|
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
|
||||||
|
k = cache["k"][cache["stage"]]
|
||||||
|
v = cache["v"][cache["stage"]]
|
||||||
|
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
|
||||||
|
|
||||||
|
attn_mask = _canonical_mask(
|
||||||
|
mask=attn_mask,
|
||||||
|
mask_name="attn_mask",
|
||||||
|
other_type=None,
|
||||||
|
other_name="",
|
||||||
|
target_type=q.dtype,
|
||||||
|
check_other=False,
|
||||||
|
)
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
|
||||||
|
q = q.view(-1, num_heads, head_dim).transpose(0, 1)
|
||||||
|
k = k.view(-1, num_heads, head_dim).transpose(0, 1)
|
||||||
|
v = v.view(-1, num_heads, head_dim).transpose(0, 1)
|
||||||
|
|
||||||
|
dropout_p = 0.0
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
q = q.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||||
|
k = k.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||||
|
v = v.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask, dropout_p, is_causal
|
||||||
|
)
|
||||||
|
attn_output = (
|
||||||
|
attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
|
||||||
|
)
|
||||||
|
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||||
|
attn_output = attn_output.view(-1, 1, attn_output.size(1))
|
||||||
|
|
||||||
|
return attn_output
|
292
GPT_SoVITS/AR/modules/transformer_onnx.py
Normal file
292
GPT_SoVITS/AR/modules/transformer_onnx.py
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
|
||||||
|
import copy
|
||||||
|
import numbers
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from AR.modules.activation_onnx import MultiheadAttention
|
||||||
|
from AR.modules.scaling import BalancedDoubleSwish
|
||||||
|
from torch import nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
_shape_t = Union[int, List[int], torch.Size]
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
|
||||||
|
normalized_shape: Tuple[int, ...]
|
||||||
|
eps: float
|
||||||
|
elementwise_affine: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
normalized_shape: _shape_t,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
elementwise_affine: bool = True,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super(LayerNorm, self).__init__()
|
||||||
|
if isinstance(normalized_shape, numbers.Integral):
|
||||||
|
# mypy error: incompatible types in assignment
|
||||||
|
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
||||||
|
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
||||||
|
self.eps = eps
|
||||||
|
self.elementwise_affine = elementwise_affine
|
||||||
|
if self.elementwise_affine:
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||||
|
)
|
||||||
|
self.bias = nn.Parameter(
|
||||||
|
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.register_parameter("weight", None)
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
if self.elementwise_affine:
|
||||||
|
nn.init.ones_(self.weight)
|
||||||
|
nn.init.zeros_(self.bias)
|
||||||
|
|
||||||
|
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
||||||
|
if isinstance(input, tuple):
|
||||||
|
input, embedding = input
|
||||||
|
return (
|
||||||
|
F.layer_norm(
|
||||||
|
input,
|
||||||
|
self.normalized_shape,
|
||||||
|
self.weight,
|
||||||
|
self.bias,
|
||||||
|
self.eps,
|
||||||
|
),
|
||||||
|
embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedding is None
|
||||||
|
return F.layer_norm(
|
||||||
|
input, self.normalized_shape, self.weight, self.bias, self.eps
|
||||||
|
)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return (
|
||||||
|
"{normalized_shape}, eps={eps}, "
|
||||||
|
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityNorm(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
) -> None:
|
||||||
|
super(IdentityNorm, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
||||||
|
if isinstance(input, tuple):
|
||||||
|
return input
|
||||||
|
|
||||||
|
assert embedding is None
|
||||||
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
|
||||||
|
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
||||||
|
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
||||||
|
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||||
|
norm: the layer normalization component (optional).
|
||||||
|
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
||||||
|
(and convert back on output). This will improve the overall performance of
|
||||||
|
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
|
||||||
|
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
|
||||||
|
>>> src = torch.rand(10, 32, 512)
|
||||||
|
>>> out = transformer_encoder(src)
|
||||||
|
"""
|
||||||
|
__constants__ = ["norm"]
|
||||||
|
|
||||||
|
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||||
|
super(TransformerEncoder, self).__init__()
|
||||||
|
self.layers = _get_clones(encoder_layer, num_layers)
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.norm = norm
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
return_layer_states: bool = False,
|
||||||
|
cache=None,
|
||||||
|
) -> Tensor:
|
||||||
|
output = src
|
||||||
|
for mod in self.layers:
|
||||||
|
output = mod(
|
||||||
|
output,
|
||||||
|
src_mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
cache=cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.norm is not None:
|
||||||
|
output = self.norm(output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderLayer(nn.Module):
|
||||||
|
__constants__ = ["batch_first", "norm_first"]
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
nhead: int,
|
||||||
|
dim_feedforward: int = 2048,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
||||||
|
batch_first: bool = False,
|
||||||
|
norm_first: bool = False,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
linear1_self_attention_cls: nn.Module = nn.Linear,
|
||||||
|
linear2_self_attention_cls: nn.Module = nn.Linear,
|
||||||
|
linear1_feedforward_cls: nn.Module = nn.Linear,
|
||||||
|
linear2_feedforward_cls: nn.Module = nn.Linear,
|
||||||
|
layer_norm_cls: nn.Module = LayerNorm,
|
||||||
|
layer_norm_eps: float = 1e-5,
|
||||||
|
adaptive_layer_norm=False,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super(TransformerEncoderLayer, self).__init__()
|
||||||
|
self.self_attn = MultiheadAttention(
|
||||||
|
d_model, # 512 16
|
||||||
|
nhead,
|
||||||
|
dropout=dropout,
|
||||||
|
batch_first=batch_first,
|
||||||
|
linear1_cls=linear1_self_attention_cls,
|
||||||
|
linear2_cls=linear2_self_attention_cls,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
self.linear1 = linear1_feedforward_cls(
|
||||||
|
d_model, dim_feedforward, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.linear2 = linear2_feedforward_cls(
|
||||||
|
dim_feedforward, d_model, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.norm_first = norm_first
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(dropout)
|
||||||
|
if isinstance(activation, str):
|
||||||
|
activation = _get_activation_fn(activation)
|
||||||
|
elif isinstance(activation, partial):
|
||||||
|
activation = activation(d_model)
|
||||||
|
elif activation == BalancedDoubleSwish:
|
||||||
|
activation = BalancedDoubleSwish(d_model)
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
||||||
|
if layer_norm_cls == IdentityNorm:
|
||||||
|
norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
||||||
|
else:
|
||||||
|
norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
||||||
|
|
||||||
|
if adaptive_layer_norm:
|
||||||
|
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
||||||
|
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
||||||
|
else:
|
||||||
|
self.norm1 = norm1
|
||||||
|
self.norm2 = norm2
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super(TransformerEncoderLayer, self).__setstate__(state)
|
||||||
|
if not hasattr(self, "activation"):
|
||||||
|
self.activation = F.relu
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
src_mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
cache=None,
|
||||||
|
) -> Tensor:
|
||||||
|
x = src
|
||||||
|
stage_embedding = None
|
||||||
|
x = self.norm1(
|
||||||
|
x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
|
||||||
|
stage_embedding,
|
||||||
|
)
|
||||||
|
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _sa_block(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
attn_mask: Optional[Tensor],
|
||||||
|
key_padding_mask: Optional[Tensor],
|
||||||
|
cache=None,
|
||||||
|
) -> Tensor:
|
||||||
|
x = self.self_attn(
|
||||||
|
x,
|
||||||
|
x,
|
||||||
|
x,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
need_weights=False,
|
||||||
|
cache=cache,
|
||||||
|
)
|
||||||
|
return self.dropout1(x)
|
||||||
|
|
||||||
|
def _ff_block(self, x: Tensor) -> Tensor:
|
||||||
|
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
||||||
|
return self.dropout2(x)
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveLayerNorm(nn.Module):
|
||||||
|
r"""Adaptive Layer Normalization"""
|
||||||
|
|
||||||
|
def __init__(self, d_model, norm) -> None:
|
||||||
|
super(AdaptiveLayerNorm, self).__init__()
|
||||||
|
self.project_layer = nn.Linear(d_model, 2 * d_model)
|
||||||
|
self.norm = norm
|
||||||
|
self.d_model = d_model
|
||||||
|
self.eps = self.norm.eps
|
||||||
|
|
||||||
|
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
|
||||||
|
if isinstance(input, tuple):
|
||||||
|
input, embedding = input
|
||||||
|
weight, bias = torch.split(
|
||||||
|
self.project_layer(embedding),
|
||||||
|
split_size_or_sections=self.d_model,
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return (weight * self.norm(input) + bias, embedding)
|
||||||
|
|
||||||
|
weight, bias = torch.split(
|
||||||
|
self.project_layer(embedding),
|
||||||
|
split_size_or_sections=self.d_model,
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return weight * self.norm(input) + bias
|
||||||
|
|
||||||
|
|
||||||
|
def _get_clones(module, N):
|
||||||
|
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
365
GPT_SoVITS/module/attentions_onnx.py
Normal file
365
GPT_SoVITS/module/attentions_onnx.py
Normal file
@ -0,0 +1,365 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from module import commons
|
||||||
|
from module.modules import LayerNorm
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
def __init__(self, channels, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.gamma = nn.Parameter(torch.ones(channels))
|
||||||
|
self.beta = nn.Parameter(torch.zeros(channels))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.transpose(1, -1)
|
||||||
|
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
||||||
|
return x.transpose(1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||||
|
n_channels_int = n_channels[0]
|
||||||
|
in_act = input_a + input_b
|
||||||
|
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||||
|
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||||
|
acts = t_act * s_act
|
||||||
|
return acts
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
filter_channels,
|
||||||
|
n_heads,
|
||||||
|
n_layers,
|
||||||
|
kernel_size=1,
|
||||||
|
p_dropout=0.0,
|
||||||
|
window_size=4,
|
||||||
|
isflow=True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.filter_channels = filter_channels
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
self.window_size = window_size
|
||||||
|
# if isflow:
|
||||||
|
# cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
|
||||||
|
# self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
|
||||||
|
# self.cond_layer = weight_norm(cond_layer, name='weight')
|
||||||
|
# self.gin_channels = 256
|
||||||
|
self.cond_layer_idx = self.n_layers
|
||||||
|
if "gin_channels" in kwargs:
|
||||||
|
self.gin_channels = kwargs["gin_channels"]
|
||||||
|
if self.gin_channels != 0:
|
||||||
|
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
|
||||||
|
# vits2 says 3rd block, so idx is 2 by default
|
||||||
|
self.cond_layer_idx = (
|
||||||
|
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
|
||||||
|
)
|
||||||
|
logging.debug(self.gin_channels, self.cond_layer_idx)
|
||||||
|
assert (
|
||||||
|
self.cond_layer_idx < self.n_layers
|
||||||
|
), "cond_layer_idx should be less than n_layers"
|
||||||
|
self.drop = nn.Dropout(p_dropout)
|
||||||
|
self.attn_layers = nn.ModuleList()
|
||||||
|
self.norm_layers_1 = nn.ModuleList()
|
||||||
|
self.ffn_layers = nn.ModuleList()
|
||||||
|
self.norm_layers_2 = nn.ModuleList()
|
||||||
|
for i in range(self.n_layers):
|
||||||
|
self.attn_layers.append(
|
||||||
|
MultiHeadAttention(
|
||||||
|
hidden_channels,
|
||||||
|
hidden_channels,
|
||||||
|
n_heads,
|
||||||
|
p_dropout=p_dropout,
|
||||||
|
window_size=window_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
||||||
|
self.ffn_layers.append(
|
||||||
|
FFN(
|
||||||
|
hidden_channels,
|
||||||
|
hidden_channels,
|
||||||
|
filter_channels,
|
||||||
|
kernel_size,
|
||||||
|
p_dropout=p_dropout,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, g=None):
|
||||||
|
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||||
|
x = x * x_mask
|
||||||
|
for i in range(self.n_layers):
|
||||||
|
if i == self.cond_layer_idx and g is not None:
|
||||||
|
g = self.spk_emb_linear(g.transpose(1, 2))
|
||||||
|
g = g.transpose(1, 2)
|
||||||
|
x = x + g
|
||||||
|
x = x * x_mask
|
||||||
|
y = self.attn_layers[i](x, x, attn_mask)
|
||||||
|
y = self.drop(y)
|
||||||
|
x = self.norm_layers_1[i](x + y)
|
||||||
|
|
||||||
|
y = self.ffn_layers[i](x, x_mask)
|
||||||
|
y = self.drop(y)
|
||||||
|
x = self.norm_layers_2[i](x + y)
|
||||||
|
x = x * x_mask
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels,
|
||||||
|
out_channels,
|
||||||
|
n_heads,
|
||||||
|
p_dropout=0.0,
|
||||||
|
window_size=None,
|
||||||
|
heads_share=True,
|
||||||
|
block_length=None,
|
||||||
|
proximal_bias=False,
|
||||||
|
proximal_init=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert channels % n_heads == 0
|
||||||
|
|
||||||
|
self.channels = channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
self.window_size = window_size
|
||||||
|
self.heads_share = heads_share
|
||||||
|
self.block_length = block_length
|
||||||
|
self.proximal_bias = proximal_bias
|
||||||
|
self.proximal_init = proximal_init
|
||||||
|
self.attn = None
|
||||||
|
|
||||||
|
self.k_channels = channels // n_heads
|
||||||
|
self.conv_q = nn.Conv1d(channels, channels, 1)
|
||||||
|
self.conv_k = nn.Conv1d(channels, channels, 1)
|
||||||
|
self.conv_v = nn.Conv1d(channels, channels, 1)
|
||||||
|
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
||||||
|
self.drop = nn.Dropout(p_dropout)
|
||||||
|
|
||||||
|
if window_size is not None:
|
||||||
|
n_heads_rel = 1 if heads_share else n_heads
|
||||||
|
rel_stddev = self.k_channels**-0.5
|
||||||
|
self.emb_rel_k = nn.Parameter(
|
||||||
|
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
||||||
|
* rel_stddev
|
||||||
|
)
|
||||||
|
self.emb_rel_v = nn.Parameter(
|
||||||
|
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
||||||
|
* rel_stddev
|
||||||
|
)
|
||||||
|
|
||||||
|
nn.init.xavier_uniform_(self.conv_q.weight)
|
||||||
|
nn.init.xavier_uniform_(self.conv_k.weight)
|
||||||
|
nn.init.xavier_uniform_(self.conv_v.weight)
|
||||||
|
if proximal_init:
|
||||||
|
with torch.no_grad():
|
||||||
|
self.conv_k.weight.copy_(self.conv_q.weight)
|
||||||
|
self.conv_k.bias.copy_(self.conv_q.bias)
|
||||||
|
|
||||||
|
def forward(self, x, c, attn_mask=None):
|
||||||
|
q = self.conv_q(x)
|
||||||
|
k = self.conv_k(c)
|
||||||
|
v = self.conv_v(c)
|
||||||
|
|
||||||
|
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
||||||
|
|
||||||
|
x = self.conv_o(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def attention(self, query, key, value, mask=None):
|
||||||
|
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
||||||
|
b, d, t_s, _ = (*key.size(), query.size(2))
|
||||||
|
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
||||||
|
key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
||||||
|
value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
||||||
|
|
||||||
|
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
||||||
|
if self.window_size is not None:
|
||||||
|
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
||||||
|
rel_logits = self._matmul_with_relative_keys(
|
||||||
|
query / math.sqrt(self.k_channels), key_relative_embeddings
|
||||||
|
)
|
||||||
|
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
||||||
|
scores = scores + scores_local
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores.masked_fill(mask == 0, -1e4)
|
||||||
|
if self.block_length is not None:
|
||||||
|
block_mask = (
|
||||||
|
torch.ones_like(scores)
|
||||||
|
.triu(-self.block_length)
|
||||||
|
.tril(self.block_length)
|
||||||
|
)
|
||||||
|
scores = scores.masked_fill(block_mask == 0, -1e4)
|
||||||
|
p_attn = F.softmax(scores, dim=-1)
|
||||||
|
p_attn = self.drop(p_attn)
|
||||||
|
output = torch.matmul(p_attn, value)
|
||||||
|
if self.window_size is not None:
|
||||||
|
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
||||||
|
value_relative_embeddings = self._get_relative_embeddings(
|
||||||
|
self.emb_rel_v, t_s
|
||||||
|
)
|
||||||
|
output = output + self._matmul_with_relative_values(
|
||||||
|
relative_weights, value_relative_embeddings
|
||||||
|
)
|
||||||
|
output = (
|
||||||
|
output.transpose(2, 3).contiguous().view(b, d, -1)
|
||||||
|
)
|
||||||
|
return output, p_attn
|
||||||
|
|
||||||
|
def _matmul_with_relative_values(self, x, y):
|
||||||
|
"""
|
||||||
|
x: [b, h, l, m]
|
||||||
|
y: [h or 1, m, d]
|
||||||
|
ret: [b, h, l, d]
|
||||||
|
"""
|
||||||
|
ret = torch.matmul(x, y.unsqueeze(0))
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _matmul_with_relative_keys(self, x, y):
|
||||||
|
"""
|
||||||
|
x: [b, h, l, d]
|
||||||
|
y: [h or 1, m, d]
|
||||||
|
ret: [b, h, l, m]
|
||||||
|
"""
|
||||||
|
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _get_relative_embeddings(self, relative_embeddings, length):
|
||||||
|
max_relative_position = 2 * self.window_size + 1
|
||||||
|
# Pad first before slice to avoid using cond ops.
|
||||||
|
pad_length = max(length - (self.window_size + 1), 0)
|
||||||
|
slice_start_position = max((self.window_size + 1) - length, 0)
|
||||||
|
slice_end_position = slice_start_position + 2 * length - 1
|
||||||
|
if pad_length > 0:
|
||||||
|
padded_relative_embeddings = F.pad(
|
||||||
|
relative_embeddings,
|
||||||
|
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
padded_relative_embeddings = relative_embeddings
|
||||||
|
used_relative_embeddings = padded_relative_embeddings[
|
||||||
|
:, slice_start_position:slice_end_position
|
||||||
|
]
|
||||||
|
return used_relative_embeddings
|
||||||
|
|
||||||
|
def _relative_position_to_absolute_position(self, x):
|
||||||
|
"""
|
||||||
|
x: [b, h, l, 2*l-1]
|
||||||
|
ret: [b, h, l, l]
|
||||||
|
"""
|
||||||
|
batch, heads, length, _ = x.size()
|
||||||
|
# Concat columns of pad to shift from relative to absolute indexing.
|
||||||
|
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
||||||
|
|
||||||
|
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
||||||
|
x_flat = x.view([batch, heads, length * 2 * length])
|
||||||
|
x_flat = F.pad(
|
||||||
|
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape and slice out the padded elements.
|
||||||
|
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
||||||
|
:, :, :length, length - 1 :
|
||||||
|
]
|
||||||
|
return x_final
|
||||||
|
|
||||||
|
def _absolute_position_to_relative_position(self, x):
|
||||||
|
"""
|
||||||
|
x: [b, h, l, l]
|
||||||
|
ret: [b, h, l, 2*l-1]
|
||||||
|
"""
|
||||||
|
batch, heads, length, _ = x.size()
|
||||||
|
# padd along column
|
||||||
|
x = F.pad(
|
||||||
|
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
||||||
|
)
|
||||||
|
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
||||||
|
# add 0's in the beginning that will skew the elements after reshape
|
||||||
|
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
||||||
|
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
||||||
|
return x_final
|
||||||
|
|
||||||
|
def _attention_bias_proximal(self, length):
|
||||||
|
"""Bias for self-attention to encourage attention to close positions.
|
||||||
|
Args:
|
||||||
|
length: an integer scalar.
|
||||||
|
Returns:
|
||||||
|
a Tensor with shape [1, 1, length, length]
|
||||||
|
"""
|
||||||
|
r = torch.arange(length, dtype=torch.float32)
|
||||||
|
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
||||||
|
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
||||||
|
|
||||||
|
|
||||||
|
class FFN(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
filter_channels,
|
||||||
|
kernel_size,
|
||||||
|
p_dropout=0.0,
|
||||||
|
activation=None,
|
||||||
|
causal=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.filter_channels = filter_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
self.activation = activation
|
||||||
|
self.causal = causal
|
||||||
|
|
||||||
|
if causal:
|
||||||
|
self.padding = self._causal_padding
|
||||||
|
else:
|
||||||
|
self.padding = self._same_padding
|
||||||
|
|
||||||
|
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
||||||
|
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
||||||
|
self.drop = nn.Dropout(p_dropout)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask):
|
||||||
|
x = self.conv_1(self.padding(x * x_mask))
|
||||||
|
if self.activation == "gelu":
|
||||||
|
x = x * torch.sigmoid(1.702 * x)
|
||||||
|
else:
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.conv_2(self.padding(x * x_mask))
|
||||||
|
return x * x_mask
|
||||||
|
|
||||||
|
def _causal_padding(self, x):
|
||||||
|
if self.kernel_size == 1:
|
||||||
|
return x
|
||||||
|
pad_l = self.kernel_size - 1
|
||||||
|
pad_r = 0
|
||||||
|
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||||
|
x = F.pad(x, commons.convert_pad_shape(padding))
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _same_padding(self, x):
|
||||||
|
if self.kernel_size == 1:
|
||||||
|
return x
|
||||||
|
pad_l = (self.kernel_size - 1) // 2
|
||||||
|
pad_r = self.kernel_size // 2
|
||||||
|
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||||
|
x = F.pad(x, commons.convert_pad_shape(padding))
|
||||||
|
return x
|
920
GPT_SoVITS/module/models_onnx.py
Normal file
920
GPT_SoVITS/module/models_onnx.py
Normal file
@ -0,0 +1,920 @@
|
|||||||
|
import copy
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from module import commons
|
||||||
|
from module import modules
|
||||||
|
from module import attentions_onnx as attentions
|
||||||
|
|
||||||
|
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||||
|
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||||
|
from module.commons import init_weights, get_padding
|
||||||
|
from module.mrte_model import MRTE
|
||||||
|
from module.quantize import ResidualVectorQuantizer
|
||||||
|
from text import symbols
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
|
|
||||||
|
class StochasticDurationPredictor(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
filter_channels,
|
||||||
|
kernel_size,
|
||||||
|
p_dropout,
|
||||||
|
n_flows=4,
|
||||||
|
gin_channels=0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
filter_channels = in_channels # it needs to be removed from future version.
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.filter_channels = filter_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
self.n_flows = n_flows
|
||||||
|
self.gin_channels = gin_channels
|
||||||
|
|
||||||
|
self.log_flow = modules.Log()
|
||||||
|
self.flows = nn.ModuleList()
|
||||||
|
self.flows.append(modules.ElementwiseAffine(2))
|
||||||
|
for i in range(n_flows):
|
||||||
|
self.flows.append(
|
||||||
|
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
||||||
|
)
|
||||||
|
self.flows.append(modules.Flip())
|
||||||
|
|
||||||
|
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
||||||
|
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||||
|
self.post_convs = modules.DDSConv(
|
||||||
|
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
||||||
|
)
|
||||||
|
self.post_flows = nn.ModuleList()
|
||||||
|
self.post_flows.append(modules.ElementwiseAffine(2))
|
||||||
|
for i in range(4):
|
||||||
|
self.post_flows.append(
|
||||||
|
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
||||||
|
)
|
||||||
|
self.post_flows.append(modules.Flip())
|
||||||
|
|
||||||
|
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
||||||
|
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||||
|
self.convs = modules.DDSConv(
|
||||||
|
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
||||||
|
)
|
||||||
|
if gin_channels != 0:
|
||||||
|
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
|
||||||
|
x = torch.detach(x)
|
||||||
|
x = self.pre(x)
|
||||||
|
if g is not None:
|
||||||
|
g = torch.detach(g)
|
||||||
|
x = x + self.cond(g)
|
||||||
|
x = self.convs(x, x_mask)
|
||||||
|
x = self.proj(x) * x_mask
|
||||||
|
|
||||||
|
if not reverse:
|
||||||
|
flows = self.flows
|
||||||
|
assert w is not None
|
||||||
|
|
||||||
|
logdet_tot_q = 0
|
||||||
|
h_w = self.post_pre(w)
|
||||||
|
h_w = self.post_convs(h_w, x_mask)
|
||||||
|
h_w = self.post_proj(h_w) * x_mask
|
||||||
|
e_q = (
|
||||||
|
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
|
||||||
|
* x_mask
|
||||||
|
)
|
||||||
|
z_q = e_q
|
||||||
|
for flow in self.post_flows:
|
||||||
|
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
||||||
|
logdet_tot_q += logdet_q
|
||||||
|
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
||||||
|
u = torch.sigmoid(z_u) * x_mask
|
||||||
|
z0 = (w - u) * x_mask
|
||||||
|
logdet_tot_q += torch.sum(
|
||||||
|
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
|
||||||
|
)
|
||||||
|
logq = (
|
||||||
|
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
|
||||||
|
- logdet_tot_q
|
||||||
|
)
|
||||||
|
|
||||||
|
logdet_tot = 0
|
||||||
|
z0, logdet = self.log_flow(z0, x_mask)
|
||||||
|
logdet_tot += logdet
|
||||||
|
z = torch.cat([z0, z1], 1)
|
||||||
|
for flow in flows:
|
||||||
|
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
||||||
|
logdet_tot = logdet_tot + logdet
|
||||||
|
nll = (
|
||||||
|
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
|
||||||
|
- logdet_tot
|
||||||
|
)
|
||||||
|
return nll + logq # [b]
|
||||||
|
else:
|
||||||
|
flows = list(reversed(self.flows))
|
||||||
|
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||||
|
z = (
|
||||||
|
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
|
||||||
|
* noise_scale
|
||||||
|
)
|
||||||
|
for flow in flows:
|
||||||
|
z = flow(z, x_mask, g=x, reverse=reverse)
|
||||||
|
z0, z1 = torch.split(z, [1, 1], 1)
|
||||||
|
logw = z0
|
||||||
|
return logw
|
||||||
|
|
||||||
|
|
||||||
|
class DurationPredictor(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.filter_channels = filter_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
self.gin_channels = gin_channels
|
||||||
|
|
||||||
|
self.drop = nn.Dropout(p_dropout)
|
||||||
|
self.conv_1 = nn.Conv1d(
|
||||||
|
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||||
|
)
|
||||||
|
self.norm_1 = modules.LayerNorm(filter_channels)
|
||||||
|
self.conv_2 = nn.Conv1d(
|
||||||
|
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||||
|
)
|
||||||
|
self.norm_2 = modules.LayerNorm(filter_channels)
|
||||||
|
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||||
|
|
||||||
|
if gin_channels != 0:
|
||||||
|
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, g=None):
|
||||||
|
x = torch.detach(x)
|
||||||
|
if g is not None:
|
||||||
|
g = torch.detach(g)
|
||||||
|
x = x + self.cond(g)
|
||||||
|
x = self.conv_1(x * x_mask)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.norm_1(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.conv_2(x * x_mask)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.norm_2(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.proj(x * x_mask)
|
||||||
|
return x * x_mask
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
out_channels,
|
||||||
|
hidden_channels,
|
||||||
|
filter_channels,
|
||||||
|
n_heads,
|
||||||
|
n_layers,
|
||||||
|
kernel_size,
|
||||||
|
p_dropout,
|
||||||
|
latent_channels=192,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.filter_channels = filter_channels
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
self.latent_channels = latent_channels
|
||||||
|
|
||||||
|
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
|
||||||
|
|
||||||
|
self.encoder_ssl = attentions.Encoder(
|
||||||
|
hidden_channels,
|
||||||
|
filter_channels,
|
||||||
|
n_heads,
|
||||||
|
n_layers // 2,
|
||||||
|
kernel_size,
|
||||||
|
p_dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.encoder_text = attentions.Encoder(
|
||||||
|
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||||
|
)
|
||||||
|
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
|
||||||
|
|
||||||
|
self.mrte = MRTE()
|
||||||
|
|
||||||
|
self.encoder2 = attentions.Encoder(
|
||||||
|
hidden_channels,
|
||||||
|
filter_channels,
|
||||||
|
n_heads,
|
||||||
|
n_layers // 2,
|
||||||
|
kernel_size,
|
||||||
|
p_dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
|
def forward(self, y, text, ge):
|
||||||
|
y_mask = torch.ones_like(y[:1,:1,:])
|
||||||
|
|
||||||
|
y = self.ssl_proj(y * y_mask) * y_mask
|
||||||
|
y = self.encoder_ssl(y * y_mask, y_mask)
|
||||||
|
|
||||||
|
text_mask = torch.ones_like(text).to(y.dtype).unsqueeze(0)
|
||||||
|
|
||||||
|
text = self.text_embedding(text).transpose(1, 2)
|
||||||
|
text = self.encoder_text(text * text_mask, text_mask)
|
||||||
|
y = self.mrte(y, y_mask, text, text_mask, ge)
|
||||||
|
|
||||||
|
y = self.encoder2(y * y_mask, y_mask)
|
||||||
|
|
||||||
|
stats = self.proj(y) * y_mask
|
||||||
|
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||||
|
return y, m, logs, y_mask
|
||||||
|
|
||||||
|
def extract_latent(self, x):
|
||||||
|
x = self.ssl_proj(x)
|
||||||
|
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
|
||||||
|
return codes.transpose(0, 1)
|
||||||
|
|
||||||
|
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
|
||||||
|
quantized = self.quantizer.decode(codes)
|
||||||
|
|
||||||
|
y = self.vq_proj(quantized) * y_mask
|
||||||
|
y = self.encoder_ssl(y * y_mask, y_mask)
|
||||||
|
|
||||||
|
y = self.mrte(y, y_mask, refer, refer_mask, ge)
|
||||||
|
|
||||||
|
y = self.encoder2(y * y_mask, y_mask)
|
||||||
|
|
||||||
|
stats = self.proj(y) * y_mask
|
||||||
|
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||||
|
return y, m, logs, y_mask, quantized
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualCouplingBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels,
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation_rate,
|
||||||
|
n_layers,
|
||||||
|
n_flows=4,
|
||||||
|
gin_channels=0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.dilation_rate = dilation_rate
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.n_flows = n_flows
|
||||||
|
self.gin_channels = gin_channels
|
||||||
|
|
||||||
|
self.flows = nn.ModuleList()
|
||||||
|
for i in range(n_flows):
|
||||||
|
self.flows.append(
|
||||||
|
modules.ResidualCouplingLayer(
|
||||||
|
channels,
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation_rate,
|
||||||
|
n_layers,
|
||||||
|
gin_channels=gin_channels,
|
||||||
|
mean_only=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.flows.append(modules.Flip())
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, g=None, reverse=False):
|
||||||
|
if not reverse:
|
||||||
|
for flow in self.flows:
|
||||||
|
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
||||||
|
else:
|
||||||
|
for flow in reversed(self.flows):
|
||||||
|
x = flow(x, x_mask, g=g, reverse=reverse)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PosteriorEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation_rate,
|
||||||
|
n_layers,
|
||||||
|
gin_channels=0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.dilation_rate = dilation_rate
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.gin_channels = gin_channels
|
||||||
|
|
||||||
|
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||||
|
self.enc = modules.WN(
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation_rate,
|
||||||
|
n_layers,
|
||||||
|
gin_channels=gin_channels,
|
||||||
|
)
|
||||||
|
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
|
def forward(self, x, x_lengths, g=None):
|
||||||
|
if g != None:
|
||||||
|
g = g.detach()
|
||||||
|
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||||
|
x.dtype
|
||||||
|
)
|
||||||
|
x = self.pre(x) * x_mask
|
||||||
|
x = self.enc(x, x_mask, g=g)
|
||||||
|
stats = self.proj(x) * x_mask
|
||||||
|
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||||
|
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
||||||
|
return z, m, logs, x_mask
|
||||||
|
|
||||||
|
|
||||||
|
class WNEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation_rate,
|
||||||
|
n_layers,
|
||||||
|
gin_channels=0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.dilation_rate = dilation_rate
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.gin_channels = gin_channels
|
||||||
|
|
||||||
|
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||||
|
self.enc = modules.WN(
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation_rate,
|
||||||
|
n_layers,
|
||||||
|
gin_channels=gin_channels,
|
||||||
|
)
|
||||||
|
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||||
|
self.norm = modules.LayerNorm(out_channels)
|
||||||
|
|
||||||
|
def forward(self, x, x_lengths, g=None):
|
||||||
|
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||||
|
x.dtype
|
||||||
|
)
|
||||||
|
x = self.pre(x) * x_mask
|
||||||
|
x = self.enc(x, x_mask, g=g)
|
||||||
|
out = self.proj(x) * x_mask
|
||||||
|
out = self.norm(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Generator(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
initial_channel,
|
||||||
|
resblock,
|
||||||
|
resblock_kernel_sizes,
|
||||||
|
resblock_dilation_sizes,
|
||||||
|
upsample_rates,
|
||||||
|
upsample_initial_channel,
|
||||||
|
upsample_kernel_sizes,
|
||||||
|
gin_channels=0,
|
||||||
|
):
|
||||||
|
super(Generator, self).__init__()
|
||||||
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(upsample_rates)
|
||||||
|
self.conv_pre = Conv1d(
|
||||||
|
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
||||||
|
)
|
||||||
|
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
||||||
|
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
|
self.ups.append(
|
||||||
|
weight_norm(
|
||||||
|
ConvTranspose1d(
|
||||||
|
upsample_initial_channel // (2**i),
|
||||||
|
upsample_initial_channel // (2 ** (i + 1)),
|
||||||
|
k,
|
||||||
|
u,
|
||||||
|
padding=(k - u) // 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
for j, (k, d) in enumerate(
|
||||||
|
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
||||||
|
):
|
||||||
|
self.resblocks.append(resblock(ch, k, d))
|
||||||
|
|
||||||
|
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||||
|
self.ups.apply(init_weights)
|
||||||
|
|
||||||
|
if gin_channels != 0:
|
||||||
|
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||||
|
|
||||||
|
def forward(self, x, g=None):
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
if g is not None:
|
||||||
|
x = x + self.cond(g)
|
||||||
|
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||||
|
x = self.ups[i](x)
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
x = F.leaky_relu(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
x = torch.tanh(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
print("Removing weight norm...")
|
||||||
|
for l in self.ups:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
for l in self.resblocks:
|
||||||
|
l.remove_weight_norm()
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorP(torch.nn.Module):
|
||||||
|
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||||
|
super(DiscriminatorP, self).__init__()
|
||||||
|
self.period = period
|
||||||
|
self.use_spectral_norm = use_spectral_norm
|
||||||
|
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
norm_f(
|
||||||
|
Conv2d(
|
||||||
|
1,
|
||||||
|
32,
|
||||||
|
(kernel_size, 1),
|
||||||
|
(stride, 1),
|
||||||
|
padding=(get_padding(kernel_size, 1), 0),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
norm_f(
|
||||||
|
Conv2d(
|
||||||
|
32,
|
||||||
|
128,
|
||||||
|
(kernel_size, 1),
|
||||||
|
(stride, 1),
|
||||||
|
padding=(get_padding(kernel_size, 1), 0),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
norm_f(
|
||||||
|
Conv2d(
|
||||||
|
128,
|
||||||
|
512,
|
||||||
|
(kernel_size, 1),
|
||||||
|
(stride, 1),
|
||||||
|
padding=(get_padding(kernel_size, 1), 0),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
norm_f(
|
||||||
|
Conv2d(
|
||||||
|
512,
|
||||||
|
1024,
|
||||||
|
(kernel_size, 1),
|
||||||
|
(stride, 1),
|
||||||
|
padding=(get_padding(kernel_size, 1), 0),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
norm_f(
|
||||||
|
Conv2d(
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
(kernel_size, 1),
|
||||||
|
1,
|
||||||
|
padding=(get_padding(kernel_size, 1), 0),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fmap = []
|
||||||
|
|
||||||
|
# 1d to 2d
|
||||||
|
b, c, t = x.shape
|
||||||
|
if t % self.period != 0: # pad first
|
||||||
|
n_pad = self.period - (t % self.period)
|
||||||
|
x = F.pad(x, (0, n_pad), "reflect")
|
||||||
|
t = t + n_pad
|
||||||
|
x = x.view(b, c, t // self.period, self.period)
|
||||||
|
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||||
|
fmap.append(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
x = torch.flatten(x, 1, -1)
|
||||||
|
|
||||||
|
return x, fmap
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorS(torch.nn.Module):
|
||||||
|
def __init__(self, use_spectral_norm=False):
|
||||||
|
super(DiscriminatorS, self).__init__()
|
||||||
|
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||||
|
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
||||||
|
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
||||||
|
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
||||||
|
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
||||||
|
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fmap = []
|
||||||
|
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||||
|
fmap.append(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
x = torch.flatten(x, 1, -1)
|
||||||
|
|
||||||
|
return x, fmap
|
||||||
|
|
||||||
|
|
||||||
|
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||||
|
def __init__(self, use_spectral_norm=False):
|
||||||
|
super(MultiPeriodDiscriminator, self).__init__()
|
||||||
|
periods = [2, 3, 5, 7, 11]
|
||||||
|
|
||||||
|
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||||
|
discs = discs + [
|
||||||
|
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
||||||
|
]
|
||||||
|
self.discriminators = nn.ModuleList(discs)
|
||||||
|
|
||||||
|
def forward(self, y, y_hat):
|
||||||
|
y_d_rs = []
|
||||||
|
y_d_gs = []
|
||||||
|
fmap_rs = []
|
||||||
|
fmap_gs = []
|
||||||
|
for i, d in enumerate(self.discriminators):
|
||||||
|
y_d_r, fmap_r = d(y)
|
||||||
|
y_d_g, fmap_g = d(y_hat)
|
||||||
|
y_d_rs.append(y_d_r)
|
||||||
|
y_d_gs.append(y_d_g)
|
||||||
|
fmap_rs.append(fmap_r)
|
||||||
|
fmap_gs.append(fmap_g)
|
||||||
|
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||||
|
|
||||||
|
|
||||||
|
class ReferenceEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
inputs --- [N, Ty/r, n_mels*r] mels
|
||||||
|
outputs --- [N, ref_enc_gru_size]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, spec_channels, gin_channels=0):
|
||||||
|
super().__init__()
|
||||||
|
self.spec_channels = spec_channels
|
||||||
|
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
||||||
|
K = len(ref_enc_filters)
|
||||||
|
filters = [1] + ref_enc_filters
|
||||||
|
convs = [
|
||||||
|
weight_norm(
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=filters[i],
|
||||||
|
out_channels=filters[i + 1],
|
||||||
|
kernel_size=(3, 3),
|
||||||
|
stride=(2, 2),
|
||||||
|
padding=(1, 1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for i in range(K)
|
||||||
|
]
|
||||||
|
self.convs = nn.ModuleList(convs)
|
||||||
|
# self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
|
||||||
|
|
||||||
|
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
|
||||||
|
self.gru = nn.GRU(
|
||||||
|
input_size=ref_enc_filters[-1] * out_channels,
|
||||||
|
hidden_size=256 // 2,
|
||||||
|
batch_first=True,
|
||||||
|
)
|
||||||
|
self.proj = nn.Linear(128, gin_channels)
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
N = inputs.size(0)
|
||||||
|
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
|
||||||
|
for conv in self.convs:
|
||||||
|
out = conv(out)
|
||||||
|
# out = wn(out)
|
||||||
|
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
|
||||||
|
|
||||||
|
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
|
||||||
|
T = out.size(1)
|
||||||
|
N = out.size(0)
|
||||||
|
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
|
||||||
|
|
||||||
|
self.gru.flatten_parameters()
|
||||||
|
memory, out = self.gru(out) # out --- [1, N, 128]
|
||||||
|
|
||||||
|
return self.proj(out.squeeze(0)).unsqueeze(-1)
|
||||||
|
|
||||||
|
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
|
||||||
|
for i in range(n_convs):
|
||||||
|
L = (L - kernel_size + 2 * pad) // stride + 1
|
||||||
|
return L
|
||||||
|
|
||||||
|
|
||||||
|
class Quantizer_module(torch.nn.Module):
|
||||||
|
def __init__(self, n_e, e_dim):
|
||||||
|
super(Quantizer_module, self).__init__()
|
||||||
|
self.embedding = nn.Embedding(n_e, e_dim)
|
||||||
|
self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
d = (
|
||||||
|
torch.sum(x**2, 1, keepdim=True)
|
||||||
|
+ torch.sum(self.embedding.weight**2, 1)
|
||||||
|
- 2 * torch.matmul(x, self.embedding.weight.T)
|
||||||
|
)
|
||||||
|
min_indicies = torch.argmin(d, 1)
|
||||||
|
z_q = self.embedding(min_indicies)
|
||||||
|
return z_q, min_indicies
|
||||||
|
|
||||||
|
|
||||||
|
class Quantizer(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
|
||||||
|
super(Quantizer, self).__init__()
|
||||||
|
assert embed_dim % n_code_groups == 0
|
||||||
|
self.quantizer_modules = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Quantizer_module(n_codes, embed_dim // n_code_groups)
|
||||||
|
for _ in range(n_code_groups)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.n_code_groups = n_code_groups
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
def forward(self, xin):
|
||||||
|
# B, C, T
|
||||||
|
B, C, T = xin.shape
|
||||||
|
xin = xin.transpose(1, 2)
|
||||||
|
x = xin.reshape(-1, self.embed_dim)
|
||||||
|
x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1)
|
||||||
|
min_indicies = []
|
||||||
|
z_q = []
|
||||||
|
for _x, m in zip(x, self.quantizer_modules):
|
||||||
|
_z_q, _min_indicies = m(_x)
|
||||||
|
z_q.append(_z_q)
|
||||||
|
min_indicies.append(_min_indicies) # B * T,
|
||||||
|
z_q = torch.cat(z_q, -1).reshape(xin.shape)
|
||||||
|
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
|
||||||
|
(z_q - xin.detach()) ** 2
|
||||||
|
)
|
||||||
|
z_q = xin + (z_q - xin).detach()
|
||||||
|
z_q = z_q.transpose(1, 2)
|
||||||
|
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
|
||||||
|
return z_q, loss, codes.transpose(1, 2)
|
||||||
|
|
||||||
|
def embed(self, x):
|
||||||
|
# idx: N, 4, T
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
x = torch.split(x, 1, 2)
|
||||||
|
ret = []
|
||||||
|
for q, embed in zip(x, self.quantizer_modules):
|
||||||
|
q = embed.embedding(q.squeeze(-1))
|
||||||
|
ret.append(q)
|
||||||
|
ret = torch.cat(ret, -1)
|
||||||
|
return ret.transpose(1, 2) # N, C, T
|
||||||
|
|
||||||
|
|
||||||
|
class CodePredictor(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
filter_channels,
|
||||||
|
n_heads,
|
||||||
|
n_layers,
|
||||||
|
kernel_size,
|
||||||
|
p_dropout,
|
||||||
|
n_q=8,
|
||||||
|
dims=1024,
|
||||||
|
ssl_dim=768,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.filter_channels = filter_channels
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
|
||||||
|
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
|
||||||
|
self.ref_enc = modules.MelStyleEncoder(
|
||||||
|
ssl_dim, style_vector_dim=hidden_channels
|
||||||
|
)
|
||||||
|
|
||||||
|
self.encoder = attentions.Encoder(
|
||||||
|
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
|
||||||
|
self.n_q = n_q
|
||||||
|
self.dims = dims
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, refer, codes, infer=False):
|
||||||
|
x = x.detach()
|
||||||
|
x = self.vq_proj(x * x_mask) * x_mask
|
||||||
|
g = self.ref_enc(refer, x_mask)
|
||||||
|
x = x + g
|
||||||
|
x = self.encoder(x * x_mask, x_mask)
|
||||||
|
x = self.out_proj(x * x_mask) * x_mask
|
||||||
|
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
|
||||||
|
2, 3
|
||||||
|
)
|
||||||
|
target = codes[1:].transpose(0, 1)
|
||||||
|
if not infer:
|
||||||
|
logits = logits.reshape(-1, self.dims)
|
||||||
|
target = target.reshape(-1)
|
||||||
|
loss = torch.nn.functional.cross_entropy(logits, target)
|
||||||
|
return loss
|
||||||
|
else:
|
||||||
|
_, top10_preds = torch.topk(logits, 10, dim=-1)
|
||||||
|
correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
|
||||||
|
top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
|
||||||
|
|
||||||
|
print("Top-10 Accuracy:", top3_acc, "%")
|
||||||
|
|
||||||
|
pred_codes = torch.argmax(logits, dim=-1)
|
||||||
|
acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
|
||||||
|
print("Top-1 Accuracy:", acc, "%")
|
||||||
|
|
||||||
|
return pred_codes.transpose(0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class SynthesizerTrn(nn.Module):
|
||||||
|
"""
|
||||||
|
Synthesizer for Training
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
spec_channels,
|
||||||
|
segment_size,
|
||||||
|
inter_channels,
|
||||||
|
hidden_channels,
|
||||||
|
filter_channels,
|
||||||
|
n_heads,
|
||||||
|
n_layers,
|
||||||
|
kernel_size,
|
||||||
|
p_dropout,
|
||||||
|
resblock,
|
||||||
|
resblock_kernel_sizes,
|
||||||
|
resblock_dilation_sizes,
|
||||||
|
upsample_rates,
|
||||||
|
upsample_initial_channel,
|
||||||
|
upsample_kernel_sizes,
|
||||||
|
n_speakers=0,
|
||||||
|
gin_channels=0,
|
||||||
|
use_sdp=True,
|
||||||
|
semantic_frame_rate=None,
|
||||||
|
freeze_quantizer=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.spec_channels = spec_channels
|
||||||
|
self.inter_channels = inter_channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.filter_channels = filter_channels
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
self.resblock = resblock
|
||||||
|
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||||
|
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||||
|
self.upsample_rates = upsample_rates
|
||||||
|
self.upsample_initial_channel = upsample_initial_channel
|
||||||
|
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||||
|
self.segment_size = segment_size
|
||||||
|
self.n_speakers = n_speakers
|
||||||
|
self.gin_channels = gin_channels
|
||||||
|
|
||||||
|
self.use_sdp = use_sdp
|
||||||
|
self.enc_p = TextEncoder(
|
||||||
|
inter_channels,
|
||||||
|
hidden_channels,
|
||||||
|
filter_channels,
|
||||||
|
n_heads,
|
||||||
|
n_layers,
|
||||||
|
kernel_size,
|
||||||
|
p_dropout,
|
||||||
|
)
|
||||||
|
self.dec = Generator(
|
||||||
|
inter_channels,
|
||||||
|
resblock,
|
||||||
|
resblock_kernel_sizes,
|
||||||
|
resblock_dilation_sizes,
|
||||||
|
upsample_rates,
|
||||||
|
upsample_initial_channel,
|
||||||
|
upsample_kernel_sizes,
|
||||||
|
gin_channels=gin_channels,
|
||||||
|
)
|
||||||
|
self.enc_q = PosteriorEncoder(
|
||||||
|
spec_channels,
|
||||||
|
inter_channels,
|
||||||
|
hidden_channels,
|
||||||
|
5,
|
||||||
|
1,
|
||||||
|
16,
|
||||||
|
gin_channels=gin_channels,
|
||||||
|
)
|
||||||
|
self.flow = ResidualCouplingBlock(
|
||||||
|
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ref_enc = modules.MelStyleEncoder(
|
||||||
|
spec_channels, style_vector_dim=gin_channels
|
||||||
|
)
|
||||||
|
|
||||||
|
ssl_dim = 768
|
||||||
|
self.ssl_dim = ssl_dim
|
||||||
|
assert semantic_frame_rate in ["25hz", "50hz"]
|
||||||
|
self.semantic_frame_rate = semantic_frame_rate
|
||||||
|
if semantic_frame_rate == "25hz":
|
||||||
|
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
|
||||||
|
else:
|
||||||
|
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
|
||||||
|
|
||||||
|
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||||
|
if freeze_quantizer:
|
||||||
|
self.ssl_proj.requires_grad_(False)
|
||||||
|
self.quantizer.requires_grad_(False)
|
||||||
|
# self.enc_p.text_embedding.requires_grad_(False)
|
||||||
|
# self.enc_p.encoder_text.requires_grad_(False)
|
||||||
|
# self.enc_p.mrte.requires_grad_(False)
|
||||||
|
|
||||||
|
def forward(self, codes, text, refer):
|
||||||
|
refer_mask = torch.ones_like(refer[:1,:1,:])
|
||||||
|
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||||
|
|
||||||
|
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
||||||
|
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||||
|
|
||||||
|
quantized = self.quantizer.decode(codes)
|
||||||
|
if self.semantic_frame_rate == "25hz":
|
||||||
|
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
|
||||||
|
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
|
||||||
|
|
||||||
|
x, m_p, logs_p, y_mask = self.enc_p(
|
||||||
|
quantized, text, ge
|
||||||
|
)
|
||||||
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
|
||||||
|
|
||||||
|
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||||
|
|
||||||
|
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||||
|
return o
|
||||||
|
|
||||||
|
def extract_latent(self, x):
|
||||||
|
ssl = self.ssl_proj(x)
|
||||||
|
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||||
|
return codes.transpose(0, 1)
|
314
GPT_SoVITS/onnx_export.py
Normal file
314
GPT_SoVITS/onnx_export.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
from module.models_onnx import SynthesizerTrn, symbols
|
||||||
|
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch import nn
|
||||||
|
from feature_extractor import cnhubert
|
||||||
|
cnhubert_base_path = "pretrained_models/chinese-hubert-base"
|
||||||
|
cnhubert.cnhubert_base_path=cnhubert_base_path
|
||||||
|
ssl_model = cnhubert.get_model()
|
||||||
|
from text import cleaned_text_to_sequence
|
||||||
|
import soundfile
|
||||||
|
from my_utils import load_audio
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
||||||
|
hann_window = torch.hann_window(win_size).to(
|
||||||
|
dtype=y.dtype, device=y.device
|
||||||
|
)
|
||||||
|
y = torch.nn.functional.pad(
|
||||||
|
y.unsqueeze(1),
|
||||||
|
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||||
|
mode="reflect",
|
||||||
|
)
|
||||||
|
y = y.squeeze(1)
|
||||||
|
spec = torch.stft(
|
||||||
|
y,
|
||||||
|
n_fft,
|
||||||
|
hop_length=hop_size,
|
||||||
|
win_length=win_size,
|
||||||
|
window=hann_window,
|
||||||
|
center=center,
|
||||||
|
pad_mode="reflect",
|
||||||
|
normalized=False,
|
||||||
|
onesided=True,
|
||||||
|
return_complex=False,
|
||||||
|
)
|
||||||
|
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||||
|
return spec
|
||||||
|
|
||||||
|
|
||||||
|
class DictToAttrRecursive(dict):
|
||||||
|
def __init__(self, input_dict):
|
||||||
|
super().__init__(input_dict)
|
||||||
|
for key, value in input_dict.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = DictToAttrRecursive(value)
|
||||||
|
self[key] = value
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
try:
|
||||||
|
return self[item]
|
||||||
|
except KeyError:
|
||||||
|
raise AttributeError(f"Attribute {item} not found")
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = DictToAttrRecursive(value)
|
||||||
|
super(DictToAttrRecursive, self).__setitem__(key, value)
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __delattr__(self, item):
|
||||||
|
try:
|
||||||
|
del self[item]
|
||||||
|
except KeyError:
|
||||||
|
raise AttributeError(f"Attribute {item} not found")
|
||||||
|
|
||||||
|
|
||||||
|
class T2SEncoder(nn.Module):
|
||||||
|
def __init__(self, t2s, vits):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = t2s.onnx_encoder
|
||||||
|
self.vits = vits
|
||||||
|
|
||||||
|
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
||||||
|
codes = self.vits.extract_latent(ssl_content)
|
||||||
|
prompt_semantic = codes[0, 0]
|
||||||
|
bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
|
||||||
|
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
||||||
|
bert = bert.unsqueeze(0)
|
||||||
|
prompt = prompt_semantic.unsqueeze(0)
|
||||||
|
return self.encoder(all_phoneme_ids, bert), prompt
|
||||||
|
|
||||||
|
|
||||||
|
class T2SModel(nn.Module):
|
||||||
|
def __init__(self, t2s_path, vits_model):
|
||||||
|
super().__init__()
|
||||||
|
dict_s1 = torch.load(t2s_path, map_location="cpu")
|
||||||
|
self.config = dict_s1["config"]
|
||||||
|
self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
|
||||||
|
self.t2s_model.load_state_dict(dict_s1["weight"])
|
||||||
|
self.t2s_model.eval()
|
||||||
|
self.vits_model = vits_model.vq_model
|
||||||
|
self.hz = 50
|
||||||
|
self.max_sec = self.config["data"]["max_sec"]
|
||||||
|
self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
|
||||||
|
self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
|
||||||
|
self.t2s_model = self.t2s_model.model
|
||||||
|
self.t2s_model.init_onnx()
|
||||||
|
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
|
||||||
|
self.first_stage_decoder = self.t2s_model.first_stage_decoder
|
||||||
|
self.stage_decoder = self.t2s_model.stage_decoder
|
||||||
|
#self.t2s_model = torch.jit.script(self.t2s_model)
|
||||||
|
|
||||||
|
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
||||||
|
early_stop_num = self.t2s_model.early_stop_num
|
||||||
|
|
||||||
|
#[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
||||||
|
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||||
|
|
||||||
|
prefix_len = prompts.shape[1]
|
||||||
|
|
||||||
|
#[1,N,512] [1,N]
|
||||||
|
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||||
|
|
||||||
|
stop = False
|
||||||
|
for idx in range(1, 1500):
|
||||||
|
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||||
|
enco = self.stage_decoder(y, k, v, y_emb, x_example)
|
||||||
|
y, k, v, y_emb, logits, samples = enco
|
||||||
|
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||||
|
stop = True
|
||||||
|
if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
|
||||||
|
stop = True
|
||||||
|
if stop:
|
||||||
|
break
|
||||||
|
y[0, -1] = 0
|
||||||
|
|
||||||
|
return y[:, -idx:].unsqueeze(0)
|
||||||
|
|
||||||
|
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
|
||||||
|
#self.onnx_encoder = torch.jit.script(self.onnx_encoder)
|
||||||
|
if dynamo:
|
||||||
|
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
|
||||||
|
onnx_encoder_export_output = torch.onnx.dynamo_export(
|
||||||
|
self.onnx_encoder,
|
||||||
|
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
||||||
|
export_options=export_options
|
||||||
|
)
|
||||||
|
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
|
||||||
|
return
|
||||||
|
torch.onnx.export(
|
||||||
|
self.onnx_encoder,
|
||||||
|
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
||||||
|
f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
|
||||||
|
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
|
||||||
|
output_names=["x", "prompts"],
|
||||||
|
dynamic_axes={
|
||||||
|
"ref_seq": [1],
|
||||||
|
"text_seq": [1],
|
||||||
|
"ref_bert": [0],
|
||||||
|
"text_bert": [0],
|
||||||
|
"ssl_content": [2],
|
||||||
|
},
|
||||||
|
opset_version=16
|
||||||
|
)
|
||||||
|
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||||
|
torch.exp
|
||||||
|
torch.onnx.export(
|
||||||
|
self.first_stage_decoder,
|
||||||
|
(x, prompts),
|
||||||
|
f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
|
||||||
|
input_names=["x", "prompts"],
|
||||||
|
output_names=["y", "k", "v", "y_emb", "x_example"],
|
||||||
|
dynamic_axes={
|
||||||
|
"x": [1],
|
||||||
|
"prompts": [1],
|
||||||
|
},
|
||||||
|
verbose=True,
|
||||||
|
opset_version=16
|
||||||
|
)
|
||||||
|
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
self.stage_decoder,
|
||||||
|
(y, k, v, y_emb, x_example),
|
||||||
|
f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
|
||||||
|
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
|
||||||
|
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
|
||||||
|
dynamic_axes={
|
||||||
|
"iy": [1],
|
||||||
|
"ik": [1],
|
||||||
|
"iv": [1],
|
||||||
|
"iy_emb": [1],
|
||||||
|
"ix_example": [1],
|
||||||
|
},
|
||||||
|
verbose=True,
|
||||||
|
opset_version=16
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VitsModel(nn.Module):
|
||||||
|
def __init__(self, vits_path):
|
||||||
|
super().__init__()
|
||||||
|
dict_s2 = torch.load(vits_path,map_location="cpu")
|
||||||
|
self.hps = dict_s2["config"]
|
||||||
|
self.hps = DictToAttrRecursive(self.hps)
|
||||||
|
self.hps.model.semantic_frame_rate = "25hz"
|
||||||
|
self.vq_model = SynthesizerTrn(
|
||||||
|
self.hps.data.filter_length // 2 + 1,
|
||||||
|
self.hps.train.segment_size // self.hps.data.hop_length,
|
||||||
|
n_speakers=self.hps.data.n_speakers,
|
||||||
|
**self.hps.model
|
||||||
|
)
|
||||||
|
self.vq_model.eval()
|
||||||
|
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
|
|
||||||
|
def forward(self, text_seq, pred_semantic, ref_audio):
|
||||||
|
refer = spectrogram_torch(
|
||||||
|
ref_audio,
|
||||||
|
self.hps.data.filter_length,
|
||||||
|
self.hps.data.sampling_rate,
|
||||||
|
self.hps.data.hop_length,
|
||||||
|
self.hps.data.win_length,
|
||||||
|
center=False
|
||||||
|
)
|
||||||
|
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
|
||||||
|
|
||||||
|
|
||||||
|
class GptSoVits(nn.Module):
|
||||||
|
def __init__(self, vits, t2s):
|
||||||
|
super().__init__()
|
||||||
|
self.vits = vits
|
||||||
|
self.t2s = t2s
|
||||||
|
|
||||||
|
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content):
|
||||||
|
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||||
|
return self.vits(text_seq, pred_semantic, ref_audio)
|
||||||
|
|
||||||
|
def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
|
||||||
|
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
|
||||||
|
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||||
|
torch.onnx.export(
|
||||||
|
self.vits,
|
||||||
|
(text_seq, pred_semantic, ref_audio),
|
||||||
|
f"onnx/{project_name}/{project_name}_vits.onnx",
|
||||||
|
input_names=["text_seq", "pred_semantic", "ref_audio"],
|
||||||
|
output_names=["audio"],
|
||||||
|
dynamic_axes={
|
||||||
|
"text_seq": [1],
|
||||||
|
"pred_semantic": [2],
|
||||||
|
"ref_audio": [1],
|
||||||
|
},
|
||||||
|
opset_version=17
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SSLModel(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.ssl = ssl_model
|
||||||
|
|
||||||
|
def forward(self, ref_audio_16k):
|
||||||
|
return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def export(vits_path, gpt_path, project_name):
|
||||||
|
vits = VitsModel(vits_path)
|
||||||
|
gpt = T2SModel(gpt_path, vits)
|
||||||
|
gpt_sovits = GptSoVits(vits, gpt)
|
||||||
|
ssl = SSLModel()
|
||||||
|
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
|
||||||
|
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
|
||||||
|
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
||||||
|
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
||||||
|
ref_audio = torch.randn((1, 48000 * 5)).float()
|
||||||
|
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
|
||||||
|
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
|
||||||
|
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.mkdir(f"onnx/{project_name}")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
ssl_content = ssl(ref_audio_16k).float()
|
||||||
|
|
||||||
|
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
|
||||||
|
|
||||||
|
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
||||||
|
|
||||||
|
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
|
||||||
|
|
||||||
|
MoeVSConf = {
|
||||||
|
"Folder" : f"{project_name}",
|
||||||
|
"Name" : f"{project_name}",
|
||||||
|
"Type" : "GPT-SoVits",
|
||||||
|
"Rate" : vits.hps.data.sampling_rate,
|
||||||
|
"NumLayers": gpt.t2s_model.num_layers,
|
||||||
|
"EmbeddingDim": gpt.t2s_model.embedding_dim,
|
||||||
|
"Dict": "BasicDict",
|
||||||
|
"BertPath": "chinese-roberta-wwm-ext-large",
|
||||||
|
"Symbol": symbols,
|
||||||
|
"AddBlank": False
|
||||||
|
}
|
||||||
|
|
||||||
|
MoeVSConfJson = json.dumps(MoeVSConf)
|
||||||
|
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
|
||||||
|
json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
os.mkdir("onnx")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
gpt_path = "pt_model/koharu-e20.ckpt"
|
||||||
|
vits_path = "pt_model/koharu_e20_s4960.pth"
|
||||||
|
exp_path = "koharu"
|
||||||
|
export(vits_path, gpt_path, exp_path)
|
||||||
|
|
||||||
|
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
Loading…
x
Reference in New Issue
Block a user