diff --git a/.gitignore b/.gitignore
index 6f846a91..28b8a7a5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,8 @@ reference
 GPT_weights
 SoVITS_weights
 TEMP
+PortableGit
 ffmpeg.exe
 ffprobe.exe
-
+tmp_audio
+trained
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index c5f227c1..a9ab1562 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -249,8 +249,6 @@ class TTS:
         if self.configs.is_half and str(self.configs.device)!="cpu":
             self.bert_model = self.bert_model.half()
-
-
     def init_vits_weights(self, weights_path: str):
         print(f"Loading VITS weights from {weights_path}")
         self.configs.vits_weights_path = weights_path
@@ -274,7 +272,7 @@ class TTS:
         # if ("pretrained" not in weights_path):
         if hasattr(vits_model, "enc_q"):
             del vits_model.enc_q
-
+
         vits_model = vits_model.to(self.configs.device)
         vits_model = vits_model.eval()
         vits_model.load_state_dict(dict_s2["weight"], strict=False)
@@ -282,7 +280,6 @@ class TTS:
         if self.configs.is_half and str(self.configs.device)!="cpu":
             self.vits_model = self.vits_model.half()
-
     def init_t2s_weights(self, weights_path: str):
         print(f"Loading Text2Semantic weights from {weights_path}")
         self.configs.t2s_weights_path = weights_path
@@ -299,7 +296,7 @@ class TTS:
         self.t2s_model = t2s_model
         if self.configs.is_half and str(self.configs.device)!="cpu":
             self.t2s_model = self.t2s_model.half()
-
+
     def enable_half_precision(self, enable: bool = True):
         '''
             To enable half precision for the TTS model.
@@ -310,7 +307,7 @@ class TTS:
         if str(self.configs.device) == "cpu" and enable:
             print("Half precision is not supported on CPU.")
             return
-
+
         self.configs.is_half = enable
         self.precision = torch.float16 if enable else torch.float32
         self.configs.save_configs()
@@ -332,7 +329,7 @@ class TTS:
                 self.bert_model = self.bert_model.float()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.float()
-
+
     def set_device(self, device: torch.device):
         '''
             To set the device for all models.
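The enable_half_precision() hunks above follow a guard-then-convert pattern: refuse fp16 on CPU, record the new precision, then convert whichever sub-models are loaded. A minimal sketch of that pattern, with the class reduced to an illustrative stand-in (the real method also persists the choice via configs.save_configs()):

import torch

class PrecisionToggle:
    """Illustrative stand-in for the precision-related parts of TTS."""
    def __init__(self, device: str = "cuda"):
        self.device = device
        self.is_half = False
        self.precision = torch.float32
        self.bert_model = None      # populated by the init_*_weights methods
        self.vits_model = None
        self.t2s_model = None

    def enable_half_precision(self, enable: bool = True):
        # fp16 is only supported on non-CPU devices, mirroring the guard above
        if str(self.device) == "cpu" and enable:
            print("Half precision is not supported on CPU.")
            return
        self.is_half = enable
        self.precision = torch.float16 if enable else torch.float32
        for model in (self.bert_model, self.vits_model, self.t2s_model):
            if model is not None:
                if enable:
                    model.half()
                else:
                    model.float()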
@@ -349,7 +346,7 @@ class TTS: self.bert_model = self.bert_model.to(device) if self.cnhuhbert_model is not None: self.cnhuhbert_model = self.cnhuhbert_model.to(device) - + def set_ref_audio(self, ref_audio_path:str): ''' To set the reference audio for the TTS model, @@ -359,7 +356,7 @@ class TTS: ''' self._set_prompt_semantic(ref_audio_path) self._set_ref_spec(ref_audio_path) - + def _set_ref_spec(self, ref_audio_path): audio = load_audio(ref_audio_path, int(self.configs.sampling_rate)) audio = torch.FloatTensor(audio) @@ -378,8 +375,7 @@ class TTS: spec = spec.half() # self.refer_spec = spec self.prompt_cache["refer_spec"] = spec - - + def _set_prompt_semantic(self, ref_wav_path:str): zero_wav = np.zeros( int(self.configs.sampling_rate * 0.3), @@ -404,10 +400,10 @@ class TTS: 1, 2 ) # .float() codes = self.vits_model.extract_latent(hubert_feature) - + prompt_semantic = codes[0, 0].to(self.configs.device) self.prompt_cache["prompt_semantic"] = prompt_semantic - + def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None): seq = sequences[0] ndim = seq.dim() @@ -420,7 +416,8 @@ class TTS: max_length = max(seq_lengths) else: max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length - + # 我爱套 torch.no_grad() + # with torch.no_grad(): padded_sequences = [] for seq, length in zip(sequences, seq_lengths): padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1) @@ -428,7 +425,7 @@ class TTS: padded_sequences.append(padded_seq) batch = torch.stack(padded_sequences) return batch - + def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, @@ -438,112 +435,114 @@ class TTS: precision:torch.dtype=torch.float32, ): - _data:list = [] - index_and_len_list = [] - for idx, item in enumerate(data): - norm_text_len = len(item["norm_text"]) - index_and_len_list.append([idx, norm_text_len]) + # 但是这里不能套,反而会负优化 + # with torch.no_grad(): + _data:list = [] + index_and_len_list = [] + for idx, item in enumerate(data): + norm_text_len = len(item["norm_text"]) + index_and_len_list.append([idx, norm_text_len]) - batch_index_list = [] - if split_bucket: - index_and_len_list.sort(key=lambda x: x[1]) - index_and_len_list = np.array(index_and_len_list, dtype=np.int64) - - batch_index_list_len = 0 - pos = 0 - while pos =threshold) or (pos_end-pos==1): - batch_index=index_and_len_list[pos:pos_end, 0].tolist() - batch_index_list_len += len(batch_index) - batch_index_list.append(batch_index) - pos = pos_end - break - pos_end=pos_end-1 - - assert batch_index_list_len == len(data) - - else: - for i in range(len(data)): - if i%batch_size == 0: - batch_index_list.append([]) - batch_index_list[-1].append(i) + batch_index_list = [] + if split_bucket: + index_and_len_list.sort(key=lambda x: x[1]) + index_and_len_list = np.array(index_and_len_list, dtype=np.int64) - - for batch_idx, index_list in enumerate(batch_index_list): - item_list = [data[idx] for idx in index_list] - phones_list = [] - phones_len_list = [] - # bert_features_list = [] - all_phones_list = [] - all_phones_len_list = [] - all_bert_features_list = [] - norm_text_batch = [] - bert_max_len = 0 - phones_max_len = 0 - for item in item_list: - if prompt_data is not None: - all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\ + batch_index_list_len = 0 + pos = 0 + while pos =threshold) or (pos_end-pos==1): + batch_index=index_and_len_list[pos:pos_end, 0].tolist() + batch_index_list_len += len(batch_index) + 
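batch_sequences() above right-pads every tensor in a batch to a shared length before stacking. A self-contained sketch of that pad-then-stack step for 1-D token sequences; pad_and_stack is an illustrative helper, not the class method itself:

import torch
import torch.nn.functional as F
from typing import List

def pad_and_stack(sequences: List[torch.Tensor], pad_value: int = 0) -> torch.Tensor:
    # right-pad each 1-D sequence to the longest length, then stack into (B, T)
    max_length = max(seq.shape[-1] for seq in sequences)
    padded = [F.pad(seq, (0, max_length - seq.shape[-1]), value=pad_value)
              for seq in sequences]
    return torch.stack(padded)

# e.g. pad_and_stack([torch.tensor([1, 2, 3]), torch.tensor([4])]) has shape (2, 3)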
batch_index_list.append(batch_index) + pos = pos_end + break + pos_end=pos_end-1 + + assert batch_index_list_len == len(data) + + else: + for i in range(len(data)): + if i%batch_size == 0: + batch_index_list.append([]) + batch_index_list[-1].append(i) + + for batch_idx, index_list in enumerate(batch_index_list): + item_list = [data[idx] for idx in index_list] + phones_list = [] + phones_len_list = [] + # bert_features_list = [] + all_phones_list = [] + all_phones_len_list = [] + all_bert_features_list = [] + norm_text_batch = [] + bert_max_len = 0 + phones_max_len = 0 + # 但是这里也不能套,反而会负优化 + # with torch.no_grad(): + for item in item_list: + if prompt_data is not None: + all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\ + .to(dtype=precision, device=device) + all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device) + phones = torch.LongTensor(item["phones"]).to(device) + # norm_text = prompt_data["norm_text"]+item["norm_text"] + else: + all_bert_features = item["bert_features"]\ .to(dtype=precision, device=device) - all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device) - phones = torch.LongTensor(item["phones"]).to(device) - # norm_text = prompt_data["norm_text"]+item["norm_text"] - else: - all_bert_features = item["bert_features"]\ - .to(dtype=precision, device=device) - phones = torch.LongTensor(item["phones"]).to(device) - all_phones = phones - # norm_text = item["norm_text"] + phones = torch.LongTensor(item["phones"]).to(device) + all_phones = phones + # norm_text = item["norm_text"] + + bert_max_len = max(bert_max_len, all_bert_features.shape[-1]) + phones_max_len = max(phones_max_len, phones.shape[-1]) + + phones_list.append(phones) + phones_len_list.append(phones.shape[-1]) + all_phones_list.append(all_phones) + all_phones_len_list.append(all_phones.shape[-1]) + all_bert_features_list.append(all_bert_features) + norm_text_batch.append(item["norm_text"]) + + phones_batch = phones_list + all_phones_batch = all_phones_list + all_bert_features_batch = all_bert_features_list + + # max_len = max(bert_max_len, phones_max_len) + # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len) + #### 直接对phones和bert_features进行pad,会增大复读概率。 + # all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len) + # all_bert_features_batch = all_bert_features_list + # all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device) + # for idx, item in enumerate(all_bert_features_list): + # all_bert_features_batch[idx, :, : item.shape[-1]] = item + + # #### 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读) + # all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list] + # all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list] + # all_phones_batch = torch.stack(all_phones_list, dim=0) + + # all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list] + # all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list] + # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0) + + batch = { + "phones": phones_batch, + "phones_len": torch.LongTensor(phones_len_list).to(device), + "all_phones": all_phones_batch, + "all_phones_len": 
torch.LongTensor(all_phones_len_list).to(device), + "all_bert_features": all_bert_features_batch, + "norm_text": norm_text_batch + } + _data.append(batch) + + return _data, batch_index_list - bert_max_len = max(bert_max_len, all_bert_features.shape[-1]) - phones_max_len = max(phones_max_len, phones.shape[-1]) - - phones_list.append(phones) - phones_len_list.append(phones.shape[-1]) - all_phones_list.append(all_phones) - all_phones_len_list.append(all_phones.shape[-1]) - all_bert_features_list.append(all_bert_features) - norm_text_batch.append(item["norm_text"]) - - phones_batch = phones_list - all_phones_batch = all_phones_list - all_bert_features_batch = all_bert_features_list - - - # max_len = max(bert_max_len, phones_max_len) - # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len) - #### 直接对phones和bert_features进行pad,会增大复读概率。 - # all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len) - # all_bert_features_batch = all_bert_features_list - # all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device) - # for idx, item in enumerate(all_bert_features_list): - # all_bert_features_batch[idx, :, : item.shape[-1]] = item - - # #### 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读) - # all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list] - # all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list] - # all_phones_batch = torch.stack(all_phones_list, dim=0) - - # all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list] - # all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list] - # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0) - - batch = { - "phones": phones_batch, - "phones_len": torch.LongTensor(phones_len_list).to(device), - "all_phones": all_phones_batch, - "all_phones_len": torch.LongTensor(all_phones_len_list).to(device), - "all_bert_features": all_bert_features_batch, - "norm_text": norm_text_batch - } - _data.append(batch) - - return _data, batch_index_list - def recovery_order(self, data:list, batch_index_list:list)->list: ''' Recovery the order of the audio according to the batch_index_list. @@ -567,8 +566,7 @@ class TTS: Stop the inference process. ''' self.stop_flag = True - - + def run(self, inputs:dict): """ Text to speech inference. @@ -596,156 +594,159 @@ class TTS: returns: tuple[int, np.ndarray]: sampling rate and audio data. 
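recovery_order() undoes the length-sorted bucketing: batch_index_list records which original sentence each batched item came from, so generated audio can be returned in input order. A flattened-list sketch of that mapping, assuming simple Python lists:

from typing import List

def recovery_order(batches: List[list], batch_index_list: List[List[int]]) -> list:
    # place each batched item back at the position its source sentence had
    total = sum(len(indices) for indices in batch_index_list)
    ordered = [None] * total
    for batch, indices in zip(batches, batch_index_list):
        for item, original_idx in zip(batch, indices):
            ordered[original_idx] = item
    return ordered

# e.g. recovery_order([["b", "c"], ["a"]], [[1, 2], [0]]) -> ["a", "b", "c"]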
""" - ########## variables initialization ########### - self.stop_flag:bool = False - text:str = inputs.get("text", "") - text_lang:str = inputs.get("text_lang", "") - ref_audio_path:str = inputs.get("ref_audio_path", "") - prompt_text:str = inputs.get("prompt_text", "") - prompt_lang:str = inputs.get("prompt_lang", "") - top_k:int = inputs.get("top_k", 5) - top_p:float = inputs.get("top_p", 1) - temperature:float = inputs.get("temperature", 1) - text_split_method:str = inputs.get("text_split_method", "cut0") - batch_size = inputs.get("batch_size", 1) - batch_threshold = inputs.get("batch_threshold", 0.75) - speed_factor = inputs.get("speed_factor", 1.0) - split_bucket = inputs.get("split_bucket", True) - return_fragment = inputs.get("return_fragment", False) - fragment_interval = inputs.get("fragment_interval", 0.3) - seed = inputs.get("seed", -1) - seed = -1 if seed in ["", None] else seed - actual_seed = set_seed(seed) - if return_fragment: - # split_bucket = False - print(i18n("分段返回模式已开启")) + def make_batch(batch_texts): + batch_data = [] + print(i18n("############ 提取文本Bert特征 ############")) + for text in tqdm(batch_texts): + phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang) + if phones is None: + continue + res={ + "phones": phones, + "bert_features": bert_features, + "norm_text": norm_text, + } + batch_data.append(res) + if len(batch_data) == 0: + return None + batch, _ = self.to_batch(batch_data, + prompt_data=self.prompt_cache if not no_prompt_text else None, + batch_size=batch_size, + threshold=batch_threshold, + split_bucket=False, + device=self.configs.device, + precision=self.precision + ) + return batch[0] + + # 直接给全体套一个torch.no_grad() + with torch.no_grad(): + ########## variables initialization ########### + self.stop_flag:bool = False + text:str = inputs.get("text", "") + text_lang:str = inputs.get("text_lang", "") + ref_audio_path:str = inputs.get("ref_audio_path", "") + prompt_text:str = inputs.get("prompt_text", "") + prompt_lang:str = inputs.get("prompt_lang", "") + top_k:int = inputs.get("top_k", 5) + top_p:float = inputs.get("top_p", 1) + temperature:float = inputs.get("temperature", 1) + text_split_method:str = inputs.get("text_split_method", "cut0") + batch_size = inputs.get("batch_size", 1) + batch_threshold = inputs.get("batch_threshold", 0.75) + speed_factor = inputs.get("speed_factor", 1.0) + split_bucket = inputs.get("split_bucket", True) + return_fragment = inputs.get("return_fragment", False) + fragment_interval = inputs.get("fragment_interval", 0.3) + seed = inputs.get("seed", -1) + seed = -1 if seed in ["", None] else seed + actual_seed = set_seed(seed) + + if return_fragment: + # split_bucket = False + print(i18n("分段返回模式已开启")) + if split_bucket: + split_bucket = False + print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理")) + if split_bucket: - split_bucket = False - print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理")) - - if split_bucket: - print(i18n("分桶处理模式已开启")) - - if fragment_interval<0.01: - fragment_interval = 0.01 - print(i18n("分段间隔过小,已自动设置为0.01")) - - no_prompt_text = False - if prompt_text in [None, ""]: - no_prompt_text = True - - assert text_lang in self.configs.languages - if not no_prompt_text: - assert prompt_lang in self.configs.languages + print(i18n("分桶处理模式已开启")) - if ref_audio_path in [None, ""] and \ - ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] is None)): - raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using 
set_ref_audio()") + if fragment_interval<0.01: + fragment_interval = 0.01 + print(i18n("分段间隔过小,已自动设置为0.01")) + no_prompt_text = False + if prompt_text in [None, ""]: + no_prompt_text = True - ###### setting reference audio and prompt text preprocessing ######## - t0 = ttime() - if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]): + assert text_lang in self.configs.languages + if not no_prompt_text: + assert prompt_lang in self.configs.languages + + if ref_audio_path in [None, ""] and \ + ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] is None)): + raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()") + + ###### setting reference audio and prompt text preprocessing ######## + t0 = ttime() + if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]): self.set_ref_audio(ref_audio_path) - if not no_prompt_text: - prompt_text = prompt_text.strip("\n") - if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "." - print(i18n("实际输入的参考文本:"), prompt_text) - if self.prompt_cache["prompt_text"] != prompt_text: - self.prompt_cache["prompt_text"] = prompt_text - self.prompt_cache["prompt_lang"] = prompt_lang - phones, bert_features, norm_text = \ - self.text_preprocessor.segment_and_extract_feature_for_text( - prompt_text, - prompt_lang) - self.prompt_cache["phones"] = phones - self.prompt_cache["bert_features"] = bert_features - self.prompt_cache["norm_text"] = norm_text - - - ###### text preprocessing ######## - t1 = ttime() - data:list = None - if not return_fragment: - data = self.text_preprocessor.preprocess(text, text_lang, text_split_method) - if len(data) == 0: - yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), - dtype=np.int16) - return - - batch_index_list:list = None - data, batch_index_list = self.to_batch(data, - prompt_data=self.prompt_cache if not no_prompt_text else None, - batch_size=batch_size, - threshold=batch_threshold, - split_bucket=split_bucket, - device=self.configs.device, - precision=self.precision - ) - else: - print(i18n("############ 切分文本 ############")) - texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method) - data = [] - for i in range(len(texts)): - if i%batch_size == 0: - data.append([]) - data[-1].append(texts[i]) - - def make_batch(batch_texts): - batch_data = [] - print(i18n("############ 提取文本Bert特征 ############")) - for text in tqdm(batch_texts): - phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang) - if phones is None: - continue - res={ - "phones": phones, - "bert_features": bert_features, - "norm_text": norm_text, - } - batch_data.append(res) - if len(batch_data) == 0: - return None - batch, _ = self.to_batch(batch_data, - prompt_data=self.prompt_cache if not no_prompt_text else None, - batch_size=batch_size, - threshold=batch_threshold, - split_bucket=False, - device=self.configs.device, - precision=self.precision - ) - return batch[0] - - t2 = ttime() - try: - print("############ 推理 ############") - ###### inference ###### - t_34 = 0.0 - t_45 = 0.0 - audio = [] - for item in data: - t3 = ttime() - if return_fragment: - item = make_batch(item) - if item is None: - continue - - batch_phones:List[torch.LongTensor] = item["phones"] - batch_phones_len:torch.LongTensor = item["phones_len"] - all_phoneme_ids:List[torch.LongTensor] = item["all_phones"] - 
all_phoneme_lens:torch.LongTensor = item["all_phones_len"] - all_bert_features:List[torch.LongTensor] = item["all_bert_features"] - norm_text:str = item["norm_text"] - - print(i18n("前端处理后的文本(每句):"), norm_text) - if no_prompt_text : - prompt = None - else: - prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) - - with torch.no_grad(): + if not no_prompt_text: + prompt_text = prompt_text.strip("\n") + if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "." + print(i18n("实际输入的参考文本:"), prompt_text) + if self.prompt_cache["prompt_text"] != prompt_text: + self.prompt_cache["prompt_text"] = prompt_text + self.prompt_cache["prompt_lang"] = prompt_lang + phones, bert_features, norm_text = \ + self.text_preprocessor.segment_and_extract_feature_for_text( + prompt_text, + prompt_lang) + self.prompt_cache["phones"] = phones + self.prompt_cache["bert_features"] = bert_features + self.prompt_cache["norm_text"] = norm_text + + ###### text preprocessing ######## + t1 = ttime() + data:list = None + if not return_fragment: + data = self.text_preprocessor.preprocess(text, text_lang, text_split_method) + if len(data) == 0: + yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), + dtype=np.int16) + return + + batch_index_list:list = None + data, batch_index_list = self.to_batch(data, + prompt_data=self.prompt_cache if not no_prompt_text else None, + batch_size=batch_size, + threshold=batch_threshold, + split_bucket=split_bucket, + device=self.configs.device, + precision=self.precision + ) + else: + print(i18n("############ 切分文本 ############")) + texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method) + data = [] + for i in range(len(texts)): + if i%batch_size == 0: + data.append([]) + data[-1].append(texts[i]) + + + + t2 = ttime() + try: + print("############ 推理 ############") + ###### inference ###### + t_34 = 0.0 + t_45 = 0.0 + audio = [] + for item in data: + t3 = ttime() + if return_fragment: + item = make_batch(item) + if item is None: + continue + + batch_phones:List[torch.LongTensor] = item["phones"] + batch_phones_len:torch.LongTensor = item["phones_len"] + all_phoneme_ids:List[torch.LongTensor] = item["all_phones"] + all_phoneme_lens:torch.LongTensor = item["all_phones_len"] + all_bert_features:List[torch.LongTensor] = item["all_bert_features"] + norm_text:str = item["norm_text"] + + print(i18n("前端处理后的文本(每句):"), norm_text) + if no_prompt_text : + prompt = None + else: + prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) + + pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( all_phoneme_ids, all_phoneme_lens, @@ -757,97 +758,99 @@ class TTS: temperature=temperature, early_stop_num=self.configs.hz * self.configs.max_sec, ) - t4 = ttime() - t_34 += t4 - t3 - - refer_audio_spec:torch.Tensor = self.prompt_cache["refer_spec"]\ - .to(dtype=self.precision, device=self.configs.device) - - batch_audio_fragment = [] + t4 = ttime() + t_34 += t4 - t3 - # ## vits并行推理 method 1 - # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) - # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) - # max_len = 0 - # for i in range(0, len(batch_phones)): - # max_len = max(max_len, batch_phones[i].shape[-1]) - # batch_phones = 
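The prompt handling above broadcasts a single reference semantic sequence across the batch with expand(), which creates a view instead of copying the prompt tokens per item. A small sketch of that step with illustrative shapes:

import torch

prompt_semantic = torch.randint(0, 1024, (150,))   # one 1-D prompt of 150 dummy tokens
batch_size = 4                                      # stands in for len(all_phoneme_ids)
prompt = prompt_semantic.expand(batch_size, -1)     # shape (4, 150), no data copied
assert prompt.shape == (batch_size, 150)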
self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) - # batch_phones = batch_phones.to(self.configs.device) - # batch_audio_fragment = (self.vits_model.batched_decode( - # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec - # )) - - # ## vits并行推理 method 2 - pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - upsample_rate = math.prod(self.vits_model.upsample_rates) - audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] - audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] - all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) - _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) - _batch_audio_fragment = (self.vits_model.decode( - all_pred_semantic, _batch_phones, refer_audio_spec - ).detach()[0, 0, :]) - audio_frag_end_idx.insert(0, 0) - batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] - + refer_audio_spec:torch.Tensor = self.prompt_cache["refer_spec"]\ + .to(dtype=self.precision, device=self.configs.device) - # ## vits串行推理 - # for i, idx in enumerate(idx_list): - # phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - # _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次 - # audio_fragment =(self.vits_model.decode( - # _pred_semantic, phones, refer_audio_spec - # ).detach()[0, 0, :]) - # batch_audio_fragment.append( - # audio_fragment - # ) ###试试重建不带上prompt部分 + batch_audio_fragment = [] + + # 这里要记得加 torch.no_grad() 不然速度慢一大截 + # with torch.no_grad(): - t5 = ttime() - t_45 += t5 - t4 - if return_fragment: - print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) - yield self.audio_postprocess([batch_audio_fragment], + # ## vits并行推理 method 1 + # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) + # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) + # max_len = 0 + # for i in range(0, len(batch_phones)): + # max_len = max(max_len, batch_phones[i].shape[-1]) + # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) + # batch_phones = batch_phones.to(self.configs.device) + # batch_audio_fragment = (self.vits_model.batched_decode( + # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec + # )) + + # ## vits并行推理 method 2 + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + upsample_rate = math.prod(self.vits_model.upsample_rates) + audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] + audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] + all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) + _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) + _batch_audio_fragment = (self.vits_model.decode( + all_pred_semantic, _batch_phones, refer_audio_spec + ).detach()[0, 0, :]) + audio_frag_end_idx.insert(0, 0) + batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, 
len(audio_frag_end_idx))] + + # ## vits串行推理 + # for i, idx in enumerate(idx_list): + # phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + # _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次 + # audio_fragment =(self.vits_model.decode( + # _pred_semantic, phones, refer_audio_spec + # ).detach()[0, 0, :]) + # batch_audio_fragment.append( + # audio_fragment + # ) ###试试重建不带上prompt部分 + + t5 = ttime() + t_45 += t5 - t4 + if return_fragment: + print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) + yield self.audio_postprocess([batch_audio_fragment], + self.configs.sampling_rate, + None, + speed_factor, + False, + fragment_interval + ) + else: + audio.append(batch_audio_fragment) + + if self.stop_flag: + yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), + dtype=np.int16) + return + + if not return_fragment: + print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) + yield self.audio_postprocess(audio, self.configs.sampling_rate, - None, + batch_index_list, speed_factor, - False, + split_bucket, fragment_interval ) - else: - audio.append(batch_audio_fragment) - - if self.stop_flag: - yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), - dtype=np.int16) - return - if not return_fragment: - print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) - yield self.audio_postprocess(audio, - self.configs.sampling_rate, - batch_index_list, - speed_factor, - split_bucket, - fragment_interval - ) + except Exception as e: + traceback.print_exc() + # 必须返回一个空音频, 否则会导致显存不释放。 + yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), + dtype=np.int16) + # 重置模型, 否则会导致显存释放不完全。 + del self.t2s_model + del self.vits_model + self.t2s_model = None + self.vits_model = None + self.init_t2s_weights(self.configs.t2s_weights_path) + self.init_vits_weights(self.configs.vits_weights_path) + raise e + finally: + self.empty_cache() - except Exception as e: - traceback.print_exc() - # 必须返回一个空音频, 否则会导致显存不释放。 - yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), - dtype=np.int16) - # 重置模型, 否则会导致显存释放不完全。 - del self.t2s_model - del self.vits_model - self.t2s_model = None - self.vits_model = None - self.init_t2s_weights(self.configs.t2s_weights_path) - self.init_vits_weights(self.configs.vits_weights_path) - raise e - finally: - self.empty_cache() - def empty_cache(self): try: if "cuda" in str(self.configs.device): @@ -856,7 +859,7 @@ class TTS: torch.mps.empty_cache() except: pass - + def audio_postprocess(self, audio:List[torch.Tensor], sr:int, @@ -870,36 +873,32 @@ class TTS: dtype=self.precision, device=self.configs.device ) - + for i, batch in enumerate(audio): for j, audio_fragment in enumerate(batch): max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音 if max_audio>1: audio_fragment/=max_audio audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) audio[i][j] = audio_fragment.cpu().numpy() - - + if split_bucket: audio = self.recovery_order(audio, batch_index_list) else: # audio = [item for batch in audio for item in batch] audio = sum(audio, []) - - + audio = np.concatenate(audio, 0) audio = (audio * 32768).astype(np.int16) - + try: if speed_factor != 1.0: audio = speed_change(audio, speed=speed_factor, sr=int(sr)) except Exception as e: print(f"Failed to change speed of audio: \n{e}") - + return sr, audio - - - - + + def speed_change(input_audio:np.ndarray, speed:float, sr:int): # 将 NumPy 数组转换为原始 PCM 流 raw_audio 
= input_audio.astype(np.int16).tobytes()
@@ -919,4 +918,4 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
     # 将管道输出解码为 NumPy 数组
     processed_audio = np.frombuffer(out, np.int16)
-    return processed_audio
\ No newline at end of file
+    return processed_audio
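speed_change() pipes raw int16 PCM through ffmpeg's atempo filter, so it needs an ffmpeg binary on PATH and int16 input at the stated sample rate. A hedged usage sketch with a synthetic tone; note that a single atempo stage typically accepts speeds of roughly 0.5x to 2.0x, depending on the ffmpeg build.

import numpy as np

sr = 32000
t = np.arange(sr) / sr                               # one second of samples
tone = (0.3 * np.sin(2 * np.pi * 440 * t) * 32768).astype(np.int16)

faster = speed_change(tone, speed=1.25, sr=sr)       # roughly 0.8 s of audio out
print(len(tone), len(faster))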