From b0786f2998f1b2fce6678434524b4e0e8cc716f5 Mon Sep 17 00:00:00 2001
From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Date: Sun, 7 Jul 2024 21:34:28 +0800
Subject: [PATCH] double inference speed

double inference speed
---
 GPT_SoVITS/AR/models/t2s_model.py | 291 ++++++++++++++++++++++--------
 1 file changed, 214 insertions(+), 77 deletions(-)

diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index c8ad3d8..f465b49 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -1,8 +1,11 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
 import torch
-from tqdm import tqdm
+import random
+import numpy as np
 
+from tqdm import tqdm
+from typing import List
 from AR.models.utils import make_pad_mask
 from AR.models.utils import (
     topk_sampling,
@@ -10,7 +13,7 @@ from AR.models.utils import (
     logits_to_probs,
     multinomial_sample_one_no_sync,
     dpo_loss,
-    make_reject_y,
+    make_reject_y, get_batch_logps
 )
 from AR.modules.embedding import SinePositionalEmbedding
@@ -35,6 +38,139 @@ default_config = {
 }
 
 
+@torch.jit.script
+class T2SMLP:
+    def __init__(self, w1, b1, w2, b2):
+        self.w1 = w1
+        self.b1 = b1
+        self.w2 = w2
+        self.b2 = b2
+
+    def forward(self, x):
+        x = F.relu(F.linear(x, self.w1, self.b1))
+        x = F.linear(x, self.w2, self.b2)
+        return x
+
+
+@torch.jit.script
+class T2SBlock:
+    def __init__(
+        self,
+        num_heads,
+        hidden_dim: int,
+        mlp: T2SMLP,
+        qkv_w,
+        qkv_b,
+        out_w,
+        out_b,
+        norm_w1,
+        norm_b1,
+        norm_eps1,
+        norm_w2,
+        norm_b2,
+        norm_eps2,
+    ):
+        self.num_heads = num_heads
+        self.mlp = mlp
+        self.hidden_dim: int = hidden_dim
+        self.qkv_w = qkv_w
+        self.qkv_b = qkv_b
+        self.out_w = out_w
+        self.out_b = out_b
+        self.norm_w1 = norm_w1
+        self.norm_b1 = norm_b1
+        self.norm_eps1 = norm_eps1
+        self.norm_w2 = norm_w2
+        self.norm_b2 = norm_b2
+        self.norm_eps2 = norm_eps2
+
+    def process_prompt(self, x, attn_mask: torch.Tensor):
+        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+        kv_len = k.shape[1]
+
+        k_cache = k
+        v_cache = v
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
+
+        # merge heads batch-safely: (batch, heads, seq, head_dim) -> (batch, seq, hidden)
+        attn = attn.transpose(1, 2).reshape(batch_size, q_len, self.hidden_dim)
+        attn = F.linear(attn, self.out_w, self.out_b)
+
+        x = F.layer_norm(
+            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = F.layer_norm(
+            x + self.mlp.forward(x),
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
+        return x, k_cache, v_cache
+
+    def decode_next_token(self, x, k_cache, v_cache):
+        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        k_cache = torch.cat([k_cache, k], dim=1)
+        v_cache = torch.cat([v_cache, v], dim=1)
+        kv_len = k_cache.shape[1]
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v)
+
+        # merge heads batch-safely: (batch, heads, seq, head_dim) -> (batch, seq, hidden)
+        attn = attn.transpose(1, 2).reshape(batch_size, q_len, self.hidden_dim)
+        attn = F.linear(attn, self.out_w, self.out_b)
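
Note on the technique: process_prompt and decode_next_token above implement a standard KV cache. The prompt pass computes keys and values for the whole prefix once; each later step projects only the single newest frame and appends its K/V to the cache, so per-step attention cost drops from quadratic in the sequence length to linear against the cache. A minimal, self-contained sketch of the pattern around F.scaled_dot_product_attention — shapes and names here are illustrative stand-ins, not this repo's API:

    import torch
    import torch.nn.functional as F

    batch, heads, hidden = 1, 8, 512
    head_dim = hidden // heads

    def attend(q, k, v):
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        def to_heads(t):
            return t.view(batch, -1, heads, head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(to_heads(q), to_heads(k), to_heads(v))
        # merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, hidden)
        return out.transpose(1, 2).reshape(batch, -1, hidden)

    # Prompt step: cache K/V for the whole prefix once (stand-ins for the F.linear outputs).
    prompt = torch.randn(batch, 10, hidden)
    k_cache, v_cache = prompt.clone(), prompt.clone()
    out = attend(prompt, k_cache, v_cache)

    # Decode step: append one frame and attend with a length-1 query.
    step = torch.randn(batch, 1, hidden)
    k_cache = torch.cat([k_cache, step], dim=1)
    v_cache = torch.cat([v_cache, step], dim=1)
    out = attend(step, k_cache, v_cache)
    print(out.shape, k_cache.shape)  # torch.Size([1, 1, 512]) torch.Size([1, 11, 512])
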
+        x = F.layer_norm(
+            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = F.layer_norm(
+            x + self.mlp.forward(x),
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
+        return x, k_cache, v_cache
+
+
+@torch.jit.script
+class T2STransformer:
+    def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
+        self.num_blocks: int = num_blocks
+        self.blocks = blocks
+
+    def process_prompt(self, x, attn_mask: torch.Tensor):
+        k_cache: List[torch.Tensor] = []
+        v_cache: List[torch.Tensor] = []
+        for i in range(self.num_blocks):
+            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask)
+            k_cache.append(k_cache_)
+            v_cache.append(v_cache_)
+        return x, k_cache, v_cache
+
+    def decode_next_token(
+        self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
+    ):
+        for i in range(self.num_blocks):
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
+                x, k_cache[i], v_cache[i]
+            )
+        return x, k_cache, v_cache
+
+
 class Text2SemanticDecoder(nn.Module):
     def __init__(self, config, norm_first=False, top_k=3):
         super(Text2SemanticDecoder, self).__init__()
@@ -89,6 +225,37 @@ class Text2SemanticDecoder(nn.Module):
             ignore_index=self.EOS,
         )
 
+        blocks = []
+
+        for i in range(self.num_layers):
+            layer = self.h.layers[i]
+            t2smlp = T2SMLP(
+                layer.linear1.weight,
+                layer.linear1.bias,
+                layer.linear2.weight,
+                layer.linear2.bias,
+            )
+
+            block = T2SBlock(
+                self.num_head,
+                self.model_dim,
+                t2smlp,
+                layer.self_attn.in_proj_weight,
+                layer.self_attn.in_proj_bias,
+                layer.self_attn.out_proj.weight,
+                layer.self_attn.out_proj.bias,
+                layer.norm1.weight,
+                layer.norm1.bias,
+                layer.norm1.eps,
+                layer.norm2.weight,
+                layer.norm2.bias,
+                layer.norm2.eps,
+            )
+
+            blocks.append(block)
+
+        self.t2s_transformer = T2STransformer(self.num_layers, blocks)
+
     def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -116,7 +283,7 @@ class Text2SemanticDecoder(nn.Module):
             (0, y_len),
             value=True,
         )
-        
+
         y_attn_mask = F.pad(
             torch.triu(
                 torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@@ -177,7 +344,7 @@ class Text2SemanticDecoder(nn.Module):
         A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
 
         loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
-        
+
         loss = loss_1 + loss_2
 
         return loss, acc
@@ -246,14 +413,14 @@ class Text2SemanticDecoder(nn.Module):
 
    # need to check how this differs from forward, and what prompts should be fed when there is no semantic prompt
     def infer(
-            self,
-            x,
-            x_lens,
-            prompts,
-            bert_feature,
-            top_k: int = -100,
-            early_stop_num: int = -1,
-            temperature: float = 1.0,
+        self,
+        x,
+        x_lens,
+        prompts,
+        bert_feature,
+        top_k: int = -100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -322,15 +489,15 @@ class Text2SemanticDecoder(nn.Module):
         return targets[:, :-1], targets[:, 1:]
 
     def infer_panel(
-            self,
-            x,  ##### all text tokens
-            x_lens,
-            prompts,  #### reference audio tokens
-            bert_feature,
-            top_k: int = -100,
-            top_p: int = 100,
-            early_stop_num: int = -1,
-            temperature: float = 1.0,
+        self,
+        x,  ##### all text tokens
+        x_lens,
+        prompts,  #### reference audio tokens
+        bert_feature,
+        top_k: int = -100,
+        top_p: int = 100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
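
Note: the constructor loop above wraps the weights of the existing nn.TransformerEncoderLayer stack by reference, not by copy, so the jitted blocks always see whatever checkpoint was loaded. The fact it relies on is that nn.MultiheadAttention packs the Q, K and V projections into a single in_proj_weight of shape (3 * d_model, d_model), so one F.linear followed by chunk(3) reproduces all three projections. A small sketch of that fact (the d_model/nhead values are arbitrary):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
    x = torch.randn(1, 10, 512)

    # One matmul yields Q, K and V; chunk(3) splits them along the feature dim.
    q, k, v = F.linear(x, layer.self_attn.in_proj_weight,
                       layer.self_attn.in_proj_bias).chunk(3, dim=-1)
    print(layer.self_attn.in_proj_weight.shape, q.shape)
    # torch.Size([1536, 512]) torch.Size([1, 10, 512])
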
@@ -338,22 +505,14 @@ class Text2SemanticDecoder(nn.Module):
 
         # AR Decoder
         y = prompts
-        
+
         x_len = x.shape[1]
         x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
         stop = False
-        # print(1111111,self.num_layers)
-        cache = {
-            "all_stage": self.num_layers,
-            "k": [None] * self.num_layers,  ### hand-written based on the config
-            "v": [None] * self.num_layers,
-            # "xy_pos": None,  ## the y positional encoding differs at every step, so xy_pos cannot be cached and has to be rebuilt each time; mostly a consequence of how this is written, the history could be kept consistent, but the cost is negligible so it was left alone
-            "y_emb": None,  ## only the newest samples need an embedding; the history is concatenated on
-            # "logits": None,  ### the original already computes only the tail and concatenates, nothing to do
-            # "xy_dec": None,  ### not needed, only the last frame is used for the logits
-            "first_infer": 1,
-            "stage": 0,
-        }
+
+        k_cache = None
+        v_cache = None
         ################### first step ##########################
         if y is not None:
             y_emb = self.ar_audio_embedding(y)
@@ -361,7 +520,6 @@ class Text2SemanticDecoder(nn.Module):
             prefix_len = y.shape[1]
             y_pos = self.ar_audio_position(y_emb)
             xy_pos = torch.concat([x, y_pos], dim=1)
-            cache["y_emb"] = y_emb
             ref_free = False
         else:
             y_emb = None
@@ -373,10 +531,10 @@ class Text2SemanticDecoder(nn.Module):
             ref_free = True
 
         x_attn_mask_pad = F.pad(
-                x_attn_mask,
-                (0, y_len),  ### extend the all-zero xx block with an all-True xy block, (x, x+y)
-                value=True,
-            )
+            x_attn_mask,
+            (0, y_len),  ### extend the all-zero xx block with an all-True xy block, (x, x+y)
+            value=True,
+        )
         y_attn_mask = F.pad(  ### extend the upper-right triangular ones of yy with zeros for xy on the left, (y, x+y)
             torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
             (x_len, 0),
             value=False,
         )
@@ -385,64 +543,43 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
             x.device
         )
-
         for idx in tqdm(range(1500)):
-
-            xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
+            if xy_attn_mask is not None:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
+            else:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
+
             logits = self.ar_predict_layer(
                 xy_dec[:, -1]
-            )  ## no change needed: with the cache there is only one frame anyway, so taking the last frame is the same
-            # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
-            if (idx == 0):  ### the first step must not emit EOS, otherwise nothing is generated
-                logits = logits[:, :-1]  ### drop the probability of the 1024 stop token
-            samples = sample(
-                logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
-            )[0].unsqueeze(0)
-            # the newly generated semantic_ids and the previous y form the new y
-            # print(samples.shape)  # [1,1]; the first 1 is the batch size
-            y = torch.concat([y, samples], dim=1)
+            )
+
+            if idx == 0:
+                xy_attn_mask = None
+                logits = logits[:, :-1]  ### drop the probability of the 1024 stop token
+            samples = sample(
+                logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
+            )[0].unsqueeze(0)
+
+            y = torch.concat([y, samples], dim=1)
 
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 print("use early stop num:", early_stop_num)
                 stop = True
 
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
-                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                 stop = True
             if stop:
-                # if prompts.shape[1] == y.shape[1]:
-                #     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
-                #     print("bad zero prediction")
-                if y.shape[1]==0:
+                if y.shape[1] == 0:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                     print("bad zero prediction")
                 print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                 break
-
-            ####################### update next step ###################################
-            cache["first_infer"] = 0
-            if cache["y_emb"] is not None:
-                y_emb = torch.cat(
-                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim=1
-                )
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos[:, -1:]
-            else:
-                y_emb = self.ar_audio_embedding(y[:, -1:])
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos
-            y_len = y_pos.shape[1]
-            ### rightmost column (wrong)
-            # xy_attn_mask = torch.ones((1, x_len + y_len), dtype=torch.bool, device=xy_pos.device)
-            # xy_attn_mask[:, -1] = False
-            ### bottom row (correct)
-            xy_attn_mask = torch.zeros(
-                (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
-            )
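
The deleted block above rebuilt xy_attn_mask on every step; the replacement (the added lines just below) instead sets xy_attn_mask = None after the first pass and never builds another mask. That is safe because with a KV cache the single new query is allowed to attend to every cached position, which is exactly what the last row of a causal mask permits. A quick equivalence check (illustrative shapes, batch and head dims of 1):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    seq = torch.randn(1, 1, 6, 64)  # (batch, heads, len, head_dim)
    causal = torch.tril(torch.ones(6, 6, dtype=torch.bool))  # True = may attend

    full = F.scaled_dot_product_attention(seq, seq, seq, causal)       # masked full pass
    step = F.scaled_dot_product_attention(seq[:, :, -1:], seq, seq)    # cached step, no mask
    print(torch.allclose(full[:, :, -1:], step, atol=1e-6))            # True
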
+
+            ####################### update next step ###################################
+            y_emb = self.ar_audio_embedding(y[:, -1:])
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, prompts.shape[1] + idx]
 
         if ref_free:
             return y[:, :-1], 0
-        return y[:, :-1], idx - 1
+        return y[:, :-1], idx - 1
\ No newline at end of file
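
The new update step also stops re-running ar_audio_position over all of y each iteration: it scales only the newest token's embedding and adds the single positional-encoding slice that token needs. A sketch of the arithmetic, assuming a SinePositionalEmbedding-style module with x_scale, alpha and a precomputed buffer pe of shape (1, max_len, dim) — the attribute names come from the patch, but all values below are stand-ins:

    import torch

    dim, max_len = 512, 4000
    pe = torch.zeros(1, max_len, dim)  # stand-in for the module's sinusoidal buffer
    pos = torch.arange(max_len).unsqueeze(1)
    div = torch.exp(torch.arange(0, dim, 2) * (-torch.log(torch.tensor(10000.0)) / dim))
    pe[0, :, 0::2] = torch.sin(pos * div)
    pe[0, :, 1::2] = torch.cos(pos * div)

    x_scale, alpha = 1.0, torch.tensor(1.0)  # module parameters, assumed values
    y_emb = torch.randn(1, 1, dim)           # embedding of y[:, -1:]
    prefix_len, idx = 17, 3                  # prompt length, current decode step
    xy_pos = y_emb * x_scale + alpha * pe[:, prefix_len + idx]
    print(xy_pos.shape)                      # torch.Size([1, 1, 512])

prefix_len + idx is the absolute position of the token generated at step idx, so this incremental path should produce the same xy_pos as embedding the whole of y and taking the last frame.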