From 6dd2f720901a91904268c5b03fa68e2ebd5a309e Mon Sep 17 00:00:00 2001 From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com> Date: Tue, 4 Mar 2025 16:45:37 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=94=B9gpt=E5=B9=B6=E8=A1=8C?= =?UTF-8?q?=E6=8E=A8=E7=90=86=E6=97=B6=E7=9A=84mask=E7=AD=96=E7=95=A5?= =?UTF-8?q?=E4=B8=BApadding=20left=20(#2144)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 更改gpt并行推理时的mask策略为padding left,使batch_infer更接近于naive_infer 减少冗余操作并使用torch_sdpa,以提升推理速度 * rollback tts_infer.yaml --- GPT_SoVITS/AR/models/t2s_model.py | 84 ++++++++++++++----------------- GPT_SoVITS/AR/models/utils.py | 33 ++++++++++++ GPT_SoVITS/TTS_infer_pack/TTS.py | 6 ++- GPT_SoVITS/configs/tts_infer.yaml | 4 +- 4 files changed, 77 insertions(+), 50 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index f8f6582..d97cfb7 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -5,7 +5,7 @@ from typing import List, Optional import torch from tqdm import tqdm -from AR.models.utils import make_pad_mask +from AR.models.utils import make_pad_mask, make_pad_mask_left from AR.models.utils import ( topk_sampling, sample, @@ -162,7 +162,7 @@ class T2SBlock: ) return x, k_cache, v_cache - def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True): + def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True): q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) k_cache = torch.cat([k_cache, k], dim=1) @@ -178,7 +178,7 @@ class T2SBlock: if torch_sdpa: - attn = F.scaled_dot_product_attention(q, k, v) + attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None) else: attn = scaled_dot_product_attention(q, k, v, attn_mask) @@ -223,7 +223,7 @@ class T2STransformer: self, x:torch.Tensor, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor], - attn_mask : Optional[torch.Tensor]=None, + attn_mask : torch.Tensor=None, torch_sdpa:bool=True ): for i in range(self.num_blocks): @@ -573,71 +573,61 @@ class Text2SemanticDecoder(nn.Module): x_item = self.ar_text_embedding(x_item.unsqueeze(0)) x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0)) x_item = self.ar_text_position(x_item).squeeze(0) - x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0] torch.Tensor: return expaned_lengths >= lengths.unsqueeze(-1) +def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """ + Args: + lengths: + A 1-D tensor containing sentence lengths. + max_len: + The length of masks. + Returns: + Return a 2-D bool tensor, where masked positions + are filled with `True` and non-masked positions are + filled with `False`. + + #>>> lengths = torch.tensor([1, 3, 2, 5]) + #>>> make_pad_mask(lengths) + tensor( + [ + [True, True, False], + [True, False, False], + [True, True, False], + ... + ] + ) + """ + assert lengths.ndim == 1, lengths.ndim + max_len = max(max_len, lengths.max()) + n = lengths.size(0) + seq_range = torch.arange(0, max_len, device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1) + expaned_lengths -= (max_len-lengths).unsqueeze(-1) + + return expaned_lengths<0 + + # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index ee2ec1e..c768fb3 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -145,7 +145,11 @@ class TTS_Config: self.device = self.configs.get("device", torch.device("cpu")) - self.is_half = self.configs.get("is_half", False) + if str(self.device) == "cpu": + print(f"Warning: Half precision is not supported on CPU, set is_half to False.") + self.is_half = False + else: + self.is_half = self.configs.get("is_half", False) self.version = version self.t2s_weights_path = self.configs.get("t2s_weights_path", None) self.vits_weights_path = self.configs.get("vits_weights_path", None) diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml index 66f1193..d58cfce 100644 --- a/GPT_SoVITS/configs/tts_infer.yaml +++ b/GPT_SoVITS/configs/tts_infer.yaml @@ -1,8 +1,8 @@ custom: bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base - device: cuda - is_half: true + device: cpu + is_half: false t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt version: v2 vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth