From af7b95bc9d4e586d769cbd58db1b7582eb07fc5a Mon Sep 17 00:00:00 2001
From: ChasonJiang <1440499136@qq.com>
Date: Mon, 24 Nov 2025 18:52:35 +0800
Subject: [PATCH] modified:   .gitignore
 	modified:   GPT_SoVITS/AR/models/t2s_model.py
 	modified:   GPT_SoVITS/TTS_infer_pack/TTS.py
 	modified:   GPT_SoVITS/module/models.py

---
 GPT_SoVITS/AR/models/t2s_model.py | 49 ++++++++++---------------------
 GPT_SoVITS/TTS_infer_pack/TTS.py  | 12 +++++---
 GPT_SoVITS/module/models.py       | 27 ++++++++++++-----
 3 files changed, 42 insertions(+), 46 deletions(-)

diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index 806729d0..ae1fcc3c 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -828,9 +828,7 @@ class Text2SemanticDecoder(nn.Module):
     ):
         mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
         sim_thershold = kwargs.get("sim_thershold", 0.3)
-        min_chunk_len = kwargs.get("min_chunk_len", 12)
-        limited_chunk_len = kwargs.get("limited_chunk_len", False)
-        only_for_the_first_chunk = kwargs.get("only_for_the_first_chunk", True)
+        check_token_num = 2
 
         x = self.ar_text_embedding(x)
@@ -884,8 +882,8 @@ class Text2SemanticDecoder(nn.Module):
             .to(device=x.device, dtype=torch.bool)
         )
 
-        is_yield = False
         token_counter = 0
+        curr_ptr = prefix_len
         for idx in tqdm(range(1500)):
             token_counter+=1
             if xy_attn_mask is not None:
@@ -924,42 +922,25 @@ class Text2SemanticDecoder(nn.Module):
                     print("bad zero prediction")
                 # print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                 if streaming_mode:
-                    # y=y[:, :-1]
-                    # res_len = (y.shape[1] - prefix_len)%chunk_length
-                    yield (y[:, -token_counter:]) if token_counter != 0 else None, True
+                    yield y[:, curr_ptr:] if curr_ptr < y.shape[1] else None, True
 
-                if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter > min_chunk_len):
-                    sim = mute_emb_sim_matrix[y[0,-1]]
-                    if sim >= sim_thershold: is_yield = True
-                elif streaming_mode and (mute_emb_sim_matrix is None):
-                    is_yield = token_counter == chunk_length
-                if streaming_mode and is_yield:
-                    is_yield = False
-                    yield y[:, -token_counter:], False
-                    token_counter = 0
+                if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length + check_token_num):
+                    score = mute_emb_sim_matrix[y[0, curr_ptr:]] - sim_thershold
+                    score[score < 0] = -1
+                    score[:-1] = score[:-1] + score[1:]
+                    argmax_idx = score.argmax()
 
-                if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter > min_chunk_len):
-                    last_sim = mute_emb_sim_matrix[y[0,-1]]
+                    if score[argmax_idx] >= 0 and argmax_idx + 1 >= chunk_length:
+                        print(f"\n\ncurr_ptr:{curr_ptr}")
+                        yield y[:, curr_ptr:], False
+                        token_counter -= argmax_idx + 1
+                        curr_ptr += argmax_idx + 1
 
-                    if (not limited_chunk_len) and last_sim >= sim_thershold:
-                        yield y[:, -token_counter:], False
-                        token_counter = 0
-                        # if is_first_package: is_first_package = False
-                    elif limited_chunk_len and token_counter == chunk_length:
-                        # is_first_package = False
-                        limited_chunk_len = False if only_for_the_first_chunk else limited_chunk_len
-                        sim = mute_emb_sim_matrix[y[0,-(token_counter-min_chunk_len):]]
-                        # print(f"sim:{sim}")
-                        i = chunk_length-(sim.argmax()+min_chunk_len+1)
-                        token_counter = i
-                        yield y[:, -chunk_length:-i] if i != 0 else y[:, -chunk_length:], False
-
-
-                elif streaming_mode and (mute_emb_sim_matrix is None):
-                    is_yield = token_counter == chunk_length
+                elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length):
+                    token_counter = chunk_length
                     yield y[:, -token_counter:], False
                     token_counter = 0
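
Note on the t2s_model.py change: instead of emitting a chunk every fixed `chunk_length` tokens, or on a single mute-like token as the deleted `is_yield` logic did, the decoder now scores every not-yet-emitted token against a precomputed mute-embedding similarity table and cuts the stream at the strongest run of two adjacent mute-like tokens, provided the resulting chunk is at least `chunk_length` tokens long. A minimal sketch of that cut-point heuristic, assuming `mute_emb_sim_matrix` maps each semantic token id to its similarity to a "mute" token (`pick_cut` and the dummy tensors are illustrative, not part of the patch):

    import torch

    def pick_cut(sim_matrix, pending, chunk_length, sim_threshold=0.3):
        # sim_matrix: (vocab,) similarity of each token id to "mute".
        # pending: (n,) token ids generated since the last emitted chunk.
        score = sim_matrix[pending] - sim_threshold
        score[score < 0] = -1                # rule out non-mute positions
        score[:-1] = score[:-1] + score[1:]  # favour two mute tokens in a row
        idx = int(score.argmax())
        # Cut only above threshold and only once the chunk is long enough.
        if score[idx] >= 0 and idx + 1 >= chunk_length:
            return idx + 1
        return None

    sim_matrix = torch.rand(1025)                # dummy similarity table
    pending = torch.randint(0, 1025, (30,))      # dummy pending tokens
    cut = pick_cut(sim_matrix, pending, chunk_length=24)

Cutting on mute-like tokens places chunk boundaries in pauses rather than mid-phoneme, so the latent crossfade in models.py below has less audible work to do.
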
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 813117a2..2f097f6c 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -1365,7 +1365,6 @@ class TTS:
                 all_phoneme_lens,
                 prompt,
                 all_bert_features[0].unsqueeze(0),
-                # prompt_phone_len=ph_offset,
                 top_k=top_k,
                 top_p=top_p,
                 temperature=temperature,
@@ -1375,8 +1374,6 @@ class TTS:
                 streaming_mode=True,
                 chunk_length=chunk_length,
                 mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix,
-                only_for_the_first_chunk=is_first_package,
-                limited_chunk_len=True
             )
             t4 = time.perf_counter()
             t_34 += t4 - t3
@@ -1429,6 +1426,13 @@ class TTS:
 
                 if not self.configs.use_vocoder:
+                    token_padding_length = 0
+                    # token_padding_length = int(phones.shape[-1]*2)-_semantic_tokens.shape[-1]
+                    # if token_padding_length>0:
+                    #     _semantic_tokens = F.pad(_semantic_tokens, (0, token_padding_length), "constant", 486)
+                    # else:
+                    #     token_padding_length = 0
+
                     audio_chunk, latent, latent_mask = self.vits_model.decode(
                         _semantic_tokens.unsqueeze(0), phones, refer_audio_spec,
@@ -1436,7 +1440,7 @@ class TTS:
                         result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
                         overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
                             if last_latent is not None else None,
-                        # result_length=chunk_length if not is_first_chunk else None
+                        padding_length=token_padding_length
                     )
                     audio_chunk=audio_chunk.detach()[0, 0, :]
                 else:
diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py
index 2076960e..6cb317f6 100644
--- a/GPT_SoVITS/module/models.py
+++ b/GPT_SoVITS/module/models.py
@@ -151,7 +151,7 @@ class DurationPredictor(nn.Module):
         return x * x_mask
 
 
-HANN_WINDOW = {}
+WINDOW = {}
 
 class TextEncoder(nn.Module):
     def __init__(
@@ -211,7 +211,7 @@ class TextEncoder(nn.Module):
 
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
-    def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None):
+    def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
 
         y = self.ssl_proj(y * y_mask) * y_mask
@@ -224,23 +224,33 @@ class TextEncoder(nn.Module):
         text = self.text_embedding(text).transpose(1, 2)
         text = self.encoder_text(text * text_mask, text_mask)
         y = self.mrte(y, y_mask, text, text_mask, ge)
+
+        if padding_length is not None and padding_length != 0:
+            y = y[:, :, :-padding_length]
+            y_mask = y_mask[:, :, :-padding_length]
+
         y = self.encoder2(y * y_mask, y_mask)
 
         if result_length is not None:
             y = y[:, :, -result_length:]
             y_mask = y_mask[:, :, -result_length:]
 
         if overlap_frames is not None:
             overlap_len = overlap_frames.shape[-1]
-            window = HANN_WINDOW.get(overlap_len, None)
+            window = WINDOW.get(overlap_len, None)
             if window is None:
-                HANN_WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype)
-                window = HANN_WINDOW[overlap_len]
+                # WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype)
+                WINDOW[overlap_len] = torch.sin(torch.arange(overlap_len*2, device=y.device) * torch.pi / (overlap_len*2))
+                window = WINDOW[overlap_len]
+
+            window = window.to(y.device)
             y[:,:,:overlap_len] = (
                 window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len]
                 + window[overlap_len:].view(1, 1, -1) * overlap_frames
             )
+
         y_ = y
         y_mask_ = y_mask
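
Note on the windowing change: the cached Hann window is replaced by a plain sine window for blending the first `overlap_len` latent frames into the previous chunk's tail. With `w[k] = sin(k*pi/2N)`, the rising half is `sin(theta)` and the falling half is `cos(theta)`, so the crossfade is equal-power (`sin^2 + cos^2 = 1`); the halves of a periodic Hann window are `sin^2(theta)` and `cos^2(theta)`, which instead sum to 1 in amplitude. Equal-power blending preserves energy better when the overlapping frames are not perfectly correlated. A self-contained check (the `crossfade` helper is illustrative, not part of the patch):

    import torch

    def crossfade(new, old_tail):
        # Blend old_tail into the head of new over the last dimension.
        n = old_tail.shape[-1]
        w = torch.sin(torch.arange(2 * n, device=new.device) * torch.pi / (2 * n))
        out = new.clone()
        out[..., :n] = w[:n] * new[..., :n] + w[n:] * old_tail
        return out

    n = 8
    w = torch.sin(torch.arange(2 * n) * torch.pi / (2 * n))
    assert torch.allclose(w[:n] ** 2 + w[n:] ** 2, torch.ones(n), atol=1e-6)

The added `window = window.to(y.device)` also covers a latent device mismatch: the cached window stays on whichever device first populated the `WINDOW` dict.
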
@@ -981,7 +991,7 @@ class SynthesizerTrn(nn.Module):
         return o, y_mask, (z, z_p, m_p, logs_p)
 
     @torch.no_grad()
-    def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None):
+    def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
         def get_ge(refer, sv_emb):
             ge = None
             if refer is not None:
@@ -1013,6 +1023,7 @@ class SynthesizerTrn(nn.Module):
         if self.semantic_frame_rate == "25hz":
             quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
             result_length = (2*result_length) if result_length is not None else None
+            padding_length = (2*padding_length) if padding_length is not None else None
 
         x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(
             quantized,
             y_lengths,
             text,
             text_lengths,
             self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
             speed
-            , result_length=result_length, overlap_frames=overlap_frames)
+            , result_length=result_length, overlap_frames=overlap_frames, padding_length=padding_length)
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
         z = self.flow(z_p, y_mask, g=ge, reverse=True)
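
Note on `padding_length`: TTS.py currently pins `token_padding_length` to 0, but the commented-out block documents the intent: pad each semantic-token chunk up to twice the phoneme count with pad id 486 so the text encoder sees a consistent length, then have `decode()` strip the padded frames again. `decode()` doubles `padding_length` on the 25 Hz path because the tokens are upsampled 2x before `enc_p`, and `enc_p` trims `y` and `y_mask` right after the MRTE step. A sketch of that round trip, assuming the commented-out flow (helper names are illustrative):

    import torch
    import torch.nn.functional as F

    PAD_ID = 486  # pad token id from the commented-out TTS.py block

    def pad_tokens(tokens, target_len):
        # Right-pad the token chunk to target_len; report how much was added.
        pad = max(target_len - tokens.shape[-1], 0)
        return F.pad(tokens, (0, pad), "constant", PAD_ID), pad

    def trim_latent(latent, padding_length, rate_25hz=True):
        # Mirror decode(): padded tokens become 2x as many latent frames at 25 Hz.
        frames = padding_length * (2 if rate_25hz else 1)
        return latent[:, :, :-frames] if frames else latent

    tokens = torch.randint(0, PAD_ID, (1, 40))
    padded, n = pad_tokens(tokens, target_len=48)        # n == 8
    latent = torch.randn(1, 192, padded.shape[-1] * 2)   # after 2x upsample
    assert trim_latent(latent, n).shape[-1] == 80
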