Finer-grained streaming inference mode (#2671)

* Better streaming inference mode

* Clean up unused code

* modified:   GPT_SoVITS/AR/models/t2s_model.py
	modified:   GPT_SoVITS/TTS_infer_pack/TTS.py
	modified:   GPT_SoVITS/module/models.py

* modified:   GPT_SoVITS/TTS_infer_pack/TTS.py

* modified:   .gitignore
	modified:   GPT_SoVITS/AR/models/t2s_model.py
	modified:   GPT_SoVITS/TTS_infer_pack/TTS.py
	modified:   GPT_SoVITS/module/models.py

* modified:   GPT_SoVITS/AR/models/t2s_model.py
	modified:   GPT_SoVITS/TTS_infer_pack/TTS.py
	modified:   GPT_SoVITS/module/models.py
	modified:   api_v2.py

* modified:   GPT_SoVITS/TTS_infer_pack/TTS.py

* Fix spelling errors

* Support streaming inference with a fixed chunk length; optimize the SOLA algorithm

* Fix OGG-format transmission in api_v2
Author: ChasonJiang, 2025-11-28 21:12:41 +08:00 (committed by GitHub)
Parent: 11aa78bd9b
Commit: 92ab59c553
4 changed files with 548 additions and 166 deletions
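For orientation, below is a minimal client sketch (not part of this commit) showing how the new streaming parameters introduced by these diffs might be exercised against api_v2. The host, port, endpoint path, and audio/transcript paths are assumptions; adjust them to however api_v2.py is launched in your setup.

import requests

# Assumed endpoint; the default host/port and the /tts path depend on how api_v2.py is launched.
URL = "http://127.0.0.1:9880/tts"

payload = {
    "text": "Hello, streaming world.",
    "text_lang": "en",
    "ref_audio_path": "ref.wav",            # placeholder reference audio
    "prompt_text": "reference transcript",  # placeholder transcript of the reference audio
    "prompt_lang": "en",
    "media_type": "raw",                    # headerless PCM chunks are easiest to concatenate
    "streaming_mode": True,                 # new chunk-by-chunk streaming mode
    "min_chunk_length": 16,                 # minimum semantic-token chunk size
    "overlap_length": 2,                    # token overlap used for chunk stitching
    "fixed_length_chunk": False,            # True trades quality for a faster first chunk
}

with requests.post(URL, json=payload, stream=True) as resp:
    resp.raise_for_status()
    with open("out.pcm", "wb") as f:
        for chunk in resp.iter_content(chunk_size=None):
            f.write(chunk)                  # each chunk is one synthesized audio fragment

With "media_type": "raw", the received bytes should be int16 PCM at the model's output sampling rate and can be concatenated directly.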

GPT_SoVITS/AR/models/t2s_model.py

@ -794,7 +794,7 @@ class Text2SemanticDecoder(nn.Module):
y_list = []
idx_list = []
for i in range(len(x)):
y, idx = self.infer_panel_naive(
y, idx = next(self.infer_panel_naive(
x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
@ -805,7 +805,7 @@ class Text2SemanticDecoder(nn.Module):
temperature,
repetition_penalty,
**kwargs,
)
))
y_list.append(y[0])
idx_list.append(idx)
@ -822,8 +822,15 @@ class Text2SemanticDecoder(nn.Module):
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
streaming_mode: bool = False,
chunk_length: int = 24,
**kwargs,
):
mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
chunk_split_thershold = kwargs.get("chunk_split_thershold", 0.3)
check_token_num = 2
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
@ -875,7 +882,10 @@ class Text2SemanticDecoder(nn.Module):
.to(device=x.device, dtype=torch.bool)
)
token_counter = 0
curr_ptr = prefix_len
for idx in tqdm(range(1500)):
token_counter+=1
if xy_attn_mask is not None:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
else:
@ -900,22 +910,56 @@ class Text2SemanticDecoder(nn.Module):
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
y=y[:, :-1]
token_counter -= 1
if idx == 1499:
stop = True
if stop:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
# print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
if streaming_mode:
yield y[:, curr_ptr:] if curr_ptr<y.shape[1] else None, True
break
if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length+check_token_num):
score = mute_emb_sim_matrix[y[0, curr_ptr:]] - chunk_split_thershold
score[score<0]=-1
score[:-1]=score[:-1]+score[1:] ## consider two consecutive tokens
argmax_idx = score.argmax()
if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length:
print(f"\n\ncurr_ptr:{curr_ptr}")
yield y[:, curr_ptr:], False
token_counter -= argmax_idx+1
curr_ptr += argmax_idx+1
elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length):
yield y[:, -token_counter:], False
curr_ptr+=token_counter
token_counter = 0
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if not streaming_mode:
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx
yield y, 0
yield y, idx
def infer_panel(
self,
@ -930,6 +974,6 @@ class Text2SemanticDecoder(nn.Module):
repetition_penalty: float = 1.35,
**kwargs,
):
return self.infer_panel_naive(
return next(self.infer_panel_naive(
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
)
))
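As a reading aid, here is a hedged, self-contained sketch of the chunk-splitting heuristic added above: generated token IDs are scored by their cosine similarity to the mute token, adjacent scores are summed to favour two consecutive quiet tokens, and the chunk is cut at the best position at or beyond the minimum chunk length. The names find_split_point, sim_matrix, tokens, and threshold are illustrative; the real code works on mute_emb_sim_matrix, y[0, curr_ptr:], and chunk_split_thershold inside infer_panel_naive.

from typing import Optional

import torch


def find_split_point(sim_matrix: torch.Tensor,
                     tokens: torch.Tensor,
                     min_chunk_length: int,
                     threshold: float = 0.3) -> Optional[int]:
    """Illustrative sketch: pick a chunk boundary inside a non-empty run of tokens.

    sim_matrix[k] is the cosine similarity between codebook entry k and the mute
    token embedding; tokens holds the semantic IDs generated since the last chunk
    was emitted. Returns the cut length (number of tokens to emit), or None if no
    position at or beyond min_chunk_length looks quiet enough.
    """
    score = sim_matrix[tokens.long()] - threshold  # positive => candidate split point
    score[score < 0] = -1                          # mask out non-quiet tokens
    score[:-1] = score[:-1] + score[1:]            # favour two consecutive quiet tokens
    best = int(score.argmax())
    if score[best] >= 0 and best + 1 >= min_chunk_length:
        return best + 1
    return None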

GPT_SoVITS/TTS_infer_pack/TTS.py

@ -275,6 +275,15 @@ class TTS_Config:
v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"]
languages: list = v2_languages
mute_tokens: dict = {
"v1" : 486,
"v2" : 486,
"v2Pro": 486,
"v2ProPlus": 486,
"v3" : 486,
"v4" : 486,
}
mute_emb_sim_matrix: torch.Tensor = None
# "all_zh",#全部按中文识别
# "en",#全部按英文识别#######不变
# "all_ja",#全部按日文识别
@ -598,6 +607,11 @@ class TTS:
if self.configs.is_half and str(self.configs.device) != "cpu":
self.t2s_model = self.t2s_model.half()
codebook = t2s_model.model.ar_audio_embedding.weight.clone()
mute_emb = codebook[self.configs.mute_tokens[self.configs.version]].unsqueeze(0)
sim_matrix = F.cosine_similarity(mute_emb.float(), codebook.float(), dim=-1)
self.configs.mute_emb_sim_matrix = sim_matrix
def init_vocoder(self, version: str):
if version == "v3":
if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
@ -997,18 +1011,22 @@ class TTS:
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
}
returns:
Tuple[int, np.ndarray]: sampling rate and audio data.
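For illustration, a sketch of driving TTS.run directly with the new streaming options, assuming a valid tts_infer.yaml and reference audio (both paths below are placeholders). As documented above, the generator yields (sampling_rate, audio) tuples, with the audio already converted to int16.

import numpy as np

from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

tts_pipeline = TTS(TTS_Config("GPT_SoVITS/configs/tts_infer.yaml"))  # placeholder config path

request = {
    "text": "Streaming synthesis, chunk by chunk.",
    "text_lang": "en",
    "ref_audio_path": "ref.wav",            # placeholder reference audio
    "prompt_text": "reference transcript",  # placeholder transcript
    "prompt_lang": "en",
    "streaming_mode": True,                 # yield audio chunk by chunk
    "min_chunk_length": 16,
    "overlap_length": 2,
    "parallel_infer": False,                # streaming falls back to naive inference anyway
}

chunks = []
for sr, audio_chunk in tts_pipeline.run(request):
    chunks.append(audio_chunk)              # int16 numpy array at sampling rate `sr`

full_audio = np.concatenate(chunks) if chunks else np.zeros(0, dtype=np.int16)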
@ -1024,7 +1042,7 @@ class TTS:
top_k: int = inputs.get("top_k", 5)
top_p: float = inputs.get("top_p", 1)
temperature: float = inputs.get("temperature", 1)
text_split_method: str = inputs.get("text_split_method", "cut0")
text_split_method: str = inputs.get("text_split_method", "cut1")
batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75)
speed_factor = inputs.get("speed_factor", 1.0)
@ -1038,19 +1056,43 @@ class TTS:
repetition_penalty = inputs.get("repetition_penalty", 1.35)
sample_steps = inputs.get("sample_steps", 32)
super_sampling = inputs.get("super_sampling", False)
streaming_mode = inputs.get("streaming_mode", False)
overlap_length = inputs.get("overlap_length", 2)
min_chunk_length = inputs.get("min_chunk_length", 16)
fixed_length_chunk = inputs.get("fixed_length_chunk", False)
chunk_split_thershold = 0.0 # cosine-similarity threshold between a semantic token and the mute token; positions above this threshold are treated as valid split points.
if parallel_infer:
if parallel_infer and not streaming_mode:
print(i18n("并行推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
elif not parallel_infer and streaming_mode and not self.configs.use_vocoder:
print(i18n("流式推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
elif streaming_mode and self.configs.use_vocoder:
print(i18n("SoVits V3/4模型不支持流式推理模式已自动回退到分段返回模式"))
streaming_mode = False
return_fragment = True
if parallel_infer:
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
else:
print(i18n("并行推理模式已关闭"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
# self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
elif parallel_infer and streaming_mode:
print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式"))
parallel_infer = False
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
else:
print(i18n("朴素推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
if return_fragment:
print(i18n("分段返回模式已开启"))
if split_bucket:
if return_fragment and streaming_mode:
print(i18n("流式推理模式不支持分段返回,已自动关闭分段返回"))
return_fragment = False
if (return_fragment or streaming_mode) and split_bucket:
print(i18n("分段返回模式/流式推理模式不支持分桶处理,已自动关闭分桶处理"))
split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer):
print(i18n("分桶处理模式已开启"))
@ -1063,9 +1105,9 @@ class TTS:
else:
print(i18n("分桶处理模式已关闭"))
if fragment_interval < 0.01:
fragment_interval = 0.01
print(i18n("分段间隔过小已自动设置为0.01"))
# if fragment_interval < 0.01:
# fragment_interval = 0.01
# print(i18n("分段间隔过小已自动设置为0.01"))
no_prompt_text = False
if prompt_text in [None, ""]:
@ -1126,7 +1168,7 @@ class TTS:
###### text preprocessing ########
t1 = time.perf_counter()
data: list = None
if not return_fragment:
if not (return_fragment or streaming_mode):
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
if len(data) == 0:
yield 16000, np.zeros(int(16000), dtype=np.int16)
@ -1186,10 +1228,11 @@ class TTS:
t_34 = 0.0
t_45 = 0.0
audio = []
is_first_package = True
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
for item in data:
t3 = time.perf_counter()
if return_fragment:
if return_fragment or streaming_mode:
item = make_batch(item)
if item is None:
continue
@ -1211,6 +1254,16 @@ class TTS:
self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
)
refer_audio_spec = []
sv_emb = [] if self.is_v2pro else None
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
spec = spec.to(dtype=self.precision, device=self.configs.device)
refer_audio_spec.append(spec)
if self.is_v2pro:
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
if not streaming_mode:
print(f"############ {i18n('预测语义Token')} ############")
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
all_phoneme_ids,
@ -1228,14 +1281,6 @@ class TTS:
t4 = time.perf_counter()
t_34 += t4 - t3
refer_audio_spec = []
if self.is_v2pro:
sv_emb = []
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
spec = spec.to(dtype=self.precision, device=self.configs.device)
refer_audio_spec.append(spec)
if self.is_v2pro:
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
batch_audio_fragment = []
@ -1267,14 +1312,11 @@ class TTS:
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
if self.is_v2pro != True:
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
else:
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
).detach()[0, 0, :]
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment = [
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
@ -1287,11 +1329,6 @@ class TTS:
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
if self.is_v2pro != True:
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
else:
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
).detach()[0, 0, :]
@ -1314,6 +1351,132 @@ class TTS:
)
batch_audio_fragment.append(audio_fragment)
else:
# refer_audio_spec: torch.Tensor = [
# item.to(dtype=self.precision, device=self.configs.device)
# for item in self.prompt_cache["refer_spec"]
# ]
semantic_token_generator =self.t2s_model.model.infer_panel(
all_phoneme_ids[0].unsqueeze(0),
all_phoneme_lens,
prompt,
all_bert_features[0].unsqueeze(0),
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
max_len=max_len,
repetition_penalty=repetition_penalty,
streaming_mode=True,
chunk_length=min_chunk_length,
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix if not fixed_length_chunk else None,
chunk_split_thershold=chunk_split_thershold,
)
t4 = time.perf_counter()
t_34 += t4 - t3
phones = batch_phones[0].unsqueeze(0).to(self.configs.device)
is_first_chunk = True
if not self.configs.use_vocoder:
# if speed_factor == 1.0:
# upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
# else:
upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor)
else:
# if speed_factor == 1.0:
# upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4)
# else:
upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)
last_audio_chunk = None
# last_tokens = None
last_latent = None
previous_tokens = []
overlap_len = overlap_length
overlap_size = math.ceil(overlap_length*upsample_rate)
for semantic_tokens, is_final in semantic_token_generator:
if semantic_tokens is None and last_audio_chunk is not None:
yield self.audio_postprocess(
[[last_audio_chunk[-overlap_size:]]],
output_sr,
None,
speed_factor,
False,
0.0,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
break
_semantic_tokens = semantic_tokens
print(f"semantic_tokens shape:{semantic_tokens.shape}")
previous_tokens.append(semantic_tokens)
_semantic_tokens = torch.cat(previous_tokens, dim=-1)
if not is_first_chunk and semantic_tokens.shape[-1] < 10:
overlap_len = overlap_length+(10-semantic_tokens.shape[-1])
else:
overlap_len = overlap_length
if not self.configs.use_vocoder:
token_padding_length = 0
# token_padding_length = int(phones.shape[-1]*2)-_semantic_tokens.shape[-1]
# if token_padding_length>0:
# _semantic_tokens = F.pad(_semantic_tokens, (0, token_padding_length), "constant", 486)
# else:
# token_padding_length = 0
audio_chunk, latent, latent_mask = self.vits_model.decode_streaming(
_semantic_tokens.unsqueeze(0),
phones, refer_audio_spec,
speed=speed_factor,
sv_emb=sv_emb,
result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
if last_latent is not None else None,
padding_length=token_padding_length
)
audio_chunk=audio_chunk.detach()[0, 0, :]
else:
raise RuntimeError(i18n("SoVits V3/4模型不支持流式推理模式"))
if overlap_len>overlap_length:
audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):]
audio_chunk_ = audio_chunk
if is_first_chunk and not is_final:
is_first_chunk = False
audio_chunk_ = audio_chunk_[:-overlap_size]
elif is_first_chunk and is_final:
is_first_chunk = False
elif not is_first_chunk and not is_final:
audio_chunk_ = self.sola_algorithm([last_audio_chunk, audio_chunk_], overlap_size)
audio_chunk_ = (
audio_chunk_[last_audio_chunk.shape[0]-overlap_size:-overlap_size] if not is_final \
else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:]
)
last_latent = latent
last_audio_chunk = audio_chunk
yield self.audio_postprocess(
[[audio_chunk_]],
output_sr,
None,
speed_factor,
False,
0.0,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
if is_first_package:
print(f"first_package_delay: {time.perf_counter()-t0:.3f}")
is_first_package = False
yield output_sr, np.zeros(int(output_sr*fragment_interval), dtype=np.int16)
t5 = time.perf_counter()
t_45 += t5 - t4
if return_fragment:
@ -1327,17 +1490,18 @@ class TTS:
fragment_interval,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
elif streaming_mode:...
else:
audio.append(batch_audio_fragment)
if self.stop_flag:
yield 16000, np.zeros(int(16000), dtype=np.int16)
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
return
if not return_fragment:
if not (return_fragment or streaming_mode):
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
if len(audio) == 0:
yield 16000, np.zeros(int(16000), dtype=np.int16)
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
return
yield self.audio_postprocess(
audio,
@ -1384,6 +1548,7 @@ class TTS:
fragment_interval: float = 0.3,
super_sampling: bool = False,
) -> Tuple[int, np.ndarray]:
if fragment_interval>0:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
)
@ -1393,7 +1558,7 @@ class TTS:
max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音
if max_audio > 1:
audio_fragment /= max_audio
audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) if fragment_interval>0 else audio_fragment
audio[i][j] = audio_fragment
if split_bucket:
@ -1413,13 +1578,18 @@ class TTS:
max_audio = np.abs(audio).max()
if max_audio > 1:
audio /= max_audio
audio = (audio * 32768).astype(np.int16)
t2 = time.perf_counter()
print(f"超采样用时:{t2 - t1:.3f}s")
else:
# audio = audio.float() * 32768
# audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy()
audio = audio.cpu().numpy()
audio = (audio * 32768).astype(np.int16)
# try:
# if speed_factor != 1.0:
# audio = speed_change(audio, speed=speed_factor, sr=int(sr))
@ -1612,24 +1782,43 @@ class TTS:
self,
audio_fragments: List[torch.Tensor],
overlap_len: int,
search_len:int= 320
):
# overlap_len-=search_len
dtype = audio_fragments[0].dtype
for i in range(len(audio_fragments) - 1):
f1 = audio_fragments[i]
f2 = audio_fragments[i + 1]
f1 = audio_fragments[i].float()
f2 = audio_fragments[i + 1].float()
w1 = f1[-overlap_len:]
w2 = f2[:overlap_len]
assert w1.shape == w2.shape
corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1), padding=w2.shape[-1] // 2).view(-1)[:-1]
idx = corr.argmax()
f1_ = f1[: -(overlap_len - idx)]
w2 = f2[:overlap_len+search_len]
# w2 = w2[-w2.shape[-1]//2:]
# assert w1.shape == w2.shape
corr_norm = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1)
corr_den = F.conv1d(w2.view(1, 1, -1)**2, torch.ones_like(w1).view(1, 1, -1)).view(-1)+ 1e-8
idx = (corr_norm/corr_den.sqrt()).argmax()
print(f"seg_idx: {idx}")
# idx = corr.argmax()
f1_ = f1[: -overlap_len]
audio_fragments[i] = f1_
f2_ = f2[idx:]
window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype)
f2_[: (overlap_len - idx)] = (
window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)]
+ window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :]
window = torch.hann_window((overlap_len) * 2, device=f1.device, dtype=f1.dtype)
f2_[: overlap_len] = (
window[: overlap_len] * f2_[: overlap_len]
+ window[overlap_len :] * f1[-overlap_len :]
)
# window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx))
# f2_[: (overlap_len - idx)] = (
# window * f2_[: (overlap_len - idx)]
# + (1-window) * f1[-(overlap_len - idx) :]
# )
audio_fragments[i + 1] = f2_
return torch.cat(audio_fragments, 0)
return torch.cat(audio_fragments, 0).to(dtype)
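For reference, a reduced two-fragment sketch of the revised SOLA stitching above: the head of the next chunk is aligned against the tail of the previous one by normalized cross-correlation over a small search window, then the overlap is cross-faded with a Hann window. This mirrors the technique but is not the project's exact sola_algorithm; it assumes the next chunk is at least overlap_len + search_len samples long.

import torch
import torch.nn.functional as F


def sola_stitch(prev: torch.Tensor, nxt: torch.Tensor,
                overlap_len: int, search_len: int = 320) -> torch.Tensor:
    """Align `nxt` against the tail of `prev`, then cross-fade the overlap."""
    f1 = prev.float()
    f2 = nxt.float().clone()
    w1 = f1[-overlap_len:]                  # template: tail of the previous chunk
    w2 = f2[:overlap_len + search_len]      # search region: head of the next chunk

    # Normalized cross-correlation; argmax gives the best alignment offset.
    corr = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1)
    energy = F.conv1d(w2.view(1, 1, -1) ** 2,
                      torch.ones_like(w1).view(1, 1, -1)).view(-1) + 1e-8
    idx = int((corr / energy.sqrt()).argmax())

    f2 = f2[idx:]                           # shift the next chunk to the aligned offset
    window = torch.hann_window(overlap_len * 2, device=f1.device, dtype=f1.dtype)
    f2[:overlap_len] = (window[:overlap_len] * f2[:overlap_len]
                        + window[overlap_len:] * f1[-overlap_len:])
    return torch.cat([f1[:-overlap_len], f2], dim=0)

Under the same assumptions, something like sola_stitch(last_audio_chunk, audio_chunk, overlap_size) corresponds to the per-chunk call made in the streaming branch above.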

GPT_SoVITS/module/models.py

@ -151,6 +151,8 @@ class DurationPredictor(nn.Module):
return x * x_mask
WINDOW = {}
class TextEncoder(nn.Module):
def __init__(
self,
@ -209,7 +211,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y = self.ssl_proj(y * y_mask) * y_mask
@ -222,13 +224,44 @@ class TextEncoder(nn.Module):
text = self.text_embedding(text).transpose(1, 2)
text = self.encoder_text(text * text_mask, text_mask)
y = self.mrte(y, y_mask, text, text_mask, ge)
if padding_length is not None and padding_length!=0:
y = y[:, :, :-padding_length]
y_mask = y_mask[:, :, :-padding_length]
y = self.encoder2(y * y_mask, y_mask)
if result_length is not None:
y = y[:, :, -result_length:]
y_mask = y_mask[:, :, -result_length:]
if overlap_frames is not None:
overlap_len = overlap_frames.shape[-1]
window = WINDOW.get(overlap_len, None)
if window is None:
# WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype)
WINDOW[overlap_len] = torch.sin(torch.arange(overlap_len*2, device=y.device) * torch.pi / (overlap_len*2))
window = WINDOW[overlap_len]
window = window.to(y.device)
y[:,:,:overlap_len] = (
window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len]
+ window[overlap_len:].view(1, 1, -1) * overlap_frames
)
y_ = y
y_mask_ = y_mask
if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask
return y, m, logs, y_mask, y_, y_mask_
def extract_latent(self, x):
x = self.ssl_proj(x)
@ -921,7 +954,7 @@ class SynthesizerTrn(nn.Module):
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
@ -949,7 +982,7 @@ class SynthesizerTrn(nn.Module):
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -957,6 +990,7 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
def get_ge(refer, sv_emb):
@ -989,7 +1023,7 @@ class SynthesizerTrn(nn.Module):
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(
x, m_p, logs_p, y_mask, _, _ = self.enc_p(
quantized,
y_lengths,
text,
@ -1004,6 +1038,59 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o
@torch.no_grad()
def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
def get_ge(refer, sv_emb):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
return ge
if type(refer) == list:
ges = []
for idx, _refer in enumerate(refer):
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
ges.append(ge)
ge = torch.stack(ges, 0).mean(0)
else:
ge = get_ge(refer, sv_emb)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
result_length = (2*result_length) if result_length is not None else None
padding_length = (2*padding_length) if padding_length is not None else None
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(
quantized,
y_lengths,
text,
text_lengths,
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
speed,
result_length=result_length,
overlap_frames=overlap_frames,
padding_length=padding_length
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o, y_, y_mask_
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
@ -1226,7 +1313,7 @@ class SynthesizerTrnV3(nn.Module):
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
fea, y_mask_ = self.wns1(
@ -1260,7 +1347,7 @@ class SynthesizerTrnV3(nn.Module):
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
####more wn paramter to learn mel
@ -1377,7 +1464,7 @@ class SynthesizerTrnV3b(nn.Module):
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
@ -1420,7 +1507,7 @@ class SynthesizerTrnV3b(nn.Module):
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel
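Finally, a hedged sketch of the latent-domain blending that TextEncoder.forward now performs when overlap_frames is passed (see the changes earlier in this file): the first frames of the new chunk's encoder output are cross-faded with the cached tail frames of the previous chunk using a sine window. The tensor names here are illustrative.

import torch


def blend_overlap(latent: torch.Tensor, prev_tail: torch.Tensor) -> torch.Tensor:
    """Sketch of the streaming cross-fade in TextEncoder.forward.

    latent:    (B, C, T) encoder output for the current chunk.
    prev_tail: (B, C, overlap_len) cached tail frames of the previous chunk.
    """
    overlap_len = prev_tail.shape[-1]
    # Sine window: rises over the first half, falls over the second half.
    window = torch.sin(torch.arange(overlap_len * 2, device=latent.device)
                       * torch.pi / (overlap_len * 2))
    latent = latent.clone()
    latent[:, :, :overlap_len] = (window[:overlap_len].view(1, 1, -1) * latent[:, :, :overlap_len]
                                  + window[overlap_len:].view(1, 1, -1) * prev_tail)
    return latent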

api_v2.py

@ -30,17 +30,22 @@ POST:
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"streaming_mode": False, # bool. whether to return a streaming response.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False # bool. whether to use super-sampling for audio when using VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
}
```
@ -121,6 +126,7 @@ from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
from pydantic import BaseModel
import threading
# print(sys.path)
i18n = I18nAuto()
@ -170,12 +176,55 @@ class TTS_Request(BaseModel):
repetition_penalty: float = 1.35
sample_steps: int = 32
super_sampling: bool = False
overlap_length: int = 2
min_chunk_length: int = 16
return_fragment: bool = False
fixed_length_chunk: bool = False
### modified from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
# Author: AkagawaTsurunaki
# Issue:
# Stack overflow probabilistically occurs
# when the function `sf_writef_short` of `libsndfile_64bit.dll` is called
# using the Python library `soundfile`
# Note:
# This is an issue related to `libsndfile`, not this project itself.
# It happens when you generate a large audio tensor (about 499804 frames in my PC)
# and try to convert it to an ogg file.
# Related:
# https://github.com/RVC-Boss/GPT-SoVITS/issues/1199
# https://github.com/libsndfile/libsndfile/issues/1023
# https://github.com/bastibe/python-soundfile/issues/396
# Suggestion:
# Or split the whole audio data into smaller audio segments to avoid stack overflow?
def handle_pack_ogg():
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
audio_file.write(data)
# See: https://docs.python.org/3/library/threading.html
# The stack size of this thread is at least 32768
# If stack overflow error still occurs, just modify the `stack_size`.
# stack_size = n * 4096, where n should be a positive integer.
# Here we chose n = 4096.
stack_size = 4096 * 4096
try:
threading.stack_size(stack_size)
pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
pack_ogg_thread.start()
pack_ogg_thread.join()
except RuntimeError as e:
# If changing the thread stack size is unsupported, a RuntimeError is raised.
print("RuntimeError: {}".format(e))
print("Changing the thread stack size is unsupported.")
except ValueError as e:
# If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified.
print("ValueError: {}".format(e))
print("The specified stack size is invalid.")
return io_buffer
@ -286,8 +335,8 @@ def check_params(req: dict):
)
if media_type not in ["wav", "raw", "ogg", "aac"]:
return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
elif media_type == "ogg" and not streaming_mode:
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
# elif media_type == "ogg" and not streaming_mode:
# return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
if text_split_method not in cut_method_names:
return JSONResponse(
@ -307,7 +356,7 @@ async def tts_handle(req: dict):
"text": "", # str.(required) text to be synthesized
"text_lang: "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
@ -316,16 +365,19 @@ async def tts_handle(req: dict):
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
"streaming_mode": False, # bool. whether to return a streaming response.
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
}
returns:
StreamingResponse: audio stream response.
@ -339,8 +391,10 @@ async def tts_handle(req: dict):
if check_res is not None:
return check_res
if streaming_mode or return_fragment:
req["return_fragment"] = True
req["streaming_mode"] = streaming_mode
req["return_fragment"] = return_fragment
streaming_mode = streaming_mode or return_fragment
try:
tts_generator = tts_pipeline.run(req)
@ -391,7 +445,7 @@ async def tts_get_endpoint(
top_k: int = 5,
top_p: float = 1,
temperature: float = 1,
text_split_method: str = "cut0",
text_split_method: str = "cut5",
batch_size: int = 1,
batch_threshold: float = 0.75,
split_bucket: bool = True,
@ -399,11 +453,15 @@ async def tts_get_endpoint(
fragment_interval: float = 0.3,
seed: int = -1,
media_type: str = "wav",
streaming_mode: bool = False,
parallel_infer: bool = True,
repetition_penalty: float = 1.35,
sample_steps: int = 32,
super_sampling: bool = False,
return_fragment: bool = False,
streaming_mode: bool = False,
overlap_length: int = 2,
min_chunk_length: int = 16,
fixed_length_chunk: bool = False,
):
req = {
"text": text,
@ -428,6 +486,10 @@ async def tts_get_endpoint(
"repetition_penalty": float(repetition_penalty),
"sample_steps": int(sample_steps),
"super_sampling": super_sampling,
"overlap_length": int(overlap_length),
"min_chunk_length": int(min_chunk_length),
"return_fragment": return_fragment,
"fixed_length_chunk": fixed_length_chunk
}
return await tts_handle(req)