diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index a9ab1562..d3e394ac 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -272,7 +272,7 @@ class TTS:
         # if ("pretrained" not in weights_path):
         if hasattr(vits_model, "enc_q"):
             del vits_model.enc_q
-        
+
         vits_model = vits_model.to(self.configs.device)
         vits_model = vits_model.eval()
         vits_model.load_state_dict(dict_s2["weight"], strict=False)
@@ -280,6 +280,7 @@ class TTS:
         if self.configs.is_half and str(self.configs.device)!="cpu":
             self.vits_model = self.vits_model.half()
 
+
     def init_t2s_weights(self, weights_path: str):
         print(f"Loading Text2Semantic weights from {weights_path}")
         self.configs.t2s_weights_path = weights_path
@@ -296,7 +297,7 @@ class TTS:
         self.t2s_model = t2s_model
         if self.configs.is_half and str(self.configs.device)!="cpu":
             self.t2s_model = self.t2s_model.half()
-        
+
     def enable_half_precision(self, enable: bool = True):
         '''
             To enable half precision for the TTS model.
@@ -307,7 +308,7 @@ class TTS:
         if str(self.configs.device) == "cpu" and enable:
             print("Half precision is not supported on CPU.")
             return
-        
+
         self.configs.is_half = enable
         self.precision = torch.float16 if enable else torch.float32
         self.configs.save_configs()
@@ -329,7 +330,7 @@ class TTS:
             self.bert_model = self.bert_model.float()
         if self.cnhuhbert_model is not None:
             self.cnhuhbert_model = self.cnhuhbert_model.float()
-        
+
     def set_device(self, device: torch.device):
         '''
             To set the device for all models.
@@ -346,7 +347,7 @@ class TTS:
             self.bert_model = self.bert_model.to(device)
         if self.cnhuhbert_model is not None:
             self.cnhuhbert_model = self.cnhuhbert_model.to(device)
-        
+
     def set_ref_audio(self, ref_audio_path:str):
         '''
             To set the reference audio for the TTS model,
@@ -356,7 +357,7 @@ class TTS:
         '''
         self._set_prompt_semantic(ref_audio_path)
         self._set_ref_spec(ref_audio_path)
-        
+
     def _set_ref_spec(self, ref_audio_path):
         audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
         audio = torch.FloatTensor(audio)
@@ -375,7 +376,8 @@ class TTS:
             spec = spec.half()
         # self.refer_spec = spec
         self.prompt_cache["refer_spec"] = spec
-        
+
+
     def _set_prompt_semantic(self, ref_wav_path:str):
         zero_wav = np.zeros(
             int(self.configs.sampling_rate * 0.3),
@@ -400,10 +402,10 @@ class TTS:
             1, 2
         )  # .float()
         codes = self.vits_model.extract_latent(hubert_feature)
-    
+
         prompt_semantic = codes[0, 0].to(self.configs.device)
         self.prompt_cache["prompt_semantic"] = prompt_semantic
-    
+
     def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None):
         seq = sequences[0]
         ndim = seq.dim()
@@ -416,8 +418,7 @@ class TTS:
             max_length = max(seq_lengths)
         else:
             max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
-        # I love wrapping everything in torch.no_grad()
-        # with torch.no_grad():
+
         padded_sequences = []
         for seq, length in zip(sequences, seq_lengths):
             padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
@@ -425,7 +426,7 @@ class TTS:
             padded_sequences.append(padded_seq)
         batch = torch.stack(padded_sequences)
         return batch
-    
+
     def to_batch(self, data:list,
                  prompt_data:dict=None,
                  batch_size:int=5,
@@ -434,115 +435,116 @@ class TTS:
                  device:torch.device=torch.device("cpu"),
                  precision:torch.dtype=torch.float32,
                  ):
-        # But no_grad must not be used here; it actually hurts performance
         # with torch.no_grad():
-        _data:list = []
-        index_and_len_list = []
-        for idx, item in enumerate(data):
-            norm_text_len = len(item["norm_text"])
-            index_and_len_list.append([idx, norm_text_len])
-
-        batch_index_list = []
-        if split_bucket:
-            index_and_len_list.sort(key=lambda x: x[1])
-            index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
-            batch_index_list_len = 0
-            pos = 0
-            while pos < index_and_len_list.shape[0]:
-                pos_end = min(pos+batch_size, index_and_len_list.shape[0])
-                while pos < pos_end:
-                    batch = index_and_len_list[pos:pos_end, 1].astype(np.float32)
-                    score = batch[(pos_end-pos)//2] / (batch.mean()+1e-8)
-                    if (score>=threshold) or (pos_end-pos==1):
-                        batch_index=index_and_len_list[pos:pos_end, 0].tolist()
-                        batch_index_list_len += len(batch_index)
-                        batch_index_list.append(batch_index)
-                        pos = pos_end
-                        break
-                    pos_end=pos_end-1
-
-            assert batch_index_list_len == len(data)
-
-        else:
-            for i in range(len(data)):
-                if i%batch_size == 0:
-                    batch_index_list.append([])
-                batch_index_list[-1].append(i)
-
-        for batch_idx, index_list in enumerate(batch_index_list):
-            item_list = [data[idx] for idx in index_list]
-            phones_list = []
-            phones_len_list = []
-            # bert_features_list = []
-            all_phones_list = []
-            all_phones_len_list = []
-            all_bert_features_list = []
-            norm_text_batch = []
-            bert_max_len = 0
-            phones_max_len = 0
-            # no_grad must not be used here either; it actually hurts performance
-            # with torch.no_grad():
-            for item in item_list:
-                if prompt_data is not None:
-                    all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
-                        .to(dtype=precision, device=device)
-                    all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
-                    phones = torch.LongTensor(item["phones"]).to(device)
-                    # norm_text = prompt_data["norm_text"]+item["norm_text"]
-                else:
-                    all_bert_features = item["bert_features"]\
-                        .to(dtype=precision, device=device)
-                    phones = torch.LongTensor(item["phones"]).to(device)
-                    all_phones = phones
-                    # norm_text = item["norm_text"]
-
-                bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
-                phones_max_len = max(phones_max_len, phones.shape[-1])
-
-                phones_list.append(phones)
-                phones_len_list.append(phones.shape[-1])
-                all_phones_list.append(all_phones)
-                all_phones_len_list.append(all_phones.shape[-1])
-                all_bert_features_list.append(all_bert_features)
-                norm_text_batch.append(item["norm_text"])
-
-            phones_batch = phones_list
-            all_phones_batch = all_phones_list
-            all_bert_features_batch = all_bert_features_list
-
-            # max_len = max(bert_max_len, phones_max_len)
-            # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
-            #### Padding phones and bert_features directly increases the chance of repeated speech.
-            # all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
-            # all_bert_features_batch = all_bert_features_list
-            # all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
-            # for idx, item in enumerate(all_bert_features_list):
-            #     all_bert_features_batch[idx, :, : item.shape[-1]] = item
-
-            # #### Embed the phones and project the bert_features first, then pad to the same length, to mitigate repetition. (Other factors may also cause repetition.)
-            # all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
-            # all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
-            # all_phones_batch = torch.stack(all_phones_list, dim=0)
-
-            # all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
-            # all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
-            # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)
-
-            batch = {
-                "phones": phones_batch,
-                "phones_len": torch.LongTensor(phones_len_list).to(device),
-                "all_phones": all_phones_batch,
-                "all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
-                "all_bert_features": all_bert_features_batch,
-                "norm_text": norm_text_batch
-            }
-            _data.append(batch)
-
-        return _data, batch_index_list
+        _data:list = []
+        index_and_len_list = []
+        for idx, item in enumerate(data):
+            norm_text_len = len(item["norm_text"])
+            index_and_len_list.append([idx, norm_text_len])
+
+        batch_index_list = []
+        if split_bucket:
+            index_and_len_list.sort(key=lambda x: x[1])
+            index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
+
+            batch_index_list_len = 0
+            pos = 0
+            while pos < index_and_len_list.shape[0]:
+                pos_end = min(pos+batch_size, index_and_len_list.shape[0])
+                while pos < pos_end:
+                    batch = index_and_len_list[pos:pos_end, 1].astype(np.float32)
+                    score = batch[(pos_end-pos)//2] / (batch.mean()+1e-8)
+                    if (score>=threshold) or (pos_end-pos==1):
+                        batch_index=index_and_len_list[pos:pos_end, 0].tolist()
+                        batch_index_list_len += len(batch_index)
+                        batch_index_list.append(batch_index)
+                        pos = pos_end
+                        break
+                    pos_end=pos_end-1
+
+            assert batch_index_list_len == len(data)
+
+        else:
+            for i in range(len(data)):
+                if i%batch_size == 0:
+                    batch_index_list.append([])
+                batch_index_list[-1].append(i)
+
+
+        for batch_idx, index_list in enumerate(batch_index_list):
+            item_list = [data[idx] for idx in index_list]
+            phones_list = []
+            phones_len_list = []
+            # bert_features_list = []
+            all_phones_list = []
+            all_phones_len_list = []
+            all_bert_features_list = []
+            norm_text_batch = []
+            bert_max_len = 0
+            phones_max_len = 0
+            # no_grad must not be used here either; it actually hurts performance
+            # with torch.no_grad():
+            for item in item_list:
+                if prompt_data is not None:
+                    all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
+                        .to(dtype=precision, device=device)
+                    all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
+                    phones = torch.LongTensor(item["phones"]).to(device)
+                    # norm_text = prompt_data["norm_text"]+item["norm_text"]
+                else:
+                    all_bert_features = item["bert_features"]\
+                        .to(dtype=precision, device=device)
+                    phones = torch.LongTensor(item["phones"]).to(device)
+                    all_phones = phones
+                    # norm_text = item["norm_text"]
+                bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
+                phones_max_len = max(phones_max_len, phones.shape[-1])
+
+                phones_list.append(phones)
+                phones_len_list.append(phones.shape[-1])
+                all_phones_list.append(all_phones)
+                all_phones_len_list.append(all_phones.shape[-1])
+                all_bert_features_list.append(all_bert_features)
+                norm_text_batch.append(item["norm_text"])
+
+            phones_batch = phones_list
+            all_phones_batch = all_phones_list
+            all_bert_features_batch = all_bert_features_list
+
+
+            # max_len = max(bert_max_len, phones_max_len)
+            # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
+            #### Padding phones and bert_features directly increases the chance of repeated speech.
+            # all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
+            # all_bert_features_batch = all_bert_features_list
+            # all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
+            # for idx, item in enumerate(all_bert_features_list):
+            #     all_bert_features_batch[idx, :, : item.shape[-1]] = item
+
+            # #### Embed the phones and project the bert_features first, then pad to the same length, to mitigate repetition. (Other factors may also cause repetition.)
+            # all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
+            # all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
+            # all_phones_batch = torch.stack(all_phones_list, dim=0)
+
+            # all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
+            # all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
+            # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)
+
+            batch = {
+                "phones": phones_batch,
+                "phones_len": torch.LongTensor(phones_len_list).to(device),
+                "all_phones": all_phones_batch,
+                "all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
+                "all_bert_features": all_bert_features_batch,
+                "norm_text": norm_text_batch
+            }
+            _data.append(batch)
+
+        return _data, batch_index_list
+
     def recovery_order(self, data:list, batch_index_list:list)->list:
         '''
         Recover the order of the audio according to the batch_index_list.
@@ -566,7 +568,8 @@ class TTS:
             Stop the inference process.
         '''
         self.stop_flag = True
-    
+
+
     def run(self, inputs:dict):
         """
         Text to speech inference.
@@ -850,7 +853,7 @@ class TTS:
             raise e
         finally:
             self.empty_cache()
-    
+
     def empty_cache(self):
         try:
             if "cuda" in str(self.configs.device):
@@ -859,7 +862,7 @@ class TTS:
                 torch.mps.empty_cache()
         except:
             pass
-    
+
     def audio_postprocess(self,
                           audio:List[torch.Tensor],
                           sr:int,
@@ -873,32 +876,36 @@ class TTS:
                         dtype=self.precision,
                         device=self.configs.device
                     )
-        
+
         for i, batch in enumerate(audio):
             for j, audio_fragment in enumerate(batch):
                 max_audio=torch.abs(audio_fragment).max()  # simple guard against 16-bit clipping
                 if max_audio>1: audio_fragment/=max_audio
                 audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
                 audio[i][j] = audio_fragment.cpu().numpy()
-        
+
+
         if split_bucket:
             audio = self.recovery_order(audio, batch_index_list)
         else:
             # audio = [item for batch in audio for item in batch]
             audio = sum(audio, [])
-        
+
+
        audio = np.concatenate(audio, 0)
         audio = (audio * 32768).astype(np.int16)
-        
+
         try:
             if speed_factor != 1.0:
                 audio = speed_change(audio, speed=speed_factor, sr=int(sr))
         except Exception as e:
             print(f"Failed to change speed of audio: \n{e}")
-        
+
         return sr, audio
-        
-        
+
+
+
+
 def speed_change(input_audio:np.ndarray, speed:float, sr:int):
     # Convert the NumPy array to a raw PCM byte stream
     raw_audio = input_audio.astype(np.int16).tobytes()
@@ -918,4 +925,4 @@
     # Decode the piped output back into a NumPy array
     processed_audio = np.frombuffer(out, np.int16)
 
-    return processed_audio
+    return processed_audio
\ No newline at end of file
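
Review note: the densest part of this patch is the `split_bucket` branch of `to_batch`, which the re-indentation makes easy to misread. Below is a minimal standalone sketch of that bucketing logic, shown for reference only: indices are sorted by `norm_text` length, then each candidate batch is shrunk from `batch_size` downward until the middle element's length is close enough to the batch mean (`score >= threshold`), so every batch holds texts of similar length and padding waste stays low. The helper name `split_buckets` and the toy data are illustrative, not part of the patch.

```python
import numpy as np

def split_buckets(lengths, batch_size=5, threshold=0.75):
    """Group indices of `lengths` into batches of similar length.

    Mirrors the split_bucket branch of TTS.to_batch: sort by length,
    then shrink each candidate batch until the middle element's length
    divided by the batch mean reaches `threshold`.
    """
    index_and_len = sorted(enumerate(lengths), key=lambda x: x[1])
    index_and_len = np.array(index_and_len, dtype=np.int64)

    batches = []
    pos = 0
    while pos < index_and_len.shape[0]:
        pos_end = min(pos + batch_size, index_and_len.shape[0])
        while pos < pos_end:
            batch = index_and_len[pos:pos_end, 1].astype(np.float32)
            score = batch[(pos_end - pos) // 2] / (batch.mean() + 1e-8)
            if score >= threshold or pos_end - pos == 1:
                # Lengths are homogeneous enough (or only one item left).
                batches.append(index_and_len[pos:pos_end, 0].tolist())
                pos = pos_end
                break
            pos_end -= 1  # otherwise shrink the batch and retry
    return batches

# Toy example: short and long texts land in separate batches.
print(split_buckets([3, 40, 5, 42, 4, 41, 6, 39, 50, 7]))
```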
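Review note on `speed_change`: the diff only shows the first and last lines of the function, so here is a hedged sketch of the same raw-PCM round trip, written against the ffmpeg CLI via `subprocess` (the actual file may use an ffmpeg binding instead, and the mono `-ac 1` channel layout is an assumption). One caveat relevant to the `speed_factor` handling in `audio_postprocess`: ffmpeg's `atempo` filter only accepts factors in [0.5, 2.0] per filter instance, so values outside that range would need chained `atempo` filters.

```python
import subprocess
import numpy as np

def speed_change_sketch(input_audio: np.ndarray, speed: float, sr: int) -> np.ndarray:
    # Convert the NumPy array to a raw 16-bit PCM byte stream.
    raw_audio = input_audio.astype(np.int16).tobytes()
    cmd = [
        "ffmpeg",
        "-f", "s16le", "-ar", str(sr), "-ac", "1", "-i", "pipe:0",  # mono input is an assumption
        "-filter:a", f"atempo={speed}",  # atempo accepts 0.5..2.0 per instance
        "-f", "s16le", "-ar", str(sr), "-ac", "1", "pipe:1",
    ]
    out = subprocess.run(cmd, input=raw_audio, capture_output=True, check=True).stdout
    # Decode the piped output back into a NumPy array.
    return np.frombuffer(out, np.int16)
```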