modified: .gitignore

modified:   GPT_SoVITS/AR/models/t2s_model.py
	modified:   GPT_SoVITS/TTS_infer_pack/TTS.py
	modified:   GPT_SoVITS/module/models.py
This commit is contained in:
ChasonJiang 2025-11-24 18:52:35 +08:00
parent d08214dd22
commit af7b95bc9d
3 changed files with 42 additions and 46 deletions

View File

@ -828,9 +828,7 @@ class Text2SemanticDecoder(nn.Module):
):
mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
sim_thershold = kwargs.get("sim_thershold", 0.3)
min_chunk_len = kwargs.get("min_chunk_len", 12)
limited_chunk_len = kwargs.get("limited_chunk_len", False)
only_for_the_first_chunk = kwargs.get("only_for_the_first_chunk", True)
check_token_num = 2
x = self.ar_text_embedding(x)
@ -884,8 +882,8 @@ class Text2SemanticDecoder(nn.Module):
.to(device=x.device, dtype=torch.bool)
)
is_yield = False
token_counter = 0
curr_ptr = prefix_len
for idx in tqdm(range(1500)):
token_counter+=1
if xy_attn_mask is not None:
@ -924,42 +922,25 @@ class Text2SemanticDecoder(nn.Module):
print("bad zero prediction")
# print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
if streaming_mode:
# y=y[:, :-1]
# res_len = (y.shape[1] - prefix_len)%chunk_length
yield (y[:, -token_counter:]) if token_counter!= 0 else None, True
yield y[:, curr_ptr:] if curr_ptr<y.shape[1] else None, True
break
# if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter > min_chunk_len):
# sim = mute_emb_sim_matrix[y[0,-1]]
# if sim >= sim_thershold: is_yield = True
# elif streaming_mode and (mute_emb_sim_matrix is None):
# is_yield = token_counter == chunk_length
# if streaming_mode and is_yield:
# is_yield = False
# yield y[:, -token_counter:], False
# token_counter = 0
if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length+check_token_num):
score = mute_emb_sim_matrix[y[0, curr_ptr:]] - sim_thershold
score[score<0]=-1
score[:-1]=score[:-1]+score[1:]
argmax_idx = score.argmax()
if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter > min_chunk_len):
last_sim = mute_emb_sim_matrix[y[0,-1]]
if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length:
print(f"\n\ncurr_ptr:{curr_ptr}")
yield y[:, curr_ptr:], False
token_counter -= argmax_idx+1
curr_ptr += argmax_idx+1
if (not limited_chunk_len) and last_sim >= sim_thershold:
yield y[:, -token_counter:], False
token_counter = 0
# if is_first_package: is_first_package = False
elif limited_chunk_len and token_counter == chunk_length:
# is_first_package = False
limited_chunk_len = False if only_for_the_first_chunk else limited_chunk_len
sim = mute_emb_sim_matrix[y[0,-(token_counter-min_chunk_len):]]
# print(f"sim:{sim}")
i = chunk_length-(sim.argmax()+min_chunk_len+1)
token_counter = i
yield y[:, -chunk_length:-i] if i!= 0 else y[:, -chunk_length:], False
elif streaming_mode and (mute_emb_sim_matrix is None):
is_yield = token_counter == chunk_length
elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length):
token_counter == chunk_length
yield y[:, -token_counter:], False
token_counter = 0

View File

@ -1365,7 +1365,6 @@ class TTS:
all_phoneme_lens,
prompt,
all_bert_features[0].unsqueeze(0),
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
@ -1375,8 +1374,6 @@ class TTS:
streaming_mode=True,
chunk_length=chunk_length,
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix,
only_for_the_first_chunk=is_first_package,
limited_chunk_len=True
)
t4 = time.perf_counter()
t_34 += t4 - t3
@ -1429,6 +1426,13 @@ class TTS:
if not self.configs.use_vocoder:
token_padding_length = 0
# token_padding_length = int(phones.shape[-1]*2)-_semantic_tokens.shape[-1]
# if token_padding_length>0:
# _semantic_tokens = F.pad(_semantic_tokens, (0, token_padding_length), "constant", 486)
# else:
# token_padding_length = 0
audio_chunk, latent, latent_mask = self.vits_model.decode(
_semantic_tokens.unsqueeze(0),
phones, refer_audio_spec,
@ -1436,7 +1440,7 @@ class TTS:
result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
if last_latent is not None else None,
# result_length=chunk_length if not is_first_chunk else None
padding_length=token_padding_length
)
audio_chunk=audio_chunk.detach()[0, 0, :]
else:

View File

@ -151,7 +151,7 @@ class DurationPredictor(nn.Module):
return x * x_mask
HANN_WINDOW = {}
WINDOW = {}
class TextEncoder(nn.Module):
def __init__(
@ -211,7 +211,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None):
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y = self.ssl_proj(y * y_mask) * y_mask
@ -224,23 +224,33 @@ class TextEncoder(nn.Module):
text = self.text_embedding(text).transpose(1, 2)
text = self.encoder_text(text * text_mask, text_mask)
y = self.mrte(y, y_mask, text, text_mask, ge)
if padding_length is not None and padding_length!=0:
y = y[:, :, :-padding_length]
y_mask = y_mask[:, :, :-padding_length]
y = self.encoder2(y * y_mask, y_mask)
if result_length is not None:
y = y[:, :, -result_length:]
y_mask = y_mask[:, :, -result_length:]
if overlap_frames is not None:
overlap_len = overlap_frames.shape[-1]
window = HANN_WINDOW.get(overlap_len, None)
window = WINDOW.get(overlap_len, None)
if window is None:
HANN_WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype)
window = HANN_WINDOW[overlap_len]
# WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype)
WINDOW[overlap_len] = torch.sin(torch.arange(overlap_len*2, device=y.device) * torch.pi / (overlap_len*2))
window = WINDOW[overlap_len]
window = window.to(y.device)
y[:,:,:overlap_len] = (
window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len]
+ window[overlap_len:].view(1, 1, -1) * overlap_frames
)
y_ = y
y_mask_ = y_mask
@ -981,7 +991,7 @@ class SynthesizerTrn(nn.Module):
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None):
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
def get_ge(refer, sv_emb):
ge = None
if refer is not None:
@ -1013,6 +1023,7 @@ class SynthesizerTrn(nn.Module):
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
result_length = (2*result_length) if result_length is not None else None
padding_length = (2*padding_length) if padding_length is not None else None
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(
quantized,
y_lengths,
@ -1020,7 +1031,7 @@ class SynthesizerTrn(nn.Module):
text_lengths,
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
speed,
, result_length=result_length, overlap_frames=overlap_frames)
, result_length=result_length, overlap_frames=overlap_frames, padding_length=padding_length)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)