diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 61ba7be..62bb2e9 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -2,6 +2,7 @@ from copy import deepcopy
 import math
 import os, sys
 import random
+import traceback
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 import ffmpeg
@@ -48,8 +49,18 @@ custom:
 """
-
-
+# def set_seed(seed):
+#     random.seed(seed)
+#     os.environ['PYTHONHASHSEED'] = str(seed)
+#     np.random.seed(seed)
+#     torch.manual_seed(seed)
+#     torch.cuda.manual_seed(seed)
+#     torch.cuda.manual_seed_all(seed)
+#     torch.backends.cudnn.deterministic = True
+#     torch.backends.cudnn.benchmark = False
+#     torch.backends.cudnn.enabled = True
+# set_seed(1234)
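+# If re-enabled, the helper above pins every RNG source (Python hashing,
+# random, NumPy, Torch CPU/CUDA); cudnn.deterministic additionally forces
+# deterministic kernels at some speed cost, which is presumably why the
+# whole block ships commented out.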
+
 class TTS_Config:
     default_configs={
         "device": "cpu",
@@ -630,125 +641,141 @@ class TTS:
                     split_bucket=split_bucket
                     )
         t2 = ttime()
+        try:
+            print("############ inference ############")
+            ###### inference ######
+            t_34 = 0.0
+            t_45 = 0.0
+            audio = []
+            for item in data:
+                t3 = ttime()
+                batch_phones = item["phones"]
+                batch_phones_len = item["phones_len"]
+                all_phoneme_ids = item["all_phones"]
+                all_phoneme_lens = item["all_phones_len"]
+                all_bert_features = item["all_bert_features"]
+                norm_text = item["norm_text"]
+
+                # batch_phones = batch_phones.to(self.configs.device)
+                batch_phones_len = batch_phones_len.to(self.configs.device)
+                all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
+                all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
+                all_bert_features = all_bert_features.to(self.configs.device)
+                if self.configs.is_half:
+                    all_bert_features = all_bert_features.half()
-        print("############ inference ############")
-        ###### inference ######
-        t_34 = 0.0
-        t_45 = 0.0
-        audio = []
-        for item in data:
-            t3 = ttime()
-            batch_phones = item["phones"]
-            batch_phones_len = item["phones_len"]
-            all_phoneme_ids = item["all_phones"]
-            all_phoneme_lens = item["all_phones_len"]
-            all_bert_features = item["all_bert_features"]
-            norm_text = item["norm_text"]
-
-            # batch_phones = batch_phones.to(self.configs.device)
-            batch_phones_len = batch_phones_len.to(self.configs.device)
-            all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
-            all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
-            all_bert_features = all_bert_features.to(self.configs.device)
-            if self.configs.is_half:
-                all_bert_features = all_bert_features.half()
-
-            print(i18n("前端处理后的文本(每句):"), norm_text)
-            if no_prompt_text :
-                prompt = None
-            else:
-                prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
-
-            with torch.no_grad():
-                pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
-                    all_phoneme_ids,
-                    all_phoneme_lens,
-                    prompt,
-                    all_bert_features,
-                    # prompt_phone_len=ph_offset,
-                    top_k=top_k,
-                    top_p=top_p,
-                    temperature=temperature,
-                    early_stop_num=self.configs.hz * self.configs.max_sec,
-                )
-            t4 = ttime()
-            t_34 += t4 - t3
-
-            refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
-                .to(dtype=self.precison, device=self.configs.device)
+                print(i18n("前端处理后的文本(每句):"), norm_text)
+                if no_prompt_text :
+                    prompt = None
+                else:
+                    prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
-            batch_audio_fragment = []
-
-            # ## VITS parallel inference, method 1
-            # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
-            # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
-            # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
-            # max_len = 0
-            # for i in range(0, len(batch_phones)):
-            #     max_len = max(max_len, batch_phones[i].shape[-1])
-            # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
-            # batch_phones = batch_phones.to(self.configs.device)
-            # batch_audio_fragment = (self.vits_model.batched_decode(
-            #         pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
-            #     ))
-
-            # ## VITS parallel inference, method 2
-            pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
-            upsample_rate = math.prod(self.vits_model.upsample_rates)
-            audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
-            audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
-            all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
-            _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-            _batch_audio_fragment = (self.vits_model.decode(
-                    all_pred_semantic, _batch_phones,refer_audio_spepc
-                ).detach()[0, 0, :])
-            audio_frag_end_idx.insert(0, 0)
-            batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
+                with torch.no_grad():
+                    pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
+                        all_phoneme_ids,
+                        all_phoneme_lens,
+                        prompt,
+                        all_bert_features,
+                        # prompt_phone_len=ph_offset,
+                        top_k=top_k,
+                        top_p=top_p,
+                        temperature=temperature,
+                        early_stop_num=self.configs.hz * self.configs.max_sec,
+                    )
+                t4 = ttime()
+                t_34 += t4 - t3
+                refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
+                    .to(dtype=self.precison, device=self.configs.device)
+
+                batch_audio_fragment = []
-            # ## VITS serial inference
-            # for i, idx in enumerate(idx_list):
-            #     phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-            #     _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))    # .unsqueeze(0) # mq needs one extra unsqueeze
-            #     audio_fragment =(self.vits_model.decode(
-            #             _pred_semantic, phones, refer_audio_spepc
-            #         ).detach()[0, 0, :])
-            #     batch_audio_fragment.append(
-            #         audio_fragment
-            #     )  ### try reconstructing without the prompt part
-
-            t5 = ttime()
-            t_45 += t5 - t4
-            if return_fragment:
-                print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
-                yield self.audio_postprocess([batch_audio_fragment],
+                # ## VITS parallel inference, method 1
+                # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+                # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
+                # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
+                # max_len = 0
+                # for i in range(0, len(batch_phones)):
+                #     max_len = max(max_len, batch_phones[i].shape[-1])
+                # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
+                # batch_phones = batch_phones.to(self.configs.device)
+                # batch_audio_fragment = (self.vits_model.batched_decode(
+                #         pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
+                #     ))
+
+                # ## VITS parallel inference, method 2
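+                # Each predicted semantic token expands to 2 * prod(upsample_rates)
+                # output samples, so fragment boundaries are known before decoding:
+                # with a hypothetical rate product of 320, fragments of 25 and 40
+                # tokens end at samples 16000 and 41600. One batched decode pass
+                # plus slicing then replaces a decode call per fragment.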
+                pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+                upsample_rate = math.prod(self.vits_model.upsample_rates)
+                audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
+                audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
+                all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
+                _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
+                _batch_audio_fragment = (self.vits_model.decode(
+                        all_pred_semantic, _batch_phones,refer_audio_spepc
+                    ).detach()[0, 0, :])
+                audio_frag_end_idx.insert(0, 0)
+                batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
+
+
+                # ## VITS serial inference
+                # for i, idx in enumerate(idx_list):
+                #     phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+                #     _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))    # .unsqueeze(0) # mq needs one extra unsqueeze
+                #     audio_fragment =(self.vits_model.decode(
+                #             _pred_semantic, phones, refer_audio_spepc
+                #         ).detach()[0, 0, :])
+                #     batch_audio_fragment.append(
+                #         audio_fragment
+                #     )  ### try reconstructing without the prompt part
+
+                t5 = ttime()
+                t_45 += t5 - t4
+                if return_fragment:
+                    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
+                    yield self.audio_postprocess([batch_audio_fragment],
+                                                 self.configs.sampling_rate,
+                                                 batch_index_list,
+                                                 speed_factor,
+                                                 split_bucket)
+                else:
+                    audio.append(batch_audio_fragment)
+
+                if self.stop_flag:
+                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
+                                                               dtype=np.int16)
+                    return
+
+            if not return_fragment:
+                print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
+                yield self.audio_postprocess(audio,
+                                             self.configs.sampling_rate,
+                                             batch_index_list,
+                                             speed_factor,
-                                         split_bucket)
-            else:
-                audio.append(batch_audio_fragment)
-
-            if self.stop_flag:
-                yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
-                                                           dtype=np.int16)
-                return
-
-        if not return_fragment:
-            print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
-            yield self.audio_postprocess(audio,
-                                         self.configs.sampling_rate,
-                                         batch_index_list,
-                                         speed_factor,
-                                         split_bucket)
-
-        try:
-            torch.cuda.empty_cache()
+                                             split_bucket)
+        except Exception as e:
+            traceback.print_exc()
+            # Must yield an empty audio clip, otherwise GPU memory is not released.
+            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
+                                                       dtype=np.int16)
+            # Reset the models, otherwise GPU memory is not fully released.
+            del self.t2s_model
+            del self.vits_model
+            self.t2s_model = None
+            self.vits_model = None
+            self.init_t2s_weights(self.configs.t2s_weights_path)
+            self.init_vits_weights(self.configs.vits_weights_path)
+        finally:
+            self.empty_cache()
+
+    def empty_cache(self):
+        try:
+            if str(self.configs.device) == "cuda":
+                torch.cuda.empty_cache()
+            elif str(self.configs.device) == "mps":
+                torch.mps.empty_cache()
         except:
             pass
-
-
-
+
     def audio_postprocess(self, audio:List[torch.Tensor], sr:int,
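Taken together, the new error path yields a short silent clip so the consumer of the generator still receives audio and the frame's GPU tensors can be dropped, rebuilds both models, and empties the backend cache on every exit. A minimal sketch of that control flow, with hypothetical `decode_batch`, `reload_models`, and `empty_cache` callables standing in for the class's methods:

```python
import traceback

import numpy as np


def generate(data, sampling_rate, decode_batch, reload_models, empty_cache):
    """Sketch of the patch's recovery flow; the callables are hypothetical
    stand-ins for the repo's methods, not its real API."""
    try:
        for item in data:
            # Normal path: hand each decoded fragment to the caller.
            yield sampling_rate, decode_batch(item)
    except Exception:
        traceback.print_exc()
        # Yield one second of silence so the caller's loop still completes
        # and this frame's references to GPU tensors can be dropped.
        yield sampling_rate, np.zeros(int(sampling_rate), dtype=np.int16)
        # Rebuild the models: empty_cache alone cannot reclaim memory that
        # the crashed modules still reference.
        reload_models()
    finally:
        # Best-effort cache release on both normal and error exits.
        empty_cache()
```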