diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index c8ad3d82..0947531d 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -321,6 +321,12 @@ class Text2SemanticDecoder(nn.Module):
         # 错位
         return targets[:, :-1], targets[:, 1:]
 
+    def check_end(self, logits, samples):
+        # stop only when every sequence in the batch has reached EOS
+        if torch.all(torch.argmax(logits, dim=-1) == self.EOS) or torch.all(samples[:, 0] == self.EOS):
+            return True
+        return False
+
     def infer_panel(
         self,
         x,  #####全部文本token
@@ -338,8 +344,8 @@ class Text2SemanticDecoder(nn.Module):
         # AR Decoder
         y = prompts
-
-        x_len = x.shape[1]
+        ref_y_len = y.shape[1] if y is not None else 0
+        x_len = x_lens.max()
         x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
         stop = False
         # print(1111111,self.num_layers)
 
@@ -385,10 +391,10 @@
         xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
             x.device
         )
-
+        batch_end_index = [0 for _ in range(x.shape[0])]
         for idx in tqdm(range(1500)):
-
+
             xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
             logits = self.ar_predict_layer(
                 xy_dec[:, -1]
             )
@@ -397,17 +403,18 @@
             if(idx==0):###第一次跑不能EOS否则没有了
                 logits = logits[:, :-1]  ###刨除1024终止符号的概率
             samples = sample(
-                logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
-            )[0].unsqueeze(0)
+                logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
+            )[0]
             # 本次生成的 semantic_ids 和之前的 y 构成新的 y
             # print(samples.shape)#[1,1]#第一个1是bs
-            y = torch.concat([y, samples], dim=1)
+            y = torch.concat([y, samples], dim=1)
+            # print(y)
 
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 print("use early stop num:", early_stop_num)
                 stop = True
 
-            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
+            if self.check_end(logits, samples):
                 # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                 stop = True
             if stop:
@@ -443,6 +450,7 @@
             xy_attn_mask = torch.zeros(
                 (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
             )
+        batch_end_index = [min(torch.where(by == self.EOS)[0]) if self.EOS in by else idx for by in y]
         if ref_free:
-            return y[:, :-1], 0
-        return y[:, :-1], idx-1
+            return y[:, :-1], [0 for _ in range(y.shape[0])], batch_end_index
+        return y[:, :-1], [ref_y_len for _ in range(y.shape[0])], batch_end_index
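The batched decoder above stops only when every sequence in the batch has finished, and records a per-sequence end position so each row can later be sliced out of the padded batch. A minimal sketch of that stopping logic outside the model — assuming `EOS = 1024` (the `1024` terminator the inline comments mention) and plain tensors in place of `self`:

```python
import torch

EOS = 1024  # assumed terminator id, per the "刨除1024终止符号" comment above

def check_end(logits: torch.Tensor, samples: torch.Tensor) -> bool:
    # Done only when EVERY row points at EOS, either via the argmax of its
    # logits or via the token it just sampled.
    return bool(
        torch.all(torch.argmax(logits, dim=-1) == EOS)
        or torch.all(samples[:, 0] == EOS)
    )

def end_indices(y: torch.Tensor, last_idx: int) -> list:
    # First EOS position per row; rows that never emitted EOS end at last_idx.
    return [
        int(torch.where(row == EOS)[0].min()) if (row == EOS).any() else last_idx
        for row in y
    ]

y = torch.tensor([[5, 7, EOS, 0], [3, 9, 11, 2]])
print(end_indices(y, last_idx=4))  # -> [2, 4]
```

Note that this stopping rule is conservative: a row that hits EOS early keeps generating until the whole batch is done, which is why the per-row end indices are needed to trim the overshoot.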
diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py
index 9678c7e1..ce0a98b7 100644
--- a/GPT_SoVITS/AR/models/utils.py
+++ b/GPT_SoVITS/AR/models/utils.py
@@ -115,17 +115,17 @@ def logits_to_probs(
     top_p: Optional[int] = None,
     repetition_penalty: float = 1.0,
 ):
-    if previous_tokens is not None:
-        previous_tokens = previous_tokens.squeeze()
+    # if previous_tokens is not None:
+    #     previous_tokens = previous_tokens.squeeze()
     # print(logits.shape,previous_tokens.shape)
     # pdb.set_trace()
     if previous_tokens is not None and repetition_penalty != 1.0:
         previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
+        score = torch.gather(logits, dim=1, index=previous_tokens)
         score = torch.where(
             score < 0, score * repetition_penalty, score / repetition_penalty
         )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
+        logits.scatter_(dim=1, index=previous_tokens, src=score)
 
     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -133,9 +133,9 @@ def logits_to_probs(
             torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
         )
         sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[0] = False  # keep at least one option
+        sorted_indices_to_remove[:, 0] = False  # keep at least one option
         indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0, index=sorted_indices, src=sorted_indices_to_remove
+            dim=1, index=sorted_indices, src=sorted_indices_to_remove
         )
         logits = logits.masked_fill(indices_to_remove, -float("Inf"))
 
@@ -143,7 +143,7 @@
     if top_k is not None:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
+        pivot = v[:, -1].unsqueeze(-1)
         logits = torch.where(logits < pivot, -float("Inf"), logits)
 
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
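The `logits_to_probs` changes above generalize sampling from a single sequence to batched `[batch, vocab]` logits: the old `squeeze()` is dropped, the repetition-penalty `gather`/`scatter` moves to `dim=1`, the top-p keep-one guard masks column 0 of every row instead of index 0, and the top-k pivot becomes each row's k-th largest logit. A self-contained sketch of the same operations on toy values (shapes and numbers are assumptions for illustration):

```python
import torch

# Toy batched inputs: logits [batch=2, vocab=6], previous_tokens [2, T].
logits = torch.tensor([[2.0, 0.5, -1.0, 0.1, 0.0, 0.3],
                       [0.2, 1.5, 0.7, -0.4, 0.9, 0.0]])
previous_tokens = torch.tensor([[0, 3], [1, 4]])
repetition_penalty, top_p, top_k = 1.35, 0.8, 3

# Repetition penalty along the vocab dim (dim=1): each row is penalized
# only at its own history of token ids.
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
logits.scatter_(dim=1, index=previous_tokens, src=score)

# Top-p per row; the keep-at-least-one guard now masks column 0 of each row.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(
    dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))

# Top-k per row: v[:, -1] is each row's k-th largest logit.
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)

print(torch.softmax(logits, dim=-1))  # one independent distribution per row
```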
- print(i18n("实际输入的目标文本(每句):"), text) - phones2,bert2,norm_text2=get_phones_and_bert(text, text_language) - print(i18n("前端处理后的文本(每句):"), norm_text2) - if not ref_free: - bert = torch.cat([bert1, bert2], 1) - all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0) + batch_num = len(texts) // infer_batch + if len(texts) % infer_batch != 0: + batch_num += 1 + for i in range(batch_num): + if i == batch_num - 1: + text = texts[i * infer_batch:] else: - bert = bert2 - all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0) + text = texts[i * infer_batch: (i + 1) * infer_batch] + # for text in texts: + # 解决输入目标文本的空行导致报错的问题 + # 过滤每一个需要合成的文本 + all_phoneme_ids_batch = [] + all_bert_batch = [] + all_phoneme_len_batch = [] + all_vits_phones2 = [] + for T in text: + if (len(T.strip()) == 0): + continue + if (T[-1] not in splits): T += "。" if text_language != "en" else "." + print(i18n("实际输入的目标文本(每句):"), T) + phones2,bert2,norm_text2=get_phones_and_bert(T, text_language) + all_vits_phones2+=phones2 + print(i18n("前端处理后的文本(每句):"), norm_text2) + if not ref_free: + bert = torch.cat([bert1, bert2], 1) + all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0) + else: + bert = bert2 + all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0) + bert = bert.to(device).unsqueeze(0) + all_phoneme_ids_batch.append(all_phoneme_ids) + all_bert_batch.append(bert) + all_phoneme_len_batch.append(all_phoneme_ids.shape[-1]) - bert = bert.to(device).unsqueeze(0) - all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) - prompt = prompt_semantic.unsqueeze(0).to(device) + all_phoneme_ids = tensor_padding(all_phoneme_ids_batch) + bert = tensor_padding(all_bert_batch) + all_phoneme_len = torch.tensor(all_phoneme_len_batch).to(device) + prompt = prompt_semantic.unsqueeze(0).to(device).expand(all_phoneme_ids.shape[0], -1, -1) t2 = ttime() with torch.no_grad(): # pred_semantic = t2s_model.model.infer( - pred_semantic, idx = t2s_model.model.infer_panel( + pred_semantic, sidx, eidx = t2s_model.model.infer_panel( all_phoneme_ids, all_phoneme_len, None if ref_free else prompt, @@ -405,9 +426,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, ) t3 = ttime() # print(pred_semantic.shape,idx) - pred_semantic = pred_semantic[:, -idx:].unsqueeze( - 0 - ) # .unsqueeze(0)#mq要多unsqueeze一次 + pred_semantic = torch.cat([pred[si:ei] for pred, si, ei in zip(pred_semantic, sidx, eidx)], dim=0).unsqueeze(0).unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 refer = get_spepc(hps, ref_wav_path) # .to(device) if is_half == True: refer = refer.half().to(device) @@ -416,7 +435,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] audio = ( vq_model.decode( - pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer + pred_semantic, torch.LongTensor(all_vits_phones2).to(device).unsqueeze(0), refer ) .detach() .cpu() @@ -588,12 +607,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True) top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True) temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True) - inference_button = gr.Button(i18n("合成语音"), variant="primary") + with gr.Row(): + batch_size = 
gr.Slider(minimum=1,maximum=32,step=1,label=i18n("batch_size"),value=1,interactive=True) + inference_button = gr.Button(i18n("合成语音"), variant="primary") output = gr.Audio(label=i18n("输出的语音")) inference_button.click( get_tts_wav, - [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free], + [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free, batch_size], [output], ) diff --git a/GPT_SoVITS/utils.py b/GPT_SoVITS/utils.py index 7984b5a8..99e10aad 100644 --- a/GPT_SoVITS/utils.py +++ b/GPT_SoVITS/utils.py @@ -6,6 +6,7 @@ import logging import json import subprocess import traceback +from typing import List import librosa import numpy as np @@ -270,7 +271,6 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim del_routine = lambda x: [os.remove(x), del_info(x)] rs = [del_routine(fn) for fn in to_del] - def get_hparams_from_dir(model_dir): config_save_path = os.path.join(model_dir, "config.json") with open(config_save_path, "r") as f: @@ -331,6 +331,12 @@ def get_logger(model_dir, filename="train.log"): return logger +def tensor_padding(tensor_list: List[torch.Tensor]): + max_len = max([t.shape[-1] for t in tensor_list]) + tensor_list = [torch.cat([t, torch.zeros(*list(t.shape[:-1]), max_len - t.size(-1)).to(t)], dim=-1) for t in tensor_list] + + return torch.cat(tensor_list, dim=0) + class HParams: def __init__(self, **kwargs): for k, v in kwargs.items(): diff --git a/gweight.txt b/gweight.txt new file mode 100644 index 00000000..26947402 --- /dev/null +++ b/gweight.txt @@ -0,0 +1 @@ +GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt \ No newline at end of file diff --git a/sweight.txt b/sweight.txt new file mode 100644 index 00000000..1885189b --- /dev/null +++ b/sweight.txt @@ -0,0 +1 @@ +GPT_SoVITS/pretrained_models/s2G488k.pth \ No newline at end of file
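For reference, the new `tensor_padding` helper is what makes ragged inputs batchable: it right-pads every tensor with zeros along the last dimension up to the longest length in the list, then concatenates on dim 0, so it works both for `[1, T]` phoneme-id rows and for the higher-rank BERT feature tensors. A quick usage sketch (the import path assumes the script runs from the `GPT_SoVITS` directory, as `inference_webui.py` does):

```python
import torch
from utils import tensor_padding  # GPT_SoVITS/utils.py, per this diff

a = torch.LongTensor([[1, 2, 3]])        # shape [1, 3]
b = torch.LongTensor([[4, 5, 6, 7, 8]])  # shape [1, 5]

batch = tensor_padding([a, b])
print(batch.shape)  # torch.Size([2, 5])
print(batch[0])     # tensor([1, 2, 3, 0, 0]) -- zero right-padded to max length
```

The true lengths still travel separately as `all_phoneme_len`, so downstream code can distinguish padding from content.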