diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index e140b4f..ed46b2b 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -97,7 +97,7 @@ class T2SBlock:
         k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
         v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
 
-        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
+        attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
 
         attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
         attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
@@ -532,6 +532,20 @@ class Text2SemanticDecoder(nn.Module):
             y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
             ref_free = True
 
+
+        ##### create mask #####
+        bsz = x.shape[0]
+        src_len = x_len + y_len
+        y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
+        y_mask = make_pad_mask(y_lens)
+        x_mask = make_pad_mask(x_lens)
+
+
+        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
+        _xy_padding_mask = (
+            xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1)
+        )
+
         x_attn_mask_pad = F.pad(
             x_attn_mask,
             (0, y_len),  ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
@@ -545,7 +559,12 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
             x.device
         )
+        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
+        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
+        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
+        xy_attn_mask = new_attn_mask
+
         ###### decode #####
         y_list = [None]*y.shape[0]
         batch_idx_map = list(range(y.shape[0]))
         idx_list = [None]*y.shape[0]
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index cc460b8..ba29a03 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -361,6 +361,7 @@ class TTS:
         phones_list = []
         # bert_features_list = []
         all_phones_list = []
+        all_phones_len_list = []
         all_bert_features_list = []
         norm_text_batch = []
         bert_max_len = 0
@@ -376,16 +377,18 @@ class TTS:
                 phones = torch.LongTensor(item["phones"])
                 all_phones = phones.clone()
                 # norm_text = item["norm_text"]
+            bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
             phones_max_len = max(phones_max_len, phones.shape[-1])
 
             phones_list.append(phones)
             all_phones_list.append(all_phones)
+            all_phones_len_list.append(all_phones.shape[-1])
             all_bert_features_list.append(all_bert_features)
             norm_text_batch.append(item["norm_text"])
 
-        # phones_batch = phones_list
+        phones_batch = phones_list
         max_len = max(bert_max_len, phones_max_len)
-        phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
+        # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
         all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
         all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
         all_bert_features_batch.zero_()
@@ -397,6 +400,7 @@ class TTS:
         batch = {
             "phones": phones_batch,
             "all_phones": all_phones_batch,
+            "all_phones_len": torch.LongTensor(all_phones_len_list),
             "all_bert_features": all_bert_features_batch,
             "norm_text": norm_text_batch
         }
@@ -541,10 +545,12 @@ class TTS:
             t3 = ttime()
             batch_phones = item["phones"]
             all_phoneme_ids = item["all_phones"]
+            all_phoneme_lens = item["all_phones_len"]
             all_bert_features = item["all_bert_features"]
             norm_text = item["norm_text"]
 
             all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
+            all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
             all_bert_features = all_bert_features.to(self.configs.device)
             if self.configs.is_half:
                 all_bert_features = all_bert_features.half()
@@ -558,7 +564,7 @@ class TTS:
             with torch.no_grad():
                 pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                     all_phoneme_ids,
-                    None,
+                    all_phoneme_lens,
                     prompt,
                     all_bert_features,
                     # prompt_phone_len=ph_offset,
@@ -588,7 +594,7 @@ class TTS:
             ## 改成串行处理
             batch_audio_fragment = []
             for i, idx in enumerate(idx_list):
-                phones = batch_phones[i].clone().unsqueeze(0).to(self.configs.device)
+                phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
                 _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))   # .unsqueeze(0)#mq要多unsqueeze一次
                 audio_fragment =(self.vits_model.decode(
                         _pred_semantic, phones, refer_audio_spepc
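Note (not part of the patch): the t2s_model.py changes build an additive float mask (0 for visible positions, -inf for masked ones) and therefore pass attn_mask to F.scaled_dot_product_attention directly instead of the inverted boolean ~attn_mask. PyTorch treats a boolean mask as "True = attend" but adds a float mask to the attention scores before softmax, and the padding mask merged in via logical_or appears intended to keep padded positions out of attention when batch_size > 1. Below is a minimal standalone sketch of the same bool-to-additive-mask conversion; the shapes and variable names are illustrative and not taken from the repository:

    import torch
    import torch.nn.functional as F

    # Toy shapes: batch of 2, 4 heads, sequence length 6, per-sample valid lengths 6 and 4.
    bsz, num_heads, seq_len, head_dim = 2, 4, 6, 8
    lengths = torch.tensor([6, 4])

    # True marks positions that must NOT be attended to (padding or future tokens).
    pad_mask = torch.arange(seq_len)[None, :] >= lengths[:, None]            # (bsz, seq_len)
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), 1)   # (seq_len, seq_len)
    bool_mask = causal[None, None, :, :] | pad_mask[:, None, None, :]        # (bsz, 1, seq_len, seq_len)
    bool_mask = bool_mask.expand(-1, num_heads, -1, -1)                      # (bsz, heads, seq_len, seq_len)

    # Convert "True = masked" into the additive float mask that SDPA adds to the scores.
    float_mask = torch.zeros(bool_mask.shape, dtype=torch.float)
    float_mask.masked_fill_(bool_mask, float("-inf"))

    q = k = v = torch.randn(bsz, num_heads, seq_len, head_dim)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
    print(out.shape)  # torch.Size([2, 4, 6, 8])

With an additive -inf mask, a query row that is fully masked produces NaNs after softmax, so the per-sample lengths in the sketch are chosen so every query still sees at least one visible key.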