From 053a356ffe3d72b475d512b4f48b229e9b145f19 Mon Sep 17 00:00:00 2001 From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com> Date: Wed, 5 Mar 2025 17:14:43 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dgpt=E7=9A=84padding=20mask?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98=20(#2153)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复gpt的padding mask的问题 * rollback tts_config --- GPT_SoVITS/AR/models/t2s_model.py | 35 +++++++++++++++++++++++++++---- GPT_SoVITS/TTS_infer_pack/TTS.py | 14 ++++++++----- GPT_SoVITS/configs/tts_infer.yaml | 4 ++-- 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index d97cfb7..8a32d0d 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -622,12 +622,39 @@ class Text2SemanticDecoder(nn.Module): ) causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device) - padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y] - - + # padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y] + ### 上面是错误的,会导致padding的token被"看见" + + # 正确的padding_mask应该是: + # | pad_len | x_len | y_len | + # [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], + # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], + # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], 前3行按理说也应该被mask掉,但是为了防止计算attention时不出现nan,还是保留了,不影响结果 + # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], + # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], + # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], + # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], + # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], + # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]] + + padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1) + attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask) attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool() - # padding_mask = padding_mask.view(bsz, src_len, 1) + + + # 正确的attn_mask应该是这样的: + # | pad_len | x_len | y_len | + # [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], + # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], + # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], 前3行按理说也应该被mask掉,但是为了防止计算attention时不出现nan,还是保留了,不影响结果 + # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], + # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], + # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], + # [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS], + # [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS], + # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]] + ###### decode ##### y_list = [None]*y.shape[0] diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index c768fb3..420aa77 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -145,11 +145,15 @@ class TTS_Config: self.device = self.configs.get("device", torch.device("cpu")) - if str(self.device) == "cpu": - print(f"Warning: Half precision is not supported on CPU, set is_half to False.") - self.is_half = False - else: - self.is_half = self.configs.get("is_half", False) + if "cuda" in str(self.device) and not torch.cuda.is_available(): + print(f"Warning: CUDA is not available, set device to CPU.") + self.device = torch.device("cpu") + + # self.is_half = self.configs.get("is_half", False) + # if str(self.device) == "cpu" and self.is_half: + # print(f"Warning: Half precision is not supported on CPU, set is_half to False.") + # self.is_half = False + self.version = version self.t2s_weights_path = self.configs.get("t2s_weights_path", None) self.vits_weights_path = self.configs.get("vits_weights_path", None) diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml index d58cfce..66f1193 100644 --- a/GPT_SoVITS/configs/tts_infer.yaml +++ b/GPT_SoVITS/configs/tts_infer.yaml @@ -1,8 +1,8 @@ custom: bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base - device: cpu - is_half: false + device: cuda + is_half: true t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt version: v2 vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth