diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 509f0592..f3d7dd54 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -568,8 +568,9 @@ class TTS:
         Stop the inference process.
         '''
         self.stop_flag = True
-
-
+
+    # Use a decorator
+    @torch.no_grad()
     def run(self, inputs:dict):
         """
         Text to speech inference.
@@ -600,260 +601,259 @@ class TTS:
-        # Wrap the whole body in a single torch.no_grad()
-        with torch.no_grad():
-            ########## variables initialization ###########
-            self.stop_flag:bool = False
-            text:str = inputs.get("text", "")
-            text_lang:str = inputs.get("text_lang", "")
-            ref_audio_path:str = inputs.get("ref_audio_path", "")
-            prompt_text:str = inputs.get("prompt_text", "")
-            prompt_lang:str = inputs.get("prompt_lang", "")
-            top_k:int = inputs.get("top_k", 5)
-            top_p:float = inputs.get("top_p", 1)
-            temperature:float = inputs.get("temperature", 1)
-            text_split_method:str = inputs.get("text_split_method", "cut0")
-            batch_size = inputs.get("batch_size", 1)
-            batch_threshold = inputs.get("batch_threshold", 0.75)
-            speed_factor = inputs.get("speed_factor", 1.0)
-            split_bucket = inputs.get("split_bucket", True)
-            return_fragment = inputs.get("return_fragment", False)
-            fragment_interval = inputs.get("fragment_interval", 0.3)
-            seed = inputs.get("seed", -1)
-            seed = -1 if seed in ["", None] else seed
-            actual_seed = set_seed(seed)
-
-            if return_fragment:
-                # split_bucket = False
-                print(i18n("分段返回模式已开启"))
-                if split_bucket:
-                    split_bucket = False
-                    print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
+
+        ########## variables initialization ###########
+        self.stop_flag:bool = False
+        text:str = inputs.get("text", "")
+        text_lang:str = inputs.get("text_lang", "")
+        ref_audio_path:str = inputs.get("ref_audio_path", "")
+        prompt_text:str = inputs.get("prompt_text", "")
+        prompt_lang:str = inputs.get("prompt_lang", "")
+        top_k:int = inputs.get("top_k", 5)
+        top_p:float = inputs.get("top_p", 1)
+        temperature:float = inputs.get("temperature", 1)
+        text_split_method:str = inputs.get("text_split_method", "cut0")
+        batch_size = inputs.get("batch_size", 1)
+        batch_threshold = inputs.get("batch_threshold", 0.75)
+        speed_factor = inputs.get("speed_factor", 1.0)
+        split_bucket = inputs.get("split_bucket", True)
+        return_fragment = inputs.get("return_fragment", False)
+        fragment_interval = inputs.get("fragment_interval", 0.3)
+        seed = inputs.get("seed", -1)
+        seed = -1 if seed in ["", None] else seed
+        actual_seed = set_seed(seed)
+        if return_fragment:
+            # split_bucket = False
+            print(i18n("分段返回模式已开启"))
             if split_bucket:
-                print(i18n("分桶处理模式已开启"))
+                split_bucket = False
+                print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))

-            if fragment_interval<0.01:
-                fragment_interval = 0.01
-                print(i18n("分段间隔过小,已自动设置为0.01"))
+        if split_bucket:
+            print(i18n("分桶处理模式已开启"))

-            no_prompt_text = False
-            if prompt_text in [None, ""]:
-                no_prompt_text = True
+        if fragment_interval<0.01:
+            fragment_interval = 0.01
+            print(i18n("分段间隔过小,已自动设置为0.01"))

-            assert text_lang in self.configs.languages
-            if not no_prompt_text:
-                assert prompt_lang in self.configs.languages
+        no_prompt_text = False
+        if prompt_text in [None, ""]:
+            no_prompt_text = True

-            if ref_audio_path in [None, ""] and \
-                ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] is None)):
-                raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
+        assert text_lang in self.configs.languages
+        if not no_prompt_text:
+            assert prompt_lang in self.configs.languages

-            ###### setting reference audio and prompt text preprocessing ########
-            t0 = ttime()
-            if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
-                self.set_ref_audio(ref_audio_path)
+        if ref_audio_path in [None, ""] and \
+            ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] is None)):
+            raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")

-            if not no_prompt_text:
-                prompt_text = prompt_text.strip("\n")
-                if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "."
-                print(i18n("实际输入的参考文本:"), prompt_text)
-                if self.prompt_cache["prompt_text"] != prompt_text:
-                    self.prompt_cache["prompt_text"] = prompt_text
-                    self.prompt_cache["prompt_lang"] = prompt_lang
-                    phones, bert_features, norm_text = \
-                        self.text_preprocessor.segment_and_extract_feature_for_text(
-                            prompt_text,
-                            prompt_lang)
-                    self.prompt_cache["phones"] = phones
-                    self.prompt_cache["bert_features"] = bert_features
-                    self.prompt_cache["norm_text"] = norm_text
+        ###### setting reference audio and prompt text preprocessing ########
+        t0 = ttime()
+        if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
+            self.set_ref_audio(ref_audio_path)

-            ###### text preprocessing ########
-            t1 = ttime()
-            data:list = None
-            if not return_fragment:
-                data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
-                if len(data) == 0:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                               dtype=np.int16)
-                    return
+        if not no_prompt_text:
+            prompt_text = prompt_text.strip("\n")
+            if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "."
+            print(i18n("实际输入的参考文本:"), prompt_text)
+            if self.prompt_cache["prompt_text"] != prompt_text:
+                self.prompt_cache["prompt_text"] = prompt_text
+                self.prompt_cache["prompt_lang"] = prompt_lang
+                phones, bert_features, norm_text = \
+                    self.text_preprocessor.segment_and_extract_feature_for_text(
+                        prompt_text,
+                        prompt_lang)
+                self.prompt_cache["phones"] = phones
+                self.prompt_cache["bert_features"] = bert_features
+                self.prompt_cache["norm_text"] = norm_text

-                batch_index_list:list = None
-                data, batch_index_list = self.to_batch(data,
-                                    prompt_data=self.prompt_cache if not no_prompt_text else None,
-                                    batch_size=batch_size,
-                                    threshold=batch_threshold,
-                                    split_bucket=split_bucket,
-                                    device=self.configs.device,
-                                    precision=self.precision
-                                    )
-            else:
-                print(i18n("############ 切分文本 ############"))
-                texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
-                data = []
-                for i in range(len(texts)):
-                    if i%batch_size == 0:
-                        data.append([])
-                    data[-1].append(texts[i])
-
-                def make_batch(batch_texts):
-                    batch_data = []
-                    print(i18n("############ 提取文本Bert特征 ############"))
-                    for text in tqdm(batch_texts):
-                        phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang)
-                        if phones is None:
-                            continue
-                        res={
-                            "phones": phones,
-                            "bert_features": bert_features,
-                            "norm_text": norm_text,
-                        }
-                        batch_data.append(res)
-                    if len(batch_data) == 0:
-                        return None
-                    batch, _ = self.to_batch(batch_data,
+        ###### text preprocessing ########
+        t1 = ttime()
+        data:list = None
+        if not return_fragment:
+            data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
+            if len(data) == 0:
+                yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
+                                                           dtype=np.int16)
+                return
+
+            batch_index_list:list = None
+            data, batch_index_list = self.to_batch(data,
                                 prompt_data=self.prompt_cache if not no_prompt_text else None,
                                 batch_size=batch_size,
                                 threshold=batch_threshold,
-                                split_bucket=False,
+                                split_bucket=split_bucket,
                                 device=self.configs.device,
                                 precision=self.precision
                                 )
-                    return batch[0]
+        else:
+            print(i18n("############ 切分文本 ############"))
+            texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
+            data = []
+            for i in range(len(texts)):
+                if i%batch_size == 0:
+                    data.append([])
+                data[-1].append(texts[i])
+
+            def make_batch(batch_texts):
+                batch_data = []
+                print(i18n("############ 提取文本Bert特征 ############"))
+                for text in tqdm(batch_texts):
+                    phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang)
+                    if phones is None:
+                        continue
+                    res={
+                        "phones": phones,
+                        "bert_features": bert_features,
+                        "norm_text": norm_text,
+                    }
+                    batch_data.append(res)
+                if len(batch_data) == 0:
+                    return None
+                batch, _ = self.to_batch(batch_data,
+                                prompt_data=self.prompt_cache if not no_prompt_text else None,
+                                batch_size=batch_size,
+                                threshold=batch_threshold,
+                                split_bucket=False,
+                                device=self.configs.device,
+                                precision=self.precision
+                                )
+                return batch[0]

-            t2 = ttime()
-            try:
-                print("############ 推理 ############")
-                ###### inference ######
-                t_34 = 0.0
-                t_45 = 0.0
-                audio = []
-                for item in data:
-                    t3 = ttime()
-                    if return_fragment:
-                        item = make_batch(item)
-                        if item is None:
-                            continue
+        t2 = ttime()
+        try:
+            print("############ 推理 ############")
+            ###### inference ######
+            t_34 = 0.0
+            t_45 = 0.0
+            audio = []
+            for item in data:
+                t3 = ttime()
+                if return_fragment:
+                    item = make_batch(item)
+                    if item is None:
+                        continue

-                    batch_phones:List[torch.LongTensor] = item["phones"]
-                    batch_phones_len:torch.LongTensor = item["phones_len"]
-                    all_phoneme_ids:List[torch.LongTensor] = item["all_phones"]
-                    all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
-                    all_bert_features:List[torch.LongTensor] = item["all_bert_features"]
-                    norm_text:str = item["norm_text"]
+                batch_phones:List[torch.LongTensor] = item["phones"]
+                batch_phones_len:torch.LongTensor = item["phones_len"]
+                all_phoneme_ids:List[torch.LongTensor] = item["all_phones"]
+                all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
+                all_bert_features:List[torch.LongTensor] = item["all_bert_features"]
+                norm_text:str = item["norm_text"]

-                    print(i18n("前端处理后的文本(每句):"), norm_text)
-                    if no_prompt_text :
-                        prompt = None
-                    else:
-                        prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
+                print(i18n("前端处理后的文本(每句):"), norm_text)
+                if no_prompt_text :
+                    prompt = None
+                else:
+                    prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)

-                    pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
-                        all_phoneme_ids,
-                        all_phoneme_lens,
-                        prompt,
-                        all_bert_features,
-                        # prompt_phone_len=ph_offset,
-                        top_k=top_k,
-                        top_p=top_p,
-                        temperature=temperature,
-                        early_stop_num=self.configs.hz * self.configs.max_sec,
-                    )
-                    t4 = ttime()
-                    t_34 += t4 - t3
+                pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
+                    all_phoneme_ids,
+                    all_phoneme_lens,
+                    prompt,
+                    all_bert_features,
+                    # prompt_phone_len=ph_offset,
+                    top_k=top_k,
+                    top_p=top_p,
+                    temperature=temperature,
+                    early_stop_num=self.configs.hz * self.configs.max_sec,
+                )
+                t4 = ttime()
+                t_34 += t4 - t3

-                    refer_audio_spec:torch.Tensor = self.prompt_cache["refer_spec"]\
-                        .to(dtype=self.precision, device=self.configs.device)
+                refer_audio_spec:torch.Tensor = self.prompt_cache["refer_spec"]\
+                    .to(dtype=self.precision, device=self.configs.device)

-                    batch_audio_fragment = []
+                batch_audio_fragment = []

-                    # Remember to add torch.no_grad() here, otherwise it is much slower
-                    # with torch.no_grad():
-
-                    # ## parallel VITS inference, method 1
-                    # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
-                    # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
-                    # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
-                    # max_len = 0
-                    # for i in range(0, len(batch_phones)):
-                    #     max_len = max(max_len, batch_phones[i].shape[-1])
-                    # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
-                    # batch_phones = batch_phones.to(self.configs.device)
-                    # batch_audio_fragment = (self.vits_model.batched_decode(
-                    #         pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
-                    #     ))
+                # Remember to add torch.no_grad() here, otherwise it is much slower
+                # with torch.no_grad():
+
+                # ## parallel VITS inference, method 1
+                # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+                # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
+                # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
+                # max_len = 0
+                # for i in range(0, len(batch_phones)):
+                #     max_len = max(max_len, batch_phones[i].shape[-1])
+                # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
+                # batch_phones = batch_phones.to(self.configs.device)
+                # batch_audio_fragment = (self.vits_model.batched_decode(
+                #         pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
+                #     ))

-                    # ## parallel VITS inference, method 2
-                    pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
-                    upsample_rate = math.prod(self.vits_model.upsample_rates)
-                    audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
-                    audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
-                    all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
-                    _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-                    _batch_audio_fragment = (self.vits_model.decode(
-                            all_pred_semantic, _batch_phones, refer_audio_spec
-                        ).detach()[0, 0, :])
-                    audio_frag_end_idx.insert(0, 0)
-                    batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
+                # ## parallel VITS inference, method 2
+                pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+                upsample_rate = math.prod(self.vits_model.upsample_rates)
+                audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
+                audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
+                all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
+                _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
+                _batch_audio_fragment = (self.vits_model.decode(
+                        all_pred_semantic, _batch_phones, refer_audio_spec
+                    ).detach()[0, 0, :])
+                audio_frag_end_idx.insert(0, 0)
+                batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]

-                    # ## serial VITS inference
-                    # for i, idx in enumerate(idx_list):
-                    #     phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-                    #     _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # .unsqueeze(0)  # mq needs one extra unsqueeze
-                    #     audio_fragment =(self.vits_model.decode(
-                    #             _pred_semantic, phones, refer_audio_spec
-                    #         ).detach()[0, 0, :])
-                    #     batch_audio_fragment.append(
-                    #         audio_fragment
-                    #     )  ### try reconstructing without the prompt part
+                # ## serial VITS inference
+                # for i, idx in enumerate(idx_list):
+                #     phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+                #     _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # .unsqueeze(0)  # mq needs one extra unsqueeze
+                #     audio_fragment =(self.vits_model.decode(
+                #             _pred_semantic, phones, refer_audio_spec
+                #         ).detach()[0, 0, :])
+                #     batch_audio_fragment.append(
+                #         audio_fragment
+                #     )  ### try reconstructing without the prompt part

-                    t5 = ttime()
-                    t_45 += t5 - t4
-                    if return_fragment:
-                        print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
-                        yield self.audio_postprocess([batch_audio_fragment],
-                                                     self.configs.sampling_rate,
-                                                     None,
-                                                     speed_factor,
-                                                     False,
-                                                     fragment_interval
-                                                     )
-                    else:
-                        audio.append(batch_audio_fragment)
-
-                    if self.stop_flag:
-                        yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                                   dtype=np.int16)
-                        return
-
-                if not return_fragment:
-                    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
-                    yield self.audio_postprocess(audio,
+                t5 = ttime()
+                t_45 += t5 - t4
+                if return_fragment:
+                    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
+                    yield self.audio_postprocess([batch_audio_fragment],
                                                  self.configs.sampling_rate,
-                                                 batch_index_list,
+                                                 None,
                                                  speed_factor,
-                                                 split_bucket,
+                                                 False,
                                                  fragment_interval
                                                  )
+                else:
+                    audio.append(batch_audio_fragment)

-            except Exception as e:
-                traceback.print_exc()
-                # Must yield an empty audio clip, otherwise GPU memory is not released.
-                yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                           dtype=np.int16)
-                # Reset the models, otherwise GPU memory is not fully released.
-                del self.t2s_model
-                del self.vits_model
-                self.t2s_model = None
-                self.vits_model = None
-                self.init_t2s_weights(self.configs.t2s_weights_path)
-                self.init_vits_weights(self.configs.vits_weights_path)
-                raise e
-            finally:
-                self.empty_cache()
+                if self.stop_flag:
+                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
+                                                               dtype=np.int16)
+                    return
+
+            if not return_fragment:
+                print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
+                yield self.audio_postprocess(audio,
+                                             self.configs.sampling_rate,
+                                             batch_index_list,
+                                             speed_factor,
+                                             split_bucket,
+                                             fragment_interval
+                                             )
+
+        except Exception as e:
+            traceback.print_exc()
+            # Must yield an empty audio clip, otherwise GPU memory is not released.
+            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
+                                                       dtype=np.int16)
+            # Reset the models, otherwise GPU memory is not fully released.
+            del self.t2s_model
+            del self.vits_model
+            self.t2s_model = None
+            self.vits_model = None
+            self.init_t2s_weights(self.configs.t2s_weights_path)
+            self.init_vits_weights(self.configs.vits_weights_path)
+            raise e
+        finally:
+            self.empty_cache()

     def empty_cache(self):
         try:
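
Note on the headline change: `run` is a generator (it yields audio fragments), and the diff replaces the inline `with torch.no_grad():` wrapper with a `@torch.no_grad()` decorator on the whole method. This relies on PyTorch treating decorated generator functions specially: recent releases re-enter the context manager each time the generator is resumed, so every step between yields still runs without autograd tracking. A minimal sketch of that behavior, with a placeholder `Linear` model standing in for the TTS models:

    import torch

    model = torch.nn.Linear(4, 4)

    @torch.no_grad()
    def stream_chunks(xs):
        # Recent PyTorch re-enters no_grad on every resume of a
        # decorated generator, so each chunk skips autograd tracking.
        for x in xs:
            y = model(x)
            assert not y.requires_grad  # holds even across yields
            yield y

    for chunk in stream_chunks([torch.randn(2, 4) for _ in range(3)]):
        print(chunk.shape, chunk.requires_grad)

On an older PyTorch where the decorator did not wrap generators, the grad mode would only be active while the generator object was created, not while its body ran, which is presumably why the original code used an explicit `with` block.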
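In the `return_fragment` branch, pre-segmented sentences are grouped into fixed-size batches up front, and `make_batch` then extracts BERT features lazily one batch at a time so audio can be yielded as soon as each batch is decoded. The grouping loop itself is tiny; here it is lifted out of the diff as a standalone, testable function:

    def chunk(texts, batch_size):
        # Start a new batch every batch_size sentences, preserving
        # order, as the return_fragment branch above does.
        batches = []
        for i, t in enumerate(texts):
            if i % batch_size == 0:
                batches.append([])
            batches[-1].append(t)
        return batches

    assert chunk(["a", "b", "c", "d", "e"], 2) == [["a", "b"], ["c", "d"], ["e"]]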
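The "parallel VITS inference, method 2" block decodes the whole batch's semantic tokens in a single `vits_model.decode` call and then recovers per-sentence waveforms by slicing at precomputed sample offsets: each sentence contributes `semantic_len * 2 * upsample_rate` samples, and `audio_frag_end_idx` holds the cumulative ends with a leading 0. A self-contained sketch of just that offset arithmetic; the rates and token counts below are made-up stand-ins, not values taken from the model:

    import math
    import torch

    upsample_rates = [8, 8, 2, 2]   # placeholder HiFi-GAN-style stack
    upsample_rate = math.prod(upsample_rates)

    semantic_lens = [50, 73, 41]    # semantic tokens per sentence
    frag_lens = [n * 2 * upsample_rate for n in semantic_lens]

    # Cumulative end offsets with a leading 0, mirroring
    # audio_frag_end_idx.insert(0, 0) in the diff.
    ends = [0]
    for n in frag_lens:
        ends.append(ends[-1] + n)

    waveform = torch.zeros(ends[-1])  # stand-in for the decoded batch
    fragments = [waveform[ends[i]:ends[i + 1]] for i in range(len(frag_lens))]
    assert [f.numel() for f in fragments] == frag_lens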
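The new `except`/`finally` tail is worth calling out as a pattern: on any inference failure it yields one second of silence so the consumer's loop over the generator completes cleanly, drops and re-initializes both models (per the comments, otherwise GPU memory is not fully released), re-raises, and empties the cache unconditionally. A stripped-down sketch of the same shape; `infer` and `reload_models` are hypothetical hooks standing in for the class internals, not part of the actual API:

    import numpy as np
    import torch

    def run_with_recovery(sr, infer, reload_models):
        # infer() yields (sample_rate, waveform) tuples;
        # reload_models() drops and re-creates the model weights.
        try:
            yield from infer()
        except Exception:
            # Yield silence first so the caller's loop stays
            # well-formed before the error propagates.
            yield sr, np.zeros(sr, dtype=np.int16)
            reload_models()
            raise
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()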