Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-05 04:22:46 +08:00)
Fix the GPT padding-mask problem (#2153)
* Fix the GPT padding-mask problem
* rollback tts_config
parent fe2f04bdb8
commit 053a356ffe
@@ -622,12 +622,39 @@ class Text2SemanticDecoder(nn.Module):
)

causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2)  ### [b, x+y, x+y]

# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2)  ### [b, x+y, x+y]
### The line above is wrong: it lets padded tokens be "seen" by the attention.

# The correct padding_mask should be:
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],   (strictly speaking the first 3 rows should also be masked, but they are kept so that attention does not produce NaN; this does not change the result)
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]

padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)

attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
# padding_mask = padding_mask.view(bsz, src_len, 1)

# The correct attn_mask should look like this:
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],   (strictly speaking the first 3 rows should also be masked, but they are kept so that attention does not produce NaN; this does not change the result)
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]

###### decode #####
y_list = [None] * y.shape[0]
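To see why the old construction is wrong, here is a minimal, self-contained sketch (not part of the commit). It assumes padding_mask is a boolean [bsz, src_len] tensor in which True marks a left-padded position, and that True in the combined mask means "may not be attended to", which matches how the mask is used above. The old outer-product mask leaves padded keys visible to valid queries; the fixed broadcast-along-queries mask hides them:

import torch

# Toy sizes: 3 padded positions followed by 6 valid tokens.
bsz, pad_len, valid_len = 1, 3, 6
src_len = pad_len + valid_len
padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
padding_mask[:, :pad_len] = True  # True = padded position

# Old (buggy) construction: True only where BOTH query and key are padding,
# so a valid query still attends to padded keys.
buggy = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2)      # [b, src, src]

# Fixed construction: broadcast along the query axis, so every query row
# masks out the padded key columns.
fixed = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)   # [b, src, src]

q = pad_len  # first valid (non-padding) query position
print(buggy[0, q, :pad_len])  # tensor([False, False, False]) -> padding is visible
print(fixed[0, q, :pad_len])  # tensor([True,  True,  True])  -> padding is masked
# Rows belonging to padded queries stay unmasked on purpose: a fully masked row
# would make the softmax produce NaN, and those outputs are discarded anyway.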
@@ -145,11 +145,15 @@ class TTS_Config:

self.device = self.configs.get("device", torch.device("cpu"))
if str(self.device) == "cpu":
    print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
    self.is_half = False
else:
    self.is_half = self.configs.get("is_half", False)
if "cuda" in str(self.device) and not torch.cuda.is_available():
    print(f"Warning: CUDA is not available, set device to CPU.")
    self.device = torch.device("cpu")

# self.is_half = self.configs.get("is_half", False)
# if str(self.device) == "cpu" and self.is_half:
#     print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
#     self.is_half = False

self.version = version
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
self.vits_weights_path = self.configs.get("vits_weights_path", None)
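As a quick illustration of the resolution order the updated TTS_Config code enforces, here is a standalone sketch; the resolve_device_and_half helper is hypothetical and not part of the repository. A CPU device always disables half precision, and a requested CUDA device falls back to CPU when CUDA is unavailable:

import torch

def resolve_device_and_half(configs: dict):
    # Hypothetical standalone version of the logic in the diff above.
    device = configs.get("device", torch.device("cpu"))
    if str(device) == "cpu":
        # CPU never runs in half precision.
        is_half = False
    else:
        is_half = configs.get("is_half", False)
    if "cuda" in str(device) and not torch.cuda.is_available():
        # Requested CUDA but none is available: fall back to CPU.
        device = torch.device("cpu")
    return torch.device(device), is_half

# Example: with {"device": "cuda", "is_half": True} this returns
# (device('cuda'), True) on a GPU machine and (device('cpu'), True)
# on a CPU-only machine.
print(resolve_device_and_half({"device": "cuda", "is_half": True}))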
@@ -1,8 +1,8 @@
custom:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  device: cuda
  is_half: true
  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
  version: v2
  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
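To make the config hunk concrete, here is an illustrative loader for the custom preset above; the file path GPT_SoVITS/configs/tts_infer.yaml and the use of PyYAML are assumptions for this sketch, not something this commit introduces:

import yaml  # PyYAML

# Read the "custom" preset from the rolled-back config file (path assumed).
with open("GPT_SoVITS/configs/tts_infer.yaml", "r", encoding="utf-8") as f:
    custom = yaml.safe_load(f)["custom"]

# These keys mirror the YAML entries shown in the diff above.
print(custom["device"], custom["is_half"])
print(custom["t2s_weights_path"])
print(custom["vits_weights_path"])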