From 74f81f3009b9aa42bf8f41dfc161f1cebd278577 Mon Sep 17 00:00:00 2001
From: XTer
Date: Thu, 7 Mar 2024 14:31:33 +0800
Subject: [PATCH] Sync the 0306 update

---
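
Note on the approach: this patch swaps the dict-based transformer cache for
per-layer K/V tensors that grow by one token per decoding step (or, on the
flash-attn path, are written in place). Reduced to a single attention layer in
plain PyTorch, the fallback pattern is the sketch below. This is a minimal
illustration with toy sizes, not the repo's exact module; the real first step
also passes the x/y cross-attention mask instead of a plain causal one.

import torch
import torch.nn.functional as F

batch, n_heads, head_dim = 1, 16, 32
dim = n_heads * head_dim

def decode_step(q, k, v, k_cache, v_cache, causal):
    # Append the newest K/V to the cache, then attend the new queries
    # against everything cached so far.
    k_cache = k if k_cache is None else torch.cat([k_cache, k], dim=1)
    v_cache = v if v_cache is None else torch.cat([v_cache, v], dim=1)
    heads = lambda t: t.view(batch, -1, n_heads, head_dim).transpose(1, 2)
    out = F.scaled_dot_product_attention(
        heads(q), heads(k_cache), heads(v_cache), is_causal=causal
    )
    return out.transpose(1, 2).reshape(batch, -1, dim), k_cache, v_cache

k_cache = v_cache = None
prefix = torch.randn(batch, 8, dim)   # first step: the whole x+y prefix
out, k_cache, v_cache = decode_step(prefix, prefix, prefix, k_cache, v_cache, causal=True)
tok = torch.randn(batch, 1, dim)      # later steps: one new token at a time
out, k_cache, v_cache = decode_step(tok, tok, tok, k_cache, v_cache, causal=False)
assert k_cache.shape[1] == 9          # cache now covers every position seen so far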
 GPT_SoVITS/AR/models/t2s_model.py  | 219 +++++++++++++++++++----------
 docs/cn/Changelog_CN.md            |  13 ++
 tools/uvr5/uvr5_weights/.gitignore |   2 +
 3 files changed, 156 insertions(+), 78 deletions(-)
 create mode 100644 tools/uvr5/uvr5_weights/.gitignore

diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index c8ad3d82..e4891f4c 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -10,7 +10,7 @@ from AR.models.utils import (
     logits_to_probs,
     multinomial_sample_one_no_sync,
     dpo_loss,
-    make_reject_y,
+    make_reject_y, get_batch_logps
 )
 from AR.modules.embedding import SinePositionalEmbedding
@@ -22,6 +22,11 @@
 from torch import nn
 from torch.nn import functional as F
 from torchmetrics.classification import MulticlassAccuracy
 
+try:
+    from flash_attn import flash_attn_with_kvcache
+except ImportError:
+    flash_attn_with_kvcache = None
+
 default_config = {
     "embedding_dim": 512,
     "hidden_dim": 512,
@@ -116,7 +121,7 @@ class Text2SemanticDecoder(nn.Module):
             (0, y_len),
             value=True,
         )
-        
+
         y_attn_mask = F.pad(
             torch.triu(
                 torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@@ -177,7 +182,7 @@ class Text2SemanticDecoder(nn.Module):
             A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
             loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
-        
+
         loss = loss_1 + loss_2
         return loss, acc
@@ -246,14 +251,14 @@ class Text2SemanticDecoder(nn.Module):
     # TODO: check how this differs from forward, and what to feed as prompts when there is no semantic input
     def infer(
-        self,
-        x,
-        x_lens,
-        prompts,
-        bert_feature,
-        top_k: int = -100,
-        early_stop_num: int = -1,
-        temperature: float = 1.0,
+            self,
+            x,
+            x_lens,
+            prompts,
+            bert_feature,
+            top_k: int = -100,
+            early_stop_num: int = -1,
+            temperature: float = 1.0,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -321,16 +326,100 @@ class Text2SemanticDecoder(nn.Module):
         # shifted by one position
         return targets[:, :-1], targets[:, 1:]
 
+    def infer_one_step(self, x, xy_attn_mask, k_cache, v_cache, cache_seqlens):
+        hidden_dim = x.shape[-1]
+
+        for layer_id in range(self.num_layers):
+            layer = self.h.layers[layer_id]
+
+            q, k, v = F.linear(
+                x,
+                layer.self_attn.in_proj_weight,
+                layer.self_attn.in_proj_bias
+            ).chunk(3, dim=-1)
+
+            batch_size = q.shape[0]
+            q_len = q.shape[1]
+
+            if flash_attn_with_kvcache is None:
+                past_k = k_cache[layer_id]
+                past_v = v_cache[layer_id]
+
+                if past_k is not None:
+                    k = torch.cat([past_k, k], 1)
+                    v = torch.cat([past_v, v], 1)
+                k_cache[layer_id] = k
+                v_cache[layer_id] = v
+                kv_len = k.shape[1]
+
+                q = q.view(batch_size, q_len, layer.self_attn.num_heads, -1).transpose(1, 2)
+                k = k.view(batch_size, kv_len, layer.self_attn.num_heads, -1).transpose(1, 2)
+                v = v.view(batch_size, kv_len, layer.self_attn.num_heads, -1).transpose(1, 2)
+
+                if xy_attn_mask is None:
+                    attn = F.scaled_dot_product_attention(q, k, v)
+                else:
+                    attn = F.scaled_dot_product_attention(q, k, v, ~xy_attn_mask)
+
+                attn = attn.permute(2, 0, 1, 3).reshape(-1, hidden_dim)
+            else:
+                q = q.view(batch_size, q_len, layer.self_attn.num_heads, -1)
+                k = k.view(batch_size, q_len, layer.self_attn.num_heads, -1)
+                v = v.view(batch_size, q_len, layer.self_attn.num_heads, -1)
+
+                if xy_attn_mask is None:
+                    attn = flash_attn_with_kvcache(q, k_cache[layer_id], v_cache[layer_id], k, v, cache_seqlens=cache_seqlens, causal=True)
+                else:
+                    # NOTE: the result differs slightly from what SDPA produces.
+                    x_len = (~xy_attn_mask).sum(1)[0].item()
+
+                    attn_x = flash_attn_with_kvcache(
+                        q[:, :x_len],
+                        k_cache[layer_id],
+                        v_cache[layer_id],
+                        k[:, :x_len],
+                        v[:, :x_len],
+                        cache_seqlens=cache_seqlens,
+                        causal=False
+                    )
+
+                    attn_y = flash_attn_with_kvcache(
+                        q[:, x_len:],
+                        k_cache[layer_id],
+                        v_cache[layer_id],
+                        k[:, x_len:],
+                        v[:, x_len:],
+                        cache_seqlens=cache_seqlens + x_len,
+                        causal=True
+                    )
+
+                    attn = torch.cat([attn_x, attn_y], dim=1)
+                    attn = attn.view(-1, hidden_dim)
+
+            attn_out = F.linear(attn, layer.self_attn.out_proj.weight, layer.self_attn.out_proj.bias)
+
+            x = layer.norm1(x + attn_out, None)
+
+            x = layer.norm2(x + layer.linear2(F.relu(layer.linear1(x))), None)
+
+        xy_dec = x
+
+        logits = self.ar_predict_layer(
+            xy_dec[:, -1]
+        )
+
+        return logits
+
     def infer_panel(
-        self,
-        x,  ##### all text tokens
-        x_lens,
-        prompts,  #### reference audio tokens
-        bert_feature,
-        top_k: int = -100,
-        top_p: int = 100,
-        early_stop_num: int = -1,
-        temperature: float = 1.0,
+            self,
+            x,  ##### all text tokens
+            x_lens,
+            prompts,  #### reference audio tokens
+            bert_feature,
+            top_k: int = -100,
+            top_p: int = 100,
+            early_stop_num: int = -1,
+            temperature: float = 1.0,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -338,22 +427,18 @@ class Text2SemanticDecoder(nn.Module):
 
         # AR Decoder
         y = prompts
-        
+
         x_len = x.shape[1]
         x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
         stop = False
-        # print(1111111,self.num_layers)
-        cache = {
-            "all_stage": self.num_layers,
-            "k": [None] * self.num_layers,  ### hand-written based on the config
-            "v": [None] * self.num_layers,
-            # "xy_pos": None,  ## the y positional encoding differs every step, so it cannot be cached and xy_pos is rebuilt each time; the history could be made uniform, but the cost is negligible, so it was left alone
-            "y_emb": None,  ## only the newest samples need embedding; concatenate with the history
-            # "logits": None,  ### the original already computes only the tail and concatenates, nothing to do
-            # "xy_dec": None,  ### not needed; only the last frame is used for the logits
-            "first_infer": 1,
-            "stage": 0,
-        }
+
+        if flash_attn_with_kvcache is not None:
+            k_cache = [torch.empty(x.shape[0], 2048, 16, 32, dtype=x.dtype, device=x.device) for _ in range(self.num_layers)]
+            v_cache = [torch.empty(x.shape[0], 2048, 16, 32, dtype=x.dtype, device=x.device) for _ in range(self.num_layers)]
+        else:
+            k_cache = [None] * self.num_layers
+            v_cache = [None] * self.num_layers
         ################### first step ##########################
         if y is not None:
             y_emb = self.ar_audio_embedding(y)
@@ -361,7 +446,6 @@ class Text2SemanticDecoder(nn.Module):
             prefix_len = y.shape[1]
             y_pos = self.ar_audio_position(y_emb)
             xy_pos = torch.concat([x, y_pos], dim=1)
-            cache["y_emb"] = y_emb
             ref_free = False
         else:
             y_emb = None
@@ -373,10 +457,10 @@ class Text2SemanticDecoder(nn.Module):
             ref_free = True
 
         x_attn_mask_pad = F.pad(
-            x_attn_mask,
-            (0, y_len),  ### extend the all-zero xx block into all-zero xx plus all-one xy, (x, x+y)
-            value=True,
-        )
+                x_attn_mask,
+                (0, y_len),  ### extend the all-zero xx block into all-zero xx plus all-one xy, (x, x+y)
+                value=True,
+            )
         y_attn_mask = F.pad(  ### extend the upper-right 1s of yy to the left with 0s for xy, (y, x+y)
             torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
             (x_len, 0),
@@ -385,64 +469,43 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
             x.device
         )
-        
+
+        cache_seqlens = torch.zeros(x.shape[0], dtype=torch.int32, device=x.device)
         for idx in tqdm(range(1500)):
-
-            xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
-            logits = self.ar_predict_layer(
-                xy_dec[:, -1]
-            )  ## no change needed: with the cache there is only one frame by default, same as taking the last frame
-            # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
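
For the flash-attn path in infer_one_step, infer_panel pre-allocates
(batch, 2048, 16, 32) K/V buffers per layer, and flash_attn_with_kvcache
writes each step's keys and values into them in place at the offsets tracked
by cache_seqlens. Isolated from the model, the call pattern is roughly the
following; a hedged sketch that assumes the flash-attn 2.x package, a CUDA
device, and toy sequence lengths:

import torch
from flash_attn import flash_attn_with_kvcache

batch, max_seq, n_heads, head_dim = 1, 2048, 16, 32
dev, dt = "cuda", torch.float16  # flash-attn requires fp16/bf16 on GPU

k_cache = torch.empty(batch, max_seq, n_heads, head_dim, device=dev, dtype=dt)
v_cache = torch.empty_like(k_cache)
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device=dev)  # per-sequence write offset

for step in range(3):
    new_len = 8 if step == 0 else 1  # first step: whole prefix; then one token at a time
    q = torch.randn(batch, new_len, n_heads, head_dim, device=dev, dtype=dt)
    k, v = torch.randn_like(q), torch.randn_like(q)
    # k/v are appended into k_cache/v_cache at position cache_seqlens, in place,
    # and attention runs over everything cached so far.
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens, causal=True
    )
    cache_seqlens += new_len  # advance the offset, as infer_panel does
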
-            if(idx==0):  ### the first step must not produce EOS, otherwise nothing is generated
-                logits = logits[:, :-1]  ### strip the probability of the 1024 EOS token
+            logits = self.infer_one_step(xy_pos, xy_attn_mask, k_cache, v_cache, cache_seqlens)
+
+            if idx == 0:
+                cache_seqlens += xy_pos.shape[1]
+            else:
+                cache_seqlens += 1
+            xy_attn_mask = None
+
+            if idx == 0:
+                logits = logits[:, :-1]
             samples = sample(
                 logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
             )[0].unsqueeze(0)
-            # the newly generated semantic_ids are appended to the previous y to form the new y
-            # print(samples.shape)  # [1, 1]; the first 1 is the batch size
-            y = torch.concat([y, samples], dim=1)
+
+            y = torch.concat([y, samples], dim=1)
 
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 print("use early stop num:", early_stop_num)
                 stop = True
 
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
-                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                 stop = True
             if stop:
-                # if prompts.shape[1] == y.shape[1]:
-                #     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
-                #     print("bad zero prediction")
-                if y.shape[1]==0:
+                if y.shape[1] == 0:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                     print("bad zero prediction")
                 print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                 break
-
-            ####################### update next step ###################################
-            cache["first_infer"] = 0
-            if cache["y_emb"] is not None:
-                y_emb = torch.cat(
-                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
-                )
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos[:, -1:]
-            else:
-                y_emb = self.ar_audio_embedding(y[:, -1:])
-                cache["y_emb"] = y_emb
-                y_pos = self.ar_audio_position(y_emb)
-                xy_pos = y_pos
-            y_len = y_pos.shape[1]
-
-            ### the rightmost column (wrong)
-            # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
-            # xy_attn_mask[:,-1]=False
-            ### the bottom row (correct)
-            xy_attn_mask = torch.zeros(
-                (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
-            )
+
+            ####################### update next step ###################################
+            y_emb = self.ar_audio_embedding(y[:, -1:])
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, prompts.shape[1] + idx]
+
         if ref_free:
             return y[:, :-1], 0
-        return y[:, :-1], idx-1
+        return y[:, :-1], idx - 1
\ No newline at end of file
diff --git a/docs/cn/Changelog_CN.md b/docs/cn/Changelog_CN.md
index 8afd3514..25705c70 100644
--- a/docs/cn/Changelog_CN.md
+++ b/docs/cn/Changelog_CN.md
@@ -135,6 +135,19 @@
 
 4- Fixed Colab not opening a public URL
 
+### 20240306 Update
+
+1- Inference is about 50% faster (tested with RTX 3090 + PyTorch 2.2.1 + CUDA 11.8 + Win10 + Py39) https://github.com/RVC-Boss/GPT-SoVITS/pull/672
+
+2- When using Faster Whisper for non-Chinese ASR, the Chinese FunASR model no longer has to be downloaded first
+
+3- Fixed the UVR5 dereverberation model's reverberation option being inverted https://github.com/RVC-Boss/GPT-SoVITS/pull/610
+
+4- Faster Whisper now falls back to CPU inference automatically when CUDA is unavailable https://github.com/RVC-Boss/GPT-SoVITS/pull/675
+
+5- Adjusted the is_half check so that CPU inference works correctly on Mac https://github.com/RVC-Boss/GPT-SoVITS/pull/573
+
+
 todolist:
 
 1- Optimize inference for Chinese polyphonic characters
diff --git a/tools/uvr5/uvr5_weights/.gitignore b/tools/uvr5/uvr5_weights/.gitignore
new file mode 100644
index 00000000..d6b7ef32
--- /dev/null
+++ b/tools/uvr5/uvr5_weights/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
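
The new .gitignore above uses the standard keep-the-directory pattern: ignore
everything in tools/uvr5/uvr5_weights (the downloaded weights) except the
.gitignore itself, so the otherwise-empty directory stays tracked.

One detail of the t2s_model.py changes worth spelling out is the new
positional-embedding update at the end of infer_panel: instead of re-running
ar_audio_position over the whole y history each step, the code scales the
newest token's embedding and adds the precomputed sinusoidal table entry at
its absolute position (prompts.shape[1] + idx). A minimal sketch of that
arithmetic follows; x_scale, alpha, and pe mirror the SinePositionalEmbedding
attributes the patch reads, while the table construction and sizes here are
illustrative assumptions:

import math
import torch

d_model, max_len = 512, 4096

# Standard sinusoidal table, shaped (1, max_len, d_model).
pe = torch.zeros(1, max_len, d_model)
position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)

x_scale = math.sqrt(d_model)        # embedding scale factor
alpha = torch.ones(1)               # a learned blend weight in the real module
y_emb = torch.randn(1, 1, d_model)  # embedding of the newest sampled token only
abs_pos = 16 + 3                    # stands in for prompts.shape[1] + idx
xy_pos = y_emb * x_scale + alpha * pe[:, abs_pos]  # next step's decoder input
assert xy_pos.shape == (1, 1, d_model)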