mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-04 05:01:27 +08:00
Merge branch 'RVC-Boss:main' into main
This commit is contained in:
commit
0235857b89
@ -707,10 +707,12 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
if idx == 0:
|
if idx == 0:
|
||||||
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
|
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
|
||||||
logits = logits[:, :-1]
|
|
||||||
else:
|
else:
|
||||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||||
|
|
||||||
|
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
|
||||||
samples = sample(
|
samples = sample(
|
||||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||||
)[0]
|
)[0]
|
||||||
@ -794,7 +796,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
y_list = []
|
y_list = []
|
||||||
idx_list = []
|
idx_list = []
|
||||||
for i in range(len(x)):
|
for i in range(len(x)):
|
||||||
y, idx = self.infer_panel_naive(
|
y, idx = next(self.infer_panel_naive(
|
||||||
x[i].unsqueeze(0),
|
x[i].unsqueeze(0),
|
||||||
x_lens[i],
|
x_lens[i],
|
||||||
prompts[i].unsqueeze(0) if prompts is not None else None,
|
prompts[i].unsqueeze(0) if prompts is not None else None,
|
||||||
@ -805,7 +807,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
temperature,
|
temperature,
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
))
|
||||||
y_list.append(y[0])
|
y_list.append(y[0])
|
||||||
idx_list.append(idx)
|
idx_list.append(idx)
|
||||||
|
|
||||||
@ -822,8 +824,15 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
early_stop_num: int = -1,
|
early_stop_num: int = -1,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
repetition_penalty: float = 1.35,
|
repetition_penalty: float = 1.35,
|
||||||
|
streaming_mode: bool = False,
|
||||||
|
chunk_length: int = 24,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
|
||||||
|
chunk_split_thershold = kwargs.get("chunk_split_thershold", 0.3)
|
||||||
|
check_token_num = 2
|
||||||
|
|
||||||
|
|
||||||
x = self.ar_text_embedding(x)
|
x = self.ar_text_embedding(x)
|
||||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||||
x = self.ar_text_position(x)
|
x = self.ar_text_position(x)
|
||||||
@ -875,7 +884,10 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
.to(device=x.device, dtype=torch.bool)
|
.to(device=x.device, dtype=torch.bool)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
token_counter = 0
|
||||||
|
curr_ptr = prefix_len
|
||||||
for idx in tqdm(range(1500)):
|
for idx in tqdm(range(1500)):
|
||||||
|
token_counter+=1
|
||||||
if xy_attn_mask is not None:
|
if xy_attn_mask is not None:
|
||||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
|
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
|
||||||
else:
|
else:
|
||||||
@ -900,22 +912,56 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
||||||
stop = True
|
stop = True
|
||||||
|
y=y[:, :-1]
|
||||||
|
token_counter -= 1
|
||||||
|
|
||||||
|
if idx == 1499:
|
||||||
|
stop = True
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
if y.shape[1] == 0:
|
if y.shape[1] == 0:
|
||||||
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||||
print("bad zero prediction")
|
print("bad zero prediction")
|
||||||
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
# print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
||||||
|
if streaming_mode:
|
||||||
|
yield y[:, curr_ptr:] if curr_ptr<y.shape[1] else None, True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length+check_token_num):
|
||||||
|
score = mute_emb_sim_matrix[y[0, curr_ptr:]] - chunk_split_thershold
|
||||||
|
score[score<0]=-1
|
||||||
|
score[:-1]=score[:-1]+score[1:] ##考虑连续两个token
|
||||||
|
argmax_idx = score.argmax()
|
||||||
|
|
||||||
|
if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length:
|
||||||
|
print(f"\n\ncurr_ptr:{curr_ptr}")
|
||||||
|
yield y[:, curr_ptr:], False
|
||||||
|
token_counter -= argmax_idx+1
|
||||||
|
curr_ptr += argmax_idx+1
|
||||||
|
|
||||||
|
|
||||||
|
elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length):
|
||||||
|
yield y[:, -token_counter:], False
|
||||||
|
curr_ptr+=token_counter
|
||||||
|
token_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
####################### update next step ###################################
|
####################### update next step ###################################
|
||||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||||
:, y_len + idx
|
:, y_len + idx
|
||||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||||
|
|
||||||
if ref_free:
|
|
||||||
return y[:, :-1], 0
|
|
||||||
return y[:, :-1], idx
|
if not streaming_mode:
|
||||||
|
if ref_free:
|
||||||
|
yield y, 0
|
||||||
|
yield y, idx
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def infer_panel(
|
def infer_panel(
|
||||||
self,
|
self,
|
||||||
@ -930,6 +976,6 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
repetition_penalty: float = 1.35,
|
repetition_penalty: float = 1.35,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
return self.infer_panel_naive(
|
return next(self.infer_panel_naive(
|
||||||
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
|
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
|
||||||
)
|
))
|
||||||
|
|||||||
@ -275,6 +275,15 @@ class TTS_Config:
|
|||||||
v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
||||||
v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"]
|
v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"]
|
||||||
languages: list = v2_languages
|
languages: list = v2_languages
|
||||||
|
mute_tokens: dict = {
|
||||||
|
"v1" : 486,
|
||||||
|
"v2" : 486,
|
||||||
|
"v2Pro": 486,
|
||||||
|
"v2ProPlus": 486,
|
||||||
|
"v3" : 486,
|
||||||
|
"v4" : 486,
|
||||||
|
}
|
||||||
|
mute_emb_sim_matrix: torch.Tensor = None
|
||||||
# "all_zh",#全部按中文识别
|
# "all_zh",#全部按中文识别
|
||||||
# "en",#全部按英文识别#######不变
|
# "en",#全部按英文识别#######不变
|
||||||
# "all_ja",#全部按日文识别
|
# "all_ja",#全部按日文识别
|
||||||
@ -598,6 +607,11 @@ class TTS:
|
|||||||
if self.configs.is_half and str(self.configs.device) != "cpu":
|
if self.configs.is_half and str(self.configs.device) != "cpu":
|
||||||
self.t2s_model = self.t2s_model.half()
|
self.t2s_model = self.t2s_model.half()
|
||||||
|
|
||||||
|
codebook = t2s_model.model.ar_audio_embedding.weight.clone()
|
||||||
|
mute_emb = codebook[self.configs.mute_tokens[self.configs.version]].unsqueeze(0)
|
||||||
|
sim_matrix = F.cosine_similarity(mute_emb.float(), codebook.float(), dim=-1)
|
||||||
|
self.configs.mute_emb_sim_matrix = sim_matrix
|
||||||
|
|
||||||
def init_vocoder(self, version: str):
|
def init_vocoder(self, version: str):
|
||||||
if version == "v3":
|
if version == "v3":
|
||||||
if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
|
if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
|
||||||
@ -994,21 +1008,25 @@ class TTS:
|
|||||||
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
||||||
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
||||||
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
||||||
"top_k": 5, # int. top k sampling
|
"top_k": 15, # int. top k sampling
|
||||||
"top_p": 1, # float. top p sampling
|
"top_p": 1, # float. top p sampling
|
||||||
"temperature": 1, # float. temperature for sampling
|
"temperature": 1, # float. temperature for sampling
|
||||||
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
|
"text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details.
|
||||||
"batch_size": 1, # int. batch size for inference
|
"batch_size": 1, # int. batch size for inference
|
||||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||||
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
||||||
"return_fragment": False, # bool. step by step return the audio fragment.
|
|
||||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||||
"seed": -1, # int. random seed for reproducibility.
|
"seed": -1, # int. random seed for reproducibility.
|
||||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||||
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
|
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||||
|
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
|
||||||
|
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
|
||||||
|
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||||
|
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||||
|
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
|
||||||
}
|
}
|
||||||
returns:
|
returns:
|
||||||
Tuple[int, np.ndarray]: sampling rate and audio data.
|
Tuple[int, np.ndarray]: sampling rate and audio data.
|
||||||
@ -1021,10 +1039,10 @@ class TTS:
|
|||||||
aux_ref_audio_paths: list = inputs.get("aux_ref_audio_paths", [])
|
aux_ref_audio_paths: list = inputs.get("aux_ref_audio_paths", [])
|
||||||
prompt_text: str = inputs.get("prompt_text", "")
|
prompt_text: str = inputs.get("prompt_text", "")
|
||||||
prompt_lang: str = inputs.get("prompt_lang", "")
|
prompt_lang: str = inputs.get("prompt_lang", "")
|
||||||
top_k: int = inputs.get("top_k", 5)
|
top_k: int = inputs.get("top_k", 15)
|
||||||
top_p: float = inputs.get("top_p", 1)
|
top_p: float = inputs.get("top_p", 1)
|
||||||
temperature: float = inputs.get("temperature", 1)
|
temperature: float = inputs.get("temperature", 1)
|
||||||
text_split_method: str = inputs.get("text_split_method", "cut0")
|
text_split_method: str = inputs.get("text_split_method", "cut1")
|
||||||
batch_size = inputs.get("batch_size", 1)
|
batch_size = inputs.get("batch_size", 1)
|
||||||
batch_threshold = inputs.get("batch_threshold", 0.75)
|
batch_threshold = inputs.get("batch_threshold", 0.75)
|
||||||
speed_factor = inputs.get("speed_factor", 1.0)
|
speed_factor = inputs.get("speed_factor", 1.0)
|
||||||
@ -1038,19 +1056,43 @@ class TTS:
|
|||||||
repetition_penalty = inputs.get("repetition_penalty", 1.35)
|
repetition_penalty = inputs.get("repetition_penalty", 1.35)
|
||||||
sample_steps = inputs.get("sample_steps", 32)
|
sample_steps = inputs.get("sample_steps", 32)
|
||||||
super_sampling = inputs.get("super_sampling", False)
|
super_sampling = inputs.get("super_sampling", False)
|
||||||
|
streaming_mode = inputs.get("streaming_mode", False)
|
||||||
|
overlap_length = inputs.get("overlap_length", 2)
|
||||||
|
min_chunk_length = inputs.get("min_chunk_length", 16)
|
||||||
|
fixed_length_chunk = inputs.get("fixed_length_chunk", False)
|
||||||
|
chunk_split_thershold = 0.0 # 该值代表语义token与mute token的余弦相似度阈值,若大于该阈值,则视为可切分点。
|
||||||
|
|
||||||
if parallel_infer:
|
if parallel_infer and not streaming_mode:
|
||||||
print(i18n("并行推理模式已开启"))
|
print(i18n("并行推理模式已开启"))
|
||||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
|
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
|
||||||
|
elif not parallel_infer and streaming_mode and not self.configs.use_vocoder:
|
||||||
|
print(i18n("流式推理模式已开启"))
|
||||||
|
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
|
||||||
|
elif streaming_mode and self.configs.use_vocoder:
|
||||||
|
print(i18n("SoVits V3/4模型不支持流式推理模式,已自动回退到分段返回模式"))
|
||||||
|
streaming_mode = False
|
||||||
|
return_fragment = True
|
||||||
|
if parallel_infer:
|
||||||
|
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
|
||||||
|
else:
|
||||||
|
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
|
||||||
|
# self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
|
||||||
|
elif parallel_infer and streaming_mode:
|
||||||
|
print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式"))
|
||||||
|
parallel_infer = False
|
||||||
|
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
|
||||||
else:
|
else:
|
||||||
print(i18n("并行推理模式已关闭"))
|
print(i18n("朴素推理模式已开启"))
|
||||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
|
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
|
||||||
|
|
||||||
if return_fragment:
|
if return_fragment and streaming_mode:
|
||||||
print(i18n("分段返回模式已开启"))
|
print(i18n("流式推理模式不支持分段返回,已自动关闭分段返回"))
|
||||||
if split_bucket:
|
return_fragment = False
|
||||||
split_bucket = False
|
|
||||||
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
|
if (return_fragment or streaming_mode) and split_bucket:
|
||||||
|
print(i18n("分段返回模式/流式推理模式不支持分桶处理,已自动关闭分桶处理"))
|
||||||
|
split_bucket = False
|
||||||
|
|
||||||
|
|
||||||
if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer):
|
if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer):
|
||||||
print(i18n("分桶处理模式已开启"))
|
print(i18n("分桶处理模式已开启"))
|
||||||
@ -1063,9 +1105,9 @@ class TTS:
|
|||||||
else:
|
else:
|
||||||
print(i18n("分桶处理模式已关闭"))
|
print(i18n("分桶处理模式已关闭"))
|
||||||
|
|
||||||
if fragment_interval < 0.01:
|
# if fragment_interval < 0.01:
|
||||||
fragment_interval = 0.01
|
# fragment_interval = 0.01
|
||||||
print(i18n("分段间隔过小,已自动设置为0.01"))
|
# print(i18n("分段间隔过小,已自动设置为0.01"))
|
||||||
|
|
||||||
no_prompt_text = False
|
no_prompt_text = False
|
||||||
if prompt_text in [None, ""]:
|
if prompt_text in [None, ""]:
|
||||||
@ -1126,7 +1168,7 @@ class TTS:
|
|||||||
###### text preprocessing ########
|
###### text preprocessing ########
|
||||||
t1 = time.perf_counter()
|
t1 = time.perf_counter()
|
||||||
data: list = None
|
data: list = None
|
||||||
if not return_fragment:
|
if not (return_fragment or streaming_mode):
|
||||||
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
|
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
yield 16000, np.zeros(int(16000), dtype=np.int16)
|
yield 16000, np.zeros(int(16000), dtype=np.int16)
|
||||||
@ -1186,10 +1228,11 @@ class TTS:
|
|||||||
t_34 = 0.0
|
t_34 = 0.0
|
||||||
t_45 = 0.0
|
t_45 = 0.0
|
||||||
audio = []
|
audio = []
|
||||||
|
is_first_package = True
|
||||||
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
|
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
|
||||||
for item in data:
|
for item in data:
|
||||||
t3 = time.perf_counter()
|
t3 = time.perf_counter()
|
||||||
if return_fragment:
|
if return_fragment or streaming_mode:
|
||||||
item = make_batch(item)
|
item = make_batch(item)
|
||||||
if item is None:
|
if item is None:
|
||||||
continue
|
continue
|
||||||
@ -1211,108 +1254,228 @@ class TTS:
|
|||||||
self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
|
self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"############ {i18n('预测语义Token')} ############")
|
|
||||||
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
|
|
||||||
all_phoneme_ids,
|
|
||||||
all_phoneme_lens,
|
|
||||||
prompt,
|
|
||||||
all_bert_features,
|
|
||||||
# prompt_phone_len=ph_offset,
|
|
||||||
top_k=top_k,
|
|
||||||
top_p=top_p,
|
|
||||||
temperature=temperature,
|
|
||||||
early_stop_num=self.configs.hz * self.configs.max_sec,
|
|
||||||
max_len=max_len,
|
|
||||||
repetition_penalty=repetition_penalty,
|
|
||||||
)
|
|
||||||
t4 = time.perf_counter()
|
|
||||||
t_34 += t4 - t3
|
|
||||||
|
|
||||||
refer_audio_spec = []
|
refer_audio_spec = []
|
||||||
if self.is_v2pro:
|
|
||||||
sv_emb = []
|
sv_emb = [] if self.is_v2pro else None
|
||||||
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
|
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
|
||||||
spec = spec.to(dtype=self.precision, device=self.configs.device)
|
spec = spec.to(dtype=self.precision, device=self.configs.device)
|
||||||
refer_audio_spec.append(spec)
|
refer_audio_spec.append(spec)
|
||||||
if self.is_v2pro:
|
if self.is_v2pro:
|
||||||
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
|
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
|
||||||
|
|
||||||
batch_audio_fragment = []
|
if not streaming_mode:
|
||||||
|
print(f"############ {i18n('预测语义Token')} ############")
|
||||||
|
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
|
||||||
|
all_phoneme_ids,
|
||||||
|
all_phoneme_lens,
|
||||||
|
prompt,
|
||||||
|
all_bert_features,
|
||||||
|
# prompt_phone_len=ph_offset,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
temperature=temperature,
|
||||||
|
early_stop_num=self.configs.hz * self.configs.max_sec,
|
||||||
|
max_len=max_len,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
)
|
||||||
|
t4 = time.perf_counter()
|
||||||
|
t_34 += t4 - t3
|
||||||
|
|
||||||
# ## vits并行推理 method 1
|
|
||||||
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
batch_audio_fragment = []
|
||||||
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
|
|
||||||
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
|
# ## vits并行推理 method 1
|
||||||
# max_len = 0
|
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
||||||
# for i in range(0, len(batch_phones)):
|
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
|
||||||
# max_len = max(max_len, batch_phones[i].shape[-1])
|
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
|
||||||
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
|
# max_len = 0
|
||||||
# batch_phones = batch_phones.to(self.configs.device)
|
# for i in range(0, len(batch_phones)):
|
||||||
# batch_audio_fragment = (self.vits_model.batched_decode(
|
# max_len = max(max_len, batch_phones[i].shape[-1])
|
||||||
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
|
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
|
||||||
# ))
|
# batch_phones = batch_phones.to(self.configs.device)
|
||||||
print(f"############ {i18n('合成音频')} ############")
|
# batch_audio_fragment = (self.vits_model.batched_decode(
|
||||||
if not self.configs.use_vocoder:
|
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
|
||||||
if speed_factor == 1.0:
|
# ))
|
||||||
print(f"{i18n('并行合成中')}...")
|
print(f"############ {i18n('合成音频')} ############")
|
||||||
# ## vits并行推理 method 2
|
if not self.configs.use_vocoder:
|
||||||
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
if speed_factor == 1.0:
|
||||||
upsample_rate = math.prod(self.vits_model.upsample_rates)
|
print(f"{i18n('并行合成中')}...")
|
||||||
audio_frag_idx = [
|
# ## vits并行推理 method 2
|
||||||
pred_semantic_list[i].shape[0] * 2 * upsample_rate
|
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
||||||
for i in range(0, len(pred_semantic_list))
|
upsample_rate = math.prod(self.vits_model.upsample_rates)
|
||||||
]
|
audio_frag_idx = [
|
||||||
audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))]
|
pred_semantic_list[i].shape[0] * 2 * upsample_rate
|
||||||
all_pred_semantic = (
|
for i in range(0, len(pred_semantic_list))
|
||||||
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
|
]
|
||||||
)
|
audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))]
|
||||||
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
|
all_pred_semantic = (
|
||||||
if self.is_v2pro != True:
|
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
|
||||||
_batch_audio_fragment = self.vits_model.decode(
|
|
||||||
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
|
|
||||||
).detach()[0, 0, :]
|
|
||||||
else:
|
|
||||||
_batch_audio_fragment = self.vits_model.decode(
|
|
||||||
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
|
|
||||||
).detach()[0, 0, :]
|
|
||||||
audio_frag_end_idx.insert(0, 0)
|
|
||||||
batch_audio_fragment = [
|
|
||||||
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
|
|
||||||
for i in range(1, len(audio_frag_end_idx))
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
# ## vits串行推理
|
|
||||||
for i, idx in enumerate(tqdm(idx_list)):
|
|
||||||
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
|
|
||||||
_pred_semantic = (
|
|
||||||
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
|
||||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
|
||||||
if self.is_v2pro != True:
|
|
||||||
audio_fragment = self.vits_model.decode(
|
|
||||||
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
|
|
||||||
).detach()[0, 0, :]
|
|
||||||
else:
|
|
||||||
audio_fragment = self.vits_model.decode(
|
|
||||||
_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
|
|
||||||
).detach()[0, 0, :]
|
|
||||||
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
|
|
||||||
else:
|
|
||||||
if parallel_infer:
|
|
||||||
print(f"{i18n('并行合成中')}...")
|
|
||||||
audio_fragments = self.using_vocoder_synthesis_batched_infer(
|
|
||||||
idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
|
|
||||||
)
|
|
||||||
batch_audio_fragment.extend(audio_fragments)
|
|
||||||
else:
|
|
||||||
for i, idx in enumerate(tqdm(idx_list)):
|
|
||||||
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
|
|
||||||
_pred_semantic = (
|
|
||||||
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
|
||||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
|
||||||
audio_fragment = self.using_vocoder_synthesis(
|
|
||||||
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
|
|
||||||
)
|
)
|
||||||
batch_audio_fragment.append(audio_fragment)
|
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
|
||||||
|
|
||||||
|
_batch_audio_fragment = self.vits_model.decode(
|
||||||
|
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
|
||||||
|
).detach()[0, 0, :]
|
||||||
|
|
||||||
|
audio_frag_end_idx.insert(0, 0)
|
||||||
|
batch_audio_fragment = [
|
||||||
|
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
|
||||||
|
for i in range(1, len(audio_frag_end_idx))
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# ## vits串行推理
|
||||||
|
for i, idx in enumerate(tqdm(idx_list)):
|
||||||
|
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
|
||||||
|
_pred_semantic = (
|
||||||
|
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
||||||
|
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||||
|
audio_fragment = self.vits_model.decode(
|
||||||
|
_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
|
||||||
|
).detach()[0, 0, :]
|
||||||
|
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
|
||||||
|
else:
|
||||||
|
if parallel_infer:
|
||||||
|
print(f"{i18n('并行合成中')}...")
|
||||||
|
audio_fragments = self.using_vocoder_synthesis_batched_infer(
|
||||||
|
idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
|
||||||
|
)
|
||||||
|
batch_audio_fragment.extend(audio_fragments)
|
||||||
|
else:
|
||||||
|
for i, idx in enumerate(tqdm(idx_list)):
|
||||||
|
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
|
||||||
|
_pred_semantic = (
|
||||||
|
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
||||||
|
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||||
|
audio_fragment = self.using_vocoder_synthesis(
|
||||||
|
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
|
||||||
|
)
|
||||||
|
batch_audio_fragment.append(audio_fragment)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# refer_audio_spec: torch.Tensor = [
|
||||||
|
# item.to(dtype=self.precision, device=self.configs.device)
|
||||||
|
# for item in self.prompt_cache["refer_spec"]
|
||||||
|
# ]
|
||||||
|
semantic_token_generator =self.t2s_model.model.infer_panel(
|
||||||
|
all_phoneme_ids[0].unsqueeze(0),
|
||||||
|
all_phoneme_lens,
|
||||||
|
prompt,
|
||||||
|
all_bert_features[0].unsqueeze(0),
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
temperature=temperature,
|
||||||
|
early_stop_num=self.configs.hz * self.configs.max_sec,
|
||||||
|
max_len=max_len,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
streaming_mode=True,
|
||||||
|
chunk_length=min_chunk_length,
|
||||||
|
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix if not fixed_length_chunk else None,
|
||||||
|
chunk_split_thershold=chunk_split_thershold,
|
||||||
|
)
|
||||||
|
t4 = time.perf_counter()
|
||||||
|
t_34 += t4 - t3
|
||||||
|
phones = batch_phones[0].unsqueeze(0).to(self.configs.device)
|
||||||
|
is_first_chunk = True
|
||||||
|
|
||||||
|
if not self.configs.use_vocoder:
|
||||||
|
# if speed_factor == 1.0:
|
||||||
|
# upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
|
||||||
|
# else:
|
||||||
|
upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor)
|
||||||
|
else:
|
||||||
|
# if speed_factor == 1.0:
|
||||||
|
# upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4)
|
||||||
|
# else:
|
||||||
|
upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)
|
||||||
|
|
||||||
|
last_audio_chunk = None
|
||||||
|
# last_tokens = None
|
||||||
|
last_latent = None
|
||||||
|
previous_tokens = []
|
||||||
|
overlap_len = overlap_length
|
||||||
|
overlap_size = math.ceil(overlap_length*upsample_rate)
|
||||||
|
for semantic_tokens, is_final in semantic_token_generator:
|
||||||
|
if semantic_tokens is None and last_audio_chunk is not None:
|
||||||
|
yield self.audio_postprocess(
|
||||||
|
[[last_audio_chunk[-overlap_size:]]],
|
||||||
|
output_sr,
|
||||||
|
None,
|
||||||
|
speed_factor,
|
||||||
|
False,
|
||||||
|
0.0,
|
||||||
|
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
_semantic_tokens = semantic_tokens
|
||||||
|
print(f"semantic_tokens shape:{semantic_tokens.shape}")
|
||||||
|
|
||||||
|
previous_tokens.append(semantic_tokens)
|
||||||
|
|
||||||
|
_semantic_tokens = torch.cat(previous_tokens, dim=-1)
|
||||||
|
|
||||||
|
if not is_first_chunk and semantic_tokens.shape[-1] < 10:
|
||||||
|
overlap_len = overlap_length+(10-semantic_tokens.shape[-1])
|
||||||
|
else:
|
||||||
|
overlap_len = overlap_length
|
||||||
|
|
||||||
|
|
||||||
|
if not self.configs.use_vocoder:
|
||||||
|
token_padding_length = 0
|
||||||
|
# token_padding_length = int(phones.shape[-1]*2)-_semantic_tokens.shape[-1]
|
||||||
|
# if token_padding_length>0:
|
||||||
|
# _semantic_tokens = F.pad(_semantic_tokens, (0, token_padding_length), "constant", 486)
|
||||||
|
# else:
|
||||||
|
# token_padding_length = 0
|
||||||
|
|
||||||
|
audio_chunk, latent, latent_mask = self.vits_model.decode_streaming(
|
||||||
|
_semantic_tokens.unsqueeze(0),
|
||||||
|
phones, refer_audio_spec,
|
||||||
|
speed=speed_factor,
|
||||||
|
sv_emb=sv_emb,
|
||||||
|
result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
|
||||||
|
overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
|
||||||
|
if last_latent is not None else None,
|
||||||
|
padding_length=token_padding_length
|
||||||
|
)
|
||||||
|
audio_chunk=audio_chunk.detach()[0, 0, :]
|
||||||
|
else:
|
||||||
|
raise RuntimeError(i18n("SoVits V3/4模型不支持流式推理模式"))
|
||||||
|
|
||||||
|
if overlap_len>overlap_length:
|
||||||
|
audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):]
|
||||||
|
|
||||||
|
audio_chunk_ = audio_chunk
|
||||||
|
if is_first_chunk and not is_final:
|
||||||
|
is_first_chunk = False
|
||||||
|
audio_chunk_ = audio_chunk_[:-overlap_size]
|
||||||
|
elif is_first_chunk and is_final:
|
||||||
|
is_first_chunk = False
|
||||||
|
elif not is_first_chunk and not is_final:
|
||||||
|
audio_chunk_ = self.sola_algorithm([last_audio_chunk, audio_chunk_], overlap_size)
|
||||||
|
audio_chunk_ = (
|
||||||
|
audio_chunk_[last_audio_chunk.shape[0]-overlap_size:-overlap_size] if not is_final \
|
||||||
|
else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:]
|
||||||
|
)
|
||||||
|
|
||||||
|
last_latent = latent
|
||||||
|
last_audio_chunk = audio_chunk
|
||||||
|
yield self.audio_postprocess(
|
||||||
|
[[audio_chunk_]],
|
||||||
|
output_sr,
|
||||||
|
None,
|
||||||
|
speed_factor,
|
||||||
|
False,
|
||||||
|
0.0,
|
||||||
|
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_first_package:
|
||||||
|
print(f"first_package_delay: {time.perf_counter()-t0:.3f}")
|
||||||
|
is_first_package = False
|
||||||
|
|
||||||
|
|
||||||
|
yield output_sr, np.zeros(int(output_sr*fragment_interval), dtype=np.int16)
|
||||||
|
|
||||||
t5 = time.perf_counter()
|
t5 = time.perf_counter()
|
||||||
t_45 += t5 - t4
|
t_45 += t5 - t4
|
||||||
@ -1327,17 +1490,18 @@ class TTS:
|
|||||||
fragment_interval,
|
fragment_interval,
|
||||||
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
|
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
|
||||||
)
|
)
|
||||||
|
elif streaming_mode:...
|
||||||
else:
|
else:
|
||||||
audio.append(batch_audio_fragment)
|
audio.append(batch_audio_fragment)
|
||||||
|
|
||||||
if self.stop_flag:
|
if self.stop_flag:
|
||||||
yield 16000, np.zeros(int(16000), dtype=np.int16)
|
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not return_fragment:
|
if not (return_fragment or streaming_mode):
|
||||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
|
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
|
||||||
if len(audio) == 0:
|
if len(audio) == 0:
|
||||||
yield 16000, np.zeros(int(16000), dtype=np.int16)
|
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
|
||||||
return
|
return
|
||||||
yield self.audio_postprocess(
|
yield self.audio_postprocess(
|
||||||
audio,
|
audio,
|
||||||
@ -1384,16 +1548,17 @@ class TTS:
|
|||||||
fragment_interval: float = 0.3,
|
fragment_interval: float = 0.3,
|
||||||
super_sampling: bool = False,
|
super_sampling: bool = False,
|
||||||
) -> Tuple[int, np.ndarray]:
|
) -> Tuple[int, np.ndarray]:
|
||||||
zero_wav = torch.zeros(
|
if fragment_interval>0:
|
||||||
int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
|
zero_wav = torch.zeros(
|
||||||
)
|
int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
|
||||||
|
)
|
||||||
|
|
||||||
for i, batch in enumerate(audio):
|
for i, batch in enumerate(audio):
|
||||||
for j, audio_fragment in enumerate(batch):
|
for j, audio_fragment in enumerate(batch):
|
||||||
max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音
|
max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音
|
||||||
if max_audio > 1:
|
if max_audio > 1:
|
||||||
audio_fragment /= max_audio
|
audio_fragment /= max_audio
|
||||||
audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
|
audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) if fragment_interval>0 else audio_fragment
|
||||||
audio[i][j] = audio_fragment
|
audio[i][j] = audio_fragment
|
||||||
|
|
||||||
if split_bucket:
|
if split_bucket:
|
||||||
@ -1413,13 +1578,18 @@ class TTS:
|
|||||||
max_audio = np.abs(audio).max()
|
max_audio = np.abs(audio).max()
|
||||||
if max_audio > 1:
|
if max_audio > 1:
|
||||||
audio /= max_audio
|
audio /= max_audio
|
||||||
|
audio = (audio * 32768).astype(np.int16)
|
||||||
t2 = time.perf_counter()
|
t2 = time.perf_counter()
|
||||||
print(f"超采样用时:{t2 - t1:.3f}s")
|
print(f"超采样用时:{t2 - t1:.3f}s")
|
||||||
else:
|
else:
|
||||||
|
# audio = audio.float() * 32768
|
||||||
|
# audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy()
|
||||||
|
|
||||||
audio = audio.cpu().numpy()
|
audio = audio.cpu().numpy()
|
||||||
|
|
||||||
audio = (audio * 32768).astype(np.int16)
|
audio = (audio * 32768).astype(np.int16)
|
||||||
|
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# if speed_factor != 1.0:
|
# if speed_factor != 1.0:
|
||||||
# audio = speed_change(audio, speed=speed_factor, sr=int(sr))
|
# audio = speed_change(audio, speed=speed_factor, sr=int(sr))
|
||||||
@ -1612,24 +1782,43 @@ class TTS:
|
|||||||
self,
|
self,
|
||||||
audio_fragments: List[torch.Tensor],
|
audio_fragments: List[torch.Tensor],
|
||||||
overlap_len: int,
|
overlap_len: int,
|
||||||
|
search_len:int= 320
|
||||||
):
|
):
|
||||||
|
# overlap_len-=search_len
|
||||||
|
|
||||||
|
dtype = audio_fragments[0].dtype
|
||||||
|
|
||||||
for i in range(len(audio_fragments) - 1):
|
for i in range(len(audio_fragments) - 1):
|
||||||
f1 = audio_fragments[i]
|
f1 = audio_fragments[i].float()
|
||||||
f2 = audio_fragments[i + 1]
|
f2 = audio_fragments[i + 1].float()
|
||||||
w1 = f1[-overlap_len:]
|
w1 = f1[-overlap_len:]
|
||||||
w2 = f2[:overlap_len]
|
w2 = f2[:overlap_len+search_len]
|
||||||
assert w1.shape == w2.shape
|
# w2 = w2[-w2.shape[-1]//2:]
|
||||||
corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1), padding=w2.shape[-1] // 2).view(-1)[:-1]
|
# assert w1.shape == w2.shape
|
||||||
idx = corr.argmax()
|
corr_norm = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1)
|
||||||
f1_ = f1[: -(overlap_len - idx)]
|
|
||||||
|
corr_den = F.conv1d(w2.view(1, 1, -1)**2, torch.ones_like(w1).view(1, 1, -1)).view(-1)+ 1e-8
|
||||||
|
idx = (corr_norm/corr_den.sqrt()).argmax()
|
||||||
|
|
||||||
|
print(f"seg_idx: {idx}")
|
||||||
|
|
||||||
|
# idx = corr.argmax()
|
||||||
|
f1_ = f1[: -overlap_len]
|
||||||
audio_fragments[i] = f1_
|
audio_fragments[i] = f1_
|
||||||
|
|
||||||
f2_ = f2[idx:]
|
f2_ = f2[idx:]
|
||||||
window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype)
|
window = torch.hann_window((overlap_len) * 2, device=f1.device, dtype=f1.dtype)
|
||||||
f2_[: (overlap_len - idx)] = (
|
f2_[: overlap_len] = (
|
||||||
window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)]
|
window[: overlap_len] * f2_[: overlap_len]
|
||||||
+ window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :]
|
+ window[overlap_len :] * f1[-overlap_len :]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx))
|
||||||
|
# f2_[: (overlap_len - idx)] = (
|
||||||
|
# window * f2_[: (overlap_len - idx)]
|
||||||
|
# + (1-window) * f1[-(overlap_len - idx) :]
|
||||||
|
# )
|
||||||
|
|
||||||
audio_fragments[i + 1] = f2_
|
audio_fragments[i + 1] = f2_
|
||||||
|
|
||||||
return torch.cat(audio_fragments, 0)
|
return torch.cat(audio_fragments, 0).to(dtype)
|
||||||
|
|||||||
@ -261,41 +261,21 @@ class T2SBlock:
|
|||||||
|
|
||||||
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
|
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
|
||||||
|
|
||||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
# attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
# attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||||
|
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
|
||||||
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||||
|
|
||||||
if padding_mask is not None:
|
x = x + attn
|
||||||
for i in range(batch_size):
|
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||||
# mask = padding_mask[i,:,0]
|
x = x + self.mlp.forward(x)
|
||||||
if self.false.device != padding_mask.device:
|
x = F.layer_norm(
|
||||||
self.false = self.false.to(padding_mask.device)
|
x,
|
||||||
idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
|
[self.hidden_dim],
|
||||||
x_item = x[i, idx, :].unsqueeze(0)
|
self.norm_w2,
|
||||||
attn_item = attn[i, idx, :].unsqueeze(0)
|
self.norm_b2,
|
||||||
x_item = x_item + attn_item
|
self.norm_eps2,
|
||||||
x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
)
|
||||||
x_item = x_item + self.mlp.forward(x_item)
|
|
||||||
x_item = F.layer_norm(
|
|
||||||
x_item,
|
|
||||||
[self.hidden_dim],
|
|
||||||
self.norm_w2,
|
|
||||||
self.norm_b2,
|
|
||||||
self.norm_eps2,
|
|
||||||
)
|
|
||||||
x[i, idx, :] = x_item.squeeze(0)
|
|
||||||
x = self.to_mask(x, padding_mask)
|
|
||||||
else:
|
|
||||||
x = x + attn
|
|
||||||
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
|
||||||
x = x + self.mlp.forward(x)
|
|
||||||
x = F.layer_norm(
|
|
||||||
x,
|
|
||||||
[self.hidden_dim],
|
|
||||||
self.norm_w2,
|
|
||||||
self.norm_b2,
|
|
||||||
self.norm_eps2,
|
|
||||||
)
|
|
||||||
return x, k_cache, v_cache
|
return x, k_cache, v_cache
|
||||||
|
|
||||||
def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
|
def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
|
||||||
|
|||||||
@ -417,7 +417,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
|||||||
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
|
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
|
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True)
|
||||||
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
|
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
temperature = gr.Slider(
|
temperature = gr.Slider(
|
||||||
|
|||||||
@ -37,6 +37,10 @@ from einops import rearrange, repeat
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from module.distrib import broadcast_tensors, is_distributed
|
||||||
|
from module.ddp_utils import SyncFunction
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
@ -69,27 +73,40 @@ def sample_vectors(samples, num: int):
|
|||||||
return samples[indices]
|
return samples[indices]
|
||||||
|
|
||||||
|
|
||||||
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64):
|
||||||
dim, dtype = samples.shape[-1], samples.dtype
|
N, D = samples.shape
|
||||||
max_kmeans_samples = 500
|
dtype, device = samples.dtype, samples.device
|
||||||
samples = samples[:max_kmeans_samples, :]
|
|
||||||
|
if frames_to_use < N:
|
||||||
|
indices = torch.randperm(N, device=device)[:frames_to_use]
|
||||||
|
samples = samples[indices]
|
||||||
|
|
||||||
means = sample_vectors(samples, num_clusters)
|
means = sample_vectors(samples, num_clusters)
|
||||||
|
|
||||||
print("kmeans start ... ")
|
print("kmeans start ... ")
|
||||||
for _ in tqdm(range(num_iters)):
|
for _ in tqdm(range(num_iters)):
|
||||||
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
|
# Store cluster assignments
|
||||||
dists = -(diffs**2).sum(dim=-1)
|
all_assignments = []
|
||||||
|
|
||||||
buckets = dists.max(dim=-1).indices
|
for i in range(0, samples.shape[0], batch_size):
|
||||||
|
batch = samples[i : i + batch_size] # [B, D]
|
||||||
|
dists = torch.cdist(batch, means, p=2) # [B, C]
|
||||||
|
assignments = dists.argmin(dim=1) # [B]
|
||||||
|
all_assignments.append(assignments)
|
||||||
|
|
||||||
|
buckets = torch.cat(all_assignments, dim=0) # [N]
|
||||||
bins = torch.bincount(buckets, minlength=num_clusters)
|
bins = torch.bincount(buckets, minlength=num_clusters)
|
||||||
zero_mask = bins == 0
|
zero_mask = bins == 0
|
||||||
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
||||||
|
|
||||||
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
# Compute new means
|
||||||
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
new_means = torch.zeros_like(means)
|
||||||
new_means = new_means / bins_min_clamped[..., None]
|
for i in range(num_clusters):
|
||||||
|
mask = buckets == i
|
||||||
|
if mask.any():
|
||||||
|
new_means[i] = samples[mask].mean(dim=0)
|
||||||
|
|
||||||
means = torch.where(zero_mask[..., None], means, new_means)
|
means = torch.where(zero_mask[:, None], means, new_means)
|
||||||
|
|
||||||
return means, bins
|
return means, bins
|
||||||
|
|
||||||
@ -141,13 +158,24 @@ class EuclideanCodebook(nn.Module):
|
|||||||
if self.inited:
|
if self.inited:
|
||||||
return
|
return
|
||||||
|
|
||||||
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
# [B * T * world_size, D]
|
||||||
|
data = SyncFunction.apply(data)
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
||||||
|
else:
|
||||||
|
embed = torch.empty_like(self.embed)
|
||||||
|
cluster_size = torch.empty_like(self.cluster_size)
|
||||||
|
dist.broadcast(embed, src=0)
|
||||||
|
dist.broadcast(cluster_size, src=0)
|
||||||
|
|
||||||
self.embed.data.copy_(embed)
|
self.embed.data.copy_(embed)
|
||||||
self.embed_avg.data.copy_(embed.clone())
|
self.embed_avg.data.copy_(embed.clone())
|
||||||
self.cluster_size.data.copy_(cluster_size)
|
self.cluster_size.data.copy_(cluster_size)
|
||||||
self.inited.data.copy_(torch.Tensor([True]))
|
self.inited.data.copy_(torch.Tensor([True]))
|
||||||
# Make sure all buffers across workers are in sync after initialization
|
# Make sure all buffers across workers are in sync after initialization
|
||||||
# broadcast_tensors(self.buffers())
|
broadcast_tensors(self.buffers())
|
||||||
|
|
||||||
def replace_(self, samples, mask):
|
def replace_(self, samples, mask):
|
||||||
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
|
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
|
||||||
@ -161,9 +189,17 @@ class EuclideanCodebook(nn.Module):
|
|||||||
if not torch.any(expired_codes):
|
if not torch.any(expired_codes):
|
||||||
return
|
return
|
||||||
|
|
||||||
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
if is_distributed():
|
||||||
self.replace_(batch_samples, mask=expired_codes)
|
# [B * T * world_size, D]
|
||||||
# broadcast_tensors(self.buffers())
|
batch_samples = SyncFunction.apply(batch_samples)
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
new_embeds = sample_vectors(batch_samples, expired_codes.sum())
|
||||||
|
else:
|
||||||
|
new_embeds = torch.zeros(expired_codes.sum(), self.embed.size(1), device=self.embed.device)
|
||||||
|
dist.broadcast(new_embeds, src=0)
|
||||||
|
self.embed.data[expired_codes] = new_embeds
|
||||||
|
broadcast_tensors(self.buffers())
|
||||||
|
|
||||||
def preprocess(self, x):
|
def preprocess(self, x):
|
||||||
x = rearrange(x, "... d -> (...) d")
|
x = rearrange(x, "... d -> (...) d")
|
||||||
@ -208,17 +244,26 @@ class EuclideanCodebook(nn.Module):
|
|||||||
quantize = self.dequantize(embed_ind)
|
quantize = self.dequantize(embed_ind)
|
||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
|
### Update codebook by EMA
|
||||||
|
embed_onehot_sum = embed_onehot.sum(0) # [cb-size,]
|
||||||
|
embed_sum = x.t() @ embed_onehot # [D, cb-size]
|
||||||
|
if is_distributed():
|
||||||
|
dist.all_reduce(embed_onehot_sum)
|
||||||
|
dist.all_reduce(embed_sum)
|
||||||
|
# Update ema cluster count N_i^t, eq. (6) in vqvae paper
|
||||||
|
self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
|
||||||
|
# Update ema embed: eq. (7) in vqvae paper
|
||||||
|
self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
|
||||||
|
# apply laplace smoothing
|
||||||
|
n = self.cluster_size.sum()
|
||||||
|
cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
|
||||||
|
# Update ema embed: eq. (8) in vqvae paper
|
||||||
|
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
||||||
|
self.embed.data.copy_(embed_normalized)
|
||||||
|
|
||||||
# We do the expiry of code at that point as buffers are in sync
|
# We do the expiry of code at that point as buffers are in sync
|
||||||
# and all the workers will take the same decision.
|
# and all the workers will take the same decision.
|
||||||
self.expire_codes_(x)
|
self.expire_codes_(x)
|
||||||
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
|
||||||
embed_sum = x.t() @ embed_onehot
|
|
||||||
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
|
||||||
cluster_size = (
|
|
||||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
|
|
||||||
)
|
|
||||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
|
||||||
self.embed.data.copy_(embed_normalized)
|
|
||||||
|
|
||||||
return quantize, embed_ind
|
return quantize, embed_ind
|
||||||
|
|
||||||
|
|||||||
181
GPT_SoVITS/module/ddp_utils.py
Normal file
181
GPT_SoVITS/module/ddp_utils.py
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
import torch
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.nn.parallel.distributed import _find_tensors
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
|
||||||
|
# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
|
||||||
|
class SyncFunction(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
# @torch.no_grad()
|
||||||
|
def forward(ctx, tensor):
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
|
|
||||||
|
# Collect batch sizes from all processes
|
||||||
|
local_bs = torch.tensor([tensor.shape[0]], device=tensor.device)
|
||||||
|
batch_sizes = [torch.zeros_like(local_bs) for _ in range(world_size)]
|
||||||
|
torch.distributed.all_gather(batch_sizes, local_bs)
|
||||||
|
|
||||||
|
# Convert to integer list and find the minimum
|
||||||
|
batch_sizes_int = [bs.item() for bs in batch_sizes]
|
||||||
|
min_bs = min(batch_sizes_int)
|
||||||
|
|
||||||
|
# Crop the tensor to the minimum batch size if needed
|
||||||
|
cropped_tensor = tensor[:min_bs] if tensor.shape[0] > min_bs else tensor
|
||||||
|
|
||||||
|
# Prepare for gathering
|
||||||
|
out_shape = (min_bs * world_size,) + tensor.shape[1:]
|
||||||
|
gathered_tensor = torch.zeros(out_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
# Build tensor list for all_gather
|
||||||
|
tensor_list = list(torch.chunk(gathered_tensor, world_size))
|
||||||
|
|
||||||
|
# Perform all_gather using the cropped tensors
|
||||||
|
torch.distributed.all_gather(tensor_list, cropped_tensor)
|
||||||
|
|
||||||
|
# Save for backward pass
|
||||||
|
ctx.min_bs = min_bs
|
||||||
|
ctx.world_size = world_size
|
||||||
|
ctx.orig_shape = tensor.shape
|
||||||
|
|
||||||
|
return gathered_tensor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
assert False
|
||||||
|
grad_input = grad_output.clone()
|
||||||
|
torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
|
||||||
|
|
||||||
|
idx_from = torch.distributed.get_rank() * ctx.batch_size
|
||||||
|
idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
|
||||||
|
return grad_input[idx_from:idx_to]
|
||||||
|
|
||||||
|
class DDP(DistributedDataParallel):
|
||||||
|
"""
|
||||||
|
Override the forward call in lightning so it goes to training and validation step respectively
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, *inputs, **kwargs): # pragma: no cover
|
||||||
|
if version.parse(torch.__version__[:6]) < version.parse("1.11"):
|
||||||
|
self._sync_params()
|
||||||
|
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
||||||
|
assert len(self.device_ids) == 1
|
||||||
|
if self.module.training:
|
||||||
|
output = self.module.training_step(*inputs[0], **kwargs[0])
|
||||||
|
elif self.module.testing:
|
||||||
|
output = self.module.test_step(*inputs[0], **kwargs[0])
|
||||||
|
else:
|
||||||
|
output = self.module.validation_step(*inputs[0], **kwargs[0])
|
||||||
|
if torch.is_grad_enabled():
|
||||||
|
# We'll return the output object verbatim since it is a freeform
|
||||||
|
# object. We need to find any tensors in this object, though,
|
||||||
|
# because we need to figure out which parameters were used during
|
||||||
|
# this forward pass, to ensure we short circuit reduction for any
|
||||||
|
# unused parameters. Only if `find_unused_parameters` is set.
|
||||||
|
if self.find_unused_parameters:
|
||||||
|
self.reducer.prepare_for_backward(list(_find_tensors(output)))
|
||||||
|
else:
|
||||||
|
self.reducer.prepare_for_backward([])
|
||||||
|
else:
|
||||||
|
from torch.nn.parallel.distributed import (
|
||||||
|
Join,
|
||||||
|
_DDPSink,
|
||||||
|
_tree_flatten_with_rref,
|
||||||
|
_tree_unflatten_with_rref,
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
|
||||||
|
if torch.is_grad_enabled() and self.require_backward_grad_sync:
|
||||||
|
self.logger.set_runtime_stats_and_log()
|
||||||
|
self.num_iterations += 1
|
||||||
|
self.reducer.prepare_for_forward()
|
||||||
|
|
||||||
|
# Notify the join context that this process has not joined, if
|
||||||
|
# needed
|
||||||
|
work = Join.notify_join_context(self)
|
||||||
|
if work:
|
||||||
|
self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)
|
||||||
|
|
||||||
|
# Calling _rebuild_buckets before forward compuation,
|
||||||
|
# It may allocate new buckets before deallocating old buckets
|
||||||
|
# inside _rebuild_buckets. To save peak memory usage,
|
||||||
|
# call _rebuild_buckets before the peak memory usage increases
|
||||||
|
# during forward computation.
|
||||||
|
# This should be called only once during whole training period.
|
||||||
|
if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
|
||||||
|
print("Reducer buckets have been rebuilt in this iteration.")
|
||||||
|
self._has_rebuilt_buckets = True
|
||||||
|
|
||||||
|
# sync params according to location (before/after forward) user
|
||||||
|
# specified as part of hook, if hook was specified.
|
||||||
|
buffer_hook_registered = hasattr(self, "buffer_hook")
|
||||||
|
if self._check_sync_bufs_pre_fwd():
|
||||||
|
self._sync_buffers()
|
||||||
|
|
||||||
|
if self._join_config.enable:
|
||||||
|
# Notify joined ranks whether they should sync in backwards pass or not.
|
||||||
|
self._check_global_requires_backward_grad_sync(is_joined_rank=False)
|
||||||
|
|
||||||
|
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
||||||
|
if self.module.training:
|
||||||
|
output = self.module.training_step(*inputs[0], **kwargs[0])
|
||||||
|
elif self.module.testing:
|
||||||
|
output = self.module.test_step(*inputs[0], **kwargs[0])
|
||||||
|
else:
|
||||||
|
output = self.module.validation_step(*inputs[0], **kwargs[0])
|
||||||
|
|
||||||
|
# sync params according to location (before/after forward) user
|
||||||
|
# specified as part of hook, if hook was specified.
|
||||||
|
if self._check_sync_bufs_post_fwd():
|
||||||
|
self._sync_buffers()
|
||||||
|
|
||||||
|
if torch.is_grad_enabled() and self.require_backward_grad_sync:
|
||||||
|
self.require_forward_param_sync = True
|
||||||
|
# We'll return the output object verbatim since it is a freeform
|
||||||
|
# object. We need to find any tensors in this object, though,
|
||||||
|
# because we need to figure out which parameters were used during
|
||||||
|
# this forward pass, to ensure we short circuit reduction for any
|
||||||
|
# unused parameters. Only if `find_unused_parameters` is set.
|
||||||
|
if self.find_unused_parameters and not self.static_graph:
|
||||||
|
# Do not need to populate this for static graph.
|
||||||
|
self.reducer.prepare_for_backward(list(_find_tensors(output)))
|
||||||
|
else:
|
||||||
|
self.reducer.prepare_for_backward([])
|
||||||
|
else:
|
||||||
|
self.require_forward_param_sync = False
|
||||||
|
|
||||||
|
# TODO: DDPSink is currently enabled for unused parameter detection and
|
||||||
|
# static graph training for first iteration.
|
||||||
|
if (self.find_unused_parameters and not self.static_graph) or (
|
||||||
|
self.static_graph and self.num_iterations == 1
|
||||||
|
):
|
||||||
|
state_dict = {
|
||||||
|
"static_graph": self.static_graph,
|
||||||
|
"num_iterations": self.num_iterations,
|
||||||
|
}
|
||||||
|
|
||||||
|
output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
|
||||||
|
output_placeholders = [None for _ in range(len(output_tensor_list))]
|
||||||
|
# Do not touch tensors that have no grad_fn, which can cause issues
|
||||||
|
# such as https://github.com/pytorch/pytorch/issues/60733
|
||||||
|
for i, output in enumerate(output_tensor_list):
|
||||||
|
if torch.is_tensor(output) and output.grad_fn is None:
|
||||||
|
output_placeholders[i] = output
|
||||||
|
|
||||||
|
# When find_unused_parameters=True, makes tensors which require grad
|
||||||
|
# run through the DDPSink backward pass. When not all outputs are
|
||||||
|
# used in loss, this makes those corresponding tensors receive
|
||||||
|
# undefined gradient which the reducer then handles to ensure
|
||||||
|
# param.grad field is not touched and we don't error out.
|
||||||
|
passthrough_tensor_list = _DDPSink.apply(
|
||||||
|
self.reducer,
|
||||||
|
state_dict,
|
||||||
|
*output_tensor_list,
|
||||||
|
)
|
||||||
|
for i in range(len(output_placeholders)):
|
||||||
|
if output_placeholders[i] is None:
|
||||||
|
output_placeholders[i] = passthrough_tensor_list[i]
|
||||||
|
|
||||||
|
# Reconstruct output data structure.
|
||||||
|
output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
|
||||||
|
return output
|
||||||
123
GPT_SoVITS/module/distrib.py
Normal file
123
GPT_SoVITS/module/distrib.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Torch distributed utilities."""
|
||||||
|
|
||||||
|
import typing as tp
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def rank():
|
||||||
|
if torch.distributed.is_initialized():
|
||||||
|
return torch.distributed.get_rank()
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def world_size():
|
||||||
|
if torch.distributed.is_initialized():
|
||||||
|
return torch.distributed.get_world_size()
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def is_distributed():
|
||||||
|
return world_size() > 1
|
||||||
|
|
||||||
|
|
||||||
|
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
|
||||||
|
if is_distributed():
|
||||||
|
return torch.distributed.all_reduce(tensor, op)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_complex_or_float(tensor):
|
||||||
|
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_number_of_params(params: tp.List[torch.Tensor]):
|
||||||
|
# utility function to check that the number of params in all workers is the same,
|
||||||
|
# and thus avoid a deadlock with distributed all reduce.
|
||||||
|
if not is_distributed() or not params:
|
||||||
|
return
|
||||||
|
# print('params[0].device ', params[0].device)
|
||||||
|
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
|
||||||
|
all_reduce(tensor)
|
||||||
|
if tensor.item() != len(params) * world_size():
|
||||||
|
# If not all the workers have the same number, for at least one of them,
|
||||||
|
# this inequality will be verified.
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
|
||||||
|
"""Broadcast the tensors from the given parameters to all workers.
|
||||||
|
This can be used to ensure that all workers have the same model to start with.
|
||||||
|
"""
|
||||||
|
if not is_distributed():
|
||||||
|
return
|
||||||
|
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
|
||||||
|
_check_number_of_params(tensors)
|
||||||
|
handles = []
|
||||||
|
for tensor in tensors:
|
||||||
|
handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
|
||||||
|
handles.append(handle)
|
||||||
|
for handle in handles:
|
||||||
|
handle.wait()
|
||||||
|
|
||||||
|
|
||||||
|
def sync_buffer(buffers, average=True):
|
||||||
|
"""
|
||||||
|
Sync grad for buffers. If average is False, broadcast instead of averaging.
|
||||||
|
"""
|
||||||
|
if not is_distributed():
|
||||||
|
return
|
||||||
|
handles = []
|
||||||
|
for buffer in buffers:
|
||||||
|
if torch.is_floating_point(buffer.data):
|
||||||
|
if average:
|
||||||
|
handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
|
||||||
|
else:
|
||||||
|
handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
|
||||||
|
handles.append((buffer, handle))
|
||||||
|
for buffer, handle in handles:
|
||||||
|
handle.wait()
|
||||||
|
if average:
|
||||||
|
buffer.data /= world_size
|
||||||
|
|
||||||
|
|
||||||
|
def sync_grad(params):
|
||||||
|
"""
|
||||||
|
Simpler alternative to DistributedDataParallel, that doesn't rely
|
||||||
|
on any black magic. For simple models it can also be as fast.
|
||||||
|
Just call this on your model parameters after the call to backward!
|
||||||
|
"""
|
||||||
|
if not is_distributed():
|
||||||
|
return
|
||||||
|
handles = []
|
||||||
|
for p in params:
|
||||||
|
if p.grad is not None:
|
||||||
|
handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
|
||||||
|
handles.append((p, handle))
|
||||||
|
for p, handle in handles:
|
||||||
|
handle.wait()
|
||||||
|
p.grad.data /= world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def average_metrics(metrics: tp.Dict[str, float], count=1.0):
|
||||||
|
"""Average a dictionary of metrics across all workers, using the optional
|
||||||
|
`count` as unormalized weight.
|
||||||
|
"""
|
||||||
|
if not is_distributed():
|
||||||
|
return metrics
|
||||||
|
keys, values = zip(*metrics.items())
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
|
||||||
|
tensor *= count
|
||||||
|
all_reduce(tensor)
|
||||||
|
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
|
||||||
|
return dict(zip(keys, averaged))
|
||||||
@ -151,6 +151,8 @@ class DurationPredictor(nn.Module):
|
|||||||
return x * x_mask
|
return x * x_mask
|
||||||
|
|
||||||
|
|
||||||
|
WINDOW = {}
|
||||||
|
|
||||||
class TextEncoder(nn.Module):
|
class TextEncoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -209,7 +211,7 @@ class TextEncoder(nn.Module):
|
|||||||
|
|
||||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
|
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
|
||||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||||
|
|
||||||
y = self.ssl_proj(y * y_mask) * y_mask
|
y = self.ssl_proj(y * y_mask) * y_mask
|
||||||
@ -222,13 +224,44 @@ class TextEncoder(nn.Module):
|
|||||||
text = self.text_embedding(text).transpose(1, 2)
|
text = self.text_embedding(text).transpose(1, 2)
|
||||||
text = self.encoder_text(text * text_mask, text_mask)
|
text = self.encoder_text(text * text_mask, text_mask)
|
||||||
y = self.mrte(y, y_mask, text, text_mask, ge)
|
y = self.mrte(y, y_mask, text, text_mask, ge)
|
||||||
|
|
||||||
|
if padding_length is not None and padding_length!=0:
|
||||||
|
y = y[:, :, :-padding_length]
|
||||||
|
y_mask = y_mask[:, :, :-padding_length]
|
||||||
|
|
||||||
|
|
||||||
y = self.encoder2(y * y_mask, y_mask)
|
y = self.encoder2(y * y_mask, y_mask)
|
||||||
|
|
||||||
|
if result_length is not None:
|
||||||
|
y = y[:, :, -result_length:]
|
||||||
|
y_mask = y_mask[:, :, -result_length:]
|
||||||
|
|
||||||
|
if overlap_frames is not None:
|
||||||
|
overlap_len = overlap_frames.shape[-1]
|
||||||
|
window = WINDOW.get(overlap_len, None)
|
||||||
|
if window is None:
|
||||||
|
# WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype)
|
||||||
|
WINDOW[overlap_len] = torch.sin(torch.arange(overlap_len*2, device=y.device) * torch.pi / (overlap_len*2))
|
||||||
|
window = WINDOW[overlap_len]
|
||||||
|
|
||||||
|
|
||||||
|
window = window.to(y.device)
|
||||||
|
y[:,:,:overlap_len] = (
|
||||||
|
window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len]
|
||||||
|
+ window[overlap_len:].view(1, 1, -1) * overlap_frames
|
||||||
|
)
|
||||||
|
|
||||||
|
y_ = y
|
||||||
|
y_mask_ = y_mask
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if speed != 1:
|
if speed != 1:
|
||||||
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
|
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
|
||||||
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
|
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
|
||||||
stats = self.proj(y) * y_mask
|
stats = self.proj(y) * y_mask
|
||||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||||
return y, m, logs, y_mask
|
return y, m, logs, y_mask, y_, y_mask_
|
||||||
|
|
||||||
def extract_latent(self, x):
|
def extract_latent(self, x):
|
||||||
x = self.ssl_proj(x)
|
x = self.ssl_proj(x)
|
||||||
@ -921,7 +954,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
if self.semantic_frame_rate == "25hz":
|
if self.semantic_frame_rate == "25hz":
|
||||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||||
|
|
||||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
|
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
|
||||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
||||||
z_p = self.flow(z, y_mask, g=ge)
|
z_p = self.flow(z, y_mask, g=ge)
|
||||||
|
|
||||||
@ -949,7 +982,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
if self.semantic_frame_rate == "25hz":
|
if self.semantic_frame_rate == "25hz":
|
||||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||||
|
|
||||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
|
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
|
||||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||||
|
|
||||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||||
@ -957,6 +990,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||||
return o, y_mask, (z, z_p, m_p, logs_p)
|
return o, y_mask, (z, z_p, m_p, logs_p)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
|
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
|
||||||
def get_ge(refer, sv_emb):
|
def get_ge(refer, sv_emb):
|
||||||
@ -989,7 +1023,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
quantized = self.quantizer.decode(codes)
|
quantized = self.quantizer.decode(codes)
|
||||||
if self.semantic_frame_rate == "25hz":
|
if self.semantic_frame_rate == "25hz":
|
||||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||||
x, m_p, logs_p, y_mask = self.enc_p(
|
x, m_p, logs_p, y_mask, _, _ = self.enc_p(
|
||||||
quantized,
|
quantized,
|
||||||
y_lengths,
|
y_lengths,
|
||||||
text,
|
text,
|
||||||
@ -1004,6 +1038,59 @@ class SynthesizerTrn(nn.Module):
|
|||||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
|
||||||
|
def get_ge(refer, sv_emb):
|
||||||
|
ge = None
|
||||||
|
if refer is not None:
|
||||||
|
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
||||||
|
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
|
||||||
|
if self.version == "v1":
|
||||||
|
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||||
|
else:
|
||||||
|
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||||
|
if self.is_v2pro:
|
||||||
|
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
|
||||||
|
ge += sv_emb.unsqueeze(-1)
|
||||||
|
ge = self.prelu(ge)
|
||||||
|
return ge
|
||||||
|
|
||||||
|
if type(refer) == list:
|
||||||
|
ges = []
|
||||||
|
for idx, _refer in enumerate(refer):
|
||||||
|
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
|
||||||
|
ges.append(ge)
|
||||||
|
ge = torch.stack(ges, 0).mean(0)
|
||||||
|
else:
|
||||||
|
ge = get_ge(refer, sv_emb)
|
||||||
|
|
||||||
|
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
||||||
|
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||||
|
|
||||||
|
quantized = self.quantizer.decode(codes)
|
||||||
|
if self.semantic_frame_rate == "25hz":
|
||||||
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||||
|
result_length = (2*result_length) if result_length is not None else None
|
||||||
|
padding_length = (2*padding_length) if padding_length is not None else None
|
||||||
|
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(
|
||||||
|
quantized,
|
||||||
|
y_lengths,
|
||||||
|
text,
|
||||||
|
text_lengths,
|
||||||
|
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
|
||||||
|
speed,
|
||||||
|
result_length=result_length,
|
||||||
|
overlap_frames=overlap_frames,
|
||||||
|
padding_length=padding_length
|
||||||
|
)
|
||||||
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||||
|
|
||||||
|
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||||
|
|
||||||
|
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||||
|
return o, y_, y_mask_
|
||||||
|
|
||||||
def extract_latent(self, x):
|
def extract_latent(self, x):
|
||||||
ssl = self.ssl_proj(x)
|
ssl = self.ssl_proj(x)
|
||||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||||
@ -1226,7 +1313,7 @@ class SynthesizerTrnV3(nn.Module):
|
|||||||
ssl = self.ssl_proj(ssl)
|
ssl = self.ssl_proj(ssl)
|
||||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
|
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
|
||||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||||
fea = self.bridge(x)
|
fea = self.bridge(x)
|
||||||
fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
|
fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
|
||||||
fea, y_mask_ = self.wns1(
|
fea, y_mask_ = self.wns1(
|
||||||
@ -1260,7 +1347,7 @@ class SynthesizerTrnV3(nn.Module):
|
|||||||
quantized = self.quantizer.decode(codes)
|
quantized = self.quantizer.decode(codes)
|
||||||
if self.semantic_frame_rate == "25hz":
|
if self.semantic_frame_rate == "25hz":
|
||||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
|
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
|
||||||
fea = self.bridge(x)
|
fea = self.bridge(x)
|
||||||
fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
|
fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
|
||||||
####more wn paramter to learn mel
|
####more wn paramter to learn mel
|
||||||
@ -1377,7 +1464,7 @@ class SynthesizerTrnV3b(nn.Module):
|
|||||||
ssl = self.ssl_proj(ssl)
|
ssl = self.ssl_proj(ssl)
|
||||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
|
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
|
||||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
||||||
z_p = self.flow(z, y_mask, g=ge)
|
z_p = self.flow(z, y_mask, g=ge)
|
||||||
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
|
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
|
||||||
@ -1420,7 +1507,7 @@ class SynthesizerTrnV3b(nn.Module):
|
|||||||
quantized = self.quantizer.decode(codes)
|
quantized = self.quantizer.decode(codes)
|
||||||
if self.semantic_frame_rate == "25hz":
|
if self.semantic_frame_rate == "25hz":
|
||||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||||
fea = self.bridge(x)
|
fea = self.bridge(x)
|
||||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
|
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
|
||||||
####more wn paramter to learn mel
|
####more wn paramter to learn mel
|
||||||
|
|||||||
@ -124,7 +124,7 @@ def run(rank, n_gpus, hps):
|
|||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
batch_sampler=train_sampler,
|
batch_sampler=train_sampler,
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
prefetch_factor=4,
|
prefetch_factor=3,
|
||||||
)
|
)
|
||||||
# if rank == 0:
|
# if rank == 0:
|
||||||
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
|
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
|
||||||
|
|||||||
@ -118,13 +118,13 @@ def run(rank, n_gpus, hps):
|
|||||||
collate_fn = TextAudioSpeakerCollate()
|
collate_fn = TextAudioSpeakerCollate()
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
num_workers=6,
|
num_workers=5,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
batch_sampler=train_sampler,
|
batch_sampler=train_sampler,
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
prefetch_factor=4,
|
prefetch_factor=3,
|
||||||
)
|
)
|
||||||
# if rank == 0:
|
# if rank == 0:
|
||||||
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
|
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
|
||||||
|
|||||||
@ -120,13 +120,13 @@ def run(rank, n_gpus, hps):
|
|||||||
collate_fn = TextAudioSpeakerCollate()
|
collate_fn = TextAudioSpeakerCollate()
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
num_workers=6,
|
num_workers=5,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
batch_sampler=train_sampler,
|
batch_sampler=train_sampler,
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
prefetch_factor=4,
|
prefetch_factor=3,
|
||||||
)
|
)
|
||||||
save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
|
save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
|
||||||
os.makedirs(save_root, exist_ok=True)
|
os.makedirs(save_root, exist_ok=True)
|
||||||
|
|||||||
611
GPT_SoVITS/stream_v2pro.py
Normal file
611
GPT_SoVITS/stream_v2pro.py
Normal file
@ -0,0 +1,611 @@
|
|||||||
|
# 这是一个实验性质的实现,旨在探索 stream infer 的可能性。(xiao hai xie zhe wan de)
|
||||||
|
from typing import List
|
||||||
|
from export_torch_script import ExportERes2NetV2, SSLModel, T2SModel, VitsModel, get_raw_t2s_model, init_sv_cn, resamplex, sample, spectrogram_torch
|
||||||
|
import export_torch_script
|
||||||
|
from my_utils import load_audio
|
||||||
|
import torch
|
||||||
|
from torch import LongTensor, Tensor, nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
from inference_webui import get_phones_and_bert
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
class StreamT2SModel(nn.Module):
|
||||||
|
def __init__(self, t2s: T2SModel):
|
||||||
|
super(StreamT2SModel, self).__init__()
|
||||||
|
self.t2s = t2s
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def pre_infer(
|
||||||
|
self,
|
||||||
|
prompts: LongTensor,
|
||||||
|
ref_seq: LongTensor,
|
||||||
|
text_seq: LongTensor,
|
||||||
|
ref_bert: torch.Tensor,
|
||||||
|
text_bert: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
) -> tuple[int, Tensor, Tensor, List[Tensor], List[Tensor]]:
|
||||||
|
bert = torch.cat([ref_bert.T, text_bert.T], 1)
|
||||||
|
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
||||||
|
bert = bert.unsqueeze(0)
|
||||||
|
|
||||||
|
x = self.t2s.ar_text_embedding(all_phoneme_ids)
|
||||||
|
x = x + self.t2s.bert_proj(bert.transpose(1, 2))
|
||||||
|
x: torch.Tensor = self.t2s.ar_text_position(x)
|
||||||
|
|
||||||
|
# [1,N,512] [1,N]
|
||||||
|
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||||
|
y = prompts
|
||||||
|
# x_example = x[:,:,0] * 0.0
|
||||||
|
|
||||||
|
x_len = x.shape[1]
|
||||||
|
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
||||||
|
|
||||||
|
y_emb = self.t2s.ar_audio_embedding(y)
|
||||||
|
y_len: int = y_emb.shape[1]
|
||||||
|
prefix_len = y.shape[1]
|
||||||
|
y_pos = self.t2s.ar_audio_position(y_emb)
|
||||||
|
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||||
|
|
||||||
|
bsz = x.shape[0]
|
||||||
|
src_len = x_len + y_len
|
||||||
|
x_attn_mask_pad = F.pad(
|
||||||
|
x_attn_mask,
|
||||||
|
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
||||||
|
value=True,
|
||||||
|
)
|
||||||
|
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||||
|
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||||
|
(x_len, 0),
|
||||||
|
value=False,
|
||||||
|
)
|
||||||
|
xy_attn_mask = (
|
||||||
|
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(bsz * self.t2s.num_head, -1, -1)
|
||||||
|
.view(bsz, self.t2s.num_head, src_len, src_len)
|
||||||
|
.to(device=x.device, dtype=torch.bool)
|
||||||
|
)
|
||||||
|
|
||||||
|
xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.process_prompt(
|
||||||
|
xy_pos, xy_attn_mask, None
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
samples = sample(
|
||||||
|
logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0
|
||||||
|
)[0]
|
||||||
|
y = torch.concat([y, samples], dim=1)
|
||||||
|
y_emb: Tensor = self.t2s.ar_audio_embedding(y[:, -1:])
|
||||||
|
xy_pos: Tensor = (
|
||||||
|
y_emb * self.t2s.ar_audio_position.x_scale
|
||||||
|
+ self.t2s.ar_audio_position.alpha
|
||||||
|
* self.t2s.ar_audio_position.pe[:, y_len].to(
|
||||||
|
dtype=y_emb.dtype, device=y_emb.device
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return y_len, y, xy_pos, k_cache, v_cache
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def decode_next_token(
|
||||||
|
self,
|
||||||
|
idx: int, # 记住从1开始 到1500
|
||||||
|
top_k: int,
|
||||||
|
y_len: int,
|
||||||
|
y: Tensor,
|
||||||
|
xy_pos: Tensor,
|
||||||
|
k_cache: List[Tensor],
|
||||||
|
v_cache: List[Tensor],
|
||||||
|
) -> tuple[Tensor, Tensor, int, List[Tensor], List[Tensor]]:
|
||||||
|
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||||
|
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
|
||||||
|
xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.decode_next_token(
|
||||||
|
xy_pos, k_cache, v_cache
|
||||||
|
)
|
||||||
|
logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
|
||||||
|
|
||||||
|
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
|
||||||
|
samples = sample(
|
||||||
|
logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
y = torch.concat([y, samples], dim=1)
|
||||||
|
last_token = int(samples[0, 0])
|
||||||
|
|
||||||
|
# if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||||
|
# stop = True
|
||||||
|
if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
|
||||||
|
return y[:,:-1], xy_pos, self.t2s.EOS, k_cache, v_cache
|
||||||
|
|
||||||
|
# if stop:
|
||||||
|
# if y.shape[1] == 0:
|
||||||
|
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||||
|
# break
|
||||||
|
|
||||||
|
y_emb = self.t2s.ar_audio_embedding(y[:, -1:])
|
||||||
|
xy_pos = (
|
||||||
|
y_emb * self.t2s.ar_audio_position.x_scale
|
||||||
|
+ self.t2s.ar_audio_position.alpha
|
||||||
|
* self.t2s.ar_audio_position.pe[:, y_len + idx].to(
|
||||||
|
dtype=y_emb.dtype, device=y_emb.device
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return y, xy_pos, last_token, k_cache, v_cache
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
idx: int, # 记住从1开始 到1500
|
||||||
|
top_k: int,
|
||||||
|
y_len: int,
|
||||||
|
y: Tensor,
|
||||||
|
xy_pos: Tensor,
|
||||||
|
k_cache: List[Tensor],
|
||||||
|
v_cache: List[Tensor],
|
||||||
|
):
|
||||||
|
return self.decode_next_token(idx,top_k,y_len,y,xy_pos,k_cache,v_cache)
|
||||||
|
|
||||||
|
|
||||||
|
class StepVitsModel(nn.Module):
|
||||||
|
def __init__(self, vits: VitsModel,sv_model:ExportERes2NetV2):
|
||||||
|
super().__init__()
|
||||||
|
self.hps = vits.hps
|
||||||
|
self.vq_model = vits.vq_model
|
||||||
|
self.hann_window = vits.hann_window
|
||||||
|
self.sv = sv_model
|
||||||
|
|
||||||
|
def ref_handle(self, ref_audio_32k):
|
||||||
|
refer = spectrogram_torch(
|
||||||
|
self.hann_window,
|
||||||
|
ref_audio_32k.float(),
|
||||||
|
self.hps.data.filter_length,
|
||||||
|
self.hps.data.sampling_rate,
|
||||||
|
self.hps.data.hop_length,
|
||||||
|
self.hps.data.win_length,
|
||||||
|
center=False,
|
||||||
|
)
|
||||||
|
refer = refer.to(ref_audio_32k.dtype)
|
||||||
|
ref_audio_16k = resamplex(ref_audio_32k, 32000, 16000).to(ref_audio_32k.dtype).to(ref_audio_32k.device)
|
||||||
|
sv_emb = self.sv(ref_audio_16k)
|
||||||
|
return refer, sv_emb
|
||||||
|
|
||||||
|
def extract_latent(self, ssl_content):
|
||||||
|
codes = self.vq_model.extract_latent(ssl_content)
|
||||||
|
return codes[0]
|
||||||
|
|
||||||
|
def forward(self, pred_semantic, text_seq, refer, sv_emb=None):
|
||||||
|
return self.vq_model(
|
||||||
|
pred_semantic, text_seq, refer, speed=1.0, sv_emb=sv_emb
|
||||||
|
)[0, 0]
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def find_best_audio_offset_fast(reference_audio: Tensor, search_audio: Tensor):
|
||||||
|
ref_len = len(reference_audio)
|
||||||
|
search_len = len(search_audio)
|
||||||
|
|
||||||
|
if search_len < ref_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"搜索音频长度 ({search_len}) 必须大于等于参考音频长度 ({ref_len})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用F.conv1d计算原始互相关
|
||||||
|
reference_flipped = reference_audio.unsqueeze(0).unsqueeze(0)
|
||||||
|
search_padded = search_audio.unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
|
# 计算点积
|
||||||
|
dot_products = F.conv1d(search_padded, reference_flipped).squeeze()
|
||||||
|
|
||||||
|
if len(dot_products.shape) == 0:
|
||||||
|
dot_products = dot_products.unsqueeze(0)
|
||||||
|
|
||||||
|
# 计算参考音频的平方和
|
||||||
|
ref_squared_sum = torch.sum(reference_audio**2)
|
||||||
|
|
||||||
|
# 计算搜索音频每个位置的平方和(滑动窗口)
|
||||||
|
search_squared = search_audio**2
|
||||||
|
search_squared_padded = search_squared.unsqueeze(0).unsqueeze(0)
|
||||||
|
ones_kernel = torch.ones(
|
||||||
|
1, 1, ref_len, dtype=search_audio.dtype, device=search_audio.device
|
||||||
|
)
|
||||||
|
|
||||||
|
segment_squared_sums = F.conv1d(search_squared_padded, ones_kernel).squeeze()
|
||||||
|
|
||||||
|
if len(segment_squared_sums.shape) == 0:
|
||||||
|
segment_squared_sums = segment_squared_sums.unsqueeze(0)
|
||||||
|
|
||||||
|
# 计算归一化因子
|
||||||
|
ref_norm = torch.sqrt(ref_squared_sum)
|
||||||
|
segment_norms = torch.sqrt(segment_squared_sums)
|
||||||
|
|
||||||
|
# 避免除零
|
||||||
|
epsilon = 1e-8
|
||||||
|
normalization_factor = ref_norm * segment_norms + epsilon
|
||||||
|
|
||||||
|
# 归一化互相关
|
||||||
|
correlation_scores = dot_products / normalization_factor
|
||||||
|
|
||||||
|
best_offset = torch.argmax(correlation_scores).item()
|
||||||
|
|
||||||
|
return best_offset, correlation_scores
|
||||||
|
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
def test_stream(
|
||||||
|
gpt_path,
|
||||||
|
vits_path,
|
||||||
|
version,
|
||||||
|
ref_audio_path,
|
||||||
|
ref_text,
|
||||||
|
output_path,
|
||||||
|
device="cpu",
|
||||||
|
is_half=True,
|
||||||
|
):
|
||||||
|
if export_torch_script.sv_cn_model == None:
|
||||||
|
init_sv_cn(device,is_half)
|
||||||
|
|
||||||
|
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
|
||||||
|
ssl = SSLModel()
|
||||||
|
|
||||||
|
print(f"device: {device}")
|
||||||
|
|
||||||
|
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
|
||||||
|
ref_text, "all_zh", "v2"
|
||||||
|
)
|
||||||
|
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
||||||
|
ref_bert = ref_bert_T.T
|
||||||
|
if is_half:
|
||||||
|
ref_bert = ref_bert.half()
|
||||||
|
ref_bert = ref_bert.to(ref_seq.device)
|
||||||
|
|
||||||
|
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||||
|
"这是一个简单的示例,真没想到这么简单就完成了,真的神奇,接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
|
||||||
|
)
|
||||||
|
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||||
|
text_bert = text_bert_T.T
|
||||||
|
if is_half:
|
||||||
|
text_bert = text_bert.half()
|
||||||
|
text_bert = text_bert.to(text_seq.device)
|
||||||
|
|
||||||
|
ssl_content = ssl(ref_audio)
|
||||||
|
if is_half:
|
||||||
|
ssl_content = ssl_content.half()
|
||||||
|
ssl_content = ssl_content.to(device)
|
||||||
|
|
||||||
|
sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)
|
||||||
|
|
||||||
|
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||||
|
vits = VitsModel(vits_path, version,is_half=is_half,device=device)
|
||||||
|
vits.eval()
|
||||||
|
|
||||||
|
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||||
|
# dict_s1 = torch.load(gpt_path, map_location=device)
|
||||||
|
dict_s1 = torch.load(gpt_path, weights_only=False)
|
||||||
|
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||||
|
print("#### get_raw_t2s_model ####")
|
||||||
|
print(raw_t2s.config)
|
||||||
|
if is_half:
|
||||||
|
raw_t2s = raw_t2s.half()
|
||||||
|
t2s_m = T2SModel(raw_t2s)
|
||||||
|
t2s_m.eval()
|
||||||
|
# t2s = torch.jit.script(t2s_m).to(device)
|
||||||
|
t2s = t2s_m
|
||||||
|
print("#### script t2s_m ####")
|
||||||
|
|
||||||
|
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
|
||||||
|
|
||||||
|
stream_t2s = StreamT2SModel(t2s).to(device)
|
||||||
|
stream_t2s = torch.jit.script(stream_t2s)
|
||||||
|
|
||||||
|
ref_audio_sr = resamplex(ref_audio, 16000, 32000)
|
||||||
|
if is_half:
|
||||||
|
ref_audio_sr = ref_audio_sr.half()
|
||||||
|
ref_audio_sr = ref_audio_sr.to(device)
|
||||||
|
|
||||||
|
top_k = 15
|
||||||
|
|
||||||
|
codes = vits.vq_model.extract_latent(ssl_content)
|
||||||
|
prompt_semantic = codes[0, 0]
|
||||||
|
prompts = prompt_semantic.unsqueeze(0)
|
||||||
|
|
||||||
|
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
|
||||||
|
sv_emb = sv_model(audio_16k)
|
||||||
|
print("text_seq",text_seq.shape)
|
||||||
|
|
||||||
|
refer = spectrogram_torch(
|
||||||
|
vits.hann_window,
|
||||||
|
ref_audio_sr,
|
||||||
|
vits.hps.data.filter_length,
|
||||||
|
vits.hps.data.sampling_rate,
|
||||||
|
vits.hps.data.hop_length,
|
||||||
|
vits.hps.data.win_length,
|
||||||
|
center=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
st = time.time()
|
||||||
|
et = time.time()
|
||||||
|
|
||||||
|
y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||||
|
idx = 1
|
||||||
|
last_idx = 0
|
||||||
|
audios = []
|
||||||
|
raw_audios = []
|
||||||
|
last_audio_ret = None
|
||||||
|
offset_index = []
|
||||||
|
full_audios = []
|
||||||
|
print("y.shape:", y.shape)
|
||||||
|
cut_id = 0
|
||||||
|
while True:
|
||||||
|
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
|
||||||
|
# print("y.shape:", y.shape)
|
||||||
|
stop = last_token==t2s.EOS
|
||||||
|
print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx)
|
||||||
|
|
||||||
|
if last_token < 50 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
|
||||||
|
cut_id = idx + 7
|
||||||
|
print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape)
|
||||||
|
# y = torch.cat([y, y[:,-1:]], dim=1)
|
||||||
|
# idx+=1
|
||||||
|
|
||||||
|
if stop :
|
||||||
|
idx -=1
|
||||||
|
print('stop')
|
||||||
|
print(idx, y[:,-idx+last_idx:])
|
||||||
|
print(idx,last_idx, y.shape)
|
||||||
|
print(y[:,-idx:-idx+20])
|
||||||
|
|
||||||
|
|
||||||
|
# 玄学这档子事说不清楚
|
||||||
|
if idx == cut_id or stop:
|
||||||
|
print(f"idx: {idx}, last_idx: {last_idx}, cut_id: {cut_id}, stop: {stop}")
|
||||||
|
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
|
||||||
|
full_audios.append(audio)
|
||||||
|
if last_idx == 0:
|
||||||
|
last_audio_ret = audio[-1280*8:-1280*8+256]
|
||||||
|
audio = audio[:-1280*8]
|
||||||
|
raw_audios.append(audio)
|
||||||
|
et = time.time()
|
||||||
|
else:
|
||||||
|
if stop:
|
||||||
|
audio_ = audio[last_idx*1280 -1280*8:]
|
||||||
|
raw_audios.append(audio_)
|
||||||
|
i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
|
||||||
|
offset_index.append(i)
|
||||||
|
audio = audio_[i:]
|
||||||
|
else:
|
||||||
|
audio_ = audio[last_idx*1280 -1280*8:-1280*8]
|
||||||
|
raw_audios.append(audio_)
|
||||||
|
i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
|
||||||
|
offset_index.append(i)
|
||||||
|
last_audio_ret = audio[-1280*8:-1280*8+256]
|
||||||
|
audio = audio_[i:]
|
||||||
|
last_idx = idx
|
||||||
|
# print(f'write {output_path}/out_{audio_index}')
|
||||||
|
# soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||||
|
audios.append(audio)
|
||||||
|
# print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
|
||||||
|
if idx>1500:
|
||||||
|
break
|
||||||
|
|
||||||
|
if stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
idx+=1
|
||||||
|
|
||||||
|
at = time.time()
|
||||||
|
|
||||||
|
for (i,a) in enumerate(audios):
|
||||||
|
print(f'write {output_path}/out_{i}')
|
||||||
|
soundfile.write(f"{output_path}/out_{i}.wav", a.float().detach().cpu().numpy(), 32000)
|
||||||
|
|
||||||
|
print(f"frist token: {et - st:.4f} seconds")
|
||||||
|
print(f"all token: {at - st:.4f} seconds")
|
||||||
|
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
|
||||||
|
soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||||
|
audio = torch.cat(audios, dim=0)
|
||||||
|
soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||||
|
audio_raw = torch.cat(raw_audios, dim=0)
|
||||||
|
soundfile.write(f"{output_path}/out.raw.wav", audio_raw.float().detach().cpu().numpy(), 32000)
|
||||||
|
|
||||||
|
|
||||||
|
colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
|
||||||
|
|
||||||
|
max_duration = full_audios[-1].shape[0]
|
||||||
|
plt.xlim(0, max_duration)
|
||||||
|
|
||||||
|
last_line = 0
|
||||||
|
|
||||||
|
for i,a in enumerate(full_audios):
|
||||||
|
plt.plot((a+2.0*i).float().detach().cpu().numpy(), color=colors[i], alpha=0.5, label=f"Audio {i}")
|
||||||
|
# plt.axvline(x=last_line, color=colors[i], linestyle='--')
|
||||||
|
last_line = a.shape[0]-8*1280
|
||||||
|
plt.axvline(x=last_line, color=colors[i], linestyle='--')
|
||||||
|
|
||||||
|
plt.plot((audio-2.0).float().detach().cpu().numpy(), color='black', label='Final Audio')
|
||||||
|
|
||||||
|
plt.plot((audio_raw-4.0).float().detach().cpu().numpy(), color='cyan', label='Raw Audio')
|
||||||
|
|
||||||
|
print("offset_index:", offset_index)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def export_prov2(
|
||||||
|
gpt_path,
|
||||||
|
vits_path,
|
||||||
|
version,
|
||||||
|
ref_audio_path,
|
||||||
|
ref_text,
|
||||||
|
output_path,
|
||||||
|
device="cpu",
|
||||||
|
is_half=True,
|
||||||
|
lang="auto",
|
||||||
|
):
|
||||||
|
if export_torch_script.sv_cn_model == None:
|
||||||
|
init_sv_cn(device,is_half)
|
||||||
|
|
||||||
|
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
|
||||||
|
ssl = SSLModel()
|
||||||
|
|
||||||
|
print(f"device: {device}")
|
||||||
|
|
||||||
|
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
|
||||||
|
ref_text, lang, "v2"
|
||||||
|
)
|
||||||
|
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
||||||
|
ref_bert = ref_bert_T.T
|
||||||
|
if is_half:
|
||||||
|
ref_bert = ref_bert.half()
|
||||||
|
ref_bert = ref_bert.to(ref_seq.device)
|
||||||
|
|
||||||
|
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||||
|
"这是一个简单的示例,真没想到这么简单就完成了.The King and His Stories.Once there was a king.He likes to write stories, but his stories were not good.", "auto", "v2"
|
||||||
|
)
|
||||||
|
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||||
|
text_bert = text_bert_T.T
|
||||||
|
if is_half:
|
||||||
|
text_bert = text_bert.half()
|
||||||
|
text_bert = text_bert.to(text_seq.device)
|
||||||
|
|
||||||
|
ssl_content = ssl(ref_audio)
|
||||||
|
if is_half:
|
||||||
|
ssl_content = ssl_content.half()
|
||||||
|
ssl_content = ssl_content.to(device)
|
||||||
|
|
||||||
|
sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)
|
||||||
|
|
||||||
|
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||||
|
vits = VitsModel(vits_path, version,is_half=is_half,device=device)
|
||||||
|
vits.eval()
|
||||||
|
vits = StepVitsModel(vits, sv_model)
|
||||||
|
|
||||||
|
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||||
|
# dict_s1 = torch.load(gpt_path, map_location=device)
|
||||||
|
dict_s1 = torch.load(gpt_path, weights_only=False)
|
||||||
|
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||||
|
print("#### get_raw_t2s_model ####")
|
||||||
|
print(raw_t2s.config)
|
||||||
|
if is_half:
|
||||||
|
raw_t2s = raw_t2s.half()
|
||||||
|
t2s_m = T2SModel(raw_t2s)
|
||||||
|
t2s_m.eval()
|
||||||
|
# t2s = torch.jit.script(t2s_m).to(device)
|
||||||
|
t2s = t2s_m
|
||||||
|
print("#### script t2s_m ####")
|
||||||
|
|
||||||
|
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
|
||||||
|
|
||||||
|
stream_t2s = StreamT2SModel(t2s).to(device)
|
||||||
|
stream_t2s = torch.jit.script(stream_t2s)
|
||||||
|
|
||||||
|
ref_audio_sr = resamplex(ref_audio, 16000, 32000)
|
||||||
|
ref_audio_sr = ref_audio_sr.to(device)
|
||||||
|
if is_half:
|
||||||
|
ref_audio_sr = ref_audio_sr.half()
|
||||||
|
|
||||||
|
top_k = 15
|
||||||
|
|
||||||
|
prompts = vits.extract_latent(ssl_content)
|
||||||
|
|
||||||
|
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
|
||||||
|
sv_emb = sv_model(audio_16k)
|
||||||
|
print("text_seq",text_seq.shape)
|
||||||
|
# torch.jit.trace()
|
||||||
|
|
||||||
|
refer,sv_emb = vits.ref_handle(ref_audio_sr)
|
||||||
|
|
||||||
|
st = time.time()
|
||||||
|
et = time.time()
|
||||||
|
|
||||||
|
y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||||
|
idx = 1
|
||||||
|
print("y.shape:", y.shape)
|
||||||
|
while True:
|
||||||
|
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
|
||||||
|
# print("y.shape:", y.shape)
|
||||||
|
|
||||||
|
idx+=1
|
||||||
|
# print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
|
||||||
|
if idx>1500:
|
||||||
|
break
|
||||||
|
|
||||||
|
if last_token == t2s.EOS:
|
||||||
|
break
|
||||||
|
|
||||||
|
at = time.time()
|
||||||
|
print("EOS:",t2s.EOS)
|
||||||
|
|
||||||
|
print(f"frist token: {et - st:.4f} seconds")
|
||||||
|
print(f"all token: {at - st:.4f} seconds")
|
||||||
|
print("sv_emb", sv_emb.shape)
|
||||||
|
print("refer",refer.shape)
|
||||||
|
y = y[:,-idx:].unsqueeze(0)
|
||||||
|
print("y", y.shape)
|
||||||
|
audio = vits(y, text_seq, refer, sv_emb)
|
||||||
|
soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||||
|
|
||||||
|
torch._dynamo.mark_dynamic(ssl_content, 2)
|
||||||
|
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
|
||||||
|
torch._dynamo.mark_dynamic(ref_seq, 1)
|
||||||
|
torch._dynamo.mark_dynamic(text_seq, 1)
|
||||||
|
torch._dynamo.mark_dynamic(ref_bert, 0)
|
||||||
|
torch._dynamo.mark_dynamic(text_bert, 0)
|
||||||
|
torch._dynamo.mark_dynamic(refer, 2)
|
||||||
|
torch._dynamo.mark_dynamic(y, 2)
|
||||||
|
|
||||||
|
inputs = {
|
||||||
|
"forward": (y, text_seq, refer, sv_emb),
|
||||||
|
"extract_latent": ssl_content,
|
||||||
|
"ref_handle": ref_audio_sr,
|
||||||
|
}
|
||||||
|
|
||||||
|
stream_t2s.save(f"{output_path}/t2s.pt")
|
||||||
|
torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt")
|
||||||
|
torch.jit.script(find_best_audio_offset_fast, optimize=True).save(f"{output_path}/find_best_audio_offset_fast.pt")
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||||
|
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--sovits_model", required=True, help="Path to the SoVITS model file"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ref_audio", required=True, help="Path to the reference audio file"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ref_text", required=True, help="Path to the reference text file"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_path", required=True, help="Path to the output directory"
|
||||||
|
)
|
||||||
|
parser.add_argument("--device", help="Device to use", default="cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
parser.add_argument("--version", help="version of the model", default="v2Pro")
|
||||||
|
parser.add_argument("--no-half", action="store_true", help = "Do not use half precision for model weights")
|
||||||
|
parser.add_argument("--lang", default="auto", help="Language for text processing (default: auto)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not os.path.exists(args.output_path):
|
||||||
|
os.makedirs(args.output_path)
|
||||||
|
|
||||||
|
is_half = not args.no_half
|
||||||
|
with torch.no_grad():
|
||||||
|
export_prov2(
|
||||||
|
gpt_path=args.gpt_model,
|
||||||
|
vits_path=args.sovits_model,
|
||||||
|
version=args.version,
|
||||||
|
ref_audio_path=args.ref_audio,
|
||||||
|
ref_text=args.ref_text,
|
||||||
|
output_path=args.output_path,
|
||||||
|
device=args.device,
|
||||||
|
is_half=is_half,
|
||||||
|
lang=args.lang,
|
||||||
|
)
|
||||||
@ -238,6 +238,46 @@ def _expand_number(m):
|
|||||||
return _inflect.number_to_words(num, andword="")
|
return _inflect.number_to_words(num, andword="")
|
||||||
|
|
||||||
|
|
||||||
|
# 加减乘除
|
||||||
|
RE_ASMD = re.compile(
|
||||||
|
r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))\s+([\+\-\×÷=])\s+((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))"
|
||||||
|
)
|
||||||
|
# RE_ASMD = re.compile(
|
||||||
|
# r"\b((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))\b"
|
||||||
|
# )
|
||||||
|
|
||||||
|
asmd_map = {"+": " plus ", "-": " minus ", "×": " times ", "÷": " divided by ", "=": " Equals "}
|
||||||
|
|
||||||
|
|
||||||
|
def replace_asmd(match) -> str:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
match (re.Match)
|
||||||
|
Returns:
|
||||||
|
str
|
||||||
|
"""
|
||||||
|
result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
RE_INTEGER = re.compile(r"(?:^|\s+)(-)" r"(\d+)")
|
||||||
|
|
||||||
|
|
||||||
|
def replace_negative_num(match) -> str:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
match (re.Match)
|
||||||
|
Returns:
|
||||||
|
str
|
||||||
|
"""
|
||||||
|
sign = match.group(1)
|
||||||
|
number = match.group(2)
|
||||||
|
sign: str = "negative " if sign else ""
|
||||||
|
result = f"{sign}{number}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def normalize(text):
|
def normalize(text):
|
||||||
"""
|
"""
|
||||||
!!! 所有的处理都需要正确的输入 !!!
|
!!! 所有的处理都需要正确的输入 !!!
|
||||||
@ -245,7 +285,13 @@ def normalize(text):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
text = re.sub(_ordinal_number_re, _convert_ordinal, text)
|
text = re.sub(_ordinal_number_re, _convert_ordinal, text)
|
||||||
text = re.sub(r"(?<!\d)-|-(?!\d)", " minus ", text)
|
|
||||||
|
# 处理数学运算
|
||||||
|
# 替换text = re.sub(r"(?<!\d)-|-(?!\d)", " minus ", text)
|
||||||
|
while RE_ASMD.search(text):
|
||||||
|
text = RE_ASMD.sub(replace_asmd, text)
|
||||||
|
text = RE_INTEGER.sub(replace_negative_num, text)
|
||||||
|
|
||||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||||
text = re.sub(_time_re, _expand_time, text)
|
text = re.sub(_time_re, _expand_time, text)
|
||||||
text = re.sub(_measurement_re, _expand_measurement, text)
|
text = re.sub(_measurement_re, _expand_measurement, text)
|
||||||
|
|||||||
@ -347,7 +347,7 @@ Use v4 from v1/v2/v3 environment:
|
|||||||
|
|
||||||
2. Clone the latest codes from github.
|
2. Clone the latest codes from github.
|
||||||
|
|
||||||
3. Download v4 pretrained models (gsv-v4-pretrained/s2v4.ckpt, and gsv-v4-pretrained/vocoder.pth) from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main) and put them into `GPT_SoVITS/pretrained_models`.
|
3. Download v4 pretrained models (gsv-v4-pretrained/s2v4.pth, and gsv-v4-pretrained/vocoder.pth) from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main) and put them into `GPT_SoVITS/pretrained_models`.
|
||||||
|
|
||||||
## V2Pro Release Notes
|
## V2Pro Release Notes
|
||||||
|
|
||||||
|
|||||||
126
api_v2.py
126
api_v2.py
@ -27,20 +27,23 @@ POST:
|
|||||||
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
||||||
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
||||||
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
||||||
"top_k": 5, # int. top k sampling
|
"top_k": 15, # int. top k sampling
|
||||||
"top_p": 1, # float. top p sampling
|
"top_p": 1, # float. top p sampling
|
||||||
"temperature": 1, # float. temperature for sampling
|
"temperature": 1, # float. temperature for sampling
|
||||||
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
|
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
|
||||||
"batch_size": 1, # int. batch size for inference
|
"batch_size": 1, # int. batch size for inference
|
||||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||||
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
||||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||||
"streaming_mode": False, # bool. whether to return a streaming response.
|
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||||
"seed": -1, # int. random seed for reproducibility.
|
"seed": -1, # int. random seed for reproducibility.
|
||||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||||
"super_sampling": False # bool. whether to use super-sampling for audio when using VITS model V3.
|
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||||
|
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||||
|
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||||
|
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -101,7 +104,7 @@ RESP:
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Generator
|
from typing import Generator, Union
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.append(now_dir)
|
sys.path.append(now_dir)
|
||||||
@ -121,6 +124,7 @@ from tools.i18n.i18n import I18nAuto
|
|||||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
||||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
import threading
|
||||||
|
|
||||||
# print(sys.path)
|
# print(sys.path)
|
||||||
i18n = I18nAuto()
|
i18n = I18nAuto()
|
||||||
@ -154,7 +158,7 @@ class TTS_Request(BaseModel):
|
|||||||
aux_ref_audio_paths: list = None
|
aux_ref_audio_paths: list = None
|
||||||
prompt_lang: str = None
|
prompt_lang: str = None
|
||||||
prompt_text: str = ""
|
prompt_text: str = ""
|
||||||
top_k: int = 5
|
top_k: int = 15
|
||||||
top_p: float = 1
|
top_p: float = 1
|
||||||
temperature: float = 1
|
temperature: float = 1
|
||||||
text_split_method: str = "cut5"
|
text_split_method: str = "cut5"
|
||||||
@ -165,17 +169,58 @@ class TTS_Request(BaseModel):
|
|||||||
fragment_interval: float = 0.3
|
fragment_interval: float = 0.3
|
||||||
seed: int = -1
|
seed: int = -1
|
||||||
media_type: str = "wav"
|
media_type: str = "wav"
|
||||||
streaming_mode: bool = False
|
streaming_mode: Union[bool, int] = False
|
||||||
parallel_infer: bool = True
|
parallel_infer: bool = True
|
||||||
repetition_penalty: float = 1.35
|
repetition_penalty: float = 1.35
|
||||||
sample_steps: int = 32
|
sample_steps: int = 32
|
||||||
super_sampling: bool = False
|
super_sampling: bool = False
|
||||||
|
overlap_length: int = 2
|
||||||
|
min_chunk_length: int = 16
|
||||||
|
|
||||||
|
|
||||||
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
|
|
||||||
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||||
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
|
# Author: AkagawaTsurunaki
|
||||||
audio_file.write(data)
|
# Issue:
|
||||||
|
# Stack overflow probabilistically occurs
|
||||||
|
# when the function `sf_writef_short` of `libsndfile_64bit.dll` is called
|
||||||
|
# using the Python library `soundfile`
|
||||||
|
# Note:
|
||||||
|
# This is an issue related to `libsndfile`, not this project itself.
|
||||||
|
# It happens when you generate a large audio tensor (about 499804 frames in my PC)
|
||||||
|
# and try to convert it to an ogg file.
|
||||||
|
# Related:
|
||||||
|
# https://github.com/RVC-Boss/GPT-SoVITS/issues/1199
|
||||||
|
# https://github.com/libsndfile/libsndfile/issues/1023
|
||||||
|
# https://github.com/bastibe/python-soundfile/issues/396
|
||||||
|
# Suggestion:
|
||||||
|
# Or split the whole audio data into smaller audio segment to avoid stack overflow?
|
||||||
|
|
||||||
|
def handle_pack_ogg():
|
||||||
|
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
|
||||||
|
audio_file.write(data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# See: https://docs.python.org/3/library/threading.html
|
||||||
|
# The stack size of this thread is at least 32768
|
||||||
|
# If stack overflow error still occurs, just modify the `stack_size`.
|
||||||
|
# stack_size = n * 4096, where n should be a positive integer.
|
||||||
|
# Here we chose n = 4096.
|
||||||
|
stack_size = 4096 * 4096
|
||||||
|
try:
|
||||||
|
threading.stack_size(stack_size)
|
||||||
|
pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
|
||||||
|
pack_ogg_thread.start()
|
||||||
|
pack_ogg_thread.join()
|
||||||
|
except RuntimeError as e:
|
||||||
|
# If changing the thread stack size is unsupported, a RuntimeError is raised.
|
||||||
|
print("RuntimeError: {}".format(e))
|
||||||
|
print("Changing the thread stack size is unsupported.")
|
||||||
|
except ValueError as e:
|
||||||
|
# If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified.
|
||||||
|
print("ValueError: {}".format(e))
|
||||||
|
print("The specified stack size is invalid.")
|
||||||
|
|
||||||
return io_buffer
|
return io_buffer
|
||||||
|
|
||||||
|
|
||||||
@ -286,8 +331,8 @@ def check_params(req: dict):
|
|||||||
)
|
)
|
||||||
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
||||||
return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
|
return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
|
||||||
elif media_type == "ogg" and not streaming_mode:
|
# elif media_type == "ogg" and not streaming_mode:
|
||||||
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
|
# return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
|
||||||
|
|
||||||
if text_split_method not in cut_method_names:
|
if text_split_method not in cut_method_names:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -307,25 +352,26 @@ async def tts_handle(req: dict):
|
|||||||
"text": "", # str.(required) text to be synthesized
|
"text": "", # str.(required) text to be synthesized
|
||||||
"text_lang: "", # str.(required) language of the text to be synthesized
|
"text_lang: "", # str.(required) language of the text to be synthesized
|
||||||
"ref_audio_path": "", # str.(required) reference audio path
|
"ref_audio_path": "", # str.(required) reference audio path
|
||||||
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
|
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
||||||
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
||||||
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
||||||
"top_k": 5, # int. top k sampling
|
"top_k": 15, # int. top k sampling
|
||||||
"top_p": 1, # float. top p sampling
|
"top_p": 1, # float. top p sampling
|
||||||
"temperature": 1, # float. temperature for sampling
|
"temperature": 1, # float. temperature for sampling
|
||||||
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
|
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
|
||||||
"batch_size": 1, # int. batch size for inference
|
"batch_size": 1, # int. batch size for inference
|
||||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||||
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
||||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||||
"seed": -1, # int. random seed for reproducibility.
|
"seed": -1, # int. random seed for reproducibility.
|
||||||
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
|
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||||
"streaming_mode": False, # bool. whether to return a streaming response.
|
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||||
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
|
|
||||||
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
|
|
||||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||||
|
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||||
|
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||||
|
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||||
}
|
}
|
||||||
returns:
|
returns:
|
||||||
StreamingResponse: audio stream response.
|
StreamingResponse: audio stream response.
|
||||||
@ -338,9 +384,35 @@ async def tts_handle(req: dict):
|
|||||||
check_res = check_params(req)
|
check_res = check_params(req)
|
||||||
if check_res is not None:
|
if check_res is not None:
|
||||||
return check_res
|
return check_res
|
||||||
|
|
||||||
|
if streaming_mode == 0:
|
||||||
|
streaming_mode = False
|
||||||
|
return_fragment = False
|
||||||
|
fixed_length_chunk = False
|
||||||
|
elif streaming_mode == 1:
|
||||||
|
streaming_mode = False
|
||||||
|
return_fragment = True
|
||||||
|
fixed_length_chunk = False
|
||||||
|
elif streaming_mode == 2:
|
||||||
|
streaming_mode = True
|
||||||
|
return_fragment = False
|
||||||
|
fixed_length_chunk = False
|
||||||
|
elif streaming_mode == 3:
|
||||||
|
streaming_mode = True
|
||||||
|
return_fragment = False
|
||||||
|
fixed_length_chunk = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"})
|
||||||
|
|
||||||
|
req["streaming_mode"] = streaming_mode
|
||||||
|
req["return_fragment"] = return_fragment
|
||||||
|
req["fixed_length_chunk"] = fixed_length_chunk
|
||||||
|
|
||||||
|
print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}")
|
||||||
|
|
||||||
|
streaming_mode = streaming_mode or return_fragment
|
||||||
|
|
||||||
if streaming_mode or return_fragment:
|
|
||||||
req["return_fragment"] = True
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tts_generator = tts_pipeline.run(req)
|
tts_generator = tts_pipeline.run(req)
|
||||||
@ -388,10 +460,10 @@ async def tts_get_endpoint(
|
|||||||
aux_ref_audio_paths: list = None,
|
aux_ref_audio_paths: list = None,
|
||||||
prompt_lang: str = None,
|
prompt_lang: str = None,
|
||||||
prompt_text: str = "",
|
prompt_text: str = "",
|
||||||
top_k: int = 5,
|
top_k: int = 15,
|
||||||
top_p: float = 1,
|
top_p: float = 1,
|
||||||
temperature: float = 1,
|
temperature: float = 1,
|
||||||
text_split_method: str = "cut0",
|
text_split_method: str = "cut5",
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
batch_threshold: float = 0.75,
|
batch_threshold: float = 0.75,
|
||||||
split_bucket: bool = True,
|
split_bucket: bool = True,
|
||||||
@ -399,11 +471,13 @@ async def tts_get_endpoint(
|
|||||||
fragment_interval: float = 0.3,
|
fragment_interval: float = 0.3,
|
||||||
seed: int = -1,
|
seed: int = -1,
|
||||||
media_type: str = "wav",
|
media_type: str = "wav",
|
||||||
streaming_mode: bool = False,
|
|
||||||
parallel_infer: bool = True,
|
parallel_infer: bool = True,
|
||||||
repetition_penalty: float = 1.35,
|
repetition_penalty: float = 1.35,
|
||||||
sample_steps: int = 32,
|
sample_steps: int = 32,
|
||||||
super_sampling: bool = False,
|
super_sampling: bool = False,
|
||||||
|
streaming_mode: Union[bool, int] = False,
|
||||||
|
overlap_length: int = 2,
|
||||||
|
min_chunk_length: int = 16,
|
||||||
):
|
):
|
||||||
req = {
|
req = {
|
||||||
"text": text,
|
"text": text,
|
||||||
@ -428,6 +502,8 @@ async def tts_get_endpoint(
|
|||||||
"repetition_penalty": float(repetition_penalty),
|
"repetition_penalty": float(repetition_penalty),
|
||||||
"sample_steps": int(sample_steps),
|
"sample_steps": int(sample_steps),
|
||||||
"super_sampling": super_sampling,
|
"super_sampling": super_sampling,
|
||||||
|
"overlap_length": int(overlap_length),
|
||||||
|
"min_chunk_length": int(min_chunk_length),
|
||||||
}
|
}
|
||||||
return await tts_handle(req)
|
return await tts_handle(req)
|
||||||
|
|
||||||
|
|||||||
@ -373,7 +373,7 @@ if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ]; then
|
|||||||
location=$(pip show torch | grep Location | awk -F ": " '{print $2}')
|
location=$(pip show torch | grep Location | awk -F ": " '{print $2}')
|
||||||
cd "${location}"/torch/lib/ || exit
|
cd "${location}"/torch/lib/ || exit
|
||||||
rm libhsa-runtime64.so*
|
rm libhsa-runtime64.so*
|
||||||
cp /opt/rocm/lib/libhsa-runtime64.so.1.2 libhsa-runtime64.so
|
cp "$(readlink -f /opt/rocm/lib/libhsa-runtime64.so)" libhsa-runtime64.so
|
||||||
echo -e "${SUCCESS}ROCm Runtime Lib Updated..."
|
echo -e "${SUCCESS}ROCm Runtime Lib Updated..."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@ -16,10 +16,10 @@ pypinyin
|
|||||||
pyopenjtalk>=0.4.1
|
pyopenjtalk>=0.4.1
|
||||||
g2p_en
|
g2p_en
|
||||||
torchaudio
|
torchaudio
|
||||||
modelscope==1.10.0
|
modelscope
|
||||||
sentencepiece
|
sentencepiece
|
||||||
transformers>=4.43,<=4.50
|
transformers>=4.43,<=4.50
|
||||||
peft
|
peft<0.18.0
|
||||||
chardet
|
chardet
|
||||||
PyYAML
|
PyYAML
|
||||||
psutil
|
psutil
|
||||||
@ -39,7 +39,5 @@ x_transformers
|
|||||||
torchmetrics<=1.5
|
torchmetrics<=1.5
|
||||||
pydantic<=2.10.6
|
pydantic<=2.10.6
|
||||||
ctranslate2>=4.0,<5
|
ctranslate2>=4.0,<5
|
||||||
huggingface_hub>=0.13
|
|
||||||
tokenizers>=0.13,<1
|
|
||||||
av>=11
|
av>=11
|
||||||
tqdm
|
tqdm
|
||||||
|
|||||||
@ -1,34 +1,13 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def check_fw_local_models():
|
|
||||||
"""
|
|
||||||
启动时检查本地是否有 Faster Whisper 模型.
|
|
||||||
"""
|
|
||||||
model_size_list = [
|
|
||||||
"medium",
|
|
||||||
"medium.en",
|
|
||||||
"distil-large-v2",
|
|
||||||
"distil-large-v3",
|
|
||||||
"large-v1",
|
|
||||||
"large-v2",
|
|
||||||
"large-v3",
|
|
||||||
]
|
|
||||||
for i, size in enumerate(model_size_list):
|
|
||||||
if os.path.exists(f"tools/asr/models/faster-whisper-{size}"):
|
|
||||||
model_size_list[i] = size + "-local"
|
|
||||||
return model_size_list
|
|
||||||
|
|
||||||
|
|
||||||
def get_models():
|
def get_models():
|
||||||
model_size_list = [
|
model_size_list = [
|
||||||
"medium",
|
"medium",
|
||||||
"medium.en",
|
"medium.en",
|
||||||
"distil-large-v2",
|
|
||||||
"distil-large-v3",
|
|
||||||
"large-v1",
|
|
||||||
"large-v2",
|
"large-v2",
|
||||||
"large-v3",
|
"large-v3",
|
||||||
|
"large-v3-turbo",
|
||||||
|
#"distil-large-v2",
|
||||||
|
#"distil-large-v3",
|
||||||
|
#"distil-large-v3.5",
|
||||||
]
|
]
|
||||||
return model_size_list
|
return model_size_list
|
||||||
|
|
||||||
@ -36,7 +15,7 @@ def get_models():
|
|||||||
asr_dict = {
|
asr_dict = {
|
||||||
"达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]},
|
"达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]},
|
||||||
"Faster Whisper (多语种)": {
|
"Faster Whisper (多语种)": {
|
||||||
"lang": ["auto", "zh", "en", "ja", "ko", "yue"],
|
"lang": ["auto", "en", "ja", "ko"],
|
||||||
"size": get_models(),
|
"size": get_models(),
|
||||||
"path": "fasterwhisper_asr.py",
|
"path": "fasterwhisper_asr.py",
|
||||||
"precision": ["float32", "float16", "int8"],
|
"precision": ["float32", "float16", "int8"],
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download as snapshot_download_hf
|
||||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
from modelscope import snapshot_download as snapshot_download_ms
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from tools.asr.config import get_models
|
from tools.asr.config import get_models
|
||||||
@ -40,11 +40,32 @@ language_code_list = [
|
|||||||
|
|
||||||
|
|
||||||
def download_model(model_size: str):
|
def download_model(model_size: str):
|
||||||
if "distil" in model_size:
|
url = "https://huggingface.co/api/models/gpt2"
|
||||||
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
|
try:
|
||||||
|
requests.get(url, timeout=3)
|
||||||
|
source = "HF"
|
||||||
|
except Exception:
|
||||||
|
source = "ModelScope"
|
||||||
|
|
||||||
|
model_path = ""
|
||||||
|
if source == "HF":
|
||||||
|
if "distil" in model_size:
|
||||||
|
if "3.5" in model_size:
|
||||||
|
repo_id = "distil-whisper/distil-large-v3.5-ct2"
|
||||||
|
model_path = "tools/asr/models/faster-distil-whisper-large-v3.5"
|
||||||
|
else:
|
||||||
|
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
|
||||||
|
elif model_size == "large-v3-turbo":
|
||||||
|
repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo"
|
||||||
|
model_path = "tools/asr/models/faster-whisper-large-v3-turbo"
|
||||||
|
else:
|
||||||
|
repo_id = f"Systran/faster-whisper-{model_size}"
|
||||||
|
model_path = (
|
||||||
|
model_path or f"tools/asr/models/{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
repo_id = f"Systran/faster-whisper-{model_size}"
|
repo_id = "XXXXRT/faster-whisper"
|
||||||
model_path = f"tools/asr/models/{repo_id.strip('Systran/')}"
|
model_path = "tools/asr/models"
|
||||||
|
|
||||||
files: list[str] = [
|
files: list[str] = [
|
||||||
"config.json",
|
"config.json",
|
||||||
@ -52,32 +73,31 @@ def download_model(model_size: str):
|
|||||||
"tokenizer.json",
|
"tokenizer.json",
|
||||||
"vocabulary.txt",
|
"vocabulary.txt",
|
||||||
]
|
]
|
||||||
if model_size == "large-v3" or "distil" in model_size:
|
if "large-v3" in model_size or "distil" in model_size:
|
||||||
files.append("preprocessor_config.json")
|
files.append("preprocessor_config.json")
|
||||||
files.append("vocabulary.json")
|
files.append("vocabulary.json")
|
||||||
|
|
||||||
files.remove("vocabulary.txt")
|
files.remove("vocabulary.txt")
|
||||||
|
|
||||||
for attempt in range(2):
|
if source == "ModelScope":
|
||||||
try:
|
files = [f"faster-whisper-{model_size}/{file}".replace("whisper-distil", "distil-whisper") for file in files]
|
||||||
snapshot_download(
|
|
||||||
repo_id=repo_id,
|
|
||||||
allow_patterns=files,
|
|
||||||
local_dir=model_path,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except LocalEntryNotFoundError:
|
|
||||||
if attempt < 1:
|
|
||||||
time.sleep(2)
|
|
||||||
else:
|
|
||||||
print("[ERROR] LocalEntryNotFoundError and no fallback.")
|
|
||||||
traceback.print_exc()
|
|
||||||
exit(1)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[ERROR] Unexpected error on attempt {attempt + 1}: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
|
if source == "HF":
|
||||||
|
print(f"Downloading model from HuggingFace: {repo_id} to {model_path}")
|
||||||
|
snapshot_download_hf(
|
||||||
|
repo_id,
|
||||||
|
local_dir=model_path,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
allow_patterns=files,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f"Downloading model from ModelScope: {repo_id} to {model_path}")
|
||||||
|
snapshot_download_ms(
|
||||||
|
repo_id,
|
||||||
|
local_dir=model_path,
|
||||||
|
allow_patterns=files,
|
||||||
|
)
|
||||||
|
return model_path + f"/faster-whisper-{model_size}".replace("whisper-distil", "distil-whisper")
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
@ -106,7 +126,7 @@ def execute_asr(input_folder, output_folder, model_path, language, precision):
|
|||||||
)
|
)
|
||||||
text = ""
|
text = ""
|
||||||
|
|
||||||
if info.language == "zh":
|
if info.language in ["zh", "yue"]:
|
||||||
print("检测为中文文本, 转 FunASR 处理")
|
print("检测为中文文本, 转 FunASR 处理")
|
||||||
text = only_asr(file_path, language=info.language.lower())
|
text = only_asr(file_path, language=info.language.lower())
|
||||||
|
|
||||||
|
|||||||
@ -4,9 +4,8 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
# from funasr.utils import version_checker
|
|
||||||
# version_checker.check_for_update = lambda: None
|
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
|
from modelscope import snapshot_download
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
funasr_models = {} # 存储模型避免重复加载
|
funasr_models = {} # 存储模型避免重复加载
|
||||||
@ -16,40 +15,43 @@ def only_asr(input_file, language):
|
|||||||
try:
|
try:
|
||||||
model = create_model(language)
|
model = create_model(language)
|
||||||
text = model.generate(input=input_file)[0]["text"]
|
text = model.generate(input=input_file)[0]["text"]
|
||||||
except:
|
except Exception:
|
||||||
text = ""
|
text = ""
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def create_model(language="zh"):
|
def create_model(language="zh"):
|
||||||
path_vad = "tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
|
||||||
path_punc = "tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
|
||||||
path_vad = path_vad if os.path.exists(path_vad) else "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
|
||||||
path_punc = path_punc if os.path.exists(path_punc) else "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
|
||||||
vad_model_revision = punc_model_revision = "v2.0.4"
|
|
||||||
|
|
||||||
if language == "zh":
|
if language == "zh":
|
||||||
|
path_vad = "tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||||
|
path_punc = "tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||||
path_asr = "tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
path_asr = "tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||||
path_asr = (
|
snapshot_download(
|
||||||
path_asr
|
"iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
if os.path.exists(path_asr)
|
local_dir="tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
else "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
)
|
||||||
|
snapshot_download(
|
||||||
|
"iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||||
|
local_dir="tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||||
|
)
|
||||||
|
snapshot_download(
|
||||||
|
"iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||||
|
local_dir="tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||||
)
|
)
|
||||||
model_revision = "v2.0.4"
|
model_revision = "v2.0.4"
|
||||||
elif language == "yue":
|
elif language == "yue":
|
||||||
path_asr = "tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
|
path_asr = "tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
|
||||||
path_asr = (
|
snapshot_download(
|
||||||
path_asr
|
"iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
|
||||||
if os.path.exists(path_asr)
|
local_dir="tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
|
||||||
else "iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
|
|
||||||
)
|
)
|
||||||
model_revision = "master"
|
|
||||||
path_vad = path_punc = None
|
path_vad = path_punc = None
|
||||||
vad_model_revision = punc_model_revision = None
|
vad_model_revision = punc_model_revision = ""
|
||||||
###友情提示:粤语带VAD识别可能会有少量shape不对报错的,但是不带VAD可以.不带vad只能分阶段单独加标点。不过标点模型对粤语效果真的不行…
|
model_revision = "master"
|
||||||
else:
|
else:
|
||||||
raise ValueError("FunASR 不支持该语言" + ": " + language)
|
raise ValueError(f"{language} is not supported")
|
||||||
|
|
||||||
|
vad_model_revision = punc_model_revision = "v2.0.4"
|
||||||
|
|
||||||
if language in funasr_models:
|
if language in funasr_models:
|
||||||
return funasr_models[language]
|
return funasr_models[language]
|
||||||
@ -83,7 +85,7 @@ def execute_asr(input_folder, output_folder, model_size, language):
|
|||||||
file_path = os.path.join(input_folder, file_name)
|
file_path = os.path.join(input_folder, file_name)
|
||||||
text = model.generate(input=file_path)[0]["text"]
|
text = model.generate(input=file_path)[0]["text"]
|
||||||
output.append(f"{file_path}|{output_file_name}|{language.upper()}|{text}")
|
output.append(f"{file_path}|{output_file_name}|{language.upper()}|{text}")
|
||||||
except:
|
except Exception:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
output_folder = output_folder or "output/asr_opt"
|
output_folder = output_folder or "output/asr_opt"
|
||||||
|
|||||||
@ -38,7 +38,7 @@
|
|||||||
"hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size: FO hop size, the smaller the value, the higher the accuracy)",
|
"hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size: FO hop size, the smaller the value, the higher the accuracy)",
|
||||||
"max:归一化后最大值多少": "Loudness multiplier after normalized",
|
"max:归一化后最大值多少": "Loudness multiplier after normalized",
|
||||||
"max_sil_kept:切完后静音最多留多长": "Maximum length for silence to be kept",
|
"max_sil_kept:切完后静音最多留多长": "Maximum length for silence to be kept",
|
||||||
"min_interval:最短切割间隔": "Minumum interval for audio cutting",
|
"min_interval:最短切割间隔": "Minimum interval for audio cutting",
|
||||||
"min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length: the minimum length of each segment. If the first segment is too short, it will be concatenated with the next segment until it exceeds this value",
|
"min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length: the minimum length of each segment. If the first segment is too short, it will be concatenated with the next segment until it exceeds this value",
|
||||||
"temperature": "temperature",
|
"temperature": "temperature",
|
||||||
"threshold:音量小于这个值视作静音的备选切割点": "Noise gate threshold (loudness below this value will be treated as noise",
|
"threshold:音量小于这个值视作静音的备选切割点": "Noise gate threshold (loudness below this value will be treated as noise",
|
||||||
@ -176,7 +176,7 @@
|
|||||||
"语音降噪": "Speech Denoising",
|
"语音降噪": "Speech Denoising",
|
||||||
"请上传3~10秒内参考音频,超过会报错!": "Please upload a reference audio within the 3-10 second range; if it exceeds this duration, it will raise errors.",
|
"请上传3~10秒内参考音频,超过会报错!": "Please upload a reference audio within the 3-10 second range; if it exceeds this duration, it will raise errors.",
|
||||||
"请上传参考音频": "Please Upload the Reference Audio",
|
"请上传参考音频": "Please Upload the Reference Audio",
|
||||||
"请填入推理文本": "Please Fill in the Terget Text",
|
"请填入推理文本": "Please Fill in the Target Text",
|
||||||
"请填入正确的List路径": "Please Fill in the Correct List Path",
|
"请填入正确的List路径": "Please Fill in the Correct List Path",
|
||||||
"请填入正确的音频文件夹路径": "Please Fill in the Correct Audio Folder Path",
|
"请填入正确的音频文件夹路径": "Please Fill in the Correct Audio Folder Path",
|
||||||
"请输入有效文本": "Please enter valid text.",
|
"请输入有效文本": "Please enter valid text.",
|
||||||
|
|||||||
25
webui.py
25
webui.py
@ -86,7 +86,6 @@ from config import (
|
|||||||
from tools import my_utils
|
from tools import my_utils
|
||||||
from tools.my_utils import check_details, check_for_existance
|
from tools.my_utils import check_details, check_for_existance
|
||||||
|
|
||||||
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||||
|
|
||||||
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu
|
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu
|
||||||
@ -117,8 +116,8 @@ def set_default():
|
|||||||
gpu_info = "\n".join(gpu_infos)
|
gpu_info = "\n".join(gpu_infos)
|
||||||
if is_gpu_ok:
|
if is_gpu_ok:
|
||||||
minmem = min(mem)
|
minmem = min(mem)
|
||||||
default_batch_size = minmem // 2 if version not in v3v4set else minmem // 8
|
default_batch_size = int(minmem // 2 if version not in v3v4set else minmem // 8)
|
||||||
default_batch_size_s1 = minmem // 2
|
default_batch_size_s1 = int(minmem // 2)
|
||||||
else:
|
else:
|
||||||
default_batch_size = default_batch_size_s1 = int(psutil.virtual_memory().total / 1024 / 1024 / 1024 / 4)
|
default_batch_size = default_batch_size_s1 = int(psutil.virtual_memory().total / 1024 / 1024 / 1024 / 4)
|
||||||
if version not in v3v4set:
|
if version not in v3v4set:
|
||||||
@ -343,7 +342,7 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so
|
|||||||
os.environ["sovits_path"] = sovits_path
|
os.environ["sovits_path"] = sovits_path
|
||||||
os.environ["cnhubert_base_path"] = cnhubert_base_path
|
os.environ["cnhubert_base_path"] = cnhubert_base_path
|
||||||
os.environ["bert_path"] = bert_path
|
os.environ["bert_path"] = bert_path
|
||||||
os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_number(gpu_number)
|
os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number))
|
||||||
os.environ["is_half"] = str(is_half)
|
os.environ["is_half"] = str(is_half)
|
||||||
os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
|
os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
|
||||||
os.environ["is_share"] = str(is_share)
|
os.environ["is_share"] = str(is_share)
|
||||||
@ -628,7 +627,7 @@ def open1Bb(
|
|||||||
data["output_dir"] = "%s/logs_s1_%s" % (s1_dir, version)
|
data["output_dir"] = "%s/logs_s1_%s" % (s1_dir, version)
|
||||||
# data["version"]=version
|
# data["version"]=version
|
||||||
|
|
||||||
os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_numbers(gpu_numbers.replace("-", ","))
|
os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ",")))
|
||||||
os.environ["hz"] = "25hz"
|
os.environ["hz"] = "25hz"
|
||||||
tmp_config_path = "%s/tmp_s1.yaml" % tmp
|
tmp_config_path = "%s/tmp_s1.yaml" % tmp
|
||||||
with open(tmp_config_path, "w") as f:
|
with open(tmp_config_path, "w") as f:
|
||||||
@ -801,7 +800,7 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
|
|||||||
{
|
{
|
||||||
"i_part": str(i_part),
|
"i_part": str(i_part),
|
||||||
"all_parts": str(all_parts),
|
"all_parts": str(all_parts),
|
||||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||||
"is_half": str(is_half),
|
"is_half": str(is_half),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -892,7 +891,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
|
|||||||
{
|
{
|
||||||
"i_part": str(i_part),
|
"i_part": str(i_part),
|
||||||
"all_parts": str(all_parts),
|
"all_parts": str(all_parts),
|
||||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
os.environ.update(config)
|
os.environ.update(config)
|
||||||
@ -914,7 +913,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
|
|||||||
{
|
{
|
||||||
"i_part": str(i_part),
|
"i_part": str(i_part),
|
||||||
"all_parts": str(all_parts),
|
"all_parts": str(all_parts),
|
||||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
os.environ.update(config)
|
os.environ.update(config)
|
||||||
@ -986,7 +985,7 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
|
|||||||
{
|
{
|
||||||
"i_part": str(i_part),
|
"i_part": str(i_part),
|
||||||
"all_parts": str(all_parts),
|
"all_parts": str(all_parts),
|
||||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
os.environ.update(config)
|
os.environ.update(config)
|
||||||
@ -1086,7 +1085,7 @@ def open1abc(
|
|||||||
{
|
{
|
||||||
"i_part": str(i_part),
|
"i_part": str(i_part),
|
||||||
"all_parts": str(all_parts),
|
"all_parts": str(all_parts),
|
||||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
os.environ.update(config)
|
os.environ.update(config)
|
||||||
@ -1133,7 +1132,7 @@ def open1abc(
|
|||||||
{
|
{
|
||||||
"i_part": str(i_part),
|
"i_part": str(i_part),
|
||||||
"all_parts": str(all_parts),
|
"all_parts": str(all_parts),
|
||||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
os.environ.update(config)
|
os.environ.update(config)
|
||||||
@ -1155,7 +1154,7 @@ def open1abc(
|
|||||||
{
|
{
|
||||||
"i_part": str(i_part),
|
"i_part": str(i_part),
|
||||||
"all_parts": str(all_parts),
|
"all_parts": str(all_parts),
|
||||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
os.environ.update(config)
|
os.environ.update(config)
|
||||||
@ -1195,7 +1194,7 @@ def open1abc(
|
|||||||
{
|
{
|
||||||
"i_part": str(i_part),
|
"i_part": str(i_part),
|
||||||
"all_parts": str(all_parts),
|
"all_parts": str(all_parts),
|
||||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
os.environ.update(config)
|
os.environ.update(config)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user