diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py b/GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py new file mode 100644 index 0000000..bb9e30b --- /dev/null +++ b/GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py @@ -0,0 +1,106 @@ +# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py +import os, sys + +now_dir = os.getcwd() +sys.path.append(now_dir) +from typing import Dict + +import torch +from pytorch_lightning import LightningModule +from AR.models.t2s_model_onnx import Text2SemanticDecoder +from AR.modules.lr_schedulers import WarmupCosineLRSchedule +from AR.modules.optim import ScaledAdam + + +class Text2SemanticLightningModule(LightningModule): + def __init__(self, config, output_dir, is_train=True): + super().__init__() + self.config = config + self.top_k = 3 + self.model = Text2SemanticDecoder(config=config, top_k=self.top_k) + pretrained_s1 = config.get("pretrained_s1") + if pretrained_s1 and is_train: + # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"])) + print( + self.load_state_dict( + torch.load(pretrained_s1, map_location="cpu")["weight"] + ) + ) + if is_train: + self.automatic_optimization = False + self.save_hyperparameters() + self.eval_dir = output_dir / "eval" + self.eval_dir.mkdir(parents=True, exist_ok=True) + + def training_step(self, batch: Dict, batch_idx: int): + opt = self.optimizers() + scheduler = self.lr_schedulers() + loss, acc = self.model.forward( + batch["phoneme_ids"], + batch["phoneme_ids_len"], + batch["semantic_ids"], + batch["semantic_ids_len"], + batch["bert_feature"], + ) + self.manual_backward(loss) + if batch_idx > 0 and batch_idx % 4 == 0: + opt.step() + opt.zero_grad() + scheduler.step() + + self.log( + "total_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + self.log( + "lr", + scheduler.get_last_lr()[0], + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + self.log( + f"top_{self.top_k}_acc", + acc, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + def validation_step(self, batch: Dict, batch_idx: int): + return + + def configure_optimizers(self): + model_parameters = self.model.parameters() + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in self.model.named_parameters()] + ) + lm_opt = ScaledAdam( + model_parameters, + lr=0.01, + betas=(0.9, 0.95), + clipping_scale=2.0, + parameters_names=parameters_names, + show_dominant_parameters=False, + clipping_update_period=1000, + ) + + return { + "optimizer": lm_opt, + "lr_scheduler": { + "scheduler": WarmupCosineLRSchedule( + lm_opt, + init_lr=self.config["optimizer"]["lr_init"], + peak_lr=self.config["optimizer"]["lr"], + end_lr=self.config["optimizer"]["lr_end"], + warmup_steps=self.config["optimizer"]["warmup_steps"], + total_steps=self.config["optimizer"]["decay_steps"], + ) + }, + } diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py new file mode 100644 index 0000000..263b933 --- /dev/null +++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py @@ -0,0 +1,337 @@ +# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py +import torch +from tqdm import tqdm + +from AR.modules.embedding_onnx import SinePositionalEmbedding +from AR.modules.embedding_onnx import TokenEmbedding +from AR.modules.transformer_onnx import LayerNorm +from AR.modules.transformer_onnx import TransformerEncoder +from AR.modules.transformer_onnx import TransformerEncoderLayer +from torch import nn +from torch.nn import functional as F +from torchmetrics.classification import MulticlassAccuracy + +default_config = { + "embedding_dim": 512, + "hidden_dim": 512, + "num_head": 8, + "num_layers": 12, + "num_codebook": 8, + "p_dropout": 0.0, + "vocab_size": 1024 + 1, + "phoneme_vocab_size": 512, + "EOS": 1024, +} + +inf_tensor_value = torch.FloatTensor([-float("Inf")]).float() + +def logits_to_probs( + logits, + previous_tokens = None, + temperature: float = 1.0, + top_k = None, + top_p = None, + repetition_penalty: float = 1.0, +): + previous_tokens = previous_tokens.squeeze() + if previous_tokens is not None and repetition_penalty != 1.0: + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=0, index=previous_tokens) + score = torch.where( + score < 0, score * repetition_penalty, score / repetition_penalty + ) + logits.scatter_(dim=0, index=previous_tokens, src=score) + + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum( + torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1 + ) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter( + dim=0, index=sorted_indices, src=sorted_indices_to_remove + ) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, inf_tensor_value, logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def multinomial_sample_one_no_sync( + probs_sort +): # Does multinomial sampling without a cuda synchronization + q = torch.randn_like(probs_sort) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def sample( + logits, + previous_tokens, + **sampling_kwargs, +): + probs = logits_to_probs( + logits=logits, previous_tokens=previous_tokens, **sampling_kwargs + ) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +class OnnxEncoder(nn.Module): + def __init__(self, ar_text_embedding, bert_proj, ar_text_position): + super().__init__() + self.ar_text_embedding = ar_text_embedding + self.bert_proj = bert_proj + self.ar_text_position = ar_text_position + + def forward(self, x, bert_feature): + x = self.ar_text_embedding(x) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) + return self.ar_text_position(x) + + +class T2SFirstStageDecoder(nn.Module): + def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric, + top_k, early_stop_num, num_layers): + super().__init__() + self.ar_audio_embedding = ar_audio_embedding + self.ar_audio_position = ar_audio_position + self.h = h + self.ar_predict_layer = ar_predict_layer + self.loss_fct = loss_fct + self.ar_accuracy_metric = ar_accuracy_metric + self.top_k = top_k + self.early_stop_num = early_stop_num + self.num_layers = num_layers + + def forward(self, x, prompt): + y = prompt + x_example = x[:,:,0] * 0.0 + #N, 1, 512 + cache = { + "all_stage": self.num_layers, + "k": None, + "v": None, + "y_emb": None, + "first_infer": 1, + "stage": 0, + } + + y_emb = self.ar_audio_embedding(y) + + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + + xy_pos = torch.concat([x, y_pos], dim=1) + + y_example = y_pos[:,:,0] * 0.0 + x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool() + y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64) + y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum( + torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0 + ) + y_attn_mask = y_attn_mask > 0 + + x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool() + y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool() + x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1) + y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\ + .unsqueeze(1).repeat(self.num_layers, 1, 1, 1) + cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\ + .unsqueeze(1).repeat(self.num_layers, 1, 1, 1) + + xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) + + y = torch.concat([y, samples], dim=1) + + return y, cache["k"], cache["v"], cache["y_emb"], x_example + + +class T2SStageDecoder(nn.Module): + def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric, + top_k, early_stop_num, num_layers): + super().__init__() + self.ar_audio_embedding = ar_audio_embedding + self.ar_audio_position = ar_audio_position + self.h = h + self.ar_predict_layer = ar_predict_layer + self.loss_fct = loss_fct + self.ar_accuracy_metric = ar_accuracy_metric + self.top_k = top_k + self.early_stop_num = early_stop_num + self.num_layers = num_layers + + def forward(self, y, k, v, y_emb, x_example): + cache = { + "all_stage": self.num_layers, + "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)), + "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)), + "y_emb": y_emb, + "first_infer": 0, + "stage": 0, + } + + y_emb = torch.cat( + [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1 + ) + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + + xy_pos = y_pos[:, -1:] + + y_example = y_pos[:,:,0] * 0.0 + + xy_attn_mask = torch.cat([x_example, y_example], dim=1) + xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool) + + xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) + + y = torch.concat([y, samples], dim=1) + + return y, cache["k"], cache["v"], cache["y_emb"], logits, samples + + +class Text2SemanticDecoder(nn.Module): + def __init__(self, config, norm_first=False, top_k=3): + super(Text2SemanticDecoder, self).__init__() + self.model_dim = config["model"]["hidden_dim"] + self.embedding_dim = config["model"]["embedding_dim"] + self.num_head = config["model"]["head"] + self.num_layers = config["model"]["n_layer"] + self.norm_first = norm_first + self.vocab_size = config["model"]["vocab_size"] + self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"] + self.p_dropout = float(config["model"]["dropout"]) + self.EOS = config["model"]["EOS"] + self.norm_first = norm_first + assert self.EOS == self.vocab_size - 1 + self.bert_proj = nn.Linear(1024, self.embedding_dim) + self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout) + self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True) + self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout) + self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True) + self.h = TransformerEncoder( + TransformerEncoderLayer( + d_model=self.model_dim, + nhead=self.num_head, + dim_feedforward=self.model_dim * 4, + dropout=0.1, + batch_first=True, + norm_first=norm_first, + ), + num_layers=self.num_layers, + norm=LayerNorm(self.model_dim) if norm_first else None, + ) + self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False) + self.loss_fct = nn.CrossEntropyLoss(reduction="sum") + self.ar_accuracy_metric = MulticlassAccuracy( + self.vocab_size, + top_k=top_k, + average="micro", + multidim_average="global", + ignore_index=self.EOS, + ) + self.top_k = torch.LongTensor([1]) + self.early_stop_num = torch.LongTensor([-1]) + + def init_onnx(self): + self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position) + self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h, + self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num, + self.num_layers) + self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h, + self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num, + self.num_layers) + + def forward(self, x, prompts, bert_feature): + early_stop_num = self.early_stop_num + prefix_len = prompts.shape[1] + + x = self.onnx_encoder(x, bert_feature) + y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts) + + stop = False + for idx in range(1, 1500): + enco = self.stage_decoder(y, k, v, y_emb, stage, x_example) + y, k, v, y_emb, stage, logits, samples = enco + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + stop = True + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + stop = True + if stop: + break + y[0, -1] = 0 + return y, idx + + def infer(self, x, prompts, bert_feature): + top_k = self.top_k + early_stop_num = self.early_stop_num + + x = self.onnx_encoder(x, bert_feature) + + y = prompts + prefix_len = y.shape[1] + x_len = x.shape[1] + x_example = x[:,:,0] * 0.0 + x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example) + x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool) + + stop = False + cache = { + "all_stage": self.num_layers, + "k": [None] * self.num_layers, + "v": [None] * self.num_layers, + "y_emb": None, + "first_infer": 1, + "stage": 0, + } + for idx in range(1500): + if cache["first_infer"] == 1: + y_emb = self.ar_audio_embedding(y) + else: + y_emb = torch.cat( + [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1 + ) + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + if cache["first_infer"] == 1: + xy_pos = torch.concat([x, y_pos], dim=1) + else: + xy_pos = y_pos[:, -1:] + y_len = y_pos.shape[1] + if cache["first_infer"] == 1: + x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True) + y_attn_mask = F.pad( + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), value=False + ) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + else: + xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool) + xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + stop = True + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + stop = True + if stop: + if prompts.shape[1] == y.shape[1]: + y = torch.concat([y, torch.zeros_like(samples)], dim=1) + break + y = torch.concat([y, samples], dim=1) + cache["first_infer"] = 0 + return y, idx \ No newline at end of file diff --git a/GPT_SoVITS/AR/modules/activation_onnx.py b/GPT_SoVITS/AR/modules/activation_onnx.py new file mode 100644 index 0000000..b54acd9 --- /dev/null +++ b/GPT_SoVITS/AR/modules/activation_onnx.py @@ -0,0 +1,178 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py +from typing import Optional +from typing import Tuple +import torch +from torch import Tensor +from torch.nn import Linear +from torch.nn import Module +from torch.nn.init import constant_ +from torch.nn.init import xavier_normal_ +from torch.nn.init import xavier_uniform_ +from torch.nn.modules.linear import NonDynamicallyQuantizableLinear +from torch.nn.parameter import Parameter + +from torch.nn import functional as F +from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched + + +class MultiheadAttention(Module): + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + if add_bias_kv: + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + if linear1_cls == Linear: + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = Parameter( + torch.empty(3 * embed_dim, **factory_kwargs) + ) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + self._reset_parameters() + else: + if not self._qkv_same_embed_dim: + raise NotImplementedError + else: + self.in_proj_linear = linear1_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) + self.in_proj_weight = self.in_proj_linear.weight + + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = self.in_proj_linear.bias + else: + self.register_parameter("in_proj_bias", None) + + self.out_proj = linear2_cls( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + self.add_zero_attn = add_zero_attn + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + cache=None, + ) -> Tuple[Tensor, Optional[Tensor]]: + any_nested = query.is_nested or key.is_nested or value.is_nested + query = key = value = query.transpose(1, 0) + attn_output = multi_head_attention_forward_patched( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + cache=cache, + ) + return attn_output.transpose(1, 0) diff --git a/GPT_SoVITS/AR/modules/embedding_onnx.py b/GPT_SoVITS/AR/modules/embedding_onnx.py new file mode 100644 index 0000000..b93405b --- /dev/null +++ b/GPT_SoVITS/AR/modules/embedding_onnx.py @@ -0,0 +1,63 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py +import math + +import torch +from torch import nn + + +class TokenEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + vocab_size: int, + dropout: float = 0.0, + ): + super().__init__() + + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + + self.dropout = torch.nn.Dropout(p=dropout) + self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim) + + @property + def weight(self) -> torch.Tensor: + return self.word_embeddings.weight + + def embedding(self, index: int) -> torch.Tensor: + return self.word_embeddings.weight[index : index + 1] + + def forward(self, x: torch.Tensor): + x = self.word_embeddings(x) + x = self.dropout(x) + return x + + +class SinePositionalEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + dropout: float = 0.0, + scale: bool = False, + alpha: bool = False, + ): + super().__init__() + self.embedding_dim = embedding_dim + self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 + self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) + self.dropout = torch.nn.Dropout(p=dropout) + self.reverse = False + self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim)) + + def extend_pe(self, x): + position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1) + scpe = (position * self.div_term).unsqueeze(0) + pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0) + pe = pe.contiguous().view(1, -1, self.embedding_dim) + return pe + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pe = self.extend_pe(x) + output = x.unsqueeze(-1) if x.ndim == 2 else x + output = output * self.x_scale + self.alpha * pe + return self.dropout(output) diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py new file mode 100644 index 0000000..14bdb55 --- /dev/null +++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py @@ -0,0 +1,92 @@ +from torch.nn.functional import * +from torch.nn.functional import ( + _mha_shape_check, + _canonical_mask, + _none_or_dtype, + _in_projection_packed, +) + +def multi_head_attention_forward_patched( + query, + key, + value, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight, + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + cache=None, +) -> Tuple[Tensor, Optional[Tensor]]: + + # set up shape vars + _, _, embed_dim = query.shape + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + head_dim = embed_dim // num_heads + + proj_qkv = linear(query, in_proj_weight, in_proj_bias) + proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2] + + if cache["first_infer"] == 1: + cache["k"][cache["stage"]] = k + cache["v"][cache["stage"]] = v + else: + cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0) + cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0) + k = cache["k"][cache["stage"]] + v = cache["v"][cache["stage"]] + cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] + + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=q.dtype, + check_other=False, + ) + attn_mask = attn_mask.unsqueeze(0) + + q = q.view(-1, num_heads, head_dim).transpose(0, 1) + k = k.view(-1, num_heads, head_dim).transpose(0, 1) + v = v.view(-1, num_heads, head_dim).transpose(0, 1) + + dropout_p = 0.0 + attn_mask = attn_mask.unsqueeze(0) + q = q.view(num_heads, -1, head_dim).unsqueeze(0) + k = k.view(num_heads, -1, head_dim).unsqueeze(0) + v = v.view(num_heads, -1, head_dim).unsqueeze(0) + attn_output = scaled_dot_product_attention( + q, k, v, attn_mask, dropout_p, is_causal + ) + attn_output = ( + attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim) + ) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(-1, 1, attn_output.size(1)) + + return attn_output diff --git a/GPT_SoVITS/AR/modules/transformer_onnx.py b/GPT_SoVITS/AR/modules/transformer_onnx.py new file mode 100644 index 0000000..a3f68b4 --- /dev/null +++ b/GPT_SoVITS/AR/modules/transformer_onnx.py @@ -0,0 +1,292 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py +import copy +import numbers +from functools import partial +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +from AR.modules.activation_onnx import MultiheadAttention +from AR.modules.scaling import BalancedDoubleSwish +from torch import nn +from torch import Tensor +from torch.nn import functional as F + +_shape_t = Union[int, List[int], torch.Size] + + +class LayerNorm(nn.Module): + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] + eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + self.bias = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + F.layer_norm( + input, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ), + embedding, + ) + + assert embedding is None + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) + + def extra_repr(self) -> str: + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +class IdentityNorm(nn.Module): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ) -> None: + super(IdentityNorm, self).__init__() + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + return input + + assert embedding is None + return input + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ["norm"] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + cache=None, + ) -> Tensor: + output = src + for mod in self.layers: + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + cache=cache, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, # 512 16 + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, d_model, **factory_kwargs + ) + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + if isinstance(activation, str): + activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + activation = activation(d_model) + elif activation == BalancedDoubleSwish: + activation = BalancedDoubleSwish(d_model) + self.activation = activation + + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + if layer_norm_cls == IdentityNorm: + norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + else: + norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + + if adaptive_layer_norm: + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + else: + self.norm1 = norm1 + self.norm2 = norm2 + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + cache=None, + ) -> Tensor: + x = src + stage_embedding = None + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache), + stage_embedding, + ) + x = self.norm2(x + self._ff_block(x), stage_embedding) + + return x + + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + cache=None, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + cache=cache, + ) + return self.dropout1(x) + + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return (weight * self.norm(input) + bias, embedding) + + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)])