mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-12-16 09:16:59 +08:00
modified: .gitignore
modified: GPT_SoVITS/AR/models/t2s_model.py modified: GPT_SoVITS/TTS_infer_pack/TTS.py modified: GPT_SoVITS/module/models.py
This commit is contained in:
parent
d08214dd22
commit
af7b95bc9d
@ -828,9 +828,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
):
|
):
|
||||||
mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
|
mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
|
||||||
sim_thershold = kwargs.get("sim_thershold", 0.3)
|
sim_thershold = kwargs.get("sim_thershold", 0.3)
|
||||||
min_chunk_len = kwargs.get("min_chunk_len", 12)
|
check_token_num = 2
|
||||||
limited_chunk_len = kwargs.get("limited_chunk_len", False)
|
|
||||||
only_for_the_first_chunk = kwargs.get("only_for_the_first_chunk", True)
|
|
||||||
|
|
||||||
|
|
||||||
x = self.ar_text_embedding(x)
|
x = self.ar_text_embedding(x)
|
||||||
@ -884,8 +882,8 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
.to(device=x.device, dtype=torch.bool)
|
.to(device=x.device, dtype=torch.bool)
|
||||||
)
|
)
|
||||||
|
|
||||||
is_yield = False
|
|
||||||
token_counter = 0
|
token_counter = 0
|
||||||
|
curr_ptr = prefix_len
|
||||||
for idx in tqdm(range(1500)):
|
for idx in tqdm(range(1500)):
|
||||||
token_counter+=1
|
token_counter+=1
|
||||||
if xy_attn_mask is not None:
|
if xy_attn_mask is not None:
|
||||||
@ -924,42 +922,25 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
print("bad zero prediction")
|
print("bad zero prediction")
|
||||||
# print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
# print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
||||||
if streaming_mode:
|
if streaming_mode:
|
||||||
# y=y[:, :-1]
|
yield y[:, curr_ptr:] if curr_ptr<y.shape[1] else None, True
|
||||||
# res_len = (y.shape[1] - prefix_len)%chunk_length
|
|
||||||
yield (y[:, -token_counter:]) if token_counter!= 0 else None, True
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter > min_chunk_len):
|
|
||||||
# sim = mute_emb_sim_matrix[y[0,-1]]
|
|
||||||
# if sim >= sim_thershold: is_yield = True
|
|
||||||
# elif streaming_mode and (mute_emb_sim_matrix is None):
|
|
||||||
# is_yield = token_counter == chunk_length
|
|
||||||
|
|
||||||
# if streaming_mode and is_yield:
|
if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length+check_token_num):
|
||||||
# is_yield = False
|
score = mute_emb_sim_matrix[y[0, curr_ptr:]] - sim_thershold
|
||||||
# yield y[:, -token_counter:], False
|
score[score<0]=-1
|
||||||
# token_counter = 0
|
score[:-1]=score[:-1]+score[1:]
|
||||||
|
argmax_idx = score.argmax()
|
||||||
|
|
||||||
if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter > min_chunk_len):
|
if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length:
|
||||||
last_sim = mute_emb_sim_matrix[y[0,-1]]
|
print(f"\n\ncurr_ptr:{curr_ptr}")
|
||||||
|
yield y[:, curr_ptr:], False
|
||||||
|
token_counter -= argmax_idx+1
|
||||||
|
curr_ptr += argmax_idx+1
|
||||||
|
|
||||||
if (not limited_chunk_len) and last_sim >= sim_thershold:
|
|
||||||
yield y[:, -token_counter:], False
|
|
||||||
token_counter = 0
|
|
||||||
# if is_first_package: is_first_package = False
|
|
||||||
|
|
||||||
elif limited_chunk_len and token_counter == chunk_length:
|
elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length):
|
||||||
# is_first_package = False
|
token_counter == chunk_length
|
||||||
limited_chunk_len = False if only_for_the_first_chunk else limited_chunk_len
|
|
||||||
sim = mute_emb_sim_matrix[y[0,-(token_counter-min_chunk_len):]]
|
|
||||||
# print(f"sim:{sim}")
|
|
||||||
i = chunk_length-(sim.argmax()+min_chunk_len+1)
|
|
||||||
token_counter = i
|
|
||||||
yield y[:, -chunk_length:-i] if i!= 0 else y[:, -chunk_length:], False
|
|
||||||
|
|
||||||
|
|
||||||
elif streaming_mode and (mute_emb_sim_matrix is None):
|
|
||||||
is_yield = token_counter == chunk_length
|
|
||||||
yield y[:, -token_counter:], False
|
yield y[:, -token_counter:], False
|
||||||
token_counter = 0
|
token_counter = 0
|
||||||
|
|
||||||
|
|||||||
@ -1365,7 +1365,6 @@ class TTS:
|
|||||||
all_phoneme_lens,
|
all_phoneme_lens,
|
||||||
prompt,
|
prompt,
|
||||||
all_bert_features[0].unsqueeze(0),
|
all_bert_features[0].unsqueeze(0),
|
||||||
# prompt_phone_len=ph_offset,
|
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@ -1375,8 +1374,6 @@ class TTS:
|
|||||||
streaming_mode=True,
|
streaming_mode=True,
|
||||||
chunk_length=chunk_length,
|
chunk_length=chunk_length,
|
||||||
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix,
|
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix,
|
||||||
only_for_the_first_chunk=is_first_package,
|
|
||||||
limited_chunk_len=True
|
|
||||||
)
|
)
|
||||||
t4 = time.perf_counter()
|
t4 = time.perf_counter()
|
||||||
t_34 += t4 - t3
|
t_34 += t4 - t3
|
||||||
@ -1429,6 +1426,13 @@ class TTS:
|
|||||||
|
|
||||||
|
|
||||||
if not self.configs.use_vocoder:
|
if not self.configs.use_vocoder:
|
||||||
|
token_padding_length = 0
|
||||||
|
# token_padding_length = int(phones.shape[-1]*2)-_semantic_tokens.shape[-1]
|
||||||
|
# if token_padding_length>0:
|
||||||
|
# _semantic_tokens = F.pad(_semantic_tokens, (0, token_padding_length), "constant", 486)
|
||||||
|
# else:
|
||||||
|
# token_padding_length = 0
|
||||||
|
|
||||||
audio_chunk, latent, latent_mask = self.vits_model.decode(
|
audio_chunk, latent, latent_mask = self.vits_model.decode(
|
||||||
_semantic_tokens.unsqueeze(0),
|
_semantic_tokens.unsqueeze(0),
|
||||||
phones, refer_audio_spec,
|
phones, refer_audio_spec,
|
||||||
@ -1436,7 +1440,7 @@ class TTS:
|
|||||||
result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
|
result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
|
||||||
overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
|
overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
|
||||||
if last_latent is not None else None,
|
if last_latent is not None else None,
|
||||||
# result_length=chunk_length if not is_first_chunk else None
|
padding_length=token_padding_length
|
||||||
)
|
)
|
||||||
audio_chunk=audio_chunk.detach()[0, 0, :]
|
audio_chunk=audio_chunk.detach()[0, 0, :]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -151,7 +151,7 @@ class DurationPredictor(nn.Module):
|
|||||||
return x * x_mask
|
return x * x_mask
|
||||||
|
|
||||||
|
|
||||||
HANN_WINDOW = {}
|
WINDOW = {}
|
||||||
|
|
||||||
class TextEncoder(nn.Module):
|
class TextEncoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -211,7 +211,7 @@ class TextEncoder(nn.Module):
|
|||||||
|
|
||||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None):
|
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
|
||||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||||
|
|
||||||
y = self.ssl_proj(y * y_mask) * y_mask
|
y = self.ssl_proj(y * y_mask) * y_mask
|
||||||
@ -224,23 +224,33 @@ class TextEncoder(nn.Module):
|
|||||||
text = self.text_embedding(text).transpose(1, 2)
|
text = self.text_embedding(text).transpose(1, 2)
|
||||||
text = self.encoder_text(text * text_mask, text_mask)
|
text = self.encoder_text(text * text_mask, text_mask)
|
||||||
y = self.mrte(y, y_mask, text, text_mask, ge)
|
y = self.mrte(y, y_mask, text, text_mask, ge)
|
||||||
|
|
||||||
|
if padding_length is not None and padding_length!=0:
|
||||||
|
y = y[:, :, :-padding_length]
|
||||||
|
y_mask = y_mask[:, :, :-padding_length]
|
||||||
|
|
||||||
|
|
||||||
y = self.encoder2(y * y_mask, y_mask)
|
y = self.encoder2(y * y_mask, y_mask)
|
||||||
|
|
||||||
if result_length is not None:
|
if result_length is not None:
|
||||||
y = y[:, :, -result_length:]
|
y = y[:, :, -result_length:]
|
||||||
y_mask = y_mask[:, :, -result_length:]
|
y_mask = y_mask[:, :, -result_length:]
|
||||||
|
|
||||||
if overlap_frames is not None:
|
if overlap_frames is not None:
|
||||||
overlap_len = overlap_frames.shape[-1]
|
overlap_len = overlap_frames.shape[-1]
|
||||||
window = HANN_WINDOW.get(overlap_len, None)
|
window = WINDOW.get(overlap_len, None)
|
||||||
if window is None:
|
if window is None:
|
||||||
HANN_WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype)
|
# WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype)
|
||||||
window = HANN_WINDOW[overlap_len]
|
WINDOW[overlap_len] = torch.sin(torch.arange(overlap_len*2, device=y.device) * torch.pi / (overlap_len*2))
|
||||||
|
window = WINDOW[overlap_len]
|
||||||
|
|
||||||
|
|
||||||
window = window.to(y.device)
|
window = window.to(y.device)
|
||||||
y[:,:,:overlap_len] = (
|
y[:,:,:overlap_len] = (
|
||||||
window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len]
|
window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len]
|
||||||
+ window[overlap_len:].view(1, 1, -1) * overlap_frames
|
+ window[overlap_len:].view(1, 1, -1) * overlap_frames
|
||||||
)
|
)
|
||||||
|
|
||||||
y_ = y
|
y_ = y
|
||||||
y_mask_ = y_mask
|
y_mask_ = y_mask
|
||||||
|
|
||||||
@ -981,7 +991,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
return o, y_mask, (z, z_p, m_p, logs_p)
|
return o, y_mask, (z, z_p, m_p, logs_p)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None):
|
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
|
||||||
def get_ge(refer, sv_emb):
|
def get_ge(refer, sv_emb):
|
||||||
ge = None
|
ge = None
|
||||||
if refer is not None:
|
if refer is not None:
|
||||||
@ -1013,6 +1023,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
if self.semantic_frame_rate == "25hz":
|
if self.semantic_frame_rate == "25hz":
|
||||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||||
result_length = (2*result_length) if result_length is not None else None
|
result_length = (2*result_length) if result_length is not None else None
|
||||||
|
padding_length = (2*padding_length) if padding_length is not None else None
|
||||||
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(
|
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(
|
||||||
quantized,
|
quantized,
|
||||||
y_lengths,
|
y_lengths,
|
||||||
@ -1020,7 +1031,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
text_lengths,
|
text_lengths,
|
||||||
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
|
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
|
||||||
speed,
|
speed,
|
||||||
, result_length=result_length, overlap_frames=overlap_frames)
|
, result_length=result_length, overlap_frames=overlap_frames, padding_length=padding_length)
|
||||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||||
|
|
||||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user