Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-08-17 23:19:47 +08:00)
更好的流式推理模式 (Better streaming inference mode)
This commit is contained in: parent 13055fa569, commit ab7589b5b4
@@ -794,7 +794,7 @@ class Text2SemanticDecoder(nn.Module):
y_list = []
idx_list = []
for i in range(len(x)):
y, idx = self.infer_panel_naive(
y, idx = next(self.infer_panel_naive(
x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
@@ -805,7 +805,7 @@ class Text2SemanticDecoder(nn.Module):
temperature,
repetition_penalty,
**kwargs,
)
))
y_list.append(y[0])
idx_list.append(idx)

@@ -822,6 +822,8 @@ class Text2SemanticDecoder(nn.Module):
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
streaming_mode: bool = False,
chunk_length: int = 24,
**kwargs,
):
x = self.ar_text_embedding(x)
@@ -875,7 +877,9 @@ class Text2SemanticDecoder(nn.Module):
.to(device=x.device, dtype=torch.bool)
)

token_counter = 0
for idx in tqdm(range(1500)):
token_counter+=1
if xy_attn_mask is not None:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
else:
@@ -900,22 +904,42 @@ class Text2SemanticDecoder(nn.Module):

if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
y=y[:, :-1]
token_counter -= 1

if idx == 1499:
stop = True

if stop:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
# print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
if streaming_mode:
# y=y[:, :-1]
# res_len = (y.shape[1] - prefix_len)%chunk_length
yield (y[:, -token_counter:]) if token_counter!= 0 else None, True
break

if streaming_mode and token_counter == chunk_length:
token_counter = 0
yield y[:, -chunk_length:], False

####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)

if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx

if not streaming_mode:
if ref_free:
yield y, 0
yield y, idx

def infer_panel(
self,
@@ -930,6 +954,6 @@ class Text2SemanticDecoder(nn.Module):
repetition_penalty: float = 1.35,
**kwargs,
):
return self.infer_panel_naive(
return next(self.infer_panel_naive(
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
)
))
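With infer_panel_naive turned into a generator, non-streaming callers take only its first yield via next(...) (as in the hunks above), while streaming callers iterate it and receive a (semantic_tokens_chunk, is_final) pair roughly every chunk_length tokens. A minimal sketch of the streaming consumption pattern (the argument names and the downstream decode helper are placeholders, not code from this commit):

```python
# Sketch of consuming the generator form of infer_panel_naive (placeholder arguments).
gen = t2s_model.infer_panel_naive(
    phoneme_ids, phoneme_lens, prompt_semantic, bert_features,
    top_k=5, top_p=1.0, temperature=1.0,
    streaming_mode=True, chunk_length=24,
)
for semantic_tokens, is_final in gen:
    if semantic_tokens is None:       # EOS landed exactly on a chunk boundary, nothing left to flush
        break
    decode_and_play(semantic_tokens)  # hypothetical downstream decode step
    if is_final:
        break
```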
@@ -258,6 +258,12 @@ class TTS_Config:
v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"]
languages: list = v2_languages
mute_tokens: dict = {
"v1" : 486,
"v2" : 486,
"v3" : 486,
"v4" : 486,
}
# "all_zh",#全部按中文识别
# "en",#全部按英文识别#######不变
# "all_ja",#全部按日文识别
@@ -956,7 +962,10 @@ class TTS:
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"streaming_mode": False, # bool. return audio chunk by chunk.
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"chunk_length": 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
}
returns:
Tuple[int, np.ndarray]: sampling rate and audio data.
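For reference, a minimal TTS.run request that exercises the new streaming path might look like the sketch below; the field names come from the docstring above, while the text and path values are placeholders and tts_pipeline is the instance name used later in api_v2.py:

```python
# Hypothetical example request; values other than the documented defaults are placeholders.
inputs = {
    "text": "...",                 # text to synthesize
    "text_lang": "zh",             # language of the text
    "ref_audio_path": "ref.wav",   # reference audio
    "prompt_text": "...",          # transcript of the reference audio
    "prompt_lang": "zh",
    "parallel_infer": False,       # streaming uses the naive (non-parallel) T2S path
    "streaming_mode": True,        # return audio chunk by chunk
    "chunk_length": 24,            # semantic tokens per chunk (affects audio chunk size)
    "overlap_length": 2,           # semantic-token overlap used to stitch consecutive chunks
}
for sr, chunk in tts_pipeline.run(inputs):
    ...  # play or forward each audio chunk as it arrives
```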
@@ -986,19 +995,40 @@ class TTS:
repetition_penalty = inputs.get("repetition_penalty", 1.35)
sample_steps = inputs.get("sample_steps", 32)
super_sampling = inputs.get("super_sampling", False)
streaming_mode = inputs.get("streaming_mode", False)
overlap_length = inputs.get("overlap_length", 2)
chunk_length = inputs.get("chunk_length", 24)

if parallel_infer:
if parallel_infer and not streaming_mode:
print(i18n("并行推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
elif not parallel_infer and streaming_mode and not self.configs.use_vocoder:
print(i18n("流式推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
elif streaming_mode and self.configs.use_vocoder:
print(i18n("SoVits V3/4模型不支持流式推理模式,已自动回退到分段返回模式"))
streaming_mode = False
return_fragment = True
if parallel_infer:
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
else:
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
elif parallel_infer and streaming_mode:
print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式"))
parallel_infer = False
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
else:
print(i18n("并行推理模式已关闭"))
print(i18n("朴素推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched

if return_fragment:
print(i18n("分段返回模式已开启"))
if split_bucket:
split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
if return_fragment and streaming_mode:
print(i18n("流式推理模式不支持分段返回,已自动关闭分段返回"))
return_fragment = False

if (return_fragment or streaming_mode) and split_bucket:
print(i18n("分段返回模式/流式推理模式不支持分桶处理,已自动关闭分桶处理"))
split_bucket = False

if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer):
print(i18n("分桶处理模式已开启"))
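Taken together, the branch above resolves the mode flags as summarized in the following sketch (a hypothetical helper written only to illustrate the dispatch, not code from this commit):

```python
def pick_infer_panel(parallel_infer: bool, streaming_mode: bool, use_vocoder: bool) -> str:
    """Illustrative summary of how the run() dispatch above chooses the T2S entry point."""
    if parallel_infer and not streaming_mode:
        return "infer_panel_batch_infer"      # parallel T2S inference
    if streaming_mode and not use_vocoder:
        return "infer_panel_naive"            # generator-based true streaming (SoVITS v1/v2)
    if streaming_mode and use_vocoder:
        # SoVITS v3/v4: streaming unsupported, falls back to fragment-by-fragment return
        return "infer_panel_batch_infer" if parallel_infer else "infer_panel_naive_batched"
    return "infer_panel_naive_batched"        # plain naive batched path
```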
@@ -1011,9 +1041,9 @@ class TTS:
else:
print(i18n("分桶处理模式已关闭"))

if fragment_interval < 0.01:
fragment_interval = 0.01
print(i18n("分段间隔过小,已自动设置为0.01"))
# if fragment_interval < 0.01:
# fragment_interval = 0.01
# print(i18n("分段间隔过小,已自动设置为0.01"))

no_prompt_text = False
if prompt_text in [None, ""]:
@@ -1071,7 +1101,7 @@ class TTS:
###### text preprocessing ########
t1 = time.perf_counter()
data: list = None
if not return_fragment:
if not (return_fragment or streaming_mode):
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
if len(data) == 0:
yield 16000, np.zeros(int(16000), dtype=np.int16)
@@ -1134,7 +1164,7 @@ class TTS:
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
for item in data:
t3 = time.perf_counter()
if return_fragment:
if return_fragment or streaming_mode:
item = make_batch(item)
if item is None:
continue
@@ -1156,94 +1186,214 @@ class TTS:
self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
)

print(f"############ {i18n('预测语义Token')} ############")
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_lens,
prompt,
all_bert_features,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
max_len=max_len,
repetition_penalty=repetition_penalty,
)
t4 = time.perf_counter()
t_34 += t4 - t3
if not streaming_mode:
print(f"############ {i18n('预测语义Token')} ############")
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_lens,
prompt,
all_bert_features,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
max_len=max_len,
repetition_penalty=repetition_penalty,
)
t4 = time.perf_counter()
t_34 += t4 - t3

refer_audio_spec: torch.Tensor = [
item.to(dtype=self.precision, device=self.configs.device)
for item in self.prompt_cache["refer_spec"]
]
refer_audio_spec: torch.Tensor = [
item.to(dtype=self.precision, device=self.configs.device)
for item in self.prompt_cache["refer_spec"]
]

batch_audio_fragment = []
batch_audio_fragment = []

# ## vits并行推理 method 1
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
# max_len = 0
# for i in range(0, len(batch_phones)):
# max_len = max(max_len, batch_phones[i].shape[-1])
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
# batch_phones = batch_phones.to(self.configs.device)
# batch_audio_fragment = (self.vits_model.batched_decode(
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
# ))
print(f"############ {i18n('合成音频')} ############")
if not self.configs.use_vocoder:
if speed_factor == 1.0:
print(f"{i18n('并行合成中')}...")
# ## vits并行推理 method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [
pred_semantic_list[i].shape[0] * 2 * upsample_rate
for i in range(0, len(pred_semantic_list))
]
audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = (
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment = [
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
for i in range(1, len(audio_frag_end_idx))
]
else:
# ## vits串行推理
for i, idx in enumerate(tqdm(idx_list)):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
else:
if parallel_infer:
print(f"{i18n('并行合成中')}...")
audio_fragments = self.using_vocoder_synthesis_batched_infer(
idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.extend(audio_fragments)
else:
for i, idx in enumerate(tqdm(idx_list)):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment = self.using_vocoder_synthesis(
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
# ## vits并行推理 method 1
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
# max_len = 0
# for i in range(0, len(batch_phones)):
# max_len = max(max_len, batch_phones[i].shape[-1])
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
# batch_phones = batch_phones.to(self.configs.device)
# batch_audio_fragment = (self.vits_model.batched_decode(
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
# ))
print(f"############ {i18n('合成音频')} ############")
if not self.configs.use_vocoder:
if speed_factor == 1.0:
print(f"{i18n('并行合成中')}...")
# ## vits并行推理 method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [
pred_semantic_list[i].shape[0] * 2 * upsample_rate
for i in range(0, len(pred_semantic_list))
]
audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = (
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
)
batch_audio_fragment.append(audio_fragment)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment = [
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
for i in range(1, len(audio_frag_end_idx))
]
else:
# ## vits串行推理
for i, idx in enumerate(tqdm(idx_list)):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
else:
if parallel_infer:
print(f"{i18n('并行合成中')}...")
audio_fragments = self.using_vocoder_synthesis_batched_infer(
idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.extend(audio_fragments)
else:
for i, idx in enumerate(tqdm(idx_list)):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment = self.using_vocoder_synthesis(
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.append(audio_fragment)

else:
refer_audio_spec: torch.Tensor = [
item.to(dtype=self.precision, device=self.configs.device)
for item in self.prompt_cache["refer_spec"]
]
semantic_token_generator =self.t2s_model.model.infer_panel(
all_phoneme_ids[0].unsqueeze(0),
all_phoneme_lens,
prompt,
all_bert_features[0].unsqueeze(0),
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
max_len=max_len,
repetition_penalty=repetition_penalty,
streaming_mode=True,
chunk_length=chunk_length,
)
t4 = time.perf_counter()
t_34 += t4 - t3
phones = batch_phones[0].unsqueeze(0).to(self.configs.device)
is_first_chunk = True

if not self.configs.use_vocoder:
if speed_factor == 1.0:
upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
else:
upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor)
else:
if speed_factor == 1.0:
upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4)
else:
upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)

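The upsample_rate computed above converts a count of semantic tokens into a count of output samples, which is the unit the overlap window is measured in. As a rough, illustrative calculation only (the concrete numbers depend on the loaded model's upsample_rates, semantic frame rate, and output sampling rate, which are assumed here):

```python
import math

# Illustrative only: the figures below are assumptions, not values read from this diff.
chunk_length = 24          # semantic tokens per streamed chunk
overlap_length = 2         # tokens shared between consecutive chunks
semantic_to_frames = 2     # if semantic_frame_rate == "25hz", each token becomes 2 latent frames
hop_product = 640          # e.g. math.prod(upsample_rates) for a hypothetical 32 kHz SoVITS config

upsample_rate = hop_product * semantic_to_frames       # samples per semantic token -> 1280
samples_per_chunk = chunk_length * upsample_rate       # 24 * 1280 = 30720 samples
overlap_size = math.ceil(overlap_length * upsample_rate)  # 2 * 1280 = 2560 samples
print(samples_per_chunk / 32000)                       # ~0.96 s of audio per chunk at 32 kHz
```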
last_audio_chunk = None
last_tokens = None
previous_tokens = []
overlap_size = math.ceil(overlap_length*upsample_rate)
for semantic_tokens, is_final in semantic_token_generator:
if semantic_tokens is None and last_audio_chunk is not None:
yield self.audio_postprocess(
[[last_audio_chunk[-overlap_size:]]],
output_sr,
None,
speed_factor,
False,
0.0,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
continue

_semantic_tokens = semantic_tokens
# if is_first_chunk:
# _semantic_tokens = torch.cat([torch.ones((1,overlap_length), dtype=torch.long, device=self.configs.device)*self.configs.mute_tokens[self.configs.version], _semantic_tokens], dim=-1)
# else:
# _semantic_tokens = torch.cat([last_tokens[:, -overlap_length:], _semantic_tokens], dim=-1)
# # _semantic_tokens = torch.cat(previous_tokens+[_semantic_tokens,], dim=-1)

previous_tokens.append(semantic_tokens)

_semantic_tokens = torch.cat(previous_tokens, dim=-1)

# last_tokens = semantic_tokens

# print(f"_semantic_tokens shape:{_semantic_tokens.shape}")

if not self.configs.use_vocoder:
audio_chunk = self.vits_model.decode(
_semantic_tokens.unsqueeze(0),
phones, refer_audio_spec,
speed=speed_factor,
result_length=semantic_tokens.shape[-1]+overlap_length if not is_first_chunk else None
# result_length=chunk_length if not is_first_chunk else None
).detach()[0, 0, :]
else:
audio_chunk = self.using_vocoder_synthesis(
_semantic_tokens.unsqueeze(0), phones,
speed=speed_factor, sample_steps=sample_steps,
result_length = semantic_tokens.shape[-1]+overlap_length if not is_first_chunk else None
)

# if is_first_chunk:
# audio_chunk = audio_chunk[overlap_size:]
# # is_first_chunk = False

audio_chunk_ = audio_chunk
if is_first_chunk and not is_final:
is_first_chunk = False
audio_chunk_ = audio_chunk_[:-overlap_size]
elif is_first_chunk and is_final:
is_first_chunk = False
elif not is_first_chunk and not is_final:
audio_chunk_ = self.sola_algorithm([last_audio_chunk, audio_chunk_], overlap_size)
audio_chunk_ = (
audio_chunk_[last_audio_chunk.shape[0]-overlap_size:-overlap_size] if not is_final \
else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:]
)
# audio_chunk_ = audio_chunk_[:-overlap_size] if not is_final else audio_chunk_

last_audio_chunk = audio_chunk
yield self.audio_postprocess(
[[audio_chunk_]],
output_sr,
None,
speed_factor,
False,
0.0,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
print(f"first_package_delay: {time.perf_counter()-t0:.3f}")

yield output_sr, np.zeros(int(output_sr*fragment_interval), dtype=np.int16)

t5 = time.perf_counter()
t_45 += t5 - t4
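The sola_algorithm call above aligns the head of each newly decoded chunk with the tail of the previous one before the overlap region is trimmed away, which is what keeps chunk boundaries free of clicks. The sketch below is a generic SOLA-style crossfade written for illustration under assumed 1-D float tensors; it is not the repository's exact implementation:

```python
import torch

def sola_crossfade(prev: torch.Tensor, cur: torch.Tensor, overlap: int) -> torch.Tensor:
    """Align `cur` against the tail of `prev`, then crossfade over `overlap` samples (illustrative sketch)."""
    tail = prev[-overlap:]
    # Search a small window at the start of `cur` for the offset best correlated with the tail.
    search = min(overlap, cur.shape[0] - overlap)
    scores = torch.stack([torch.dot(tail, cur[k:k + overlap]) for k in range(search)])
    shift = int(torch.argmax(scores))
    cur = cur[shift:]
    # Linear crossfade over the overlap region.
    fade = torch.linspace(0.0, 1.0, overlap, device=cur.device)
    blended = tail * (1.0 - fade) + cur[:overlap] * fade
    return torch.cat([prev[:-overlap], blended, cur[overlap:]])
```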
@@ -1258,17 +1408,18 @@ class TTS:
fragment_interval,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
elif streaming_mode:...
else:
audio.append(batch_audio_fragment)

if self.stop_flag:
yield 16000, np.zeros(int(16000), dtype=np.int16)
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
return

if not return_fragment:
if not (return_fragment or streaming_mode):
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
if len(audio) == 0:
yield 16000, np.zeros(int(16000), dtype=np.int16)
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
return
yield self.audio_postprocess(
audio,
@@ -1315,16 +1466,17 @@ class TTS:
fragment_interval: float = 0.3,
super_sampling: bool = False,
) -> Tuple[int, np.ndarray]:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
)
if fragment_interval>0:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
)

for i, batch in enumerate(audio):
for j, audio_fragment in enumerate(batch):
max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音
if max_audio > 1:
audio_fragment /= max_audio
audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) if fragment_interval>0 else audio_fragment
audio[i][j] = audio_fragment

if split_bucket:
@@ -1344,12 +1496,12 @@ class TTS:
max_audio = np.abs(audio).max()
if max_audio > 1:
audio /= max_audio
audio = (audio * 32768).astype(np.int16)
t2 = time.perf_counter()
print(f"超采样用时:{t2 - t1:.3f}s")
else:
audio = audio.cpu().numpy()

audio = (audio * 32768).astype(np.int16)
audio = audio.float() * 32768
audio = audio.to(dtype=torch.int16).cpu().numpy()

# try:
# if speed_factor != 1.0:
@@ -1360,7 +1512,7 @@ class TTS:
return sr, audio

def using_vocoder_synthesis(
self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32
self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32, result_length:int=None
):
prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
@@ -1392,7 +1544,7 @@ class TTS:
chunk_len = T_chunk - T_min

mel2 = mel2.to(self.precision)
fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed, result_length=result_length)

cfm_resss = []
idx = 0
@@ -1415,7 +1567,7 @@ class TTS:
cfm_res = torch.cat(cfm_resss, 2)
cfm_res = denorm_spec(cfm_res)

with torch.inference_mode():
with torch.no_grad():
wav_gen = self.vocoder(cfm_res)
audio = wav_gen[0][0] # .cpu().detach().numpy()

@@ -209,7 +209,7 @@ class TextEncoder(nn.Module):

self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)

y = self.ssl_proj(y * y_mask) * y_mask
@@ -223,6 +223,11 @@ class TextEncoder(nn.Module):
text = self.encoder_text(text * text_mask, text_mask)
y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask)

if result_length is not None:
y = y[:, :, -result_length:]
y_mask = y_mask[:, :, -result_length:]

if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
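The new result_length argument lets the encoder attend over the full token history (previous chunks plus the current one) while only the last result_length frames are carried forward to the decoder, so each streamed chunk is synthesized in context but emitted only once. A small standalone illustration of that slicing, with shapes chosen as assumptions for the example (52 would correspond to a 24-token chunk plus 2 overlap tokens at a 2x semantic-to-frame ratio):

```python
import torch

B, C, T = 1, 192, 120          # hypothetical batch, channels, total encoded frames so far
result_length = 52             # frames belonging to the newest chunk (+ overlap)

y = torch.randn(B, C, T)       # encoder output over the full history
y_mask = torch.ones(B, 1, T)

y_new = y[:, :, -result_length:]            # keep only the frames of the latest chunk
y_mask_new = y_mask[:, :, -result_length:]
print(y_new.shape)                          # torch.Size([1, 192, 52])
```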
@@ -940,7 +945,7 @@ class SynthesizerTrn(nn.Module):
return o, y_mask, (z, z_p, m_p, logs_p)

@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, result_length:int=None):
def get_ge(refer):
ge = None
if refer is not None:
@@ -967,7 +972,8 @@ class SynthesizerTrn(nn.Module):
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
result_length = (2*result_length) if result_length is not None else None
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

z = self.flow(z_p, y_mask, g=ge, reverse=True)
@@ -1187,7 +1193,7 @@ class SynthesizerTrnV3(nn.Module):
return cfm_loss

@torch.no_grad()
def decode_encp(self, codes, text, refer, ge=None, speed=1):
def decode_encp(self, codes, text, refer, ge=None, speed=1, result_length:int=None):
# print(2333333,refer.shape)
# ge=None
if ge == None:
@@ -1195,17 +1201,21 @@ class SynthesizerTrnV3(nn.Module):
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)

if speed == 1:
sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4))
sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4))
else:
sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4) / speed) + 1
sizee = int((codes.size(2) if result_length is None else result_length) * (3.875 if self.version=="v3"else 4) / speed) + 1
y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
result_length = result_length * 2 if result_length is not None else None

x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed, result_length=result_length)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=(1.875 if self.version=="v3"else 2), mode="nearest") ##BCT
####more wn paramter to learn mel
api_v2.py
@@ -40,7 +40,10 @@ POST:
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"chunk_length": 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
}
```

@@ -170,6 +173,9 @@ class TTS_Request(BaseModel):
repetition_penalty: float = 1.35
sample_steps: int = 32
super_sampling: bool = False
overlap_length: int = 2
chunk_length: int = 24
return_fragment: bool = False

### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
@@ -325,7 +331,10 @@ async def tts_handle(req: dict):
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"chunk_length": 24, # int. chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
}
returns:
StreamingResponse: audio stream response.
@@ -339,8 +348,10 @@ async def tts_handle(req: dict):
if check_res is not None:
return check_res

if streaming_mode or return_fragment:
req["return_fragment"] = True
req["streaming_mode"] = streaming_mode
req["return_fragment"] = return_fragment
streaming_mode = streaming_mode or return_fragment

try:
tts_generator = tts_pipeline.run(req)
@@ -404,6 +415,9 @@ async def tts_get_endpoint(
repetition_penalty: float = 1.35,
sample_steps: int = 32,
super_sampling: bool = False,
overlap_length: int = 2,
chunk_length: int = 24,
return_fragment: bool = False,
):
req = {
"text": text,
@@ -428,6 +442,9 @@ async def tts_get_endpoint(
"repetition_penalty": float(repetition_penalty),
"sample_steps": int(sample_steps),
"super_sampling": super_sampling,
"overlap_length": int(overlap_length),
"chunk_length": int(chunk_length),
"return_fragment": return_fragment,
}
return await tts_handle(req)
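With these parameters exposed on the GET endpoint, a streaming request can be issued by an ordinary HTTP client. The sketch below assumes a server started from api_v2.py on the default host and port, the default /tts route, and a reference voice available on the server; the paths, text, and media type are placeholders to adapt to your setup:

```python
# Illustrative client for the streaming GET endpoint (host, port, route, and paths are assumptions).
import requests

params = {
    "text": "...",                 # text to synthesize (placeholder)
    "text_lang": "zh",
    "ref_audio_path": "ref.wav",   # reference audio path on the server side (placeholder)
    "prompt_text": "...",
    "prompt_lang": "zh",
    "streaming_mode": True,        # stream audio chunk by chunk
    "chunk_length": 24,            # semantic tokens per chunk
    "overlap_length": 2,           # token overlap used to stitch chunks
    "media_type": "ogg",           # assumed: a container that can be written incrementally
}
with requests.get("http://127.0.0.1:9880/tts", params=params, stream=True) as resp:
    with open("streamed.ogg", "wb") as f:
        for chunk in resp.iter_content(chunk_size=4096):
            f.write(chunk)
```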