modified:   GPT_SoVITS/AR/models/t2s_model.py
modified:   GPT_SoVITS/TTS_infer_pack/TTS.py
modified:   GPT_SoVITS/module/models.py
modified:   api_v2.py
ChasonJiang 2025-11-24 20:47:32 +08:00
parent af7b95bc9d
commit 08d6ed0d8c
4 changed files with 118 additions and 82 deletions

GPT_SoVITS/AR/models/t2s_model.py

@ -827,7 +827,7 @@ class Text2SemanticDecoder(nn.Module):
**kwargs,
):
mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
sim_thershold = kwargs.get("sim_thershold", 0.3)
chunk_split_thershold = kwargs.get("chunk_split_thershold", 0.3)
check_token_num = 2
@ -927,9 +927,9 @@ class Text2SemanticDecoder(nn.Module):
if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length+check_token_num):
score = mute_emb_sim_matrix[y[0, curr_ptr:]] - sim_thershold
score = mute_emb_sim_matrix[y[0, curr_ptr:]] - chunk_split_thershold
score[score<0]=-1
score[:-1]=score[:-1]+score[1:]
score[:-1]=score[:-1]+score[1:] ## consider two consecutive tokens
argmax_idx = score.argmax()
if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length:
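The hunk above is the core of the new chunk-splitting rule: a streaming chunk may only be cut where the generated semantic tokens are sufficiently similar to the mute token, and only once at least chunk_length + check_token_num tokens have accumulated. A minimal stand-alone sketch of the same scoring, where sim_to_mute stands in for mute_emb_sim_matrix[y[0, curr_ptr:]] and min_chunk_length for the chunk_length argument (which TTS.py now fills with min_chunk_length); names here are chosen for illustration only:

```python
import torch

def find_chunk_split(sim_to_mute: torch.Tensor, split_threshold: float, min_chunk_length: int):
    """Stand-alone version of the split-point search in the hunk above.

    sim_to_mute plays the role of mute_emb_sim_matrix[y[0, curr_ptr:]]: the
    cosine similarity of each newly generated semantic token to the mute token.
    Returns the number of tokens to emit as a chunk, or None if no cut is allowed yet.
    """
    score = sim_to_mute - split_threshold     # positive where a token looks "mute enough"
    score[score < 0] = -1                     # hard-reject everything below the threshold
    score[:-1] = score[:-1] + score[1:]       # favour two consecutive above-threshold tokens
    argmax_idx = int(score.argmax())
    if score[argmax_idx] >= 0 and argmax_idx + 1 >= min_chunk_length:
        return argmax_idx + 1
    return None
```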

GPT_SoVITS/TTS_infer_pack/TTS.py

@ -278,6 +278,8 @@ class TTS_Config:
mute_tokens: dict = {
"v1" : 486,
"v2" : 486,
"v2Pro": 486,
"v2ProPlus": 486,
"v3" : 486,
"v4" : 486,
}
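The new mute_tokens table maps every supported model version (including the newly listed v2Pro and v2ProPlus) to the id of the "mute" semantic token, 486. The mute_emb_sim_matrix the decoder consults is presumably one cosine similarity per vocabulary token against that mute token's embedding; a hedged sketch of how such a table could be built, where token_embedding and the attribute it comes from are assumptions rather than the repository's actual names:

```python
import torch
import torch.nn.functional as F

def build_mute_emb_sim_matrix(token_embedding: torch.nn.Embedding, mute_token: int = 486) -> torch.Tensor:
    """Hypothetical construction of mute_emb_sim_matrix.

    token_embedding is assumed to be the AR model's semantic-token embedding
    table; the result is one cosine similarity per vocabulary entry, indexable
    by generated token ids exactly as in the t2s_model.py hunk above.
    """
    emb = token_embedding.weight.detach()              # [vocab_size, dim]
    mute_vec = emb[mute_token].unsqueeze(0)            # [1, dim]
    return F.cosine_similarity(emb, mute_vec, dim=-1)  # [vocab_size]
```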
@ -1009,7 +1011,7 @@ class TTS:
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
@ -1023,7 +1025,7 @@ class TTS:
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"streaming_mode": False, # bool. return audio chunk by chunk.
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
}
returns:
Tuple[int, np.ndarray]: sampling rate and audio data.
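With chunk_length renamed to min_chunk_length (and the default text-split method changed from cut0 to cut1), the chunk size is now a lower bound rather than a fixed length: the decoder waits for at least this many semantic tokens and then cuts at the next mute-like position. A hedged usage sketch of the streaming-related keys, assuming the documented dict is passed to the TTS pipeline's run() entry point and that streaming mode yields (sampling_rate, audio_chunk) pairs; neither the entry point nor its generator behaviour is part of this diff:

```python
# tts_pipeline: an already-initialised TTS instance (construction not shown in this diff).
inputs = {
    # ... text and reference-audio fields omitted ...
    "text_split_method": "cut1",   # new default in this commit
    "streaming_mode": True,        # return audio chunk by chunk
    "overlap_length": 2,           # semantic-token overlap between consecutive chunks
    "min_chunk_length": 16,        # renamed from chunk_length; now a lower bound, not a fixed size
}
for sr, audio_chunk in tts_pipeline.run(inputs):
    ...  # consume each (sampling rate, np.ndarray) fragment as it arrives
```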
@ -1039,7 +1041,7 @@ class TTS:
top_k: int = inputs.get("top_k", 5)
top_p: float = inputs.get("top_p", 1)
temperature: float = inputs.get("temperature", 1)
text_split_method: str = inputs.get("text_split_method", "cut0")
text_split_method: str = inputs.get("text_split_method", "cut1")
batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75)
speed_factor = inputs.get("speed_factor", 1.0)
@ -1055,7 +1057,8 @@ class TTS:
super_sampling = inputs.get("super_sampling", False)
streaming_mode = inputs.get("streaming_mode", False)
overlap_length = inputs.get("overlap_length", 2)
chunk_length = inputs.get("chunk_length", 24)
min_chunk_length = inputs.get("min_chunk_length", 16)
chunk_split_thershold = 0.0 # cosine-similarity threshold between a semantic token and the mute token; tokens whose similarity exceeds it are treated as valid chunk split points.
if parallel_infer and not streaming_mode:
print(i18n("并行推理模式已开启"))
@ -1249,6 +1252,15 @@ class TTS:
self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
)
refer_audio_spec = []
sv_emb = [] if self.is_v2pro else None
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
spec = spec.to(dtype=self.precision, device=self.configs.device)
refer_audio_spec.append(spec)
if self.is_v2pro:
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
if not streaming_mode:
print(f"############ {i18n('预测语义Token')} ############")
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
@ -1267,14 +1279,6 @@ class TTS:
t4 = time.perf_counter()
t_34 += t4 - t3
refer_audio_spec = []
if self.is_v2pro:
sv_emb = []
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
spec = spec.to(dtype=self.precision, device=self.configs.device)
refer_audio_spec.append(spec)
if self.is_v2pro:
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
batch_audio_fragment = []
@ -1293,7 +1297,7 @@ class TTS:
print(f"############ {i18n('合成音频')} ############")
if not self.configs.use_vocoder:
if speed_factor == 1.0:
print(f"{i18n('并行合成中')}...")
print(f"{i18n('合成中')}...")
# ## VITS parallel inference, method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
@ -1306,15 +1310,11 @@ class TTS:
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
if self.is_v2pro != True:
_batch_audio_fragment, _, _ = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
else:
_batch_audio_fragment = self.vits_model.decode(
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
)
_batch_audio_fragment = _batch_audio_fragment.detach()[0, 0, :]
).detach()[0, 0, :]
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment = [
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
@ -1327,20 +1327,14 @@ class TTS:
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)  # mq needs one extra unsqueeze
if self.is_v2pro != True:
audio_fragment, _, _ = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
else:
audio_fragment = self.vits_model.decode(
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
)
audio_fragment=audio_fragment.detach()[0, 0, :]
).detach()[0, 0, :]
batch_audio_fragment.append(audio_fragment) ### try reconstructing without the prompt part
else:
if parallel_infer:
print(f"{i18n('并行合成中')}...")
audio_fragments, y, y_mask = self.using_vocoder_synthesis_batched_infer(
audio_fragments = self.using_vocoder_synthesis_batched_infer(
idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.extend(audio_fragments)
@ -1350,16 +1344,16 @@ class TTS:
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)  # mq needs one extra unsqueeze
audio_fragment, y, y_mask = self.using_vocoder_synthesis(
audio_fragment = self.using_vocoder_synthesis(
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.append(audio_fragment)
else:
refer_audio_spec: torch.Tensor = [
item.to(dtype=self.precision, device=self.configs.device)
for item in self.prompt_cache["refer_spec"]
]
# refer_audio_spec: torch.Tensor = [
# item.to(dtype=self.precision, device=self.configs.device)
# for item in self.prompt_cache["refer_spec"]
# ]
semantic_token_generator =self.t2s_model.model.infer_panel(
all_phoneme_ids[0].unsqueeze(0),
all_phoneme_lens,
@ -1372,8 +1366,9 @@ class TTS:
max_len=max_len,
repetition_penalty=repetition_penalty,
streaming_mode=True,
chunk_length=chunk_length,
chunk_length=min_chunk_length,
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix,
chunk_split_thershold=chunk_split_thershold,
)
t4 = time.perf_counter()
t_34 += t4 - t3
@ -1381,15 +1376,15 @@ class TTS:
is_first_chunk = True
if not self.configs.use_vocoder:
if speed_factor == 1.0:
upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
else:
upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor)
# if speed_factor == 1.0:
# upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
# else:
upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor)
else:
if speed_factor == 1.0:
upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4)
else:
upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)
# if speed_factor == 1.0:
# upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4)
# else:
upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)
last_audio_chunk = None
# last_tokens = None
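upsample_rate above converts a count of semantic tokens into a count of output samples, which the overlap trimming below depends on. A rough worked example, with the model constants (32 kHz output, upsample_rates multiplying to 640, 25 Hz semantic frame rate) assumed purely for illustration:

```python
import math

# Illustrative constants only; the real values come from the loaded SoVITS model.
upsample_rates = [10, 8, 2, 2, 2]        # assumed vocoder upsample factors (product = 640)
sampling_rate = 32000                    # assumed output sampling rate
speed_factor = 1.0
samples_per_token = math.prod(upsample_rates) * (2 / speed_factor)   # x2 for a 25 Hz semantic frame rate
min_chunk_samples = 16 * samples_per_token                           # min_chunk_length = 16 tokens
print(min_chunk_samples / sampling_rate)                             # ≈ 0.64 s of audio per minimum chunk
```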
@ -1419,10 +1414,8 @@ class TTS:
if not is_first_chunk and semantic_tokens.shape[-1] < 10:
overlap_len = overlap_length+(10-semantic_tokens.shape[-1])
# overlap_size = math.ceil(overlap_len*upsample_rate)
else:
overlap_len = overlap_length
# overlap_size = math.ceil(overlap_length*upsample_rate)
if not self.configs.use_vocoder:
@ -1433,10 +1426,11 @@ class TTS:
# else:
# token_padding_length = 0
audio_chunk, latent, latent_mask = self.vits_model.decode(
audio_chunk, latent, latent_mask = self.vits_model.decode_steaming(
_semantic_tokens.unsqueeze(0),
phones, refer_audio_spec,
speed=speed_factor,
sv_emb=sv_emb,
result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
if last_latent is not None else None,
@ -1444,13 +1438,7 @@ class TTS:
)
audio_chunk=audio_chunk.detach()[0, 0, :]
else:
audio_chunk, latent, latent_mask = self.using_vocoder_synthesis(
_semantic_tokens.unsqueeze(0), phones,
speed=speed_factor, sample_steps=sample_steps,
result_length = semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
if last_latent is not None else None,
)
raise RuntimeError(i18n("SoVits V3/4模型不支持流式推理模式"))
if overlap_len>overlap_length:
audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):]
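The trimming above keeps overlap_length extra tokens' worth of audio at each chunk boundary so that consecutive chunks can be aligned and crossfaded (see sola_algorithm further down). A generic SOLA-style merge is sketched below for orientation only; it is not the repository's sola_algorithm, and the tensor shapes are assumptions:

```python
import torch

def sola_crossfade(prev_tail: torch.Tensor, new_chunk: torch.Tensor) -> torch.Tensor:
    """Generic SOLA-style merge for orientation only; not the repository's
    sola_algorithm. prev_tail is the overlapping tail of the previous chunk."""
    overlap = prev_tail.shape[-1]
    search = new_chunk[: overlap * 2]                          # search window at the head of the new chunk
    corr = torch.nn.functional.conv1d(                         # cross-correlate to find the best alignment
        search.view(1, 1, -1), prev_tail.view(1, 1, -1)
    ).squeeze(0).squeeze(0)
    shift = int(corr.argmax())
    aligned = new_chunk[shift:]
    fade = torch.linspace(0.0, 1.0, overlap, device=new_chunk.device)
    head = prev_tail * (1 - fade) + aligned[:overlap] * fade   # linear crossfade over the overlap
    return torch.cat([head, aligned[overlap:]])
```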
@ -1614,7 +1602,7 @@ class TTS:
return sr, audio
def using_vocoder_synthesis(
self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32, result_length:int=None, overlap_frames:torch.Tensor=None
self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32
):
prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
@ -1623,7 +1611,7 @@ class TTS:
raw_entry = raw_entry[0]
refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)
fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
ref_sr = self.prompt_cache["raw_sr"]
ref_audio = ref_audio.to(self.configs.device).float()
@ -1649,7 +1637,7 @@ class TTS:
chunk_len = T_chunk - T_min
mel2 = mel2.to(self.precision)
fea_todo, ge, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed, result_length=result_length, overlap_frames=overlap_frames)
fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
cfm_resss = []
idx = 0
@ -1672,11 +1660,11 @@ class TTS:
cfm_res = torch.cat(cfm_resss, 2)
cfm_res = denorm_spec(cfm_res)
with torch.no_grad():
with torch.inference_mode():
wav_gen = self.vocoder(cfm_res)
audio = wav_gen[0][0] # .cpu().detach().numpy()
return audio, y, y_mask
return audio
def using_vocoder_synthesis_batched_infer(
self,
@ -1693,7 +1681,7 @@ class TTS:
raw_entry = raw_entry[0]
refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)
fea_ref, ge, y, y_mask = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
ref_sr = self.prompt_cache["raw_sr"]
ref_audio = ref_audio.to(self.configs.device).float()
@ -1731,7 +1719,7 @@ class TTS:
semantic_tokens = (
semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)  # mq needs one extra unsqueeze
feat, _, y, y_mask = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
feat, _ = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
feat_list.append(feat)
feat_lens.append(feat.shape[2])
@ -1791,7 +1779,7 @@ class TTS:
audio_fragments.append(audio_fragment)
audio = audio[feat_len * upsample_rate :]
return audio_fragments, y, y_mask
return audio_fragments
def sola_algorithm(
self,

GPT_SoVITS/module/models.py

@ -990,8 +990,57 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
def get_ge(refer, sv_emb):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
return ge
if type(refer) == list:
ges = []
for idx, _refer in enumerate(refer):
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
ges.append(ge)
ge = torch.stack(ges, 0).mean(0)
else:
ge = get_ge(refer, sv_emb)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask, _, _ = self.enc_p(
quantized,
y_lengths,
text,
text_lengths,
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
speed,
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o
@torch.no_grad()
def decode_steaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
def get_ge(refer, sv_emb):
ge = None
if refer is not None:
@ -1031,7 +1080,10 @@ class SynthesizerTrn(nn.Module):
text_lengths,
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
speed,
, result_length=result_length, overlap_frames=overlap_frames, padding_length=padding_length)
result_length=result_length,
overlap_frames=overlap_frames,
padding_length=padding_length
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -1277,7 +1329,7 @@ class SynthesizerTrnV3(nn.Module):
return cfm_loss
@torch.no_grad()
def decode_encp(self, codes, text, refer, ge=None, speed=1, result_length:int=None, overlap_frames:torch.Tensor=None):
def decode_encp(self, codes, text, refer, ge=None, speed=1):
# print(2333333,refer.shape)
# ge=None
if ge == None:
@ -1285,26 +1337,22 @@ class SynthesizerTrnV3(nn.Module):
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
if speed == 1:
sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4))
sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4))
else:
sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4) / speed) + 1
sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4) / speed) + 1
y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
result_length = result_length * 2 if result_length is not None else None
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length, overlap_frames=overlap_frames)
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
#### more wn parameters to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea, ge, y_, y_mask_
return fea, ge
def extract_latent(self, x):
ssl = self.ssl_proj(x)

api_v2.py

@ -30,7 +30,7 @@ POST:
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
@ -42,7 +42,7 @@ POST:
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
}
```
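A hedged example of driving the updated API with the renamed parameter; the /tts route, host, and port are assumptions based on the rest of api_v2.py rather than part of this diff, and the text/reference fields are omitted:

```python
import requests

payload = {
    # ... text / text_lang / ref_audio_path / prompt_text / prompt_lang ...
    "text_split_method": "cut5",   # new default in api_v2.py
    "streaming_mode": True,        # assumed to be accepted alongside the fields above
    "overlap_length": 2,
    "min_chunk_length": 16,        # renamed from chunk_length
}
with requests.post("http://127.0.0.1:9880/tts", json=payload, stream=True) as resp:
    for chunk in resp.iter_content(chunk_size=None):
        ...  # each chunk is a fragment of the synthesized audio stream
```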
@ -174,7 +174,7 @@ class TTS_Request(BaseModel):
sample_steps: int = 32
super_sampling: bool = False
overlap_length: int = 2
chunk_length: int = 24
min_chunk_length: int = 16
return_fragment: bool = False
@ -333,7 +333,7 @@ async def tts_handle(req: dict):
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"chunk_length: 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
}
returns:
@ -416,7 +416,7 @@ async def tts_get_endpoint(
sample_steps: int = 32,
super_sampling: bool = False,
overlap_length: int = 2,
chunk_length: int = 24,
min_chunk_length: int = 16,
return_fragment: bool = False,
):
req = {
@ -443,7 +443,7 @@ async def tts_get_endpoint(
"sample_steps": int(sample_steps),
"super_sampling": super_sampling,
"overlap_length": int(overlap_length),
"chunk_length": int(chunk_length),
"min_chunk_length": int(min_chunk_length),
"return_fragment": return_fragment,
}
return await tts_handle(req)
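The same parameters can also be supplied as query parameters through tts_get_endpoint; again the route, host, and port are assumptions, and the text/reference query parameters are omitted:

```python
import requests

resp = requests.get(
    "http://127.0.0.1:9880/tts",
    params={
        # ... text / text_lang / ref_audio_path / prompt_text / prompt_lang ...
        "overlap_length": 2,
        "min_chunk_length": 16,   # renamed from chunk_length
        "sample_steps": 32,
        "super_sampling": False,
    },
    stream=True,
)
```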