From 02da15c996dca916c3ff29327ef5ac9a466b92dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Thu, 25 Jan 2024 02:30:08 +0800 Subject: [PATCH 1/3] Add Onnx Export --- GPT_SoVITS/onnx_export.py | 314 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 314 insertions(+) create mode 100644 GPT_SoVITS/onnx_export.py diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py new file mode 100644 index 0000000..f08679f --- /dev/null +++ b/GPT_SoVITS/onnx_export.py @@ -0,0 +1,314 @@ +from module.models_onnx import SynthesizerTrn, symbols +from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule +import torch +import torchaudio +from torch import nn +from feature_extractor import cnhubert +cnhubert_base_path = "pretrained_models/chinese-hubert-base" +cnhubert.cnhubert_base_path=cnhubert_base_path +ssl_model = cnhubert.get_model() +from text import cleaned_text_to_sequence +import soundfile +from my_utils import load_audio +import os +import json + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + hann_window = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +class T2SEncoder(nn.Module): + def __init__(self, t2s, vits): + super().__init__() + self.encoder = t2s.onnx_encoder + self.vits = vits + + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): + codes = self.vits.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1) + all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) + bert = bert.unsqueeze(0) + prompt = prompt_semantic.unsqueeze(0) + return self.encoder(all_phoneme_ids, bert), prompt + + +class T2SModel(nn.Module): + def __init__(self, t2s_path, vits_model): + super().__init__() + dict_s1 = torch.load(t2s_path, map_location="cpu") + self.config = dict_s1["config"] + self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False) + self.t2s_model.load_state_dict(dict_s1["weight"]) + self.t2s_model.eval() + self.vits_model = vits_model.vq_model + self.hz = 50 + self.max_sec = self.config["data"]["max_sec"] + self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]]) + self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec]) + self.t2s_model = self.t2s_model.model + self.t2s_model.init_onnx() + self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model) + self.first_stage_decoder = self.t2s_model.first_stage_decoder + self.stage_decoder = self.t2s_model.stage_decoder + #self.t2s_model = torch.jit.script(self.t2s_model) + + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): + early_stop_num = self.t2s_model.early_stop_num + + #[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N] + x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + + prefix_len = prompts.shape[1] + + #[1,N,512] [1,N] + y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) + + stop = False + for idx in range(1, 1500): + #[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] + enco = self.stage_decoder(y, k, v, y_emb, x_example) + y, k, v, y_emb, logits, samples = enco + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + stop = True + if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS: + stop = True + if stop: + break + y[0, -1] = 0 + + return y[:, -idx:].unsqueeze(0) + + def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False): + #self.onnx_encoder = torch.jit.script(self.onnx_encoder) + if dynamo: + export_options = torch.onnx.ExportOptions(dynamic_shapes=True) + onnx_encoder_export_output = torch.onnx.dynamo_export( + self.onnx_encoder, + (ref_seq, text_seq, ref_bert, text_bert, ssl_content), + export_options=export_options + ) + onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx") + return + torch.onnx.export( + self.onnx_encoder, + (ref_seq, text_seq, ref_bert, text_bert, ssl_content), + f"onnx/{project_name}/{project_name}_t2s_encoder.onnx", + input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"], + output_names=["x", "prompts"], + dynamic_axes={ + "ref_seq": [1], + "text_seq": [1], + "ref_bert": [0], + "text_bert": [0], + "ssl_content": [2], + }, + opset_version=16 + ) + x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + torch.exp + torch.onnx.export( + self.first_stage_decoder, + (x, prompts), + f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx", + input_names=["x", "prompts"], + output_names=["y", "k", "v", "y_emb", "x_example"], + dynamic_axes={ + "x": [1], + "prompts": [1], + }, + verbose=True, + opset_version=16 + ) + y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) + + torch.onnx.export( + self.stage_decoder, + (y, k, v, y_emb, x_example), + f"onnx/{project_name}/{project_name}_t2s_sdec.onnx", + input_names=["iy", "ik", "iv", "iy_emb", "ix_example"], + output_names=["y", "k", "v", "y_emb", "logits", "samples"], + dynamic_axes={ + "iy": [1], + "ik": [1], + "iv": [1], + "iy_emb": [1], + "ix_example": [1], + }, + verbose=True, + opset_version=16 + ) + + +class VitsModel(nn.Module): + def __init__(self, vits_path): + super().__init__() + dict_s2 = torch.load(vits_path,map_location="cpu") + self.hps = dict_s2["config"] + self.hps = DictToAttrRecursive(self.hps) + self.hps.model.semantic_frame_rate = "25hz" + self.vq_model = SynthesizerTrn( + self.hps.data.filter_length // 2 + 1, + self.hps.train.segment_size // self.hps.data.hop_length, + n_speakers=self.hps.data.n_speakers, + **self.hps.model + ) + self.vq_model.eval() + self.vq_model.load_state_dict(dict_s2["weight"], strict=False) + + def forward(self, text_seq, pred_semantic, ref_audio): + refer = spectrogram_torch( + ref_audio, + self.hps.data.filter_length, + self.hps.data.sampling_rate, + self.hps.data.hop_length, + self.hps.data.win_length, + center=False + ) + return self.vq_model(pred_semantic, text_seq, refer)[0, 0] + + +class GptSoVits(nn.Module): + def __init__(self, vits, t2s): + super().__init__() + self.vits = vits + self.t2s = t2s + + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content): + pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + return self.vits(text_seq, pred_semantic, ref_audio) + + def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name): + self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name) + pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + torch.onnx.export( + self.vits, + (text_seq, pred_semantic, ref_audio), + f"onnx/{project_name}/{project_name}_vits.onnx", + input_names=["text_seq", "pred_semantic", "ref_audio"], + output_names=["audio"], + dynamic_axes={ + "text_seq": [1], + "pred_semantic": [2], + "ref_audio": [1], + }, + opset_version=17 + ) + + +class SSLModel(nn.Module): + def __init__(self): + super().__init__() + self.ssl = ssl_model + + def forward(self, ref_audio_16k): + return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2) + + +def export(vits_path, gpt_path, project_name): + vits = VitsModel(vits_path) + gpt = T2SModel(gpt_path, vits) + gpt_sovits = GptSoVits(vits, gpt) + ssl = SSLModel() + ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) + text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) + ref_bert = torch.randn((ref_seq.shape[1], 1024)).float() + text_bert = torch.randn((text_seq.shape[1], 1024)).float() + ref_audio = torch.randn((1, 48000 * 5)).float() + # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float() + ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float() + ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float() + + try: + os.mkdir(f"onnx/{project_name}") + except: + pass + + ssl_content = ssl(ref_audio_16k).float() + + a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy() + + # soundfile.write("out.wav", a, vits.hps.data.sampling_rate) + + gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name) + + MoeVSConf = { + "Folder" : f"{project_name}", + "Name" : f"{project_name}", + "Type" : "GPT-SoVits", + "Rate" : vits.hps.data.sampling_rate, + "NumLayers": gpt.t2s_model.num_layers, + "EmbeddingDim": gpt.t2s_model.embedding_dim, + "Dict": "BasicDict", + "BertPath": "chinese-roberta-wwm-ext-large", + "Symbol": symbols, + "AddBlank": False + } + + MoeVSConfJson = json.dumps(MoeVSConf) + with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile: + json.dump(MoeVSConf, MoeVsConfFile, indent = 4) + + +if __name__ == "__main__": + try: + os.mkdir("onnx") + except: + pass + + gpt_path = "pt_model/koharu-e20.ckpt" + vits_path = "pt_model/koharu_e20_s4960.pth" + exp_path = "koharu" + export(vits_path, gpt_path, exp_path) + + # soundfile.write("out.wav", a, vits.hps.data.sampling_rate) \ No newline at end of file From bd68358c3f675300f028fe602733456e397d7f3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Thu, 25 Jan 2024 02:30:37 +0800 Subject: [PATCH 2/3] Add Vits Onnx Module --- GPT_SoVITS/module/attentions_onnx.py | 365 +++++++++++ GPT_SoVITS/module/models_onnx.py | 920 +++++++++++++++++++++++++++ 2 files changed, 1285 insertions(+) create mode 100644 GPT_SoVITS/module/attentions_onnx.py create mode 100644 GPT_SoVITS/module/models_onnx.py diff --git a/GPT_SoVITS/module/attentions_onnx.py b/GPT_SoVITS/module/attentions_onnx.py new file mode 100644 index 0000000..df0ae82 --- /dev/null +++ b/GPT_SoVITS/module/attentions_onnx.py @@ -0,0 +1,365 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from module import commons +from module.modules import LayerNorm + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + isflow=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + # if isflow: + # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1) + # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) + # self.cond_layer = weight_norm(cond_layer, name='weight') + # self.gin_channels = 256 + self.cond_layer_idx = self.n_layers + if "gin_channels" in kwargs: + self.gin_channels = kwargs["gin_channels"] + if self.gin_channels != 0: + self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) + # vits2 says 3rd block, so idx is 2 by default + self.cond_layer_idx = ( + kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 + ) + logging.debug(self.gin_channels, self.cond_layer_idx) + assert ( + self.cond_layer_idx < self.n_layers + ), "cond_layer_idx should be less than n_layers" + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, g=None): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + if i == self.cond_layer_idx and g is not None: + g = self.spk_emb_linear(g.transpose(1, 2)) + g = g.transpose(1, 2) + x = x + g + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, _ = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings + ) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + block_mask = ( + torch.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length) + ) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = ( + output.transpose(2, 3).contiguous().view(b, d, -1) + ) + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad( + x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad( + x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) + ) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/GPT_SoVITS/module/models_onnx.py b/GPT_SoVITS/module/models_onnx.py new file mode 100644 index 0000000..35fd291 --- /dev/null +++ b/GPT_SoVITS/module/models_onnx.py @@ -0,0 +1,920 @@ +import copy +import math +import torch +from torch import nn +from torch.nn import functional as F + +from module import commons +from module import modules +from module import attentions_onnx as attentions + +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from module.commons import init_weights, get_padding +from module.mrte_model import MRTE +from module.quantize import ResidualVectorQuantizer +from text import symbols +from torch.cuda.amp import autocast + + +class StochasticDurationPredictor(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + ): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append( + modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append( + modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = ( + torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) + * x_mask + ) + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum( + (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] + ) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) + - logdet_tot_q + ) + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = ( + torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) + - logdet_tot + ) + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = ( + torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) + * noise_scale + ) + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 + ): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + def __init__( + self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + latent_channels=192, + ): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.latent_channels = latent_channels + + self.ssl_proj = nn.Conv1d(768, hidden_channels, 1) + + self.encoder_ssl = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers // 2, + kernel_size, + p_dropout, + ) + + self.encoder_text = attentions.Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + self.text_embedding = nn.Embedding(len(symbols), hidden_channels) + + self.mrte = MRTE() + + self.encoder2 = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers // 2, + kernel_size, + p_dropout, + ) + + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, y, text, ge): + y_mask = torch.ones_like(y[:1,:1,:]) + + y = self.ssl_proj(y * y_mask) * y_mask + y = self.encoder_ssl(y * y_mask, y_mask) + + text_mask = torch.ones_like(text).to(y.dtype).unsqueeze(0) + + text = self.text_embedding(text).transpose(1, 2) + text = self.encoder_text(text * text_mask, text_mask) + y = self.mrte(y, y_mask, text, text_mask, ge) + + y = self.encoder2(y * y_mask, y_mask) + + stats = self.proj(y) * y_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + return y, m, logs, y_mask + + def extract_latent(self, x): + x = self.ssl_proj(x) + quantized, codes, commit_loss, quantized_list = self.quantizer(x) + return codes.transpose(0, 1) + + def decode_latent(self, codes, y_mask, refer, refer_mask, ge): + quantized = self.quantizer.decode(codes) + + y = self.vq_proj(quantized) * y_mask + y = self.encoder_ssl(y * y_mask, y_mask) + + y = self.mrte(y, y_mask, refer, refer_mask, ge) + + y = self.encoder2(y * y_mask, y_mask) + + stats = self.proj(y) * y_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + return y, m, logs, y_mask, quantized + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + if g != None: + g = g.detach() + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class WNEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.norm = modules.LayerNorm(out_channels) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + out = self.proj(x) * x_mask + out = self.norm(out) + return out + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + ) + ), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [ + DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods + ] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class ReferenceEncoder(nn.Module): + """ + inputs --- [N, Ty/r, n_mels*r] mels + outputs --- [N, ref_enc_gru_size] + """ + + def __init__(self, spec_channels, gin_channels=0): + super().__init__() + self.spec_channels = spec_channels + ref_enc_filters = [32, 32, 64, 64, 128, 128] + K = len(ref_enc_filters) + filters = [1] + ref_enc_filters + convs = [ + weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] + self.convs = nn.ModuleList(convs) + # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) + + out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) + self.proj = nn.Linear(128, gin_channels) + + def forward(self, inputs): + N = inputs.size(0) + out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] + for conv in self.convs: + out = conv(out) + # out = wn(out) + out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + memory, out = self.gru(out) # out --- [1, N, 128] + + return self.proj(out.squeeze(0)).unsqueeze(-1) + + def calculate_channels(self, L, kernel_size, stride, pad, n_convs): + for i in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class Quantizer_module(torch.nn.Module): + def __init__(self, n_e, e_dim): + super(Quantizer_module, self).__init__() + self.embedding = nn.Embedding(n_e, e_dim) + self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e) + + def forward(self, x): + d = ( + torch.sum(x**2, 1, keepdim=True) + + torch.sum(self.embedding.weight**2, 1) + - 2 * torch.matmul(x, self.embedding.weight.T) + ) + min_indicies = torch.argmin(d, 1) + z_q = self.embedding(min_indicies) + return z_q, min_indicies + + +class Quantizer(torch.nn.Module): + def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160): + super(Quantizer, self).__init__() + assert embed_dim % n_code_groups == 0 + self.quantizer_modules = nn.ModuleList( + [ + Quantizer_module(n_codes, embed_dim // n_code_groups) + for _ in range(n_code_groups) + ] + ) + self.n_code_groups = n_code_groups + self.embed_dim = embed_dim + + def forward(self, xin): + # B, C, T + B, C, T = xin.shape + xin = xin.transpose(1, 2) + x = xin.reshape(-1, self.embed_dim) + x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1) + min_indicies = [] + z_q = [] + for _x, m in zip(x, self.quantizer_modules): + _z_q, _min_indicies = m(_x) + z_q.append(_z_q) + min_indicies.append(_min_indicies) # B * T, + z_q = torch.cat(z_q, -1).reshape(xin.shape) + loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean( + (z_q - xin.detach()) ** 2 + ) + z_q = xin + (z_q - xin).detach() + z_q = z_q.transpose(1, 2) + codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups) + return z_q, loss, codes.transpose(1, 2) + + def embed(self, x): + # idx: N, 4, T + x = x.transpose(1, 2) + x = torch.split(x, 1, 2) + ret = [] + for q, embed in zip(x, self.quantizer_modules): + q = embed.embedding(q.squeeze(-1)) + ret.append(q) + ret = torch.cat(ret, -1) + return ret.transpose(1, 2) # N, C, T + + +class CodePredictor(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_q=8, + dims=1024, + ssl_dim=768, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1) + self.ref_enc = modules.MelStyleEncoder( + ssl_dim, style_vector_dim=hidden_channels + ) + + self.encoder = attentions.Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + + self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1) + self.n_q = n_q + self.dims = dims + + def forward(self, x, x_mask, refer, codes, infer=False): + x = x.detach() + x = self.vq_proj(x * x_mask) * x_mask + g = self.ref_enc(refer, x_mask) + x = x + g + x = self.encoder(x * x_mask, x_mask) + x = self.out_proj(x * x_mask) * x_mask + logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose( + 2, 3 + ) + target = codes[1:].transpose(0, 1) + if not infer: + logits = logits.reshape(-1, self.dims) + target = target.reshape(-1) + loss = torch.nn.functional.cross_entropy(logits, target) + return loss + else: + _, top10_preds = torch.topk(logits, 10, dim=-1) + correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1) + top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item() + + print("Top-10 Accuracy:", top3_acc, "%") + + pred_codes = torch.argmax(logits, dim=-1) + acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item() + print("Top-1 Accuracy:", acc, "%") + + return pred_codes.transpose(0, 1) + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + semantic_frame_rate=None, + freeze_quantizer=None, + **kwargs + ): + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + + self.use_sdp = use_sdp + self.enc_p = TextEncoder( + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + ) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels + ) + + self.ref_enc = modules.MelStyleEncoder( + spec_channels, style_vector_dim=gin_channels + ) + + ssl_dim = 768 + self.ssl_dim = ssl_dim + assert semantic_frame_rate in ["25hz", "50hz"] + self.semantic_frame_rate = semantic_frame_rate + if semantic_frame_rate == "25hz": + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2) + else: + self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) + + self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) + if freeze_quantizer: + self.ssl_proj.requires_grad_(False) + self.quantizer.requires_grad_(False) + # self.enc_p.text_embedding.requires_grad_(False) + # self.enc_p.encoder_text.requires_grad_(False) + # self.enc_p.mrte.requires_grad_(False) + + def forward(self, codes, text, refer): + refer_mask = torch.ones_like(refer[:1,:1,:]) + ge = self.ref_enc(refer * refer_mask, refer_mask) + + y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) + text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0) + quantized = dquantized.contiguous().view(1, self.ssl_dim, -1) + + x, m_p, logs_p, y_mask = self.enc_p( + quantized, text, ge + ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) + + z = self.flow(z_p, y_mask, g=ge, reverse=True) + + o = self.dec((z * y_mask)[:, :, :], g=ge) + return o + + def extract_latent(self, x): + ssl = self.ssl_proj(x) + quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) + return codes.transpose(0, 1) \ No newline at end of file From 7d1e94c8b05e102e1914fd59171cb2b908fd8d6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Thu, 25 Jan 2024 02:31:08 +0800 Subject: [PATCH 3/3] Add AR Onnx Module --- .../AR/models/t2s_lightning_module_onnx.py | 106 ++++++ GPT_SoVITS/AR/models/t2s_model_onnx.py | 337 ++++++++++++++++++ GPT_SoVITS/AR/modules/activation_onnx.py | 178 +++++++++ GPT_SoVITS/AR/modules/embedding_onnx.py | 63 ++++ .../AR/modules/patched_mha_with_cache_onnx.py | 92 +++++ GPT_SoVITS/AR/modules/transformer_onnx.py | 292 +++++++++++++++ 6 files changed, 1068 insertions(+) create mode 100644 GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py create mode 100644 GPT_SoVITS/AR/models/t2s_model_onnx.py create mode 100644 GPT_SoVITS/AR/modules/activation_onnx.py create mode 100644 GPT_SoVITS/AR/modules/embedding_onnx.py create mode 100644 GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py create mode 100644 GPT_SoVITS/AR/modules/transformer_onnx.py diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py b/GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py new file mode 100644 index 0000000..bb9e30b --- /dev/null +++ b/GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py @@ -0,0 +1,106 @@ +# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py +import os, sys + +now_dir = os.getcwd() +sys.path.append(now_dir) +from typing import Dict + +import torch +from pytorch_lightning import LightningModule +from AR.models.t2s_model_onnx import Text2SemanticDecoder +from AR.modules.lr_schedulers import WarmupCosineLRSchedule +from AR.modules.optim import ScaledAdam + + +class Text2SemanticLightningModule(LightningModule): + def __init__(self, config, output_dir, is_train=True): + super().__init__() + self.config = config + self.top_k = 3 + self.model = Text2SemanticDecoder(config=config, top_k=self.top_k) + pretrained_s1 = config.get("pretrained_s1") + if pretrained_s1 and is_train: + # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"])) + print( + self.load_state_dict( + torch.load(pretrained_s1, map_location="cpu")["weight"] + ) + ) + if is_train: + self.automatic_optimization = False + self.save_hyperparameters() + self.eval_dir = output_dir / "eval" + self.eval_dir.mkdir(parents=True, exist_ok=True) + + def training_step(self, batch: Dict, batch_idx: int): + opt = self.optimizers() + scheduler = self.lr_schedulers() + loss, acc = self.model.forward( + batch["phoneme_ids"], + batch["phoneme_ids_len"], + batch["semantic_ids"], + batch["semantic_ids_len"], + batch["bert_feature"], + ) + self.manual_backward(loss) + if batch_idx > 0 and batch_idx % 4 == 0: + opt.step() + opt.zero_grad() + scheduler.step() + + self.log( + "total_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + self.log( + "lr", + scheduler.get_last_lr()[0], + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + self.log( + f"top_{self.top_k}_acc", + acc, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + def validation_step(self, batch: Dict, batch_idx: int): + return + + def configure_optimizers(self): + model_parameters = self.model.parameters() + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in self.model.named_parameters()] + ) + lm_opt = ScaledAdam( + model_parameters, + lr=0.01, + betas=(0.9, 0.95), + clipping_scale=2.0, + parameters_names=parameters_names, + show_dominant_parameters=False, + clipping_update_period=1000, + ) + + return { + "optimizer": lm_opt, + "lr_scheduler": { + "scheduler": WarmupCosineLRSchedule( + lm_opt, + init_lr=self.config["optimizer"]["lr_init"], + peak_lr=self.config["optimizer"]["lr"], + end_lr=self.config["optimizer"]["lr_end"], + warmup_steps=self.config["optimizer"]["warmup_steps"], + total_steps=self.config["optimizer"]["decay_steps"], + ) + }, + } diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py new file mode 100644 index 0000000..263b933 --- /dev/null +++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py @@ -0,0 +1,337 @@ +# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py +import torch +from tqdm import tqdm + +from AR.modules.embedding_onnx import SinePositionalEmbedding +from AR.modules.embedding_onnx import TokenEmbedding +from AR.modules.transformer_onnx import LayerNorm +from AR.modules.transformer_onnx import TransformerEncoder +from AR.modules.transformer_onnx import TransformerEncoderLayer +from torch import nn +from torch.nn import functional as F +from torchmetrics.classification import MulticlassAccuracy + +default_config = { + "embedding_dim": 512, + "hidden_dim": 512, + "num_head": 8, + "num_layers": 12, + "num_codebook": 8, + "p_dropout": 0.0, + "vocab_size": 1024 + 1, + "phoneme_vocab_size": 512, + "EOS": 1024, +} + +inf_tensor_value = torch.FloatTensor([-float("Inf")]).float() + +def logits_to_probs( + logits, + previous_tokens = None, + temperature: float = 1.0, + top_k = None, + top_p = None, + repetition_penalty: float = 1.0, +): + previous_tokens = previous_tokens.squeeze() + if previous_tokens is not None and repetition_penalty != 1.0: + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=0, index=previous_tokens) + score = torch.where( + score < 0, score * repetition_penalty, score / repetition_penalty + ) + logits.scatter_(dim=0, index=previous_tokens, src=score) + + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum( + torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1 + ) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter( + dim=0, index=sorted_indices, src=sorted_indices_to_remove + ) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, inf_tensor_value, logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def multinomial_sample_one_no_sync( + probs_sort +): # Does multinomial sampling without a cuda synchronization + q = torch.randn_like(probs_sort) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def sample( + logits, + previous_tokens, + **sampling_kwargs, +): + probs = logits_to_probs( + logits=logits, previous_tokens=previous_tokens, **sampling_kwargs + ) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +class OnnxEncoder(nn.Module): + def __init__(self, ar_text_embedding, bert_proj, ar_text_position): + super().__init__() + self.ar_text_embedding = ar_text_embedding + self.bert_proj = bert_proj + self.ar_text_position = ar_text_position + + def forward(self, x, bert_feature): + x = self.ar_text_embedding(x) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) + return self.ar_text_position(x) + + +class T2SFirstStageDecoder(nn.Module): + def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric, + top_k, early_stop_num, num_layers): + super().__init__() + self.ar_audio_embedding = ar_audio_embedding + self.ar_audio_position = ar_audio_position + self.h = h + self.ar_predict_layer = ar_predict_layer + self.loss_fct = loss_fct + self.ar_accuracy_metric = ar_accuracy_metric + self.top_k = top_k + self.early_stop_num = early_stop_num + self.num_layers = num_layers + + def forward(self, x, prompt): + y = prompt + x_example = x[:,:,0] * 0.0 + #N, 1, 512 + cache = { + "all_stage": self.num_layers, + "k": None, + "v": None, + "y_emb": None, + "first_infer": 1, + "stage": 0, + } + + y_emb = self.ar_audio_embedding(y) + + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + + xy_pos = torch.concat([x, y_pos], dim=1) + + y_example = y_pos[:,:,0] * 0.0 + x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool() + y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64) + y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum( + torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0 + ) + y_attn_mask = y_attn_mask > 0 + + x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool() + y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool() + x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1) + y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\ + .unsqueeze(1).repeat(self.num_layers, 1, 1, 1) + cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\ + .unsqueeze(1).repeat(self.num_layers, 1, 1, 1) + + xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) + + y = torch.concat([y, samples], dim=1) + + return y, cache["k"], cache["v"], cache["y_emb"], x_example + + +class T2SStageDecoder(nn.Module): + def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric, + top_k, early_stop_num, num_layers): + super().__init__() + self.ar_audio_embedding = ar_audio_embedding + self.ar_audio_position = ar_audio_position + self.h = h + self.ar_predict_layer = ar_predict_layer + self.loss_fct = loss_fct + self.ar_accuracy_metric = ar_accuracy_metric + self.top_k = top_k + self.early_stop_num = early_stop_num + self.num_layers = num_layers + + def forward(self, y, k, v, y_emb, x_example): + cache = { + "all_stage": self.num_layers, + "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)), + "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)), + "y_emb": y_emb, + "first_infer": 0, + "stage": 0, + } + + y_emb = torch.cat( + [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1 + ) + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + + xy_pos = y_pos[:, -1:] + + y_example = y_pos[:,:,0] * 0.0 + + xy_attn_mask = torch.cat([x_example, y_example], dim=1) + xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool) + + xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) + + y = torch.concat([y, samples], dim=1) + + return y, cache["k"], cache["v"], cache["y_emb"], logits, samples + + +class Text2SemanticDecoder(nn.Module): + def __init__(self, config, norm_first=False, top_k=3): + super(Text2SemanticDecoder, self).__init__() + self.model_dim = config["model"]["hidden_dim"] + self.embedding_dim = config["model"]["embedding_dim"] + self.num_head = config["model"]["head"] + self.num_layers = config["model"]["n_layer"] + self.norm_first = norm_first + self.vocab_size = config["model"]["vocab_size"] + self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"] + self.p_dropout = float(config["model"]["dropout"]) + self.EOS = config["model"]["EOS"] + self.norm_first = norm_first + assert self.EOS == self.vocab_size - 1 + self.bert_proj = nn.Linear(1024, self.embedding_dim) + self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout) + self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True) + self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout) + self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True) + self.h = TransformerEncoder( + TransformerEncoderLayer( + d_model=self.model_dim, + nhead=self.num_head, + dim_feedforward=self.model_dim * 4, + dropout=0.1, + batch_first=True, + norm_first=norm_first, + ), + num_layers=self.num_layers, + norm=LayerNorm(self.model_dim) if norm_first else None, + ) + self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False) + self.loss_fct = nn.CrossEntropyLoss(reduction="sum") + self.ar_accuracy_metric = MulticlassAccuracy( + self.vocab_size, + top_k=top_k, + average="micro", + multidim_average="global", + ignore_index=self.EOS, + ) + self.top_k = torch.LongTensor([1]) + self.early_stop_num = torch.LongTensor([-1]) + + def init_onnx(self): + self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position) + self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h, + self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num, + self.num_layers) + self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h, + self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num, + self.num_layers) + + def forward(self, x, prompts, bert_feature): + early_stop_num = self.early_stop_num + prefix_len = prompts.shape[1] + + x = self.onnx_encoder(x, bert_feature) + y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts) + + stop = False + for idx in range(1, 1500): + enco = self.stage_decoder(y, k, v, y_emb, stage, x_example) + y, k, v, y_emb, stage, logits, samples = enco + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + stop = True + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + stop = True + if stop: + break + y[0, -1] = 0 + return y, idx + + def infer(self, x, prompts, bert_feature): + top_k = self.top_k + early_stop_num = self.early_stop_num + + x = self.onnx_encoder(x, bert_feature) + + y = prompts + prefix_len = y.shape[1] + x_len = x.shape[1] + x_example = x[:,:,0] * 0.0 + x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example) + x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool) + + stop = False + cache = { + "all_stage": self.num_layers, + "k": [None] * self.num_layers, + "v": [None] * self.num_layers, + "y_emb": None, + "first_infer": 1, + "stage": 0, + } + for idx in range(1500): + if cache["first_infer"] == 1: + y_emb = self.ar_audio_embedding(y) + else: + y_emb = torch.cat( + [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1 + ) + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + if cache["first_infer"] == 1: + xy_pos = torch.concat([x, y_pos], dim=1) + else: + xy_pos = y_pos[:, -1:] + y_len = y_pos.shape[1] + if cache["first_infer"] == 1: + x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True) + y_attn_mask = F.pad( + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), value=False + ) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + else: + xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool) + xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + stop = True + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + stop = True + if stop: + if prompts.shape[1] == y.shape[1]: + y = torch.concat([y, torch.zeros_like(samples)], dim=1) + break + y = torch.concat([y, samples], dim=1) + cache["first_infer"] = 0 + return y, idx \ No newline at end of file diff --git a/GPT_SoVITS/AR/modules/activation_onnx.py b/GPT_SoVITS/AR/modules/activation_onnx.py new file mode 100644 index 0000000..b54acd9 --- /dev/null +++ b/GPT_SoVITS/AR/modules/activation_onnx.py @@ -0,0 +1,178 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py +from typing import Optional +from typing import Tuple +import torch +from torch import Tensor +from torch.nn import Linear +from torch.nn import Module +from torch.nn.init import constant_ +from torch.nn.init import xavier_normal_ +from torch.nn.init import xavier_uniform_ +from torch.nn.modules.linear import NonDynamicallyQuantizableLinear +from torch.nn.parameter import Parameter + +from torch.nn import functional as F +from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched + + +class MultiheadAttention(Module): + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + if add_bias_kv: + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + if linear1_cls == Linear: + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = Parameter( + torch.empty(3 * embed_dim, **factory_kwargs) + ) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + self._reset_parameters() + else: + if not self._qkv_same_embed_dim: + raise NotImplementedError + else: + self.in_proj_linear = linear1_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) + self.in_proj_weight = self.in_proj_linear.weight + + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = self.in_proj_linear.bias + else: + self.register_parameter("in_proj_bias", None) + + self.out_proj = linear2_cls( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + self.add_zero_attn = add_zero_attn + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + cache=None, + ) -> Tuple[Tensor, Optional[Tensor]]: + any_nested = query.is_nested or key.is_nested or value.is_nested + query = key = value = query.transpose(1, 0) + attn_output = multi_head_attention_forward_patched( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + cache=cache, + ) + return attn_output.transpose(1, 0) diff --git a/GPT_SoVITS/AR/modules/embedding_onnx.py b/GPT_SoVITS/AR/modules/embedding_onnx.py new file mode 100644 index 0000000..b93405b --- /dev/null +++ b/GPT_SoVITS/AR/modules/embedding_onnx.py @@ -0,0 +1,63 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py +import math + +import torch +from torch import nn + + +class TokenEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + vocab_size: int, + dropout: float = 0.0, + ): + super().__init__() + + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + + self.dropout = torch.nn.Dropout(p=dropout) + self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim) + + @property + def weight(self) -> torch.Tensor: + return self.word_embeddings.weight + + def embedding(self, index: int) -> torch.Tensor: + return self.word_embeddings.weight[index : index + 1] + + def forward(self, x: torch.Tensor): + x = self.word_embeddings(x) + x = self.dropout(x) + return x + + +class SinePositionalEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + dropout: float = 0.0, + scale: bool = False, + alpha: bool = False, + ): + super().__init__() + self.embedding_dim = embedding_dim + self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 + self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) + self.dropout = torch.nn.Dropout(p=dropout) + self.reverse = False + self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim)) + + def extend_pe(self, x): + position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1) + scpe = (position * self.div_term).unsqueeze(0) + pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0) + pe = pe.contiguous().view(1, -1, self.embedding_dim) + return pe + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pe = self.extend_pe(x) + output = x.unsqueeze(-1) if x.ndim == 2 else x + output = output * self.x_scale + self.alpha * pe + return self.dropout(output) diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py new file mode 100644 index 0000000..14bdb55 --- /dev/null +++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py @@ -0,0 +1,92 @@ +from torch.nn.functional import * +from torch.nn.functional import ( + _mha_shape_check, + _canonical_mask, + _none_or_dtype, + _in_projection_packed, +) + +def multi_head_attention_forward_patched( + query, + key, + value, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight, + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + cache=None, +) -> Tuple[Tensor, Optional[Tensor]]: + + # set up shape vars + _, _, embed_dim = query.shape + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + head_dim = embed_dim // num_heads + + proj_qkv = linear(query, in_proj_weight, in_proj_bias) + proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2] + + if cache["first_infer"] == 1: + cache["k"][cache["stage"]] = k + cache["v"][cache["stage"]] = v + else: + cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0) + cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0) + k = cache["k"][cache["stage"]] + v = cache["v"][cache["stage"]] + cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] + + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=q.dtype, + check_other=False, + ) + attn_mask = attn_mask.unsqueeze(0) + + q = q.view(-1, num_heads, head_dim).transpose(0, 1) + k = k.view(-1, num_heads, head_dim).transpose(0, 1) + v = v.view(-1, num_heads, head_dim).transpose(0, 1) + + dropout_p = 0.0 + attn_mask = attn_mask.unsqueeze(0) + q = q.view(num_heads, -1, head_dim).unsqueeze(0) + k = k.view(num_heads, -1, head_dim).unsqueeze(0) + v = v.view(num_heads, -1, head_dim).unsqueeze(0) + attn_output = scaled_dot_product_attention( + q, k, v, attn_mask, dropout_p, is_causal + ) + attn_output = ( + attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim) + ) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(-1, 1, attn_output.size(1)) + + return attn_output diff --git a/GPT_SoVITS/AR/modules/transformer_onnx.py b/GPT_SoVITS/AR/modules/transformer_onnx.py new file mode 100644 index 0000000..a3f68b4 --- /dev/null +++ b/GPT_SoVITS/AR/modules/transformer_onnx.py @@ -0,0 +1,292 @@ +# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py +import copy +import numbers +from functools import partial +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +from AR.modules.activation_onnx import MultiheadAttention +from AR.modules.scaling import BalancedDoubleSwish +from torch import nn +from torch import Tensor +from torch.nn import functional as F + +_shape_t = Union[int, List[int], torch.Size] + + +class LayerNorm(nn.Module): + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] + eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + self.bias = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + F.layer_norm( + input, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ), + embedding, + ) + + assert embedding is None + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) + + def extra_repr(self) -> str: + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +class IdentityNorm(nn.Module): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ) -> None: + super(IdentityNorm, self).__init__() + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + return input + + assert embedding is None + return input + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ["norm"] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + cache=None, + ) -> Tensor: + output = src + for mod in self.layers: + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + cache=cache, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, # 512 16 + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, d_model, **factory_kwargs + ) + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + if isinstance(activation, str): + activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + activation = activation(d_model) + elif activation == BalancedDoubleSwish: + activation = BalancedDoubleSwish(d_model) + self.activation = activation + + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + if layer_norm_cls == IdentityNorm: + norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + else: + norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + + if adaptive_layer_norm: + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + else: + self.norm1 = norm1 + self.norm2 = norm2 + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + cache=None, + ) -> Tensor: + x = src + stage_embedding = None + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache), + stage_embedding, + ) + x = self.norm2(x + self._ff_block(x), stage_embedding) + + return x + + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + cache=None, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + cache=cache, + ) + return self.dropout1(x) + + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return (weight * self.norm(input) + bias, embedding) + + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)])