Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-06 03:57:44 +08:00)
Fixed some bugs in GPT_SoVITS/AR/models/t2s_model.py
Fixed some bugs in GPT_SoVITS/TTS_infer_pack/TTS.py
parent cae976ef5a
commit cd746848e6
GPT_SoVITS/AR/models/t2s_model.py

@@ -97,7 +97,7 @@ class T2SBlock:
         k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
         v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
 
-        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
+        attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
 
         attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
         attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
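Context for the mask change above: once the caller builds an additive float mask (0 where attention is allowed, -inf where it is blocked, as the Text2SemanticDecoder hunks below now do), the mask is passed to F.scaled_dot_product_attention as-is; the bitwise inversion ~attn_mask only made sense for a boolean mask. A minimal, self-contained sketch, not part of the commit, with illustrative shapes:

import torch
import torch.nn.functional as F

# Illustrative shapes; not the model's real configuration.
bsz, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, kv_len, head_dim)
v = torch.randn(bsz, heads, kv_len, head_dim)

# Boolean convention: True = blocked. SDPA expects True = keep for bool masks,
# hence the inversion the old code applied.
blocked = torch.zeros(q_len, kv_len, dtype=torch.bool)
blocked[:, -1] = True                       # pretend the last key position is padding
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=~blocked)

# Additive float convention: 0 where allowed, -inf where blocked, passed as-is,
# which is what the new code relies on.
float_mask = torch.zeros(q_len, kv_len)
float_mask.masked_fill_(blocked, float("-inf"))
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

assert torch.allclose(out_bool, out_float, atol=1e-6)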
@@ -532,6 +532,20 @@ class Text2SemanticDecoder(nn.Module):
             y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
             ref_free = True
 
+        ##### create mask #####
+        bsz = x.shape[0]
+        src_len = x_len + y_len
+        y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
+        y_mask = make_pad_mask(y_lens)
+        x_mask = make_pad_mask(x_lens)
+
+        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
+        _xy_padding_mask = (
+            xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1)
+        )
+
         x_attn_mask_pad = F.pad(
             x_attn_mask,
             (0, y_len),  ### extend the all-zero xx block to all-zero xx plus all-one xy, i.e. (x, x+y)
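For reading the hunk above: make_pad_mask comes from the repo's AR utilities and is assumed here to return True at padded positions. An illustrative stand-in, not the repo's implementation:

import torch

def make_pad_mask_sketch(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    # (bsz,) lengths -> (bsz, max_len) bool mask, True where the position is padding.
    max_len = max(max_len, int(lengths.max()))
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # (1, max_len)
    return positions >= lengths.unsqueeze(1)                               # (bsz, max_len)

print(make_pad_mask_sketch(torch.tensor([3, 5])))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])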
@@ -545,7 +559,12 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
             x.device
         )
+        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
+        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
+        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
+        xy_attn_mask = new_attn_mask
 
         ###### decode #####
         y_list = [None]*y.shape[0]
         batch_idx_map = list(range(y.shape[0]))
         idx_list = [None]*y.shape[0]
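A tiny worked example (illustrative values, not from the model) of the combination above: the structural xx/xy mask is OR-ed with the padding mask, then rewritten as an additive float mask for attention:

import torch

struct_mask = torch.tensor([[False,  True],
                            [False, False]])   # True = blocked by the structural/causal rule
pad_mask    = torch.tensor([[False,  True],
                            [False,  True]])   # True = padded key position

combined = struct_mask.logical_or(pad_mask)    # blocked if either mask says so
float_mask = torch.zeros_like(combined, dtype=torch.float32)
float_mask.masked_fill_(combined, float("-inf"))
print(float_mask)
# tensor([[0., -inf],
#         [0., -inf]])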
GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -361,6 +361,7 @@ class TTS:
             phones_list = []
             # bert_features_list = []
             all_phones_list = []
+            all_phones_len_list = []
             all_bert_features_list = []
             norm_text_batch = []
             bert_max_len = 0
@@ -376,16 +377,18 @@ class TTS:
                 phones = torch.LongTensor(item["phones"])
                 all_phones = phones.clone()
                 # norm_text = item["norm_text"]
 
                 bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
                 phones_max_len = max(phones_max_len, phones.shape[-1])
 
                 phones_list.append(phones)
                 all_phones_list.append(all_phones)
+                all_phones_len_list.append(all_phones.shape[-1])
                 all_bert_features_list.append(all_bert_features)
                 norm_text_batch.append(item["norm_text"])
 
-            # phones_batch = phones_list
+            phones_batch = phones_list
             max_len = max(bert_max_len, phones_max_len)
-            phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
+            # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
             all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
             all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
             all_bert_features_batch.zero_()
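batch_sequences above is a method on the TTS class; the sketch below is an illustrative stand-in, not the repo's implementation, showing the padding behaviour assumed here: right-pad 1-D sequences to a shared max_length with pad_value.

import torch

def batch_sequences_sketch(seqs, pad_value=0, max_length=None):
    # Right-pad a list of 1-D tensors into a (len(seqs), max_length) tensor.
    max_length = max_length or max(int(s.shape[-1]) for s in seqs)
    out = torch.full((len(seqs), max_length), pad_value, dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        out[i, : s.shape[-1]] = s
    return out

phones_list = [torch.LongTensor([1, 2, 3]), torch.LongTensor([4, 5])]
print(batch_sequences_sketch(phones_list, pad_value=0, max_length=4))
# tensor([[1, 2, 3, 0],
#         [4, 5, 0, 0]])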
@@ -397,6 +400,7 @@ class TTS:
             batch = {
                 "phones": phones_batch,
                 "all_phones": all_phones_batch,
+                "all_phones_len": torch.LongTensor(all_phones_len_list),
                 "all_bert_features": all_bert_features_batch,
                 "norm_text": norm_text_batch
             }
@@ -541,10 +545,12 @@ class TTS:
             t3 = ttime()
             batch_phones = item["phones"]
             all_phoneme_ids = item["all_phones"]
+            all_phoneme_lens = item["all_phones_len"]
             all_bert_features = item["all_bert_features"]
             norm_text = item["norm_text"]
 
             all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
+            all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
             all_bert_features = all_bert_features.to(self.configs.device)
             if self.configs.is_half:
                 all_bert_features = all_bert_features.half()
@@ -558,7 +564,7 @@ class TTS:
             with torch.no_grad():
                 pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                     all_phoneme_ids,
-                    None,
+                    all_phoneme_lens,
                     prompt,
                     all_bert_features,
                     # prompt_phone_len=ph_offset,
@@ -588,7 +594,7 @@ class TTS:
             ## switched to serial processing
             batch_audio_fragment = []
             for i, idx in enumerate(idx_list):
-                phones = batch_phones[i].clone().unsqueeze(0).to(self.configs.device)
+                phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
                 _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # .unsqueeze(0)  # mq needs one extra unsqueeze
                 audio_fragment = (self.vits_model.decode(
                     _pred_semantic, phones, refer_audio_spepc
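On the pred_semantic_list[i][-idx:] slice above: each item keeps only its last idx generated semantic tokens, and the two unsqueeze(0) calls add the leading dimensions the VITS decoder expects. A toy illustration with made-up values:

import torch

pred = torch.arange(10)     # stand-in for one item's generated semantic token ids
idx = 4                     # stand-in for the per-item count returned by infer_panel
fragment = pred[-idx:].unsqueeze(0).unsqueeze(0)
print(fragment.shape)       # torch.Size([1, 1, 4])
print(fragment)             # tensor([[[6, 7, 8, 9]]])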