diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index dfd6eb0..da95111 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -707,19 +707,33 @@ class Text2SemanticDecoder(nn.Module):
             y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
             ref_free = True
 
-        x_attn_mask_pad = F.pad(
-            x_attn_mask,
-            (0, y_len),  ### xx's all-zero block extended to all-zero xx + all-one xy, (x, x+y)
-            value=True,
-        )
-        y_attn_mask = F.pad(  ### yy's upper-right ones extended left with zeros for xy, (y, x+y)
+        ##### create mask #####
+        bsz = x.shape[0]
+        src_len = x_len + y_len
+        y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
+        y_mask = make_pad_mask(y_lens)
+        x_mask = make_pad_mask(x_lens)
+
+        # (bsz, x_len + y_len)
+        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
+
+        x_mask = F.pad(
+            x_attn_mask,
+            (0, y_len),  ### xx's all-zero block extended to all-zero xx + all-one xy, (x, x+y)
+            value=True,
+        )
+        y_mask = F.pad(  ### yy's upper-right ones extended left with zeros for xy, (y, x+y)
             torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-            x.device
-        )
+
+        xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).expand(bsz*self.num_head, -1, -1).to(x.device)
+        # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
+        xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(bsz, src_len, src_len).repeat(self.num_head, 1, 1)
+        xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
+        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
+        xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
 
         y_list = [None]*y.shape[0]
         batch_idx_map = list(range(y.shape[0]))
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 62bb2e9..e569f52 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -3,6 +3,8 @@ import math
 import os, sys
 import random
 import traceback
+
+from tqdm import tqdm
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 import ffmpeg
@@ -552,14 +554,15 @@ class TTS:
                     "prompt_text": "",        # str. prompt text for the reference audio
                     "prompt_lang": "",        # str. language of the prompt text for the reference audio
                     "top_k": 5,               # int. top k sampling
-                    "top_p": 1,            # float. top p sampling
-                    "temperature": 1,      # float. temperature for sampling
+                    "top_p": 1,               # float. top p sampling
+                    "temperature": 1,         # float. temperature for sampling
                     "text_split_method": "",  # str. text split method, see text_segmentation_method.py for details.
                     "batch_size": 1,          # int. batch size for inference
                     "batch_threshold": 0.75,  # float. threshold for batch splitting.
                     "split_bucket": True,     # bool. whether to split the batch into multiple buckets.
                     "return_fragment": False, # bool. return the audio fragments step by step.
                     "speed_factor": 1.0,      # float. control the speed of the synthesized audio.
+                    "fragment_interval": 0.3, # float. silence interval (in seconds) inserted between audio fragments.
                 }
         returns:
             tuple[int, np.ndarray]: sampling rate and audio data.
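The reworked block in the t2s_model.py hunk above folds per-sample padding into the attention mask: a shared causal/cross-attention mask is combined with each sample's padding mask via `logical_or`, then converted into an additive float mask (`0` / `-inf`) and expanded per attention head. Below is a minimal, self-contained sketch of that scheme; `make_pad_mask` is re-implemented here for illustration, and the shapes and variable names follow the diff.

```python
import torch
import torch.nn.functional as F

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # True marks padded positions: shape (bsz, max_len).
    max_len = int(lengths.max())
    seq = torch.arange(max_len, device=lengths.device)
    return seq.unsqueeze(0) >= lengths.unsqueeze(1)

bsz, num_head = 2, 4
x_lens = torch.tensor([5, 3])        # per-sample text lengths
x_len, y_len = int(x_lens.max()), 6  # padded text length, semantic length
src_len = x_len + y_len

# x attends bidirectionally to x and never to y; y attends causally.
x_attn_mask = torch.zeros(x_len, x_len, dtype=torch.bool)
x_part = F.pad(x_attn_mask, (0, y_len), value=True)            # (x_len, src_len)
y_part = F.pad(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
               (x_len, 0), value=False)                        # (y_len, src_len)
causal = torch.cat([x_part, y_part], dim=0)                    # (src_len, src_len)

# Padding mask: block attention to padded text positions in every row.
y_lens = torch.full((bsz,), y_len, dtype=torch.long)
padding = torch.cat([make_pad_mask(x_lens), make_pad_mask(y_lens)], dim=1)  # (bsz, src_len)

# Broadcast both to (bsz * num_head, src_len, src_len) and merge.
causal = causal.view(1, src_len, src_len).expand(bsz * num_head, -1, -1)
padding = padding.view(bsz, 1, src_len).expand(bsz, src_len, src_len).repeat(num_head, 1, 1)
bool_mask = causal.logical_or(padding)

# Convert to the additive float mask used by attention kernels.
attn_mask = torch.zeros_like(bool_mask, dtype=torch.float32).masked_fill(bool_mask, float("-inf"))
print(attn_mask.shape)  # torch.Size([8, 11, 11])
```

Expressing the mask additively (zeros at allowed positions, `-inf` at blocked ones) lets it be added directly to the attention logits, which is the convention PyTorch attention kernels expect for float masks.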
@@ -580,9 +583,10 @@ class TTS:
         speed_factor = inputs.get("speed_factor", 1.0)
         split_bucket = inputs.get("split_bucket", True)
         return_fragment = inputs.get("return_fragment", False)
+        fragment_interval = inputs.get("fragment_interval", 0.3)
 
         if return_fragment:
-            split_bucket = False
+            # split_bucket = False
             print(i18n("分段返回模式已开启"))
             if split_bucket:
                 split_bucket = False
@@ -590,7 +594,10 @@ class TTS:
 
         if split_bucket:
             print(i18n("分桶处理模式已开启"))
-
+
+        if fragment_interval<0.01:
+            fragment_interval = 0.01
+            print(i18n("分段间隔过小,已自动设置为0.01"))
 
         no_prompt_text = False
         if prompt_text in [None, ""]:
@@ -627,19 +634,52 @@ class TTS:
 
         ###### text preprocessing ########
-        data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
-        if len(data) == 0:
-            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
-                                                       dtype=np.int16)
-            return
-
-        t1 = ttime()
-        data, batch_index_list = self.to_batch(data,
-                             prompt_data=self.prompt_cache if not no_prompt_text else None,
-                             batch_size=batch_size,
-                             threshold=batch_threshold,
-                             split_bucket=split_bucket
-                             )
+        t1 = ttime()
+        data:list = None
+        if not return_fragment:
+            data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
+            if len(data) == 0:
+                yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
+                                                           dtype=np.int16)
+                return
+
+            batch_index_list:list = None
+            data, batch_index_list = self.to_batch(data,
+                                 prompt_data=self.prompt_cache if not no_prompt_text else None,
+                                 batch_size=batch_size,
+                                 threshold=batch_threshold,
+                                 split_bucket=split_bucket
+                                 )
+        else:
+            print(i18n("############ 切分文本 ############"))
+            texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
+            data = []
+            for i in range(len(texts)):
+                if i%batch_size == 0:
+                    data.append([])
+                data[-1].append(texts[i])
+
+            def make_batch(batch_texts):
+                batch_data = []
+                print(i18n("############ 提取文本Bert特征 ############"))
+                for text in tqdm(batch_texts):
+                    phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang)
+                    if phones is None:
+                        continue
+                    res = {
+                        "phones": phones,
+                        "bert_features": bert_features,
+                        "norm_text": norm_text,
+                    }
+                    batch_data.append(res)
+                batch, _ = self.to_batch(batch_data,
+                                 prompt_data=self.prompt_cache if not no_prompt_text else None,
+                                 batch_size=batch_size,
+                                 threshold=batch_threshold,
+                                 split_bucket=False
+                                 )
+                return batch[0]
+
         t2 = ttime()
         try:
             print("############ 推理 ############")
@@ -649,6 +689,9 @@ class TTS:
             audio = []
             for item in data:
                 t3 = ttime()
+                if return_fragment:
+                    item = make_batch(item)
+
                 batch_phones = item["phones"]
                 batch_phones_len = item["phones_len"]
                 all_phoneme_ids = item["all_phones"]
@@ -734,14 +777,16 @@ class TTS:
                     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
                     yield self.audio_postprocess([batch_audio_fragment],
                                                  self.configs.sampling_rate,
-                                                 batch_index_list,
+                                                 None,
                                                  speed_factor,
-                                                 split_bucket)
+                                                 False,
+                                                 fragment_interval
+                                                 )
                 else:
                     audio.append(batch_audio_fragment)
 
                 if self.stop_flag:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
+                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
                                                                dtype=np.int16)
                     return
 
@@ -751,7 +796,9 @@ class TTS:
                                              self.configs.sampling_rate,
                                              batch_index_list,
                                              speed_factor,
-                                             split_bucket)
+                                             split_bucket,
+                                             fragment_interval
+                                             )
         except Exception as e:
             traceback.print_exc()
             # an empty audio must be returned here, otherwise the GPU memory will not be released
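The new `return_fragment` path above defers feature extraction: the text is segmented up front, grouped into `batch_size`-sized chunks, and each chunk is only run through BERT feature extraction (`make_batch`) right before it is synthesized, so the first audio fragment can be yielded without preprocessing the whole input. A schematic sketch of that lazy-batching pattern; `segment`, `featurize`, and `infer` are hypothetical stand-ins for `pre_seg_text`, `segment_and_extract_feature_for_text`, and the synthesis stages.

```python
from typing import Callable, Iterator, List

def stream_fragments(text: str,
                     batch_size: int,
                     segment: Callable[[str], List[str]],
                     featurize: Callable[[str], dict],
                     infer: Callable[[List[dict]], bytes]) -> Iterator[bytes]:
    texts = segment(text)
    # Group sentence-level segments into batch_size-sized chunks, mirroring
    # `if i % batch_size == 0: data.append([])` in the diff above.
    chunks = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
    for chunk in chunks:
        # Featurize lazily, skipping segments that yield no phonemes.
        batch = [feat for t in chunk if (feat := featurize(t)) is not None]
        if not batch:
            continue
        yield infer(batch)  # one audio fragment per batch

# Usage with trivial stand-ins:
for frag in stream_fragments("a. b. c. d. e", 2,
                             segment=lambda s: [p for p in s.split(". ") if p],
                             featurize=lambda t: {"norm_text": t},
                             infer=lambda b: " | ".join(d["norm_text"] for d in b).encode()):
    print(frag)
```

The trade-off mirrors the diff: total preprocessing work is unchanged, but it is amortized across fragments instead of being paid entirely before the first yield.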
@@ -781,9 +828,11 @@ class TTS:
                           sr:int,
                           batch_index_list:list=None,
                           speed_factor:float=1.0,
-                          split_bucket:bool=True)->tuple[int, np.ndarray]:
+                          split_bucket:bool=True,
+                          fragment_interval:float=0.3
+                          )->tuple[int, np.ndarray]:
         zero_wav = torch.zeros(
-                        int(self.configs.sampling_rate * 0.3),
+                        int(self.configs.sampling_rate * fragment_interval),
                         dtype=self.precison,
                         device=self.configs.device
                     )
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 4c9d90f..7ba5d1e 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -91,7 +91,7 @@ def inference(text, text_lang,
               top_p, temperature, text_split_method,
               batch_size, speed_factor, ref_text_free,
-              split_bucket
+              split_bucket, fragment_interval,
               ):
     inputs={
         "text": text,
@@ -106,7 +106,8 @@ def inference(text, text_lang,
         "batch_size":int(batch_size),
         "speed_factor":float(speed_factor),
         "split_bucket":split_bucket,
-        "return_fragment":False
+        "return_fragment":False,
+        "fragment_interval":fragment_interval,
     }
     for item in tts_pipline.run(inputs):
@@ -188,6 +189,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
 
             with gr.Column():
                 batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
+                fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
                 speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
                 top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
                 top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
@@ -216,7 +218,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                 top_k, top_p, temperature,
                 how_to_cut, batch_size,
                 speed_factor, ref_text_free,
-                split_bucket
+                split_bucket, fragment_interval,
             ],
             [output],
         )
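For reference, what the new `fragment_interval` parameter threaded through `audio_postprocess` above controls: the previously hard-coded 0.3 s of silence appended after each fragment becomes configurable, and `run` clamps it to at least 0.01 s. A simplified NumPy sketch of the joining step; the real implementation works on torch tensors with dtype `self.precison` on the configured device.

```python
import numpy as np

def join_fragments(fragments: list, sr: int, fragment_interval: float = 0.3) -> np.ndarray:
    # Append fragment_interval seconds of silence after each fragment,
    # then concatenate, mirroring zero_wav in audio_postprocess.
    zero_wav = np.zeros(int(sr * fragment_interval), dtype=np.float32)
    return np.concatenate([np.concatenate([frag, zero_wav]) for frag in fragments])

sr = 32000
fragments = [np.ones(sr, dtype=np.float32), np.ones(sr, dtype=np.float32)]
out = join_fragments(fragments, sr, fragment_interval=0.01)  # the clamped minimum from run()
print(out.shape)  # (64640,) == 2 * sr + 2 * int(sr * 0.01)
```

In fragment mode the same `fragment_interval` also sizes the pause a client hears between streamed chunks, which is why the WebUI exposes it as a slider from 0.01 to 1 second.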