Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-12-16 01:06:57 +08:00)

Commit 0825ae80e1 (parent 9ff381b519)

modified: GPT_SoVITS/AR/models/t2s_model.py
modified: GPT_SoVITS/TTS_infer_pack/TTS.py
modified: GPT_SoVITS/module/models.py
GPT_SoVITS/AR/models/t2s_model.py

@@ -826,6 +826,13 @@ class Text2SemanticDecoder(nn.Module):
        chunk_length: int = 24,
        **kwargs,
    ):
        mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
        sim_thershold = kwargs.get("sim_thershold", 0.3)
        min_chunk_len = kwargs.get("min_chunk_len", 12)
        limited_chunk_len = kwargs.get("limited_chunk_len", False)
        only_for_the_first_chunk = kwargs.get("only_for_the_first_chunk", True)


        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)
@@ -877,6 +884,7 @@ class Text2SemanticDecoder(nn.Module):
            .to(device=x.device, dtype=torch.bool)
        )

        is_yield = False
        token_counter = 0
        for idx in tqdm(range(1500)):
            token_counter+=1
@@ -921,9 +929,39 @@ class Text2SemanticDecoder(nn.Module):
                yield (y[:, -token_counter:]) if token_counter!= 0 else None, True
                break

            if streaming_mode and token_counter == chunk_length:
            # if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter > min_chunk_len):
            # sim = mute_emb_sim_matrix[y[0,-1]]
            # if sim >= sim_thershold: is_yield = True
            # elif streaming_mode and (mute_emb_sim_matrix is None):
            # is_yield = token_counter == chunk_length

            # if streaming_mode and is_yield:
            # is_yield = False
            # yield y[:, -token_counter:], False
            # token_counter = 0

            if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter > min_chunk_len):
                last_sim = mute_emb_sim_matrix[y[0,-1]]

                if (not limited_chunk_len) and last_sim >= sim_thershold:
                    yield y[:, -token_counter:], False
                    token_counter = 0
                    # if is_first_package: is_first_package = False

                elif limited_chunk_len and token_counter == chunk_length:
                    # is_first_package = False
                    limited_chunk_len = False if only_for_the_first_chunk else limited_chunk_len
                    sim = mute_emb_sim_matrix[y[0,-(token_counter-min_chunk_len):]]
                    # print(f"sim:{sim}")
                    i = chunk_length-(sim.argmax()+min_chunk_len+1)
                    token_counter = i
                    yield y[:, -chunk_length:-i] if i!= 0 else y[:, -chunk_length:], False


            elif streaming_mode and (mute_emb_sim_matrix is None):
                is_yield = token_counter == chunk_length
                yield y[:, -token_counter:], False
                token_counter = 0
                yield y[:, -chunk_length:], False


            ####################### update next step ###################################

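In the hunk above, the streaming generator no longer cuts a chunk strictly every chunk_length tokens: when a mute_emb_sim_matrix is supplied, it yields as soon as the newest token is sufficiently close to the mute token (last_sim >= sim_thershold), or, for a length-limited first chunk, searches the tail of the window for the most mute-like token and splits there. A minimal illustrative sketch of that split rule, with hypothetical names and assuming sim_matrix[t] holds the cosine similarity of token t's embedding to the mute token (not the repository's code):

import torch

def pick_split(tokens: torch.Tensor, sim_matrix: torch.Tensor,
               min_chunk_len: int = 12, chunk_length: int = 24) -> int:
    """Return how many of the newest tokens to hold back for the next chunk."""
    # tokens: 1-D LongTensor of generated token ids, newest last (illustrative only).
    if tokens.shape[-1] < chunk_length:
        return 0  # not enough tokens accumulated yet
    window = tokens[-(chunk_length - min_chunk_len):]  # candidates past the minimum chunk length
    sims = sim_matrix[window]                          # similarity of each candidate to "mute"
    best = int(sims.argmax())                          # most pause-like position in the window
    return chunk_length - (best + min_chunk_len + 1)   # leftover tokens carried into the next chunk

With chunk_length=24 and min_chunk_len=12, a best match at window position 0 would emit the oldest 13 tokens and carry the remaining 11 forward, which matches the y[:, -chunk_length:-i] slice in the diff.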
GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -281,6 +281,7 @@ class TTS_Config:
        "v3" : 486,
        "v4" : 486,
    }
    mute_emb_sim_matrix: torch.Tensor = None
    # "all_zh",#全部按中文识别
    # "en",#全部按英文识别#######不变
    # "all_ja",#全部按日文识别
@@ -604,6 +605,11 @@ class TTS:
        if self.configs.is_half and str(self.configs.device) != "cpu":
            self.t2s_model = self.t2s_model.half()

        codebook = t2s_model.model.ar_audio_embedding.weight.clone()
        mute_emb = codebook[self.configs.mute_tokens[self.configs.version]].unsqueeze(0)
        sim_matrix = F.cosine_similarity(mute_emb.float(), codebook.float(), dim=-1)
        self.configs.mute_emb_sim_matrix = sim_matrix

    def init_vocoder(self, version: str):
        if version == "v3":
            if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
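The added lines above precompute, at weight-loading time, one cosine similarity per AR audio token against the embedding of the version-specific mute token (id 486 for v3/v4 per the config hunk). A standalone sketch of the same computation, with hypothetical names:

import torch
import torch.nn.functional as F

def mute_similarity_table(codebook: torch.Tensor, mute_token_id: int) -> torch.Tensor:
    # codebook: (vocab_size, dim) embedding weight matrix; returns a (vocab_size,) tensor.
    mute_emb = codebook[mute_token_id].unsqueeze(0)  # (1, dim)
    return F.cosine_similarity(mute_emb.float(), codebook.float(), dim=-1)

# Illustrative usage: sim[t] near 1.0 suggests token t behaves like silence.
# sim = mute_similarity_table(ar_audio_embedding.weight.detach(), 486)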
@@ -1065,6 +1071,7 @@ class TTS:
                self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
            else:
                self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
                # self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
        elif parallel_infer and streaming_mode:
            print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式"))
            parallel_infer = False

@@ -1216,6 +1223,7 @@ class TTS:
        t_34 = 0.0
        t_45 = 0.0
        audio = []
        is_first_package = True
        output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
        for item in data:
            t3 = time.perf_counter()
@@ -1299,13 +1307,14 @@ class TTS:
                        )
                    _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
                    if self.is_v2pro != True:
                        _batch_audio_fragment = self.vits_model.decode(
                        _batch_audio_fragment, _, _ = self.vits_model.decode(
                            all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
                        ).detach()[0, 0, :]
                    else:
                        _batch_audio_fragment = self.vits_model.decode(
                            all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
                        ).detach()[0, 0, :]
                        )
                    _batch_audio_fragment = _batch_audio_fragment.detach()[0, 0, :]
                    audio_frag_end_idx.insert(0, 0)
                    batch_audio_fragment = [
                        _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
@@ -1319,18 +1328,19 @@ class TTS:
                            pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                        ) # .unsqueeze(0)#mq要多unsqueeze一次
                        if self.is_v2pro != True:
                            audio_fragment = self.vits_model.decode(
                            audio_fragment, _, _ = self.vits_model.decode(
                                _pred_semantic, phones, refer_audio_spec, speed=speed_factor
                            ).detach()[0, 0, :]
                        else:
                            audio_fragment = self.vits_model.decode(
                                _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
                            ).detach()[0, 0, :]
                            )
                        audio_fragment=audio_fragment.detach()[0, 0, :]
                        batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
                else:
                    if parallel_infer:
                        print(f"{i18n('并行合成中')}...")
                        audio_fragments = self.using_vocoder_synthesis_batched_infer(
                        audio_fragments, y, y_mask = self.using_vocoder_synthesis_batched_infer(
                            idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
                        )
                        batch_audio_fragment.extend(audio_fragments)
@@ -1340,7 +1350,7 @@ class TTS:
                            _pred_semantic = (
                                pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                            ) # .unsqueeze(0)#mq要多unsqueeze一次
                            audio_fragment = self.using_vocoder_synthesis(
                            audio_fragment, y, y_mask = self.using_vocoder_synthesis(
                                _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
                            )
                            batch_audio_fragment.append(audio_fragment)

@@ -1364,6 +1374,9 @@ class TTS:
                    repetition_penalty=repetition_penalty,
                    streaming_mode=True,
                    chunk_length=chunk_length,
                    mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix,
                    only_for_the_first_chunk=is_first_package,
                    limited_chunk_len=True
                )
                t4 = time.perf_counter()
                t_34 += t4 - t3
@@ -1382,8 +1395,10 @@ class TTS:
                upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)

                last_audio_chunk = None
                last_tokens = None
                # last_tokens = None
                last_latent = None
                previous_tokens = []
                overlap_len = overlap_length
                overlap_size = math.ceil(overlap_length*upsample_rate)
                for semantic_tokens, is_final in semantic_token_generator:
                    if semantic_tokens is None and last_audio_chunk is not None:
@@ -1396,29 +1411,45 @@ class TTS:
                            0.0,
                            super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
                        )
                        continue
                        break

                    _semantic_tokens = semantic_tokens
                    print(f"semantic_tokens shape:{semantic_tokens.shape}")

                    previous_tokens.append(semantic_tokens)

                    _semantic_tokens = torch.cat(previous_tokens, dim=-1)

                    if not is_first_chunk and semantic_tokens.shape[-1] < 10:
                        overlap_len = overlap_length+(10-semantic_tokens.shape[-1])
                        # overlap_size = math.ceil(overlap_len*upsample_rate)
                    else:
                        overlap_len = overlap_length
                        # overlap_size = math.ceil(overlap_length*upsample_rate)


                    if not self.configs.use_vocoder:
                        audio_chunk = self.vits_model.decode(
                        audio_chunk, latent, latent_mask = self.vits_model.decode(
                            _semantic_tokens.unsqueeze(0),
                            phones, refer_audio_spec,
                            speed=speed_factor,
                            result_length=semantic_tokens.shape[-1]+overlap_length if not is_first_chunk else None
                            result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
                            overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
                                if last_latent is not None else None,
                            # result_length=chunk_length if not is_first_chunk else None
                        ).detach()[0, 0, :]
                        )
                        audio_chunk=audio_chunk.detach()[0, 0, :]
                    else:
                        audio_chunk = self.using_vocoder_synthesis(
                        audio_chunk, latent, latent_mask = self.using_vocoder_synthesis(
                            _semantic_tokens.unsqueeze(0), phones,
                            speed=speed_factor, sample_steps=sample_steps,
                            result_length = semantic_tokens.shape[-1]+overlap_length if not is_first_chunk else None
                            result_length = semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
                            overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
                                if last_latent is not None else None,
                        )

                    if overlap_len>overlap_length:
                        audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):]

                    audio_chunk_ = audio_chunk
                    if is_first_chunk and not is_final:
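The streaming branch above accumulates every yielded token chunk, decodes the full concatenated sequence, requests only the newest result_length frames, and hands the cached latent tail back as overlap_frames so the next chunk can be blended with it. A simplified consumer-side sketch of that pattern (decode_fn is a hypothetical wrapper mirroring the decode signature used here, not the repository's API):

import torch

def stream_chunks(token_generator, decode_fn, overlap_len: int = 4):
    # token_generator yields (tokens, is_final); decode_fn(all_tokens, result_length, overlap_frames)
    # is assumed to return (audio, latent) for only the requested tail plus the blended overlap.
    previous_tokens, last_latent = [], None
    for tokens, is_final in token_generator:
        if tokens is None:
            break
        previous_tokens.append(tokens)
        all_tokens = torch.cat(previous_tokens, dim=-1)
        is_first = last_latent is None
        audio, latent = decode_fn(
            all_tokens,
            result_length=None if is_first else tokens.shape[-1] + overlap_len,
            overlap_frames=None if is_first else last_latent[..., -overlap_len:],
        )
        last_latent = latent
        yield audio, is_final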
@@ -1433,7 +1464,12 @@ class TTS:
                        else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:]
                    )

                    # audio_chunk_ = (
                    # audio_chunk_[overlap_size:-overlap_size] if not is_final \
                    # else audio_chunk_[overlap_size:]
                    # )

                    last_latent = latent
                    last_audio_chunk = audio_chunk
                    yield self.audio_postprocess(
                        [[audio_chunk_]],
@@ -1444,8 +1480,12 @@ class TTS:
                        0.0,
                        super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
                    )
                    # print(f"first_package_delay: {time.perf_counter()-t0:.3f}")

                    if is_first_package:
                        print(f"first_package_delay: {time.perf_counter()-t0:.3f}")
                        is_first_package = False


                yield output_sr, np.zeros(int(output_sr*fragment_interval), dtype=np.int16)

                t5 = time.perf_counter()
@@ -1553,8 +1593,13 @@ class TTS:
                t2 = time.perf_counter()
                print(f"超采样用时:{t2 - t1:.3f}s")
            else:
                audio = audio.float() * 32768
                audio = audio.to(dtype=torch.int16).cpu().numpy()
                # audio = audio.float() * 32768
                # audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy()

                audio = audio.cpu().numpy()

                audio = (audio * 32768).astype(np.int16)


        # try:
        # if speed_factor != 1.0:
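The post-processing above now moves the tensor to numpy first and scales to int16 as a separate step, dropping the previously commented clamp. For reference, a clipping-safe float-to-int16 conversion (a generic sketch, not part of this commit) looks like:

import numpy as np

def float_to_int16(audio: np.ndarray) -> np.ndarray:
    # Clip before casting so samples outside [-1, 1] saturate instead of wrapping around.
    return np.clip(audio * 32768.0, -32768, 32767).astype(np.int16)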
@@ -1565,7 +1610,7 @@ class TTS:
        return sr, audio

    def using_vocoder_synthesis(
        self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32, result_length:int=None
        self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32, result_length:int=None, overlap_frames:torch.Tensor=None
    ):
        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)

@@ -1574,7 +1619,7 @@ class TTS:
            raw_entry = raw_entry[0]
        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)

        fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
        ref_sr = self.prompt_cache["raw_sr"]
        ref_audio = ref_audio.to(self.configs.device).float()
@@ -1600,7 +1645,7 @@ class TTS:
        chunk_len = T_chunk - T_min

        mel2 = mel2.to(self.precision)
        fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed, result_length=result_length)
        fea_todo, ge, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed, result_length=result_length, overlap_frames=overlap_frames)

        cfm_resss = []
        idx = 0

@@ -1627,7 +1672,7 @@ class TTS:
        wav_gen = self.vocoder(cfm_res)
        audio = wav_gen[0][0] # .cpu().detach().numpy()

        return audio
        return audio, y, y_mask

    def using_vocoder_synthesis_batched_infer(
        self,
@@ -1644,7 +1689,7 @@ class TTS:
            raw_entry = raw_entry[0]
        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)

        fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
        ref_sr = self.prompt_cache["raw_sr"]
        ref_audio = ref_audio.to(self.configs.device).float()

@@ -1682,7 +1727,7 @@ class TTS:
            semantic_tokens = (
                semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0)
            ) # .unsqueeze(0)#mq要多unsqueeze一次
            feat, _ = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
            feat, _, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
            feat_list.append(feat)
            feat_lens.append(feat.shape[2])

@@ -1742,7 +1787,7 @@ class TTS:
            audio_fragments.append(audio_fragment)
            audio = audio[feat_len * upsample_rate :]

        return audio_fragments
        return audio_fragments, y, y_mask

    def sola_algorithm(
        self,
@@ -1766,6 +1811,13 @@ class TTS:
                window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)]
                + window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :]
            )

            # window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx))
            # f2_[: (overlap_len - idx)] = (
            # window * f2_[: (overlap_len - idx)]
            # + (1-window) * f1[-(overlap_len - idx) :]
            # )

            audio_fragments[i + 1] = f2_

        return torch.cat(audio_fragments, 0)
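sola_algorithm keeps its original windowed overlap-add and leaves a sine-window variant commented out above. One reasonable alternative formulation of such a crossfade, shown only as an illustration (hypothetical helper, not the repository's implementation):

import torch

def crossfade(prev_tail: torch.Tensor, next_head: torch.Tensor) -> torch.Tensor:
    # prev_tail and next_head are 1-D tensors of equal length covering the overlap region.
    n = prev_tail.shape[0]
    fade_in = torch.sin(torch.arange(n, device=prev_tail.device) / n * (torch.pi / 2)) ** 2
    # fade_in rises from 0 toward 1; the complementary weight keeps the per-sample weight sum at 1.
    return fade_in * next_head + (1.0 - fade_in) * prev_tail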
GPT_SoVITS/module/models.py

@@ -151,6 +151,8 @@ class DurationPredictor(nn.Module):
        return x * x_mask


HANN_WINDOW = {}

class TextEncoder(nn.Module):
    def __init__(
        self,

@@ -209,7 +211,7 @@ class TextEncoder(nn.Module):

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None):
    def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)

        y = self.ssl_proj(y * y_mask) * y_mask
@@ -227,13 +229,29 @@ class TextEncoder(nn.Module):
        if result_length is not None:
            y = y[:, :, -result_length:]
            y_mask = y_mask[:, :, -result_length:]

        if overlap_frames is not None:
            overlap_len = overlap_frames.shape[-1]
            window = HANN_WINDOW.get(overlap_len, None)
            if window is None:
                HANN_WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype)
                window = HANN_WINDOW[overlap_len]
            window = window.to(y.device)
            y[:,:,:overlap_len] = (
                window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len]
                + window[overlap_len:].view(1, 1, -1) * overlap_frames
            )
        y_ = y
        y_mask_ = y_mask



        if speed != 1:
            y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
            y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask
        return y, m, logs, y_mask, y_, y_mask_

    def extract_latent(self, x):
        x = self.ssl_proj(x)
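The new TextEncoder branch crops the latent to result_length and then blends its first overlap_len frames with the cached frames from the previous chunk, using the rising and falling halves of a Hann window of length 2*overlap_len; it also returns the pre-interpolation latents (y_, y_mask_) so the caller can cache the tail. A self-contained sketch of that blend on (B, C, T) latents, with a hypothetical helper name:

import torch

def blend_overlap(y: torch.Tensor, overlap_frames: torch.Tensor) -> torch.Tensor:
    # y: (B, C, T) latent of the current chunk; overlap_frames: (B, C, overlap_len) tail of the previous chunk.
    ov = overlap_frames.shape[-1]
    window = torch.hann_window(2 * ov, device=y.device, dtype=y.dtype)
    rise, fall = window[:ov].view(1, 1, -1), window[ov:].view(1, 1, -1)
    y = y.clone()
    y[:, :, :ov] = rise * y[:, :, :ov] + fall * overlap_frames  # crossfade into the previous chunk's tail
    return y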
@@ -926,7 +944,7 @@ class SynthesizerTrn(nn.Module):
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")

        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
        x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        z_p = self.flow(z, y_mask, g=ge)

@@ -954,7 +972,7 @@ class SynthesizerTrn(nn.Module):
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")

        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
        x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)

@@ -963,7 +981,7 @@ class SynthesizerTrn(nn.Module):
        return o, y_mask, (z, z_p, m_p, logs_p)

    @torch.no_grad()
    def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None):
    def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None):
        def get_ge(refer, sv_emb):
            ge = None
            if refer is not None:
@@ -995,20 +1013,20 @@ class SynthesizerTrn(nn.Module):
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
            result_length = (2*result_length) if result_length is not None else None
        x, m_p, logs_p, y_mask = self.enc_p(
        x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(
            quantized,
            y_lengths,
            text,
            text_lengths,
            self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
            speed,
        , result_length=result_length)
        , result_length=result_length, overlap_frames=overlap_frames)
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)

        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o
        return o, y_, y_mask_

    def extract_latent(self, x):
        ssl = self.ssl_proj(x)
@@ -1232,7 +1250,7 @@ class SynthesizerTrnV3(nn.Module):
        ssl = self.ssl_proj(ssl)
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
        quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
        x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
        fea = self.bridge(x)
        fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
        fea, y_mask_ = self.wns1(

@@ -1248,7 +1266,7 @@ class SynthesizerTrnV3(nn.Module):
        return cfm_loss

    @torch.no_grad()
    def decode_encp(self, codes, text, refer, ge=None, speed=1, result_length:int=None):
    def decode_encp(self, codes, text, refer, ge=None, speed=1, result_length:int=None, overlap_frames:torch.Tensor=None):
        # print(2333333,refer.shape)
        # ge=None
        if ge == None:
@@ -1270,12 +1288,12 @@ class SynthesizerTrnV3(nn.Module):
        result_length = result_length * 2 if result_length is not None else None


        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length)
        x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length, overlap_frames=overlap_frames)
        fea = self.bridge(x)
        fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
        ####more wn paramter to learn mel
        fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
        return fea, ge
        return fea, ge, y_, y_mask_

    def extract_latent(self, x):
        ssl = self.ssl_proj(x)
@@ -1387,7 +1405,7 @@ class SynthesizerTrnV3b(nn.Module):
        ssl = self.ssl_proj(ssl)
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
        quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
        x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        z_p = self.flow(z, y_mask, g=ge)
        z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)

@@ -1430,7 +1448,7 @@ class SynthesizerTrnV3b(nn.Module):
        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
        x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
        fea = self.bridge(x)
        fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
        ####more wn paramter to learn mel