From d60d8ea3fb5c1ab741365e4ebfb7fa3b4ea0853e Mon Sep 17 00:00:00 2001
From: chasonjiang <1440499136@qq.com>
Date: Wed, 13 Mar 2024 16:25:27 +0800
Subject: [PATCH 1/5] Fix GPU memory not being released after an
 OutOfMemoryError
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/TTS_infer_pack/TTS.py | 251 +++++++++++++++++--------------
 1 file changed, 139 insertions(+), 112 deletions(-)

diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 61ba7be1..62bb2e90 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -2,6 +2,7 @@ from copy import deepcopy
 import math
 import os, sys
 import random
+import traceback
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 import ffmpeg
@@ -48,8 +49,18 @@ custom:
 
 """
 
-
-
+# def set_seed(seed):
+#     random.seed(seed)
+#     os.environ['PYTHONHASHSEED'] = str(seed)
+#     np.random.seed(seed)
+#     torch.manual_seed(seed)
+#     torch.cuda.manual_seed(seed)
+#     torch.cuda.manual_seed_all(seed)
+#     torch.backends.cudnn.deterministic = True
+#     torch.backends.cudnn.benchmark = False
+#     torch.backends.cudnn.enabled = True
+# set_seed(1234)
+
 class TTS_Config:
     default_configs={
         "device": "cpu",
@@ -630,125 +641,141 @@ class TTS:
                                 split_bucket=split_bucket
                                 )
         t2 = ttime()
+        try:
+            print("############ 推理 ############")
+            ###### inference ######
+            t_34 = 0.0
+            t_45 = 0.0
+            audio = []
+            for item in data:
+                t3 = ttime()
+                batch_phones = item["phones"]
+                batch_phones_len = item["phones_len"]
+                all_phoneme_ids = item["all_phones"]
+                all_phoneme_lens = item["all_phones_len"]
+                all_bert_features = item["all_bert_features"]
+                norm_text = item["norm_text"]
+
+                # batch_phones = batch_phones.to(self.configs.device)
+                batch_phones_len = batch_phones_len.to(self.configs.device)
+                all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
+                all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
+                all_bert_features = all_bert_features.to(self.configs.device)
+                if self.configs.is_half:
+                    all_bert_features = all_bert_features.half()
 
-        print("############ 推理 ############")
-        ###### inference ######
-        t_34 = 0.0
-        t_45 = 0.0
-        audio = []
-        for item in data:
-            t3 = ttime()
-            batch_phones = item["phones"]
-            batch_phones_len = item["phones_len"]
-            all_phoneme_ids = item["all_phones"]
-            all_phoneme_lens = item["all_phones_len"]
-            all_bert_features = item["all_bert_features"]
-            norm_text = item["norm_text"]
-
-            # batch_phones = batch_phones.to(self.configs.device)
-            batch_phones_len = batch_phones_len.to(self.configs.device)
-            all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
-            all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
-            all_bert_features = all_bert_features.to(self.configs.device)
-            if self.configs.is_half:
-                all_bert_features = all_bert_features.half()
-
-            print(i18n("前端处理后的文本(每句):"), norm_text)
-            if no_prompt_text :
-                prompt = None
-            else:
-                prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
-
-            with torch.no_grad():
-                pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
-                    all_phoneme_ids,
-                    all_phoneme_lens,
-                    prompt,
-                    all_bert_features,
-                    # prompt_phone_len=ph_offset,
-                    top_k=top_k,
-                    top_p=top_p,
-                    temperature=temperature,
-                    early_stop_num=self.configs.hz * self.configs.max_sec,
-                )
-            t4 = ttime()
-            t_34 += t4 - t3
-
-            refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
-                .to(dtype=self.precison, device=self.configs.device)
-
-            batch_audio_fragment = []
-
-            # ## vits并行推理 method 1
-            # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
-            # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
-            # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
-            # max_len = 0
-            # for i in range(0, len(batch_phones)):
-            #     max_len = max(max_len, batch_phones[i].shape[-1])
-            # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
-            # batch_phones = batch_phones.to(self.configs.device)
-            # batch_audio_fragment = (self.vits_model.batched_decode(
-            #     pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
-            #   ))
-
-            # ## vits并行推理 method 2
-            pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
-            upsample_rate = math.prod(self.vits_model.upsample_rates)
-            audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
-            audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
-            all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
-            _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-            _batch_audio_fragment = (self.vits_model.decode(
-                    all_pred_semantic, _batch_phones,refer_audio_spepc
-                ).detach()[0, 0, :])
-            audio_frag_end_idx.insert(0, 0)
-            batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
-
-
-            # ## vits串行推理
-            # for i, idx in enumerate(idx_list):
-            #     phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-            #     _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
-            #     audio_fragment =(self.vits_model.decode(
-            #             _pred_semantic, phones, refer_audio_spepc
-            #         ).detach()[0, 0, :])
-            #     batch_audio_fragment.append(
-            #         audio_fragment
-            #     )  ###试试重建不带上prompt部分
-
-            t5 = ttime()
-            t_45 += t5 - t4
-            if return_fragment:
-                print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
-                yield self.audio_postprocess([batch_audio_fragment],
-                                             self.configs.sampling_rate,
-                                             batch_index_list,
-                                             speed_factor,
-                                             split_bucket)
-            else:
-                audio.append(batch_audio_fragment)
-
-            if self.stop_flag:
-                yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
-                                                           dtype=np.int16)
-                return
-
-        if not return_fragment:
-            print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
-            yield self.audio_postprocess(audio,
-                                         self.configs.sampling_rate,
-                                         batch_index_list,
-                                         speed_factor,
-                                         split_bucket)
-
-        try:
-            torch.cuda.empty_cache()
-        except:
-            pass
-
-
-
+                print(i18n("前端处理后的文本(每句):"), norm_text)
+                if no_prompt_text :
+                    prompt = None
+                else:
+                    prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
+
+                with torch.no_grad():
+                    pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
+                        all_phoneme_ids,
+                        all_phoneme_lens,
+                        prompt,
+                        all_bert_features,
+                        # prompt_phone_len=ph_offset,
+                        top_k=top_k,
+                        top_p=top_p,
+                        temperature=temperature,
+                        early_stop_num=self.configs.hz * self.configs.max_sec,
+                    )
+                t4 = ttime()
+                t_34 += t4 - t3
+
+                refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
+                    .to(dtype=self.precison, device=self.configs.device)
+
+                batch_audio_fragment = []
+
+                # ## vits并行推理 method 1
+                # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+                # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
+                # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
+                # max_len = 0
+                # for i in range(0, len(batch_phones)):
+                #     max_len = max(max_len, batch_phones[i].shape[-1])
+                # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
+                # batch_phones = batch_phones.to(self.configs.device)
+                # batch_audio_fragment = (self.vits_model.batched_decode(
+                #     pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
+                #   ))
+
+                # ## vits并行推理 method 2
+                pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+                upsample_rate = math.prod(self.vits_model.upsample_rates)
+                audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
+                audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
+                all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
+                _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
+                _batch_audio_fragment = (self.vits_model.decode(
+                        all_pred_semantic, _batch_phones,refer_audio_spepc
+                    ).detach()[0, 0, :])
+                audio_frag_end_idx.insert(0, 0)
+                batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
+
+
+                # ## vits串行推理
+                # for i, idx in enumerate(idx_list):
+                #     phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+                #     _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
+                #     audio_fragment =(self.vits_model.decode(
+                #             _pred_semantic, phones, refer_audio_spepc
+                #         ).detach()[0, 0, :])
+                #     batch_audio_fragment.append(
+                #         audio_fragment
+                #     )  ###试试重建不带上prompt部分
+
+                t5 = ttime()
+                t_45 += t5 - t4
+                if return_fragment:
+                    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
+                    yield self.audio_postprocess([batch_audio_fragment],
+                                                 self.configs.sampling_rate,
+                                                 batch_index_list,
+                                                 speed_factor,
+                                                 split_bucket)
+                else:
+                    audio.append(batch_audio_fragment)
+
+                if self.stop_flag:
+                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
+                                                               dtype=np.int16)
+                    return
+
+            if not return_fragment:
+                print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
+                yield self.audio_postprocess(audio,
+                                             self.configs.sampling_rate,
+                                             batch_index_list,
+                                             speed_factor,
+                                             split_bucket)
+        except Exception as e:
+            traceback.print_exc()
+            # 必须返回一个空音频, 否则会导致显存不释放。
+            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
+                                                       dtype=np.int16)
+            # 重置模型, 否则会导致显存释放不完全。
+            del self.t2s_model
+            del self.vits_model
+            self.t2s_model = None
+            self.vits_model = None
+            self.init_t2s_weights(self.configs.t2s_weights_path)
+            self.init_vits_weights(self.configs.vits_weights_path)
+        finally:
+            self.empty_cache()
+
+    def empty_cache(self):
+        try:
+            if str(self.configs.device) == "cuda":
+                torch.cuda.empty_cache()
+            elif str(self.configs.device) == "mps":
+                torch.mps.empty_cache()
+        except:
+            pass
 
     def audio_postprocess(self, 
                           audio:List[torch.Tensor], 
                           sr:int, 

From 7e012e7678ee8a74c68faada2a2885c2365bb8ba Mon Sep 17 00:00:00 2001
From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Date: Wed, 13 Mar 2024 17:03:59 +0800
Subject: [PATCH 2/5] https://github.com/RVC-Boss/GPT-SoVITS/issues/747

---
 GPT_SoVITS/module/models.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py
index a4d22352..0a49513e 100644
--- a/GPT_SoVITS/module/models.py
+++ b/GPT_SoVITS/module/models.py
@@ -893,6 +893,7 @@ class SynthesizerTrn(nn.Module):
         if freeze_quantizer:
             self.ssl_proj.requires_grad_(False)
             self.quantizer.requires_grad_(False)
+            self.quentizer.eval()
             # self.enc_p.text_embedding.requires_grad_(False)
             # self.enc_p.encoder_text.requires_grad_(False)
             # self.enc_p.mrte.requires_grad_(False)

From 252c9b7eb638958f6302b2788df08a84099a0f05 Mon Sep 17 00:00:00 2001
From: chasonjiang <1440499136@qq.com>
Date: Wed, 13 Mar 2024 19:51:24 +0800
Subject: [PATCH 3/5] Inference enhancements and improvements
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/AR/models/t2s_model.py | 32 ++++++++---
 GPT_SoVITS/TTS_infer_pack/TTS.py  | 93 +++++++++++++++++++++++--------
 GPT_SoVITS/inference_webui.py     |  8 ++-
 3 files changed, 99 insertions(+), 34 deletions(-)

diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index dfd6eb0d..da95111b 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -707,19 +707,33 @@ class Text2SemanticDecoder(nn.Module):
             y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
             ref_free = True
 
-        x_attn_mask_pad = F.pad(
-            x_attn_mask,
-            (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
-            value=True,
-        )
-        y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
+        ##### create mask #####
+        bsz = x.shape[0]
+        src_len = x_len + y_len
+        y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
+        y_mask = make_pad_mask(y_lens)
+        x_mask = make_pad_mask(x_lens)
+
+        # (bsz, x_len + y_len)
+        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
+
+        x_mask = F.pad(
+            x_attn_mask,
+            (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
+            value=True,
+        )
+        y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
             torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-            x.device
-        )
+
+        xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz*self.num_head, -1, -1).to(x.device)
+        # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
+        xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(bsz, src_len, src_len).repeat(self.num_head, 1, 1)
+        xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
+        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
+        xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
 
         y_list = [None]*y.shape[0]
         batch_idx_map = list(range(y.shape[0]))
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 62bb2e90..e569f52d 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -3,6 +3,8 @@ import math
 import os, sys
 import random
 import traceback
+
+from tqdm import tqdm
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 import ffmpeg
@@ -552,14 +554,15 @@ class TTS:
                 "prompt_text": "",          # str. prompt text for the reference audio
                 "prompt_lang": "",          # str. language of the prompt text for the reference audio
                 "top_k": 5,                 # int. top k sampling
-                "top_p": 1,                 # float. top p sampling
-                "temperature": 1,           # float. temperature for sampling
+                "top_p": 1,                   # float. top p sampling
+                "temperature": 1,             # float. temperature for sampling
                 "text_split_method": "",    # str. text split method, see text_segmentaion_method.py for details.
                 "batch_size": 1,            # int. batch size for inference
                 "batch_threshold": 0.75,    # float. threshold for batch splitting.
                 "split_bucket: True,        # bool. whether to split the batch into multiple buckets.
                 "return_fragment": False,   # bool. step by step return the audio fragment.
                 "speed_factor":1.0,         # float. control the speed of the synthesized audio.
+                "fragment_interval":0.3,    # float. to control the interval of the audio fragment.
             }
         returns:
             tulpe[int, np.ndarray]: sampling rate and audio data.
@@ -580,9 +583,10 @@ class TTS:
         speed_factor = inputs.get("speed_factor", 1.0)
         split_bucket = inputs.get("split_bucket", True)
         return_fragment = inputs.get("return_fragment", False)
+        fragment_interval = inputs.get("fragment_interval", 0.3)
 
         if return_fragment:
-            split_bucket = False
+            # split_bucket = False
             print(i18n("分段返回模式已开启"))
             if split_bucket:
                 split_bucket = False
@@ -590,7 +594,10 @@ class TTS:
 
         if split_bucket:
             print(i18n("分桶处理模式已开启"))
-
+
+        if fragment_interval<0.01:
+            fragment_interval = 0.01
+            print(i18n("分段间隔过小,已自动设置为0.01"))
 
         no_prompt_text = False
         if prompt_text in [None, ""]:
@@ -627,16 +634,48 @@ class TTS:
 
         ###### text preprocessing ########
-        data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
-        if len(data) == 0:
-            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
-                                                       dtype=np.int16)
-            return
-
-        t1 = ttime()
-        data, batch_index_list = self.to_batch(data,
-                                prompt_data=self.prompt_cache if not no_prompt_text else None,
-                                batch_size=batch_size,
-                                threshold=batch_threshold,
-                                split_bucket=split_bucket
-                                )
+        data:list = None
+        if not return_fragment:
+            data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
+            if len(data) == 0:
+                yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
+                                                           dtype=np.int16)
+                return
+
+            batch_index_list:list = None
+            data, batch_index_list = self.to_batch(data,
+                                prompt_data=self.prompt_cache if not no_prompt_text else None,
+                                batch_size=batch_size,
+                                threshold=batch_threshold,
+                                split_bucket=split_bucket
+                                )
+        else:
+            print(i18n("############ 切分文本 ############"))
+            texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
+            data = []
+            for i in range(len(texts)):
+                if i%batch_size == 0:
+                    data.append([])
+                data[-1].append(texts[i])
+
+        def make_batch(batch_texts):
+            batch_data = []
+            print(i18n("############ 提取文本Bert特征 ############"))
+            for text in tqdm(batch_texts):
+                phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang)
+                if phones is None:
+                    continue
+                res={
+                    "phones": phones,
+                    "bert_features": bert_features,
+                    "norm_text": norm_text,
+                }
+                batch_data.append(res)
+            batch, _ = self.to_batch(batch_data,
+                        prompt_data=self.prompt_cache if not no_prompt_text else None,
+                        batch_size=batch_size,
+                        threshold=batch_threshold,
+                        split_bucket=False
+                        )
+            return batch[0]
 
         t2 = ttime()
@@ -651,6 +691,9 @@ class TTS:
             for item in data:
                 t3 = ttime()
+                if return_fragment:
+                    item = make_batch(item)
+
                 batch_phones = item["phones"]
                 batch_phones_len = item["phones_len"]
                 all_phoneme_ids = item["all_phones"]
                 all_phoneme_lens = item["all_phones_len"]
@@ -737,14 +779,16 @@ class TTS:
                 if return_fragment:
                     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
                     yield self.audio_postprocess([batch_audio_fragment],
                                                  self.configs.sampling_rate,
-                                                 batch_index_list,
+                                                 None,
                                                  speed_factor,
-                                                 split_bucket)
+                                                 False,
+                                                 fragment_interval
+                                                 )
                 else:
                     audio.append(batch_audio_fragment)
 
                 if self.stop_flag:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
+                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
                                                                dtype=np.int16)
                     return
@@ -751,7 +796,9 @@ class TTS:
                                              self.configs.sampling_rate,
                                              batch_index_list,
                                              speed_factor,
-                                             split_bucket)
+                                             split_bucket,
+                                             fragment_interval
+                                             )
         except Exception as e:
             traceback.print_exc()
             # 必须返回一个空音频, 否则会导致显存不释放。
@@ -781,9 +828,11 @@ class TTS:
                           sr:int,
                           batch_index_list:list=None,
                           speed_factor:float=1.0,
-                          split_bucket:bool=True)->tuple[int, np.ndarray]:
+                          split_bucket:bool=True,
+                          fragment_interval:float=0.3
+                          )->tuple[int, np.ndarray]:
         zero_wav = torch.zeros(
-                        int(self.configs.sampling_rate * 0.3),
+                        int(self.configs.sampling_rate * fragment_interval),
                         dtype=self.precison,
                         device=self.configs.device
                     )
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 4c9d90f2..7ba5d1e3 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -91,7 +91,7 @@ def inference(text, text_lang,
               top_p, temperature, text_split_method,
               batch_size, speed_factor, ref_text_free,
-              split_bucket
+              split_bucket,fragment_interval,
               ):
     inputs={
         "text": text,
@@ -106,7 +106,8 @@ def inference(text, text_lang,
         "batch_size":int(batch_size),
         "speed_factor":float(speed_factor),
         "split_bucket":split_bucket,
-        "return_fragment":False
+        "return_fragment":False,
+        "fragment_interval":fragment_interval,
     }
 
     for item in tts_pipline.run(inputs):
@@ -188,6 +189,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
 
             with gr.Column():
                 batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
+                fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
                 speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
                 top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
                 top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
@@ -216,7 +218,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             top_k, top_p, temperature, how_to_cut, batch_size,
             speed_factor, ref_text_free,
-            split_bucket
+            split_bucket,fragment_interval,
         ],
         [output],
     )

From 03ae7fdb03df551bc3ccaaf215e2ac598424edc0 Mon Sep 17 00:00:00 2001
From: chasonjiang <1440499136@qq.com>
Date: Wed, 13 Mar 2024 20:08:32 +0800
Subject: [PATCH 4/5] Improve robustness
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/TTS_infer_pack/TTS.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index e569f52d..bcab7e84 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -672,6 +672,8 @@ class TTS:
                     "norm_text": norm_text,
                 }
                 batch_data.append(res)
+            if len(batch_data) == 0:
+                return None
             batch, _ = self.to_batch(batch_data,
                         prompt_data=self.prompt_cache if not no_prompt_text else None,
                         batch_size=batch_size,
                         threshold=batch_threshold,
                         split_bucket=False
                         )
             return batch[0]
@@ -691,6 +693,8 @@ class TTS:
                 t3 = ttime()
                 if return_fragment:
                     item = make_batch(item)
+                    if item is None:
+                        continue
 
                 batch_phones = item["phones"]
                 batch_phones_len = item["phones_len"]

From 3b11cd98145c0ccb734e7d76f75b795dedb270d8 Mon Sep 17 00:00:00 2001
From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Date: Wed, 13 Mar 2024 20:57:45 +0800
Subject: [PATCH 5/5] Update models.py

---
 GPT_SoVITS/module/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py
index 0a49513e..a03f9942 100644
--- a/GPT_SoVITS/module/models.py
+++ b/GPT_SoVITS/module/models.py
@@ -893,7 +893,7 @@ class SynthesizerTrn(nn.Module):
         if freeze_quantizer:
             self.ssl_proj.requires_grad_(False)
             self.quantizer.requires_grad_(False)
-            self.quentizer.eval()
+            self.quantizer.eval()
             # self.enc_p.text_embedding.requires_grad_(False)
             # self.enc_p.encoder_text.requires_grad_(False)
             # self.enc_p.mrte.requires_grad_(False)
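
Notes on the series follow. The recovery path in PATCH 1/5 hinges on two details that generalize beyond this codebase: as the patch's own comments note, the generator must still yield an (empty) audio tuple after a failure, otherwise the suspended generator frame keeps its tensors alive and the memory is never freed; and the models are dropped and reloaded because the exception traceback can keep the old modules reachable, so the release is otherwise incomplete. A minimal standalone sketch of the same pattern — the pipeline attribute names (run, configs, init_t2s_weights, init_vits_weights) are taken from the patch, and torch.mps.empty_cache() assumes a PyTorch build recent enough to ship it:

    import traceback

    import numpy as np
    import torch


    def run_with_recovery(pipeline, inputs):
        # Generator wrapper: pass fragments through normally; on any failure
        # (most importantly torch.cuda.OutOfMemoryError) emit one second of
        # silence so the consumer loop finishes, then rebuild the models so
        # no stale reference keeps GPU memory pinned.
        try:
            yield from pipeline.run(inputs)
        except Exception:
            traceback.print_exc()
            sr = pipeline.configs.sampling_rate
            yield sr, np.zeros(int(sr), dtype=np.int16)
            # Rebind before reloading: `del` alone frees nothing while the
            # traceback object still references the dead models.
            pipeline.t2s_model = None
            pipeline.vits_model = None
            pipeline.init_t2s_weights(pipeline.configs.t2s_weights_path)
            pipeline.init_vits_weights(pipeline.configs.vits_weights_path)
        finally:
            # Hand cached allocator blocks back to the device.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif torch.backends.mps.is_available():
                torch.mps.empty_cache()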
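
PATCH 2/5 and its fixup PATCH 5/5 make a small point worth keeping: requires_grad_(False) only stops gradient updates. Train-mode side effects continue until the module is also put in eval mode — dropout keeps firing, BatchNorm keeps updating running statistics, and an EMA-based vector quantizer keeps rewriting its codebook on every forward pass, silently mutating the very weights the flag was meant to protect. A generic sketch of the two-step freeze:

    import torch.nn as nn

    def freeze(module: nn.Module) -> nn.Module:
        # Step 1: keep the optimizer away from the parameters.
        module.requires_grad_(False)
        # Step 2: stop train-mode side effects (dropout, BatchNorm running
        # stats, EMA codebook updates inside vector quantizers).
        module.eval()
        return module

One caveat: a later .train() on the parent model flips every child back to training mode, so a permanently frozen submodule either needs eval() re-applied each step or a train() override that pins it in eval mode.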
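
The t2s_model.py hunk in PATCH 3/5 swaps a single shared causal mask for a batched mask that also blocks attention to padded phoneme positions, which is what lets sentences of different lengths share one batch in infer_panel. The construction is easy to get wrong, so here is a self-contained sketch of the same idea; the shape conventions follow the patch, make_pad_mask is inlined, and repeat_interleave replaces the patch's expand/repeat pair (this ordering matches what nn.MultiheadAttention expects for a (batch*heads, L, S) mask). Device handling is omitted for brevity:

    import torch
    import torch.nn.functional as F

    def build_attn_mask(x_lens: torch.Tensor, y_len: int, num_head: int,
                        dtype: torch.dtype = torch.float32) -> torch.Tensor:
        # x_lens: (bsz,) valid text lengths; y_len: semantic tokens so far.
        # Returns an additive mask of shape (bsz*num_head, src, src):
        # 0 where attention is allowed, -inf where it is blocked.
        bsz = int(x_lens.shape[0])
        x_len = int(x_lens.max())
        src = x_len + y_len

        # True marks padded text positions; y carries no padding here.
        x_pad = torch.arange(x_len).unsqueeze(0) >= x_lens.unsqueeze(1)
        y_pad = torch.zeros(bsz, y_len, dtype=torch.bool)
        pad = torch.concat([x_pad, y_pad], dim=1)               # (bsz, src)

        # Structure: x sees all of x but none of y; y is causal over
        # itself and sees all of x.
        x_part = F.pad(torch.zeros(x_len, x_len, dtype=torch.bool),
                       (0, y_len), value=True)                  # (x_len, src)
        y_part = F.pad(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool),
                                  diagonal=1),
                       (x_len, 0), value=False)                 # (y_len, src)
        structure = torch.concat([x_part, y_part], dim=0)       # (src, src)

        # Broadcast over the batch, OR in the padded key columns, make one
        # copy per head, and convert to the additive -inf form.
        blocked = structure.unsqueeze(0) | pad.view(bsz, 1, src)
        blocked = blocked.repeat_interleave(num_head, dim=0)
        return torch.zeros_like(blocked, dtype=dtype).masked_fill(blocked, float("-inf"))

Rows belonging to padded query positions come out fully blocked, so their attention output is undefined and must be discarded downstream — which the batched decoder does when it trims each sequence with idx_list.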
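
Also in PATCH 3/5, return_fragment no longer forces all preprocessing up front: the text is only segmented eagerly, and each batch runs BERT feature extraction inside make_batch right before it is synthesized, so the first fragment arrives sooner on long inputs. A usage sketch for the streaming path — the constructor call and the file paths are assumptions, while the input keys mirror the run() docstring above:

    from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

    tts_pipeline = TTS(TTS_Config())    # construction assumed; point the
                                        # config at real model weights first

    inputs = {
        "text": "要合成的长文本……",                # the long text to synthesize
        "text_lang": "zh",
        "ref_audio_path": "ref.wav",              # hypothetical reference clip
        "prompt_text": "参考音频对应的文本。",      # transcript of the reference
        "prompt_lang": "zh",
        "batch_size": 4,
        "return_fragment": True,                  # yield fragments as they finish
        "fragment_interval": 0.3,                 # seconds of silence between fragments
    }

    fragments = []
    for sr, fragment in tts_pipeline.run(inputs):
        fragments.append(fragment)                # or feed an audio device / HTTP response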