Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git

Commit 08d6ed0d8c (parent af7b95bc9d)

    modified: GPT_SoVITS/AR/models/t2s_model.py
    modified: GPT_SoVITS/TTS_infer_pack/TTS.py
    modified: GPT_SoVITS/module/models.py
    modified: api_v2.py
GPT_SoVITS/AR/models/t2s_model.py

@@ -827,7 +827,7 @@ class Text2SemanticDecoder(nn.Module):
        **kwargs,
    ):
        mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
-       sim_thershold = kwargs.get("sim_thershold", 0.3)
+       chunk_split_thershold = kwargs.get("chunk_split_thershold", 0.3)
        check_token_num = 2


@@ -927,9 +927,9 @@ class Text2SemanticDecoder(nn.Module):


            if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length+check_token_num):
-               score = mute_emb_sim_matrix[y[0, curr_ptr:]] - sim_thershold
+               score = mute_emb_sim_matrix[y[0, curr_ptr:]] - chunk_split_thershold
                score[score<0]=-1
-               score[:-1]=score[:-1]+score[1:]
+               score[:-1]=score[:-1]+score[1:]  ## consider two consecutive tokens
                argmax_idx = score.argmax()

                if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length:
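The split-point search above scores each newly generated semantic token by its precomputed cosine similarity to the mute token, discards tokens below `chunk_split_thershold`, and favours two consecutive high-similarity tokens before cutting a streaming chunk. A minimal standalone sketch of that selection step, assuming a 1-D `mute_emb_sim_matrix` indexed by token id; the function name `pick_split_point`, the `tokens` argument, and the toy data are illustrative, not taken from the repository:

```python
import torch

def pick_split_point(tokens: torch.Tensor, mute_emb_sim_matrix: torch.Tensor,
                     chunk_split_thershold: float, chunk_length: int):
    """Return an index to split the streamed token chunk at, or None.

    tokens: 1-D LongTensor of semantic token ids generated since the last split.
    mute_emb_sim_matrix: 1-D tensor, cosine similarity of every token id to the
        mute token (assumption: precomputed offline).
    """
    # Positive score = similarity above the split threshold.
    score = mute_emb_sim_matrix[tokens] - chunk_split_thershold
    score[score < 0] = -1                    # disqualify low-similarity tokens
    score[:-1] = score[:-1] + score[1:]      # prefer two consecutive "quiet" tokens
    argmax_idx = int(score.argmax())
    if score[argmax_idx] >= 0 and argmax_idx + 1 >= chunk_length:
        return argmax_idx + 1                # split right after the quiet pair
    return None

# Toy usage: vocabulary of 10 ids, ids 7 and 8 resemble the mute token.
sim = torch.tensor([0.0, 0.1, 0.0, 0.2, 0.1, 0.0, 0.1, 0.9, 0.8, 0.1])
toks = torch.tensor([1, 3, 4, 2, 7, 8, 5])
print(pick_split_point(toks, sim, chunk_split_thershold=0.3, chunk_length=4))
```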
GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -278,6 +278,8 @@ class TTS_Config:
    mute_tokens: dict = {
        "v1" : 486,
        "v2" : 486,
        "v2Pro": 486,
        "v2ProPlus": 486,
        "v3" : 486,
        "v4" : 486,
    }

@@ -1009,7 +1011,7 @@ class TTS:
                    "top_k": 5,  # int. top k sampling
                    "top_p": 1,  # float. top p sampling
                    "temperature": 1,  # float. temperature for sampling
-                   "text_split_method": "cut0",  # str. text split method, see text_segmentation_method.py for details.
+                   "text_split_method": "cut1",  # str. text split method, see text_segmentation_method.py for details.
                    "batch_size": 1,  # int. batch size for inference
                    "batch_threshold": 0.75,  # float. threshold for batch splitting.
                    "split_bucket": True,  # bool. whether to split the batch into multiple buckets.

@@ -1023,7 +1025,7 @@ class TTS:
                    "super_sampling": False,  # bool. whether to use super-sampling for audio when using VITS model V3.
                    "streaming_mode": False,  # bool. return audio chunk by chunk.
                    "overlap_length": 2,  # int. overlap length of semantic tokens for streaming mode.
                    "chunk_length": 24,  # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
                    "min_chunk_length": 16,  # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
                }
        returns:
            Tuple[int, np.ndarray]: sampling rate and audio data.

@@ -1039,7 +1041,7 @@ class TTS:
        top_k: int = inputs.get("top_k", 5)
        top_p: float = inputs.get("top_p", 1)
        temperature: float = inputs.get("temperature", 1)
-       text_split_method: str = inputs.get("text_split_method", "cut0")
+       text_split_method: str = inputs.get("text_split_method", "cut1")
        batch_size = inputs.get("batch_size", 1)
        batch_threshold = inputs.get("batch_threshold", 0.75)
        speed_factor = inputs.get("speed_factor", 1.0)

@@ -1055,7 +1057,8 @@ class TTS:
        super_sampling = inputs.get("super_sampling", False)
        streaming_mode = inputs.get("streaming_mode", False)
        overlap_length = inputs.get("overlap_length", 2)
        chunk_length = inputs.get("chunk_length", 24)
        min_chunk_length = inputs.get("min_chunk_length", 16)
        chunk_split_thershold = 0.0  # cosine-similarity threshold between a semantic token and the mute token; tokens above it are treated as valid split points.

        if parallel_infer and not streaming_mode:
            print(i18n("并行推理模式已开启"))
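`chunk_split_thershold` is compared against `mute_emb_sim_matrix`, which maps each semantic token id to its cosine similarity with the mute token (id 486 in `mute_tokens` above). Its construction is not part of this diff; a plausible way to precompute such a matrix from an embedding table is sketched below, with the function name and the `emb_weight` argument chosen here for illustration only:

```python
import torch
import torch.nn.functional as F

def build_mute_sim_matrix(emb_weight: torch.Tensor, mute_token_id: int = 486) -> torch.Tensor:
    """Cosine similarity of every semantic-token embedding to the mute token's embedding.

    emb_weight: [vocab_size, dim] embedding table (assumed to come from the AR model).
    Returns a 1-D tensor of shape [vocab_size], indexable by token id.
    """
    emb = F.normalize(emb_weight, dim=-1)   # unit-normalize all embeddings
    mute = emb[mute_token_id]               # embedding of the mute token
    return emb @ mute                       # cosine similarity per token id

# Illustrative usage with a random table (real code would use the model's weights):
sim_matrix = build_mute_sim_matrix(torch.randn(1025, 512), mute_token_id=486)
```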
@@ -1249,6 +1252,15 @@ class TTS:
                    self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
                )

+               refer_audio_spec = []
+
+               sv_emb = [] if self.is_v2pro else None
+               for spec, audio_tensor in self.prompt_cache["refer_spec"]:
+                   spec = spec.to(dtype=self.precision, device=self.configs.device)
+                   refer_audio_spec.append(spec)
+                   if self.is_v2pro:
+                       sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
+
                if not streaming_mode:
                    print(f"############ {i18n('预测语义Token')} ############")
                    pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(

@@ -1267,14 +1279,6 @@ class TTS:
                    t4 = time.perf_counter()
                    t_34 += t4 - t3

-                   refer_audio_spec = []
-                   if self.is_v2pro:
-                       sv_emb = []
-                   for spec, audio_tensor in self.prompt_cache["refer_spec"]:
-                       spec = spec.to(dtype=self.precision, device=self.configs.device)
-                       refer_audio_spec.append(spec)
-                       if self.is_v2pro:
-                           sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))

                    batch_audio_fragment = []

@@ -1293,7 +1297,7 @@ class TTS:
                        print(f"############ {i18n('合成音频')} ############")
                        if not self.configs.use_vocoder:
                            if speed_factor == 1.0:
-                               print(f"{i18n('并行合成中')}...")
+                               print(f"{i18n('合成中')}...")
                                # ## vits parallel inference, method 2
                                pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                                upsample_rate = math.prod(self.vits_model.upsample_rates)
@@ -1306,15 +1310,11 @@ class TTS:
                                    torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
                                )
                                _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-                               if self.is_v2pro != True:
-                                   _batch_audio_fragment, _, _ = self.vits_model.decode(
-                                       all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
-                                   ).detach()[0, 0, :]
-                               else:
-                                   _batch_audio_fragment = self.vits_model.decode(
-                                       all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
-                                   )
-                                   _batch_audio_fragment = _batch_audio_fragment.detach()[0, 0, :]
+                               _batch_audio_fragment = self.vits_model.decode(
+                                   all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
+                               ).detach()[0, 0, :]

                                audio_frag_end_idx.insert(0, 0)
                                batch_audio_fragment = [
                                    _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]

@@ -1327,20 +1327,14 @@ class TTS:
                                    _pred_semantic = (
                                        pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                                    )  # .unsqueeze(0)  # mq needs one extra unsqueeze
-                                   if self.is_v2pro != True:
-                                       audio_fragment, _, _ = self.vits_model.decode(
-                                           _pred_semantic, phones, refer_audio_spec, speed=speed_factor
-                                       ).detach()[0, 0, :]
-                                   else:
-                                       audio_fragment = self.vits_model.decode(
-                                           _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
-                                       )
-                                       audio_fragment = audio_fragment.detach()[0, 0, :]
+                                   audio_fragment = self.vits_model.decode(
+                                       _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
+                                   ).detach()[0, 0, :]
                                    batch_audio_fragment.append(audio_fragment)  ### try reconstructing without the prompt part
                        else:
                            if parallel_infer:
                                print(f"{i18n('并行合成中')}...")
-                               audio_fragments, y, y_mask = self.using_vocoder_synthesis_batched_infer(
+                               audio_fragments = self.using_vocoder_synthesis_batched_infer(
                                    idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
                                )
                                batch_audio_fragment.extend(audio_fragments)

@@ -1350,16 +1344,16 @@ class TTS:
                                    _pred_semantic = (
                                        pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                                    )  # .unsqueeze(0)  # mq needs one extra unsqueeze
-                                   audio_fragment, y, y_mask = self.using_vocoder_synthesis(
+                                   audio_fragment = self.using_vocoder_synthesis(
                                        _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
                                    )
                                    batch_audio_fragment.append(audio_fragment)

                else:
-                   refer_audio_spec: torch.Tensor = [
-                       item.to(dtype=self.precision, device=self.configs.device)
-                       for item in self.prompt_cache["refer_spec"]
-                   ]
+                   # refer_audio_spec: torch.Tensor = [
+                   #     item.to(dtype=self.precision, device=self.configs.device)
+                   #     for item in self.prompt_cache["refer_spec"]
+                   # ]
                    semantic_token_generator = self.t2s_model.model.infer_panel(
                        all_phoneme_ids[0].unsqueeze(0),
                        all_phoneme_lens,
@@ -1372,8 +1366,9 @@ class TTS:
                        max_len=max_len,
                        repetition_penalty=repetition_penalty,
                        streaming_mode=True,
-                       chunk_length=chunk_length,
+                       chunk_length=min_chunk_length,
                        mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix,
+                       chunk_split_thershold=chunk_split_thershold,
                    )
                    t4 = time.perf_counter()
                    t_34 += t4 - t3

@@ -1381,15 +1376,15 @@ class TTS:
                    is_first_chunk = True

                    if not self.configs.use_vocoder:
-                       if speed_factor == 1.0:
-                           upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
-                       else:
-                           upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor)
+                       # if speed_factor == 1.0:
+                       #     upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
+                       # else:
+                       upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor)
                    else:
-                       if speed_factor == 1.0:
-                           upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4)
-                       else:
-                           upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)
+                       # if speed_factor == 1.0:
+                       #     upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4)
+                       # else:
+                       upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)

                    last_audio_chunk = None
                    # last_tokens = None

@@ -1419,10 +1414,8 @@ class TTS:

                        if not is_first_chunk and semantic_tokens.shape[-1] < 10:
                            overlap_len = overlap_length+(10-semantic_tokens.shape[-1])
-                           # overlap_size = math.ceil(overlap_len*upsample_rate)
                        else:
                            overlap_len = overlap_length
-                           # overlap_size = math.ceil(overlap_length*upsample_rate)


                        if not self.configs.use_vocoder:
@@ -1433,10 +1426,11 @@ class TTS:
                            # else:
                            #     token_padding_length = 0

-                           audio_chunk, latent, latent_mask = self.vits_model.decode(
+                           audio_chunk, latent, latent_mask = self.vits_model.decode_steaming(
                                _semantic_tokens.unsqueeze(0),
                                phones, refer_audio_spec,
                                speed=speed_factor,
+                               sv_emb=sv_emb,
                                result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
                                overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
                                    if last_latent is not None else None,

@@ -1444,13 +1438,7 @@ class TTS:
                            )
                            audio_chunk=audio_chunk.detach()[0, 0, :]
                        else:
-                           audio_chunk, latent, latent_mask = self.using_vocoder_synthesis(
-                               _semantic_tokens.unsqueeze(0), phones,
-                               speed=speed_factor, sample_steps=sample_steps,
-                               result_length = semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
-                               overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
-                                   if last_latent is not None else None,
-                           )
+                           raise RuntimeError(i18n("SoVits V3/4模型不支持流式推理模式"))

                        if overlap_len>overlap_length:
                            audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):]
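In the streaming path above, each chunk is decoded with `overlap_frames` carried over from the previous chunk's latent and then trimmed to roughly `(overlap_length + n_tokens) * upsample_rate` samples, so consecutive chunks share a short overlap that can be blended on the consumer side. The repository has its own `sola_algorithm` for aligning and mixing that overlap; the snippet below is only a simplified linear cross-fade to illustrate the idea, with the helper name, fade shape, and sample counts chosen here rather than taken from the diff:

```python
import numpy as np

def crossfade_concat(prev: np.ndarray, nxt: np.ndarray, overlap_samples: int) -> np.ndarray:
    """Blend the last `overlap_samples` of `prev` with the first `overlap_samples` of `nxt`.

    Illustrative only: GPT-SoVITS itself aligns chunks with a SOLA search before mixing.
    """
    if overlap_samples <= 0 or len(prev) < overlap_samples or len(nxt) < overlap_samples:
        return np.concatenate([prev, nxt])
    fade = np.linspace(0.0, 1.0, overlap_samples, dtype=prev.dtype)
    mixed = prev[-overlap_samples:] * (1.0 - fade) + nxt[:overlap_samples] * fade
    return np.concatenate([prev[:-overlap_samples], mixed, nxt[overlap_samples:]])

# Toy usage: stitch two 1-second chunks with a 2-token overlap (sample counts are illustrative).
upsample_rate, overlap_length = 1280, 2
a = np.random.randn(32000).astype(np.float32)
b = np.random.randn(32000).astype(np.float32)
out = crossfade_concat(a, b, overlap_length * upsample_rate)
```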
@@ -1614,7 +1602,7 @@ class TTS:
        return sr, audio

    def using_vocoder_synthesis(
-       self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32, result_length:int=None, overlap_frames:torch.Tensor=None
+       self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32
    ):
        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)

@@ -1623,7 +1611,7 @@ class TTS:
            raw_entry = raw_entry[0]
        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)

-       fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
+       fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
        ref_sr = self.prompt_cache["raw_sr"]
        ref_audio = ref_audio.to(self.configs.device).float()

@@ -1649,7 +1637,7 @@ class TTS:
        chunk_len = T_chunk - T_min

        mel2 = mel2.to(self.precision)
-       fea_todo, ge, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed, result_length=result_length, overlap_frames=overlap_frames)
+       fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)

        cfm_resss = []
        idx = 0

@@ -1672,11 +1660,11 @@ class TTS:
        cfm_res = torch.cat(cfm_resss, 2)
        cfm_res = denorm_spec(cfm_res)

-       with torch.no_grad():
+       with torch.inference_mode():
            wav_gen = self.vocoder(cfm_res)
            audio = wav_gen[0][0]  # .cpu().detach().numpy()

-       return audio, y, y_mask
+       return audio

    def using_vocoder_synthesis_batched_infer(
        self,

@@ -1693,7 +1681,7 @@ class TTS:
            raw_entry = raw_entry[0]
        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)

-       fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
+       fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
        ref_sr = self.prompt_cache["raw_sr"]
        ref_audio = ref_audio.to(self.configs.device).float()

@@ -1731,7 +1719,7 @@ class TTS:
            semantic_tokens = (
                semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0)
            )  # .unsqueeze(0)  # mq needs one extra unsqueeze
-           feat, _, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
+           feat, _ = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
            feat_list.append(feat)
            feat_lens.append(feat.shape[2])

@@ -1791,7 +1779,7 @@ class TTS:
            audio_fragments.append(audio_fragment)
            audio = audio[feat_len * upsample_rate :]

-       return audio_fragments, y, y_mask
+       return audio_fragments

    def sola_algorithm(
        self,
GPT_SoVITS/module/models.py

@@ -990,8 +990,57 @@ class SynthesizerTrn(nn.Module):
        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o, y_mask, (z, z_p, m_p, logs_p)


    @torch.no_grad()
-   def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
+   def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
        def get_ge(refer, sv_emb):
            ge = None
            if refer is not None:
                refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
                refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
                if self.version == "v1":
                    ge = self.ref_enc(refer * refer_mask, refer_mask)
                else:
                    ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
                if self.is_v2pro:
                    sv_emb = self.sv_emb(sv_emb)  # B*20480 -> B*512
                    ge += sv_emb.unsqueeze(-1)
                    ge = self.prelu(ge)
            return ge

        if type(refer) == list:
            ges = []
            for idx, _refer in enumerate(refer):
                ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
                ges.append(ge)
            ge = torch.stack(ges, 0).mean(0)
        else:
            ge = get_ge(refer, sv_emb)

        y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
        x, m_p, logs_p, y_mask, _, _ = self.enc_p(
            quantized,
            y_lengths,
            text,
            text_lengths,
            self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
            speed,
        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)

        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o


+   @torch.no_grad()
+   def decode_steaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
        def get_ge(refer, sv_emb):
            ge = None
            if refer is not None:
@@ -1031,7 +1080,10 @@ class SynthesizerTrn(nn.Module):
            text_lengths,
            self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
            speed,
-           , result_length=result_length, overlap_frames=overlap_frames, padding_length=padding_length)
+           result_length=result_length,
+           overlap_frames=overlap_frames,
+           padding_length=padding_length
+       )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)

@@ -1277,7 +1329,7 @@ class SynthesizerTrnV3(nn.Module):
        return cfm_loss

    @torch.no_grad()
-   def decode_encp(self, codes, text, refer, ge=None, speed=1, result_length:int=None, overlap_frames:torch.Tensor=None):
+   def decode_encp(self, codes, text, refer, ge=None, speed=1):
        # print(2333333,refer.shape)
        # ge=None
        if ge == None:

@@ -1285,26 +1337,22 @@ class SynthesizerTrnV3(nn.Module):
            refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
            ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
        y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)

        if speed == 1:
-           sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4))
+           sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4))
        else:
-           sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4) / speed) + 1
+           sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4) / speed) + 1
        y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ## BCT
-           result_length = result_length * 2 if result_length is not None else None

-       x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length, overlap_frames=overlap_frames)
+       x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
        fea = self.bridge(x)
        fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest")  ## BCT
        #### more wn parameters to learn mel
        fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
-       return fea, ge, y_, y_mask_
+       return fea, ge

    def extract_latent(self, x):
        ssl = self.ssl_proj(x)
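The `sizee` formulas above fix how many mel frames `wns1` must produce for a given number of semantic codes: a factor of 3.875 per code for v3, 4 otherwise, divided by `speed` (plus one frame) when `speed != 1`. A small sketch that reproduces the arithmetic (the helper name is illustrative):

```python
def mel_frames(n_codes: int, version: str = "v4", speed: float = 1.0) -> int:
    """Target mel-frame count used by decode_encp's wns1 stage (formula copied from the hunk above)."""
    factor = 3.875 if version == "v3" else 4
    if speed == 1:
        return int(n_codes * factor)
    return int(n_codes * factor / speed) + 1

print(mel_frames(100, "v3"))        # 387
print(mel_frames(100, "v4"))        # 400
print(mel_frames(100, "v4", 1.25))  # 321
```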
api_v2.py  (+12 −12)

@@ -30,7 +30,7 @@ POST:
    "top_k": 5,  # int. top k sampling
    "top_p": 1,  # float. top p sampling
    "temperature": 1,  # float. temperature for sampling
-   "text_split_method": "cut0",  # str. text split method, see text_segmentation_method.py for details.
+   "text_split_method": "cut5",  # str. text split method, see text_segmentation_method.py for details.
    "batch_size": 1,  # int. batch size for inference
    "batch_threshold": 0.75,  # float. threshold for batch splitting.
    "split_bucket": True,  # bool. whether to split the batch into multiple buckets.

@@ -42,7 +42,7 @@ POST:
    "sample_steps": 32,  # int. number of sampling steps for VITS model V3.
    "super_sampling": False,  # bool. whether to use super-sampling for audio when using VITS model V3.
    "overlap_length": 2,  # int. overlap length of semantic tokens for streaming mode.
    "chunk_length": 24,  # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
    "min_chunk_length": 16,  # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
    "return_fragment": False,  # bool. step by step return the audio fragment. (old version of streaming mode)
}
```

@@ -174,7 +174,7 @@ class TTS_Request(BaseModel):
    sample_steps: int = 32
    super_sampling: bool = False
    overlap_length: int = 2
    chunk_length: int = 24
    min_chunk_length: int = 16
    return_fragment: bool = False


@@ -333,7 +333,7 @@ async def tts_handle(req: dict):
    "sample_steps": 32,  # int. number of sampling steps for VITS model V3.
    "super_sampling": False,  # bool. whether to use super-sampling for audio when using VITS model V3.
    "overlap_length": 2,  # int. overlap length of semantic tokens for streaming mode.
    "chunk_length": 24,  # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
    "min_chunk_length": 16,  # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
    "return_fragment": False,  # bool. step by step return the audio fragment. (old version of streaming mode)
}
returns:

@@ -416,7 +416,7 @@ async def tts_get_endpoint(
    sample_steps: int = 32,
    super_sampling: bool = False,
    overlap_length: int = 2,
    chunk_length: int = 24,
    min_chunk_length: int = 16,
    return_fragment: bool = False,
):
    req = {

@@ -443,7 +443,7 @@ async def tts_get_endpoint(
        "sample_steps": int(sample_steps),
        "super_sampling": super_sampling,
        "overlap_length": int(overlap_length),
        "chunk_length": int(chunk_length),
        "min_chunk_length": int(min_chunk_length),
        "return_fragment": return_fragment,
    }
    return await tts_handle(req)
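For completeness, a minimal client sketch exercising the streaming parameters touched by this commit via the GET endpoint. The endpoint path, port, and the baseline parameters (`text`, `text_lang`, `ref_audio_path`, `prompt_text`, `prompt_lang`, `media_type`, `streaming_mode`) are not shown in this diff and are assumed from the surrounding api_v2.py; adjust them to your deployment:

```python
import requests

params = {
    # assumed baseline parameters from api_v2.py (not part of this diff)
    "text": "测试流式合成。",
    "text_lang": "zh",
    "ref_audio_path": "ref.wav",
    "prompt_text": "参考音频的文本。",
    "prompt_lang": "zh",
    "media_type": "raw",
    "streaming_mode": True,
    # parameters touched by this commit
    "overlap_length": 2,
    "chunk_length": 24,
    "min_chunk_length": 16,
}

# assumed default host/port of api_v2.py
with requests.get("http://127.0.0.1:9880/tts", params=params, stream=True) as resp:
    resp.raise_for_status()
    with open("out.raw", "wb") as f:
        # consume the audio chunk by chunk as the server streams it
        for chunk in resp.iter_content(chunk_size=4096):
            f.write(chunk)
```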