diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index ae1fcc3c..d72aa393 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -827,7 +827,7 @@ class Text2SemanticDecoder(nn.Module):
         **kwargs,
     ):
         mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
-        sim_thershold = kwargs.get("sim_thershold", 0.3)
+        chunk_split_thershold = kwargs.get("chunk_split_thershold", 0.3)
         check_token_num = 2
 
@@ -927,9 +927,9 @@ class Text2SemanticDecoder(nn.Module):
 
            if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length+check_token_num):
-                score = mute_emb_sim_matrix[y[0, curr_ptr:]] - sim_thershold
+                score = mute_emb_sim_matrix[y[0, curr_ptr:]] - chunk_split_thershold
                score[score<0]=-1
-                score[:-1]=score[:-1]+score[1:]
+                score[:-1]=score[:-1]+score[1:] ## consider two consecutive tokens
                argmax_idx = score.argmax()
                if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length:
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 2f097f6c..a0c0d6ba 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -278,6 +278,8 @@ class TTS_Config:
     mute_tokens: dict = {
         "v1" : 486,
         "v2" : 486,
+        "v2Pro": 486,
+        "v2ProPlus": 486,
         "v3" : 486,
         "v4" : 486,
     }
@@ -1009,7 +1011,7 @@ class TTS:
                    "top_k": 5, # int. top k sampling
                    "top_p": 1, # float. top p sampling
                    "temperature": 1, # float. temperature for sampling
-                    "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
+                    "text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details.
                    "batch_size": 1, # int. batch size for inference
                    "batch_threshold": 0.75, # float. threshold for batch splitting.
                    "split_bucket: True, # bool. whether to split the batch into multiple buckets.
@@ -1023,7 +1025,7 @@ class TTS:
                    "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
                    "streaming_mode": False, # bool. return audio chunk by chunk.
                    "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
-                    "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
+                    "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
                }
        returns:
            Tuple[int, np.ndarray]: sampling rate and audio data.
@@ -1039,7 +1041,7 @@ class TTS:
        top_k: int = inputs.get("top_k", 5)
        top_p: float = inputs.get("top_p", 1)
        temperature: float = inputs.get("temperature", 1)
-        text_split_method: str = inputs.get("text_split_method", "cut0")
+        text_split_method: str = inputs.get("text_split_method", "cut1")
        batch_size = inputs.get("batch_size", 1)
        batch_threshold = inputs.get("batch_threshold", 0.75)
        speed_factor = inputs.get("speed_factor", 1.0)
@@ -1055,7 +1057,8 @@ class TTS:
        super_sampling = inputs.get("super_sampling", False)
        streaming_mode = inputs.get("streaming_mode", False)
        overlap_length = inputs.get("overlap_length", 2)
-        chunk_length = inputs.get("chunk_length", 24)
+        min_chunk_length = inputs.get("min_chunk_length", 16)
+        chunk_split_thershold = 0.0 # cosine-similarity threshold between a semantic token and the mute token; positions above this threshold are treated as valid split points.
 
        if parallel_infer and not streaming_mode:
            print(i18n("并行推理模式已开启"))
@@ -1249,6 +1252,15 @@ class TTS:
                    self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
                )
 
+                refer_audio_spec = []
+
+                sv_emb = [] if self.is_v2pro else None
+                for spec, audio_tensor in self.prompt_cache["refer_spec"]:
+                    spec = spec.to(dtype=self.precision, device=self.configs.device)
+                    refer_audio_spec.append(spec)
+                    if self.is_v2pro:
+                        sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
+
                if not streaming_mode:
                    print(f"############ {i18n('预测语义Token')} ############")
                    pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
@@ -1267,14 +1279,6 @@ class TTS:
                    t4 = time.perf_counter()
                    t_34 += t4 - t3
 
-                    refer_audio_spec = []
-                    if self.is_v2pro:
-                        sv_emb = []
-                    for spec, audio_tensor in self.prompt_cache["refer_spec"]:
-                        spec = spec.to(dtype=self.precision, device=self.configs.device)
-                        refer_audio_spec.append(spec)
-                        if self.is_v2pro:
-                            sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
 
                    batch_audio_fragment = []
@@ -1293,7 +1297,7 @@ class TTS:
                    print(f"############ {i18n('合成音频')} ############")
                    if not self.configs.use_vocoder:
                        if speed_factor == 1.0:
-                            print(f"{i18n('并行合成中')}...")
+                            print(f"{i18n('合成中')}...")
                            # ## vits并行推理 method 2
                            pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                            upsample_rate = math.prod(self.vits_model.upsample_rates)
@@ -1306,15 +1310,11 @@ class TTS:
                                torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
                            )
                            _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-                            if self.is_v2pro != True:
-                                _batch_audio_fragment, _, _ = self.vits_model.decode(
-                                    all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
-                                ).detach()[0, 0, :]
-                            else:
-                                _batch_audio_fragment = self.vits_model.decode(
+
+                            _batch_audio_fragment = self.vits_model.decode(
                                all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
-                                )
-                                _batch_audio_fragment = _batch_audio_fragment.detach()[0, 0, :]
+                            ).detach()[0, 0, :]
+
                            audio_frag_end_idx.insert(0, 0)
                            batch_audio_fragment = [
                                _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
@@ -1327,20 +1327,14 @@ class TTS:
                                _pred_semantic = (
                                    pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                                ) # .unsqueeze(0)#mq要多unsqueeze一次
-                                if self.is_v2pro != True:
-                                    audio_fragment, _, _ = self.vits_model.decode(
-                                        _pred_semantic, phones, refer_audio_spec, speed=speed_factor
-                                    ).detach()[0, 0, :]
-                                else:
-                                    audio_fragment = self.vits_model.decode(
+                                audio_fragment = self.vits_model.decode(
                                    _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
-                                    )
-                                    audio_fragment=audio_fragment.detach()[0, 0, :]
+                                ).detach()[0, 0, :]
                                batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
                    else:
                        if parallel_infer:
                            print(f"{i18n('并行合成中')}...")
-                            audio_fragments, y, y_mask = self.using_vocoder_synthesis_batched_infer(
+                            audio_fragments = self.using_vocoder_synthesis_batched_infer(
                                idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
                            )
                            batch_audio_fragment.extend(audio_fragments)
@@ -1350,16 +1344,16 @@ class TTS:
                                _pred_semantic = (
                                    pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                                ) # .unsqueeze(0)#mq要多unsqueeze一次
-                                audio_fragment, y, y_mask = self.using_vocoder_synthesis(
+                                audio_fragment = self.using_vocoder_synthesis(
                                    _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
                                )
                                batch_audio_fragment.append(audio_fragment)
 
                else:
-                    refer_audio_spec: torch.Tensor = [
-                        item.to(dtype=self.precision, device=self.configs.device)
-                        for item in self.prompt_cache["refer_spec"]
-                    ]
+                    # refer_audio_spec: torch.Tensor = [
+                    #     item.to(dtype=self.precision, device=self.configs.device)
+                    #     for item in self.prompt_cache["refer_spec"]
+                    # ]
                    semantic_token_generator =self.t2s_model.model.infer_panel(
                        all_phoneme_ids[0].unsqueeze(0),
                        all_phoneme_lens,
@@ -1372,8 +1366,9 @@ class TTS:
                        max_len=max_len,
                        repetition_penalty=repetition_penalty,
                        streaming_mode=True,
-                        chunk_length=chunk_length,
+                        chunk_length=min_chunk_length,
                        mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix,
+                        chunk_split_thershold=chunk_split_thershold,
                    )
                    t4 = time.perf_counter()
                    t_34 += t4 - t3
@@ -1381,15 +1376,15 @@ class TTS:
                    is_first_chunk = True
 
                    if not self.configs.use_vocoder:
-                        if speed_factor == 1.0:
-                            upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
-                        else:
-                            upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor)
+                        # if speed_factor == 1.0:
+                        #     upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
+                        # else:
+                        upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor)
                    else:
-                        if speed_factor == 1.0:
-                            upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4)
-                        else:
-                            upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)
+                        # if speed_factor == 1.0:
+                        #     upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4)
+                        # else:
+                        upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)
 
                    last_audio_chunk = None
                    # last_tokens = None
@@ -1419,10 +1414,8 @@ class TTS:
                        if not is_first_chunk and semantic_tokens.shape[-1] < 10:
                            overlap_len = overlap_length+(10-semantic_tokens.shape[-1])
-                            # overlap_size = math.ceil(overlap_len*upsample_rate)
                        else:
                            overlap_len = overlap_length
-                            # overlap_size = math.ceil(overlap_length*upsample_rate)
 
                        if not self.configs.use_vocoder:
@@ -1433,10 +1426,11 @@ class TTS:
                            # else:
                            #     token_padding_length = 0
 
-                            audio_chunk, latent, latent_mask = self.vits_model.decode(
+                            audio_chunk, latent, latent_mask = self.vits_model.decode_steaming(
                                _semantic_tokens.unsqueeze(0), phones, refer_audio_spec, speed=speed_factor,
+                                sv_emb=sv_emb,
                                result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
                                overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
                                    if last_latent is not None else None,
@@ -1444,13 +1438,7 @@ class TTS:
                            )
                            audio_chunk=audio_chunk.detach()[0, 0, :]
                        else:
-                            audio_chunk, latent, latent_mask = self.using_vocoder_synthesis(
-                                _semantic_tokens.unsqueeze(0), phones,
-                                speed=speed_factor, sample_steps=sample_steps,
-                                result_length = semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
-                                overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
-                                    if last_latent is not None else None,
-                            )
+                            raise RuntimeError(i18n("SoVits V3/4模型不支持流式推理模式"))
 
                        if overlap_len>overlap_length:
                            audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):]
@@ -1614,7 +1602,7 @@ class TTS:
        return sr, audio
 
    def using_vocoder_synthesis(
-        self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32, result_length:int=None, overlap_frames:torch.Tensor=None
+        self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32
    ):
        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
@@ -1623,7 +1611,7 @@ class TTS:
            raw_entry = raw_entry[0]
        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)
 
-        fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
+        fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
        ref_sr = self.prompt_cache["raw_sr"]
        ref_audio = ref_audio.to(self.configs.device).float()
@@ -1649,7 +1637,7 @@ class TTS:
            chunk_len = T_chunk - T_min
 
        mel2 = mel2.to(self.precision)
-        fea_todo, ge, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed, result_length=result_length, overlap_frames=overlap_frames)
+        fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
 
        cfm_resss = []
        idx = 0
@@ -1672,11 +1660,11 @@ class TTS:
        cfm_res = torch.cat(cfm_resss, 2)
        cfm_res = denorm_spec(cfm_res)
 
-        with torch.no_grad():
+        with torch.inference_mode():
            wav_gen = self.vocoder(cfm_res)
            audio = wav_gen[0][0] # .cpu().detach().numpy()
 
-        return audio, y, y_mask
+        return audio
 
    def using_vocoder_synthesis_batched_infer(
        self,
@@ -1693,7 +1681,7 @@ class TTS:
            raw_entry = raw_entry[0]
        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)
 
-        fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
+        fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
        ref_sr = self.prompt_cache["raw_sr"]
        ref_audio = ref_audio.to(self.configs.device).float()
@@ -1731,7 +1719,7 @@ class TTS:
            semantic_tokens = (
                semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0)
            ) # .unsqueeze(0)#mq要多unsqueeze一次
-            feat, _, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
+            feat, _ = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
            feat_list.append(feat)
            feat_lens.append(feat.shape[2])
 
@@ -1791,7 +1779,7 @@ class TTS:
            audio_fragments.append(audio_fragment)
            audio = audio[feat_len * upsample_rate :]
 
-        return audio_fragments, y, y_mask
+        return audio_fragments
 
    def sola_algorithm(
        self,
diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py
index 6cb317f6..5049017f 100644
--- a/GPT_SoVITS/module/models.py
+++ b/GPT_SoVITS/module/models.py
@@ -990,8 +990,57 @@ class SynthesizerTrn(nn.Module):
        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o, y_mask, (z, z_p, m_p, logs_p)
 
+
    @torch.no_grad()
-    def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
+    def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
+        def get_ge(refer, sv_emb):
+            ge = None
+            if refer is not None:
+                refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
+                refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
+                if self.version == "v1":
+                    ge = self.ref_enc(refer * refer_mask, refer_mask)
+                else:
+                    ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
+                if self.is_v2pro:
+                    sv_emb = self.sv_emb(sv_emb)  # B*20480->B*512
+                    ge += sv_emb.unsqueeze(-1)
+                    ge = self.prelu(ge)
+            return ge
+
+        if type(refer) == list:
+            ges = []
+            for idx, _refer in enumerate(refer):
+                ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
+                ges.append(ge)
+            ge = torch.stack(ges, 0).mean(0)
+        else:
+            ge = get_ge(refer, sv_emb)
+
+        y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
+        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
+
+        quantized = self.quantizer.decode(codes)
+        if self.semantic_frame_rate == "25hz":
+            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
+        x, m_p, logs_p, y_mask, _, _ = self.enc_p(
+            quantized,
+            y_lengths,
+            text,
+            text_lengths,
+            self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
+            speed,
+        )
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+
+        z = self.flow(z_p, y_mask, g=ge, reverse=True)
+
+        o = self.dec((z * y_mask)[:, :, :], g=ge)
+        return o
+
+
+    @torch.no_grad()
+    def decode_steaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
        def get_ge(refer, sv_emb):
            ge = None
            if refer is not None:
@@ -1031,7 +1080,10 @@ class SynthesizerTrn(nn.Module):
            text_lengths,
            self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
            speed,
-            , result_length=result_length, overlap_frames=overlap_frames, padding_length=padding_length)
+            result_length=result_length,
+            overlap_frames=overlap_frames,
+            padding_length=padding_length
+        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
 
        z = self.flow(z_p, y_mask, g=ge, reverse=True)
@@ -1277,7 +1329,7 @@ class SynthesizerTrnV3(nn.Module):
        return cfm_loss
 
    @torch.no_grad()
-    def decode_encp(self, codes, text, refer, ge=None, speed=1, result_length:int=None, overlap_frames:torch.Tensor=None):
+    def decode_encp(self, codes, text, refer, ge=None, speed=1):
        # print(2333333,refer.shape)
        # ge=None
        if ge == None:
@@ -1285,26 +1337,22 @@ class SynthesizerTrnV3(nn.Module):
            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
            refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
            ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
        y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
-
        if speed == 1:
-            sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4))
+            sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4))
        else:
-            sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4) / speed) + 1
+            sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4) / speed) + 1
        y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
-        result_length = result_length * 2 if result_length is not None else None
-
-
-        x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length, overlap_frames=overlap_frames)
+        x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
        fea = self.bridge(x)
        fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
        ####more wn paramter to learn mel
        fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
-        return fea, ge, y_, y_mask_
+        return fea, ge
 
    def extract_latent(self, x):
        ssl = self.ssl_proj(x)
diff --git a/api_v2.py b/api_v2.py
index 6dcec851..7aeb5c16 100644
--- a/api_v2.py
+++ b/api_v2.py
@@ -30,7 +30,7 @@ POST:
    "top_k": 5, # int. top k sampling
    "top_p": 1, # float. top p sampling
    "temperature": 1, # float. temperature for sampling
-    "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
+    "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
    "batch_size": 1, # int. batch size for inference
    "batch_threshold": 0.75, # float. threshold for batch splitting.
    "split_bucket": True, # bool. whether to split the batch into multiple buckets.
@@ -42,7 +42,7 @@ POST:
    "sample_steps": 32, # int. number of sampling steps for VITS model V3.
    "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
    "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
-    "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
+    "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
    "return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
}
```
@@ -174,7 +174,7 @@ class TTS_Request(BaseModel):
    sample_steps: int = 32
    super_sampling: bool = False
    overlap_length: int = 2
-    chunk_length: int = 24
+    min_chunk_length: int = 16
    return_fragment: bool = False
@@ -333,7 +333,7 @@ async def tts_handle(req: dict):
                "sample_steps": 32, # int. number of sampling steps for VITS model V3.
                "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
                "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
-                "chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
+                "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
                "return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
            }
    returns:
@@ -416,7 +416,7 @@ async def tts_get_endpoint(
    sample_steps: int = 32,
    super_sampling: bool = False,
    overlap_length: int = 2,
-    chunk_length: int = 24,
+    min_chunk_length: int = 16,
    return_fragment: bool = False,
):
    req = {
@@ -443,7 +443,7 @@ async def tts_get_endpoint(
        "sample_steps": int(sample_steps),
        "super_sampling": super_sampling,
        "overlap_length": int(overlap_length),
-        "chunk_length": int(chunk_length),
+        "min_chunk_length": int(min_chunk_length),
        "return_fragment": return_fragment,
    }
    return await tts_handle(req)
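
A minimal client sketch (not part of the diff) showing how a request to the api_v2.py /tts endpoint might look after the rename from "chunk_length" to "min_chunk_length". The host/port, reference-audio fields, streaming_mode, and media_type are assumptions based on the existing api_v2.py interface, not values introduced by this change:

# Hypothetical usage sketch: stream audio from the /tts endpoint with the renamed parameter.
# Fields not shown in the diff (text, text_lang, ref_audio_path, prompt_text, prompt_lang,
# streaming_mode, media_type, host/port) are assumptions about the existing API surface.
import requests

payload = {
    "text": "Hello, this is a streaming synthesis test.",
    "text_lang": "en",
    "ref_audio_path": "ref.wav",
    "prompt_text": "reference transcript",
    "prompt_lang": "en",
    "text_split_method": "cut5",   # documented default in api_v2.py after this change
    "streaming_mode": True,        # return audio chunk by chunk
    "overlap_length": 2,           # semantic-token overlap between chunks
    "min_chunk_length": 16,        # replaces "chunk_length": 24 from before this change
    "media_type": "wav",
}

with requests.post("http://127.0.0.1:9880/tts", json=payload, stream=True) as resp:
    resp.raise_for_status()
    with open("out.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=4096):
            f.write(chunk)  # write each streamed audio chunk as it arrives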