Merge branch 'RVC-Boss:main' into main

This commit is contained in:
hsoftxl 2026-01-13 17:25:10 +08:00 committed by GitHub
commit 0235857b89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1718 additions and 336 deletions

View File

@ -707,10 +707,12 @@ class Text2SemanticDecoder(nn.Module):
if idx == 0:
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
logits = logits[:, :-1]
else:
attn_mask = F.pad(attn_mask, (0, 1), value=False)
if idx < 11: ###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
@ -794,7 +796,7 @@ class Text2SemanticDecoder(nn.Module):
y_list = []
idx_list = []
for i in range(len(x)):
y, idx = self.infer_panel_naive(
y, idx = next(self.infer_panel_naive(
x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
@ -805,7 +807,7 @@ class Text2SemanticDecoder(nn.Module):
temperature,
repetition_penalty,
**kwargs,
)
))
y_list.append(y[0])
idx_list.append(idx)
@ -822,8 +824,15 @@ class Text2SemanticDecoder(nn.Module):
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
streaming_mode: bool = False,
chunk_length: int = 24,
**kwargs,
):
mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
chunk_split_thershold = kwargs.get("chunk_split_thershold", 0.3)
check_token_num = 2
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
@ -875,7 +884,10 @@ class Text2SemanticDecoder(nn.Module):
.to(device=x.device, dtype=torch.bool)
)
token_counter = 0
curr_ptr = prefix_len
for idx in tqdm(range(1500)):
token_counter+=1
if xy_attn_mask is not None:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
else:
@ -900,22 +912,56 @@ class Text2SemanticDecoder(nn.Module):
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
y=y[:, :-1]
token_counter -= 1
if idx == 1499:
stop = True
if stop:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
# print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
if streaming_mode:
yield y[:, curr_ptr:] if curr_ptr<y.shape[1] else None, True
break
if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length+check_token_num):
score = mute_emb_sim_matrix[y[0, curr_ptr:]] - chunk_split_thershold
score[score<0]=-1
score[:-1]=score[:-1]+score[1:] ##考虑连续两个token
argmax_idx = score.argmax()
if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length:
print(f"\n\ncurr_ptr:{curr_ptr}")
yield y[:, curr_ptr:], False
token_counter -= argmax_idx+1
curr_ptr += argmax_idx+1
elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length):
yield y[:, -token_counter:], False
curr_ptr+=token_counter
token_counter = 0
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx
if not streaming_mode:
if ref_free:
yield y, 0
yield y, idx
def infer_panel(
self,
@ -930,6 +976,6 @@ class Text2SemanticDecoder(nn.Module):
repetition_penalty: float = 1.35,
**kwargs,
):
return self.infer_panel_naive(
return next(self.infer_panel_naive(
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
)
))

View File

@ -275,6 +275,15 @@ class TTS_Config:
v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"]
languages: list = v2_languages
mute_tokens: dict = {
"v1" : 486,
"v2" : 486,
"v2Pro": 486,
"v2ProPlus": 486,
"v3" : 486,
"v4" : 486,
}
mute_emb_sim_matrix: torch.Tensor = None
# "all_zh",#全部按中文识别
# "en",#全部按英文识别#######不变
# "all_ja",#全部按日文识别
@ -598,6 +607,11 @@ class TTS:
if self.configs.is_half and str(self.configs.device) != "cpu":
self.t2s_model = self.t2s_model.half()
codebook = t2s_model.model.ar_audio_embedding.weight.clone()
mute_emb = codebook[self.configs.mute_tokens[self.configs.version]].unsqueeze(0)
sim_matrix = F.cosine_similarity(mute_emb.float(), codebook.float(), dim=-1)
self.configs.mute_emb_sim_matrix = sim_matrix
def init_vocoder(self, version: str):
if version == "v3":
if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
@ -994,21 +1008,25 @@ class TTS:
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_k": 15, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
}
returns:
Tuple[int, np.ndarray]: sampling rate and audio data.
@ -1021,10 +1039,10 @@ class TTS:
aux_ref_audio_paths: list = inputs.get("aux_ref_audio_paths", [])
prompt_text: str = inputs.get("prompt_text", "")
prompt_lang: str = inputs.get("prompt_lang", "")
top_k: int = inputs.get("top_k", 5)
top_k: int = inputs.get("top_k", 15)
top_p: float = inputs.get("top_p", 1)
temperature: float = inputs.get("temperature", 1)
text_split_method: str = inputs.get("text_split_method", "cut0")
text_split_method: str = inputs.get("text_split_method", "cut1")
batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75)
speed_factor = inputs.get("speed_factor", 1.0)
@ -1038,19 +1056,43 @@ class TTS:
repetition_penalty = inputs.get("repetition_penalty", 1.35)
sample_steps = inputs.get("sample_steps", 32)
super_sampling = inputs.get("super_sampling", False)
streaming_mode = inputs.get("streaming_mode", False)
overlap_length = inputs.get("overlap_length", 2)
min_chunk_length = inputs.get("min_chunk_length", 16)
fixed_length_chunk = inputs.get("fixed_length_chunk", False)
chunk_split_thershold = 0.0 # 该值代表语义token与mute token的余弦相似度阈值若大于该阈值则视为可切分点。
if parallel_infer:
if parallel_infer and not streaming_mode:
print(i18n("并行推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
elif not parallel_infer and streaming_mode and not self.configs.use_vocoder:
print(i18n("流式推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
elif streaming_mode and self.configs.use_vocoder:
print(i18n("SoVits V3/4模型不支持流式推理模式已自动回退到分段返回模式"))
streaming_mode = False
return_fragment = True
if parallel_infer:
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
else:
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
# self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
elif parallel_infer and streaming_mode:
print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式"))
parallel_infer = False
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
else:
print(i18n("并行推理模式已关闭"))
print(i18n("朴素推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
if return_fragment:
print(i18n("分段返回模式已开启"))
if split_bucket:
split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
if return_fragment and streaming_mode:
print(i18n("流式推理模式不支持分段返回,已自动关闭分段返回"))
return_fragment = False
if (return_fragment or streaming_mode) and split_bucket:
print(i18n("分段返回模式/流式推理模式不支持分桶处理,已自动关闭分桶处理"))
split_bucket = False
if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer):
print(i18n("分桶处理模式已开启"))
@ -1063,9 +1105,9 @@ class TTS:
else:
print(i18n("分桶处理模式已关闭"))
if fragment_interval < 0.01:
fragment_interval = 0.01
print(i18n("分段间隔过小已自动设置为0.01"))
# if fragment_interval < 0.01:
# fragment_interval = 0.01
# print(i18n("分段间隔过小已自动设置为0.01"))
no_prompt_text = False
if prompt_text in [None, ""]:
@ -1126,7 +1168,7 @@ class TTS:
###### text preprocessing ########
t1 = time.perf_counter()
data: list = None
if not return_fragment:
if not (return_fragment or streaming_mode):
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
if len(data) == 0:
yield 16000, np.zeros(int(16000), dtype=np.int16)
@ -1186,10 +1228,11 @@ class TTS:
t_34 = 0.0
t_45 = 0.0
audio = []
is_first_package = True
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
for item in data:
t3 = time.perf_counter()
if return_fragment:
if return_fragment or streaming_mode:
item = make_batch(item)
if item is None:
continue
@ -1211,108 +1254,228 @@ class TTS:
self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
)
print(f"############ {i18n('预测语义Token')} ############")
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_lens,
prompt,
all_bert_features,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
max_len=max_len,
repetition_penalty=repetition_penalty,
)
t4 = time.perf_counter()
t_34 += t4 - t3
refer_audio_spec = []
if self.is_v2pro:
sv_emb = []
sv_emb = [] if self.is_v2pro else None
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
spec = spec.to(dtype=self.precision, device=self.configs.device)
refer_audio_spec.append(spec)
if self.is_v2pro:
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
batch_audio_fragment = []
if not streaming_mode:
print(f"############ {i18n('预测语义Token')} ############")
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_lens,
prompt,
all_bert_features,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
max_len=max_len,
repetition_penalty=repetition_penalty,
)
t4 = time.perf_counter()
t_34 += t4 - t3
# ## vits并行推理 method 1
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
# max_len = 0
# for i in range(0, len(batch_phones)):
# max_len = max(max_len, batch_phones[i].shape[-1])
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
# batch_phones = batch_phones.to(self.configs.device)
# batch_audio_fragment = (self.vits_model.batched_decode(
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
# ))
print(f"############ {i18n('合成音频')} ############")
if not self.configs.use_vocoder:
if speed_factor == 1.0:
print(f"{i18n('并行合成中')}...")
# ## vits并行推理 method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [
pred_semantic_list[i].shape[0] * 2 * upsample_rate
for i in range(0, len(pred_semantic_list))
]
audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = (
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
if self.is_v2pro != True:
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
else:
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
).detach()[0, 0, :]
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment = [
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
for i in range(1, len(audio_frag_end_idx))
]
else:
# ## vits串行推理
for i, idx in enumerate(tqdm(idx_list)):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
if self.is_v2pro != True:
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
else:
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
).detach()[0, 0, :]
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
else:
if parallel_infer:
print(f"{i18n('并行合成中')}...")
audio_fragments = self.using_vocoder_synthesis_batched_infer(
idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.extend(audio_fragments)
else:
for i, idx in enumerate(tqdm(idx_list)):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment = self.using_vocoder_synthesis(
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
batch_audio_fragment = []
# ## vits并行推理 method 1
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
# max_len = 0
# for i in range(0, len(batch_phones)):
# max_len = max(max_len, batch_phones[i].shape[-1])
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
# batch_phones = batch_phones.to(self.configs.device)
# batch_audio_fragment = (self.vits_model.batched_decode(
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
# ))
print(f"############ {i18n('合成音频')} ############")
if not self.configs.use_vocoder:
if speed_factor == 1.0:
print(f"{i18n('并行合成中')}...")
# ## vits并行推理 method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [
pred_semantic_list[i].shape[0] * 2 * upsample_rate
for i in range(0, len(pred_semantic_list))
]
audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = (
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
)
batch_audio_fragment.append(audio_fragment)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
).detach()[0, 0, :]
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment = [
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
for i in range(1, len(audio_frag_end_idx))
]
else:
# ## vits串行推理
for i, idx in enumerate(tqdm(idx_list)):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
).detach()[0, 0, :]
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
else:
if parallel_infer:
print(f"{i18n('并行合成中')}...")
audio_fragments = self.using_vocoder_synthesis_batched_infer(
idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.extend(audio_fragments)
else:
for i, idx in enumerate(tqdm(idx_list)):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment = self.using_vocoder_synthesis(
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.append(audio_fragment)
else:
# refer_audio_spec: torch.Tensor = [
# item.to(dtype=self.precision, device=self.configs.device)
# for item in self.prompt_cache["refer_spec"]
# ]
semantic_token_generator =self.t2s_model.model.infer_panel(
all_phoneme_ids[0].unsqueeze(0),
all_phoneme_lens,
prompt,
all_bert_features[0].unsqueeze(0),
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
max_len=max_len,
repetition_penalty=repetition_penalty,
streaming_mode=True,
chunk_length=min_chunk_length,
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix if not fixed_length_chunk else None,
chunk_split_thershold=chunk_split_thershold,
)
t4 = time.perf_counter()
t_34 += t4 - t3
phones = batch_phones[0].unsqueeze(0).to(self.configs.device)
is_first_chunk = True
if not self.configs.use_vocoder:
# if speed_factor == 1.0:
# upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
# else:
upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor)
else:
# if speed_factor == 1.0:
# upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4)
# else:
upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)
last_audio_chunk = None
# last_tokens = None
last_latent = None
previous_tokens = []
overlap_len = overlap_length
overlap_size = math.ceil(overlap_length*upsample_rate)
for semantic_tokens, is_final in semantic_token_generator:
if semantic_tokens is None and last_audio_chunk is not None:
yield self.audio_postprocess(
[[last_audio_chunk[-overlap_size:]]],
output_sr,
None,
speed_factor,
False,
0.0,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
break
_semantic_tokens = semantic_tokens
print(f"semantic_tokens shape:{semantic_tokens.shape}")
previous_tokens.append(semantic_tokens)
_semantic_tokens = torch.cat(previous_tokens, dim=-1)
if not is_first_chunk and semantic_tokens.shape[-1] < 10:
overlap_len = overlap_length+(10-semantic_tokens.shape[-1])
else:
overlap_len = overlap_length
if not self.configs.use_vocoder:
token_padding_length = 0
# token_padding_length = int(phones.shape[-1]*2)-_semantic_tokens.shape[-1]
# if token_padding_length>0:
# _semantic_tokens = F.pad(_semantic_tokens, (0, token_padding_length), "constant", 486)
# else:
# token_padding_length = 0
audio_chunk, latent, latent_mask = self.vits_model.decode_streaming(
_semantic_tokens.unsqueeze(0),
phones, refer_audio_spec,
speed=speed_factor,
sv_emb=sv_emb,
result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
if last_latent is not None else None,
padding_length=token_padding_length
)
audio_chunk=audio_chunk.detach()[0, 0, :]
else:
raise RuntimeError(i18n("SoVits V3/4模型不支持流式推理模式"))
if overlap_len>overlap_length:
audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):]
audio_chunk_ = audio_chunk
if is_first_chunk and not is_final:
is_first_chunk = False
audio_chunk_ = audio_chunk_[:-overlap_size]
elif is_first_chunk and is_final:
is_first_chunk = False
elif not is_first_chunk and not is_final:
audio_chunk_ = self.sola_algorithm([last_audio_chunk, audio_chunk_], overlap_size)
audio_chunk_ = (
audio_chunk_[last_audio_chunk.shape[0]-overlap_size:-overlap_size] if not is_final \
else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:]
)
last_latent = latent
last_audio_chunk = audio_chunk
yield self.audio_postprocess(
[[audio_chunk_]],
output_sr,
None,
speed_factor,
False,
0.0,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
if is_first_package:
print(f"first_package_delay: {time.perf_counter()-t0:.3f}")
is_first_package = False
yield output_sr, np.zeros(int(output_sr*fragment_interval), dtype=np.int16)
t5 = time.perf_counter()
t_45 += t5 - t4
@ -1327,17 +1490,18 @@ class TTS:
fragment_interval,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
elif streaming_mode:...
else:
audio.append(batch_audio_fragment)
if self.stop_flag:
yield 16000, np.zeros(int(16000), dtype=np.int16)
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
return
if not return_fragment:
if not (return_fragment or streaming_mode):
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
if len(audio) == 0:
yield 16000, np.zeros(int(16000), dtype=np.int16)
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
return
yield self.audio_postprocess(
audio,
@ -1384,16 +1548,17 @@ class TTS:
fragment_interval: float = 0.3,
super_sampling: bool = False,
) -> Tuple[int, np.ndarray]:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
)
if fragment_interval>0:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
)
for i, batch in enumerate(audio):
for j, audio_fragment in enumerate(batch):
max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音
if max_audio > 1:
audio_fragment /= max_audio
audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) if fragment_interval>0 else audio_fragment
audio[i][j] = audio_fragment
if split_bucket:
@ -1413,13 +1578,18 @@ class TTS:
max_audio = np.abs(audio).max()
if max_audio > 1:
audio /= max_audio
audio = (audio * 32768).astype(np.int16)
t2 = time.perf_counter()
print(f"超采样用时:{t2 - t1:.3f}s")
else:
# audio = audio.float() * 32768
# audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy()
audio = audio.cpu().numpy()
audio = (audio * 32768).astype(np.int16)
# try:
# if speed_factor != 1.0:
# audio = speed_change(audio, speed=speed_factor, sr=int(sr))
@ -1612,24 +1782,43 @@ class TTS:
self,
audio_fragments: List[torch.Tensor],
overlap_len: int,
search_len:int= 320
):
# overlap_len-=search_len
dtype = audio_fragments[0].dtype
for i in range(len(audio_fragments) - 1):
f1 = audio_fragments[i]
f2 = audio_fragments[i + 1]
f1 = audio_fragments[i].float()
f2 = audio_fragments[i + 1].float()
w1 = f1[-overlap_len:]
w2 = f2[:overlap_len]
assert w1.shape == w2.shape
corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1), padding=w2.shape[-1] // 2).view(-1)[:-1]
idx = corr.argmax()
f1_ = f1[: -(overlap_len - idx)]
w2 = f2[:overlap_len+search_len]
# w2 = w2[-w2.shape[-1]//2:]
# assert w1.shape == w2.shape
corr_norm = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1)
corr_den = F.conv1d(w2.view(1, 1, -1)**2, torch.ones_like(w1).view(1, 1, -1)).view(-1)+ 1e-8
idx = (corr_norm/corr_den.sqrt()).argmax()
print(f"seg_idx: {idx}")
# idx = corr.argmax()
f1_ = f1[: -overlap_len]
audio_fragments[i] = f1_
f2_ = f2[idx:]
window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype)
f2_[: (overlap_len - idx)] = (
window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)]
+ window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :]
window = torch.hann_window((overlap_len) * 2, device=f1.device, dtype=f1.dtype)
f2_[: overlap_len] = (
window[: overlap_len] * f2_[: overlap_len]
+ window[overlap_len :] * f1[-overlap_len :]
)
# window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx))
# f2_[: (overlap_len - idx)] = (
# window * f2_[: (overlap_len - idx)]
# + (1-window) * f1[-(overlap_len - idx) :]
# )
audio_fragments[i + 1] = f2_
return torch.cat(audio_fragments, 0)
return torch.cat(audio_fragments, 0).to(dtype)

View File

@ -261,41 +261,21 @@ class T2SBlock:
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
# attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
# attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
if padding_mask is not None:
for i in range(batch_size):
# mask = padding_mask[i,:,0]
if self.false.device != padding_mask.device:
self.false = self.false.to(padding_mask.device)
idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
x_item = x[i, idx, :].unsqueeze(0)
attn_item = attn[i, idx, :].unsqueeze(0)
x_item = x_item + attn_item
x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x_item = x_item + self.mlp.forward(x_item)
x_item = F.layer_norm(
x_item,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
x[i, idx, :] = x_item.squeeze(0)
x = self.to_mask(x, padding_mask)
else:
x = x + attn
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
x = x + attn
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):

View File

@ -417,7 +417,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
)
with gr.Row():
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True)
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
with gr.Row():
temperature = gr.Slider(

View File

@ -37,6 +37,10 @@ from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from module.distrib import broadcast_tensors, is_distributed
from module.ddp_utils import SyncFunction
from tqdm import tqdm
@ -69,27 +73,40 @@ def sample_vectors(samples, num: int):
return samples[indices]
def kmeans(samples, num_clusters: int, num_iters: int = 10):
dim, dtype = samples.shape[-1], samples.dtype
max_kmeans_samples = 500
samples = samples[:max_kmeans_samples, :]
def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64):
N, D = samples.shape
dtype, device = samples.dtype, samples.device
if frames_to_use < N:
indices = torch.randperm(N, device=device)[:frames_to_use]
samples = samples[indices]
means = sample_vectors(samples, num_clusters)
print("kmeans start ... ")
for _ in tqdm(range(num_iters)):
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
dists = -(diffs**2).sum(dim=-1)
# Store cluster assignments
all_assignments = []
buckets = dists.max(dim=-1).indices
for i in range(0, samples.shape[0], batch_size):
batch = samples[i : i + batch_size] # [B, D]
dists = torch.cdist(batch, means, p=2) # [B, C]
assignments = dists.argmin(dim=1) # [B]
all_assignments.append(assignments)
buckets = torch.cat(all_assignments, dim=0) # [N]
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
# Compute new means
new_means = torch.zeros_like(means)
for i in range(num_clusters):
mask = buckets == i
if mask.any():
new_means[i] = samples[mask].mean(dim=0)
means = torch.where(zero_mask[..., None], means, new_means)
means = torch.where(zero_mask[:, None], means, new_means)
return means, bins
@ -141,13 +158,24 @@ class EuclideanCodebook(nn.Module):
if self.inited:
return
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
if dist.is_available() and dist.is_initialized():
# [B * T * world_size, D]
data = SyncFunction.apply(data)
if dist.get_rank() == 0:
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
else:
embed = torch.empty_like(self.embed)
cluster_size = torch.empty_like(self.cluster_size)
dist.broadcast(embed, src=0)
dist.broadcast(cluster_size, src=0)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
# broadcast_tensors(self.buffers())
broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
@ -161,9 +189,17 @@ class EuclideanCodebook(nn.Module):
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
# broadcast_tensors(self.buffers())
if is_distributed():
# [B * T * world_size, D]
batch_samples = SyncFunction.apply(batch_samples)
if dist.get_rank() == 0:
new_embeds = sample_vectors(batch_samples, expired_codes.sum())
else:
new_embeds = torch.zeros(expired_codes.sum(), self.embed.size(1), device=self.embed.device)
dist.broadcast(new_embeds, src=0)
self.embed.data[expired_codes] = new_embeds
broadcast_tensors(self.buffers())
def preprocess(self, x):
x = rearrange(x, "... d -> (...) d")
@ -208,17 +244,26 @@ class EuclideanCodebook(nn.Module):
quantize = self.dequantize(embed_ind)
if self.training:
### Update codebook by EMA
embed_onehot_sum = embed_onehot.sum(0) # [cb-size,]
embed_sum = x.t() @ embed_onehot # [D, cb-size]
if is_distributed():
dist.all_reduce(embed_onehot_sum)
dist.all_reduce(embed_sum)
# Update ema cluster count N_i^t, eq. (6) in vqvae paper
self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
# Update ema embed: eq. (7) in vqvae paper
self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
# apply laplace smoothing
n = self.cluster_size.sum()
cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
# Update ema embed: eq. (8) in vqvae paper
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
# We do the expiry of code at that point as buffers are in sync
# and all the workers will take the same decision.
self.expire_codes_(x)
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = x.t() @ embed_onehot
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
return quantize, embed_ind

View File

@ -0,0 +1,181 @@
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
from packaging import version
# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
class SyncFunction(torch.autograd.Function):
@staticmethod
# @torch.no_grad()
def forward(ctx, tensor):
world_size = torch.distributed.get_world_size()
# Collect batch sizes from all processes
local_bs = torch.tensor([tensor.shape[0]], device=tensor.device)
batch_sizes = [torch.zeros_like(local_bs) for _ in range(world_size)]
torch.distributed.all_gather(batch_sizes, local_bs)
# Convert to integer list and find the minimum
batch_sizes_int = [bs.item() for bs in batch_sizes]
min_bs = min(batch_sizes_int)
# Crop the tensor to the minimum batch size if needed
cropped_tensor = tensor[:min_bs] if tensor.shape[0] > min_bs else tensor
# Prepare for gathering
out_shape = (min_bs * world_size,) + tensor.shape[1:]
gathered_tensor = torch.zeros(out_shape, dtype=tensor.dtype, device=tensor.device)
# Build tensor list for all_gather
tensor_list = list(torch.chunk(gathered_tensor, world_size))
# Perform all_gather using the cropped tensors
torch.distributed.all_gather(tensor_list, cropped_tensor)
# Save for backward pass
ctx.min_bs = min_bs
ctx.world_size = world_size
ctx.orig_shape = tensor.shape
return gathered_tensor
@staticmethod
def backward(ctx, grad_output):
assert False
grad_input = grad_output.clone()
torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
idx_from = torch.distributed.get_rank() * ctx.batch_size
idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
return grad_input[idx_from:idx_to]
class DDP(DistributedDataParallel):
"""
Override the forward call in lightning so it goes to training and validation step respectively
"""
def forward(self, *inputs, **kwargs): # pragma: no cover
if version.parse(torch.__version__[:6]) < version.parse("1.11"):
self._sync_params()
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
assert len(self.device_ids) == 1
if self.module.training:
output = self.module.training_step(*inputs[0], **kwargs[0])
elif self.module.testing:
output = self.module.test_step(*inputs[0], **kwargs[0])
else:
output = self.module.validation_step(*inputs[0], **kwargs[0])
if torch.is_grad_enabled():
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
from torch.nn.parallel.distributed import (
Join,
_DDPSink,
_tree_flatten_with_rref,
_tree_unflatten_with_rref,
)
with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
if torch.is_grad_enabled() and self.require_backward_grad_sync:
self.logger.set_runtime_stats_and_log()
self.num_iterations += 1
self.reducer.prepare_for_forward()
# Notify the join context that this process has not joined, if
# needed
work = Join.notify_join_context(self)
if work:
self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)
# Calling _rebuild_buckets before forward compuation,
# It may allocate new buckets before deallocating old buckets
# inside _rebuild_buckets. To save peak memory usage,
# call _rebuild_buckets before the peak memory usage increases
# during forward computation.
# This should be called only once during whole training period.
if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
print("Reducer buckets have been rebuilt in this iteration.")
self._has_rebuilt_buckets = True
# sync params according to location (before/after forward) user
# specified as part of hook, if hook was specified.
buffer_hook_registered = hasattr(self, "buffer_hook")
if self._check_sync_bufs_pre_fwd():
self._sync_buffers()
if self._join_config.enable:
# Notify joined ranks whether they should sync in backwards pass or not.
self._check_global_requires_backward_grad_sync(is_joined_rank=False)
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if self.module.training:
output = self.module.training_step(*inputs[0], **kwargs[0])
elif self.module.testing:
output = self.module.test_step(*inputs[0], **kwargs[0])
else:
output = self.module.validation_step(*inputs[0], **kwargs[0])
# sync params according to location (before/after forward) user
# specified as part of hook, if hook was specified.
if self._check_sync_bufs_post_fwd():
self._sync_buffers()
if torch.is_grad_enabled() and self.require_backward_grad_sync:
self.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if self.find_unused_parameters and not self.static_graph:
# Do not need to populate this for static graph.
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
self.require_forward_param_sync = False
# TODO: DDPSink is currently enabled for unused parameter detection and
# static graph training for first iteration.
if (self.find_unused_parameters and not self.static_graph) or (
self.static_graph and self.num_iterations == 1
):
state_dict = {
"static_graph": self.static_graph,
"num_iterations": self.num_iterations,
}
output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
output_placeholders = [None for _ in range(len(output_tensor_list))]
# Do not touch tensors that have no grad_fn, which can cause issues
# such as https://github.com/pytorch/pytorch/issues/60733
for i, output in enumerate(output_tensor_list):
if torch.is_tensor(output) and output.grad_fn is None:
output_placeholders[i] = output
# When find_unused_parameters=True, makes tensors which require grad
# run through the DDPSink backward pass. When not all outputs are
# used in loss, this makes those corresponding tensors receive
# undefined gradient which the reducer then handles to ensure
# param.grad field is not touched and we don't error out.
passthrough_tensor_list = _DDPSink.apply(
self.reducer,
state_dict,
*output_tensor_list,
)
for i in range(len(output_placeholders)):
if output_placeholders[i] is None:
output_placeholders[i] = passthrough_tensor_list[i]
# Reconstruct output data structure.
output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
return output

View File

@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch distributed utilities."""
import typing as tp
import torch
def rank():
if torch.distributed.is_initialized():
return torch.distributed.get_rank()
else:
return 0
def world_size():
if torch.distributed.is_initialized():
return torch.distributed.get_world_size()
else:
return 1
def is_distributed():
return world_size() > 1
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
if is_distributed():
return torch.distributed.all_reduce(tensor, op)
def _is_complex_or_float(tensor):
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
def _check_number_of_params(params: tp.List[torch.Tensor]):
# utility function to check that the number of params in all workers is the same,
# and thus avoid a deadlock with distributed all reduce.
if not is_distributed() or not params:
return
# print('params[0].device ', params[0].device)
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
all_reduce(tensor)
if tensor.item() != len(params) * world_size():
# If not all the workers have the same number, for at least one of them,
# this inequality will be verified.
raise RuntimeError(
f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
)
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
"""Broadcast the tensors from the given parameters to all workers.
This can be used to ensure that all workers have the same model to start with.
"""
if not is_distributed():
return
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
_check_number_of_params(tensors)
handles = []
for tensor in tensors:
handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
handles.append(handle)
for handle in handles:
handle.wait()
def sync_buffer(buffers, average=True):
"""
Sync grad for buffers. If average is False, broadcast instead of averaging.
"""
if not is_distributed():
return
handles = []
for buffer in buffers:
if torch.is_floating_point(buffer.data):
if average:
handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
else:
handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
handles.append((buffer, handle))
for buffer, handle in handles:
handle.wait()
if average:
buffer.data /= world_size
def sync_grad(params):
"""
Simpler alternative to DistributedDataParallel, that doesn't rely
on any black magic. For simple models it can also be as fast.
Just call this on your model parameters after the call to backward!
"""
if not is_distributed():
return
handles = []
for p in params:
if p.grad is not None:
handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
handles.append((p, handle))
for p, handle in handles:
handle.wait()
p.grad.data /= world_size()
def average_metrics(metrics: tp.Dict[str, float], count=1.0):
"""Average a dictionary of metrics across all workers, using the optional
`count` as unormalized weight.
"""
if not is_distributed():
return metrics
keys, values = zip(*metrics.items())
device = "cuda" if torch.cuda.is_available() else "cpu"
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
tensor *= count
all_reduce(tensor)
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
return dict(zip(keys, averaged))

View File

@ -151,6 +151,8 @@ class DurationPredictor(nn.Module):
return x * x_mask
WINDOW = {}
class TextEncoder(nn.Module):
def __init__(
self,
@ -209,7 +211,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y = self.ssl_proj(y * y_mask) * y_mask
@ -222,13 +224,44 @@ class TextEncoder(nn.Module):
text = self.text_embedding(text).transpose(1, 2)
text = self.encoder_text(text * text_mask, text_mask)
y = self.mrte(y, y_mask, text, text_mask, ge)
if padding_length is not None and padding_length!=0:
y = y[:, :, :-padding_length]
y_mask = y_mask[:, :, :-padding_length]
y = self.encoder2(y * y_mask, y_mask)
if result_length is not None:
y = y[:, :, -result_length:]
y_mask = y_mask[:, :, -result_length:]
if overlap_frames is not None:
overlap_len = overlap_frames.shape[-1]
window = WINDOW.get(overlap_len, None)
if window is None:
# WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype)
WINDOW[overlap_len] = torch.sin(torch.arange(overlap_len*2, device=y.device) * torch.pi / (overlap_len*2))
window = WINDOW[overlap_len]
window = window.to(y.device)
y[:,:,:overlap_len] = (
window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len]
+ window[overlap_len:].view(1, 1, -1) * overlap_frames
)
y_ = y
y_mask_ = y_mask
if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask
return y, m, logs, y_mask, y_, y_mask_
def extract_latent(self, x):
x = self.ssl_proj(x)
@ -921,7 +954,7 @@ class SynthesizerTrn(nn.Module):
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
@ -949,7 +982,7 @@ class SynthesizerTrn(nn.Module):
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -957,6 +990,7 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
def get_ge(refer, sv_emb):
@ -989,7 +1023,7 @@ class SynthesizerTrn(nn.Module):
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(
x, m_p, logs_p, y_mask, _, _ = self.enc_p(
quantized,
y_lengths,
text,
@ -1004,6 +1038,59 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o
@torch.no_grad()
def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
def get_ge(refer, sv_emb):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
return ge
if type(refer) == list:
ges = []
for idx, _refer in enumerate(refer):
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
ges.append(ge)
ge = torch.stack(ges, 0).mean(0)
else:
ge = get_ge(refer, sv_emb)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
result_length = (2*result_length) if result_length is not None else None
padding_length = (2*padding_length) if padding_length is not None else None
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(
quantized,
y_lengths,
text,
text_lengths,
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
speed,
result_length=result_length,
overlap_frames=overlap_frames,
padding_length=padding_length
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o, y_, y_mask_
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
@ -1226,7 +1313,7 @@ class SynthesizerTrnV3(nn.Module):
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
fea, y_mask_ = self.wns1(
@ -1260,7 +1347,7 @@ class SynthesizerTrnV3(nn.Module):
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
####more wn paramter to learn mel
@ -1377,7 +1464,7 @@ class SynthesizerTrnV3b(nn.Module):
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
@ -1420,7 +1507,7 @@ class SynthesizerTrnV3b(nn.Module):
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel

View File

@ -124,7 +124,7 @@ def run(rank, n_gpus, hps):
collate_fn=collate_fn,
batch_sampler=train_sampler,
persistent_workers=True,
prefetch_factor=4,
prefetch_factor=3,
)
# if rank == 0:
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)

View File

@ -118,13 +118,13 @@ def run(rank, n_gpus, hps):
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(
train_dataset,
num_workers=6,
num_workers=5,
shuffle=False,
pin_memory=True,
collate_fn=collate_fn,
batch_sampler=train_sampler,
persistent_workers=True,
prefetch_factor=4,
prefetch_factor=3,
)
# if rank == 0:
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)

View File

@ -120,13 +120,13 @@ def run(rank, n_gpus, hps):
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(
train_dataset,
num_workers=6,
num_workers=5,
shuffle=False,
pin_memory=True,
collate_fn=collate_fn,
batch_sampler=train_sampler,
persistent_workers=True,
prefetch_factor=4,
prefetch_factor=3,
)
save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
os.makedirs(save_root, exist_ok=True)

611
GPT_SoVITS/stream_v2pro.py Normal file
View File

@ -0,0 +1,611 @@
# 这是一个实验性质的实现,旨在探索 stream infer 的可能性。(xiao hai xie zhe wan de)
from typing import List
from export_torch_script import ExportERes2NetV2, SSLModel, T2SModel, VitsModel, get_raw_t2s_model, init_sv_cn, resamplex, sample, spectrogram_torch
import export_torch_script
from my_utils import load_audio
import torch
from torch import LongTensor, Tensor, nn
from torch.nn import functional as F
import soundfile
from inference_webui import get_phones_and_bert
import matplotlib.pyplot as plt
class StreamT2SModel(nn.Module):
def __init__(self, t2s: T2SModel):
super(StreamT2SModel, self).__init__()
self.t2s = t2s
@torch.jit.export
def pre_infer(
self,
prompts: LongTensor,
ref_seq: LongTensor,
text_seq: LongTensor,
ref_bert: torch.Tensor,
text_bert: torch.Tensor,
top_k: int,
) -> tuple[int, Tensor, Tensor, List[Tensor], List[Tensor]]:
bert = torch.cat([ref_bert.T, text_bert.T], 1)
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
bert = bert.unsqueeze(0)
x = self.t2s.ar_text_embedding(all_phoneme_ids)
x = x + self.t2s.bert_proj(bert.transpose(1, 2))
x: torch.Tensor = self.t2s.ar_text_position(x)
# [1,N,512] [1,N]
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
y = prompts
# x_example = x[:,:,0] * 0.0
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
y_emb = self.t2s.ar_audio_embedding(y)
y_len: int = y_emb.shape[1]
prefix_len = y.shape[1]
y_pos = self.t2s.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
bsz = x.shape[0]
src_len = x_len + y_len
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = (
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
.unsqueeze(0)
.expand(bsz * self.t2s.num_head, -1, -1)
.view(bsz, self.t2s.num_head, src_len, src_len)
.to(device=x.device, dtype=torch.bool)
)
xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.process_prompt(
xy_pos, xy_attn_mask, None
)
logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
logits = logits[:, :-1]
samples = sample(
logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0
)[0]
y = torch.concat([y, samples], dim=1)
y_emb: Tensor = self.t2s.ar_audio_embedding(y[:, -1:])
xy_pos: Tensor = (
y_emb * self.t2s.ar_audio_position.x_scale
+ self.t2s.ar_audio_position.alpha
* self.t2s.ar_audio_position.pe[:, y_len].to(
dtype=y_emb.dtype, device=y_emb.device
)
)
return y_len, y, xy_pos, k_cache, v_cache
@torch.jit.export
def decode_next_token(
self,
idx: int, # 记住从1开始 到1500
top_k: int,
y_len: int,
y: Tensor,
xy_pos: Tensor,
k_cache: List[Tensor],
v_cache: List[Tensor],
) -> tuple[Tensor, Tensor, int, List[Tensor], List[Tensor]]:
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.decode_next_token(
xy_pos, k_cache, v_cache
)
logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
if idx < 11: ###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
samples = sample(
logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0
)[0]
y = torch.concat([y, samples], dim=1)
last_token = int(samples[0, 0])
# if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
# stop = True
if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
return y[:,:-1], xy_pos, self.t2s.EOS, k_cache, v_cache
# if stop:
# if y.shape[1] == 0:
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
# break
y_emb = self.t2s.ar_audio_embedding(y[:, -1:])
xy_pos = (
y_emb * self.t2s.ar_audio_position.x_scale
+ self.t2s.ar_audio_position.alpha
* self.t2s.ar_audio_position.pe[:, y_len + idx].to(
dtype=y_emb.dtype, device=y_emb.device
)
)
return y, xy_pos, last_token, k_cache, v_cache
def forward(
self,
idx: int, # 记住从1开始 到1500
top_k: int,
y_len: int,
y: Tensor,
xy_pos: Tensor,
k_cache: List[Tensor],
v_cache: List[Tensor],
):
return self.decode_next_token(idx,top_k,y_len,y,xy_pos,k_cache,v_cache)
class StepVitsModel(nn.Module):
def __init__(self, vits: VitsModel,sv_model:ExportERes2NetV2):
super().__init__()
self.hps = vits.hps
self.vq_model = vits.vq_model
self.hann_window = vits.hann_window
self.sv = sv_model
def ref_handle(self, ref_audio_32k):
refer = spectrogram_torch(
self.hann_window,
ref_audio_32k.float(),
self.hps.data.filter_length,
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False,
)
refer = refer.to(ref_audio_32k.dtype)
ref_audio_16k = resamplex(ref_audio_32k, 32000, 16000).to(ref_audio_32k.dtype).to(ref_audio_32k.device)
sv_emb = self.sv(ref_audio_16k)
return refer, sv_emb
def extract_latent(self, ssl_content):
codes = self.vq_model.extract_latent(ssl_content)
return codes[0]
def forward(self, pred_semantic, text_seq, refer, sv_emb=None):
return self.vq_model(
pred_semantic, text_seq, refer, speed=1.0, sv_emb=sv_emb
)[0, 0]
@torch.jit.script
def find_best_audio_offset_fast(reference_audio: Tensor, search_audio: Tensor):
ref_len = len(reference_audio)
search_len = len(search_audio)
if search_len < ref_len:
raise ValueError(
f"搜索音频长度 ({search_len}) 必须大于等于参考音频长度 ({ref_len})"
)
# 使用F.conv1d计算原始互相关
reference_flipped = reference_audio.unsqueeze(0).unsqueeze(0)
search_padded = search_audio.unsqueeze(0).unsqueeze(0)
# 计算点积
dot_products = F.conv1d(search_padded, reference_flipped).squeeze()
if len(dot_products.shape) == 0:
dot_products = dot_products.unsqueeze(0)
# 计算参考音频的平方和
ref_squared_sum = torch.sum(reference_audio**2)
# 计算搜索音频每个位置的平方和(滑动窗口)
search_squared = search_audio**2
search_squared_padded = search_squared.unsqueeze(0).unsqueeze(0)
ones_kernel = torch.ones(
1, 1, ref_len, dtype=search_audio.dtype, device=search_audio.device
)
segment_squared_sums = F.conv1d(search_squared_padded, ones_kernel).squeeze()
if len(segment_squared_sums.shape) == 0:
segment_squared_sums = segment_squared_sums.unsqueeze(0)
# 计算归一化因子
ref_norm = torch.sqrt(ref_squared_sum)
segment_norms = torch.sqrt(segment_squared_sums)
# 避免除零
epsilon = 1e-8
normalization_factor = ref_norm * segment_norms + epsilon
# 归一化互相关
correlation_scores = dot_products / normalization_factor
best_offset = torch.argmax(correlation_scores).item()
return best_offset, correlation_scores
import time
def test_stream(
gpt_path,
vits_path,
version,
ref_audio_path,
ref_text,
output_path,
device="cpu",
is_half=True,
):
if export_torch_script.sv_cn_model == None:
init_sv_cn(device,is_half)
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ssl = SSLModel()
print(f"device: {device}")
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
ref_text, "all_zh", "v2"
)
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T
if is_half:
ref_bert = ref_bert.half()
ref_bert = ref_bert.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一个简单的示例,真没想到这么简单就完成了,真的神奇,接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T
if is_half:
text_bert = text_bert.half()
text_bert = text_bert.to(text_seq.device)
ssl_content = ssl(ref_audio)
if is_half:
ssl_content = ssl_content.half()
ssl_content = ssl_content.to(device)
sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path, version,is_half=is_half,device=device)
vits.eval()
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
# dict_s1 = torch.load(gpt_path, map_location=device)
dict_s1 = torch.load(gpt_path, weights_only=False)
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print("#### get_raw_t2s_model ####")
print(raw_t2s.config)
if is_half:
raw_t2s = raw_t2s.half()
t2s_m = T2SModel(raw_t2s)
t2s_m.eval()
# t2s = torch.jit.script(t2s_m).to(device)
t2s = t2s_m
print("#### script t2s_m ####")
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
stream_t2s = StreamT2SModel(t2s).to(device)
stream_t2s = torch.jit.script(stream_t2s)
ref_audio_sr = resamplex(ref_audio, 16000, 32000)
if is_half:
ref_audio_sr = ref_audio_sr.half()
ref_audio_sr = ref_audio_sr.to(device)
top_k = 15
codes = vits.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompts = prompt_semantic.unsqueeze(0)
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
sv_emb = sv_model(audio_16k)
print("text_seq",text_seq.shape)
refer = spectrogram_torch(
vits.hann_window,
ref_audio_sr,
vits.hps.data.filter_length,
vits.hps.data.sampling_rate,
vits.hps.data.hop_length,
vits.hps.data.win_length,
center=False,
)
st = time.time()
et = time.time()
y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
idx = 1
last_idx = 0
audios = []
raw_audios = []
last_audio_ret = None
offset_index = []
full_audios = []
print("y.shape:", y.shape)
cut_id = 0
while True:
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
# print("y.shape:", y.shape)
stop = last_token==t2s.EOS
print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx)
if last_token < 50 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
cut_id = idx + 7
print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape)
# y = torch.cat([y, y[:,-1:]], dim=1)
# idx+=1
if stop :
idx -=1
print('stop')
print(idx, y[:,-idx+last_idx:])
print(idx,last_idx, y.shape)
print(y[:,-idx:-idx+20])
# 玄学这档子事说不清楚
if idx == cut_id or stop:
print(f"idx: {idx}, last_idx: {last_idx}, cut_id: {cut_id}, stop: {stop}")
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
full_audios.append(audio)
if last_idx == 0:
last_audio_ret = audio[-1280*8:-1280*8+256]
audio = audio[:-1280*8]
raw_audios.append(audio)
et = time.time()
else:
if stop:
audio_ = audio[last_idx*1280 -1280*8:]
raw_audios.append(audio_)
i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
offset_index.append(i)
audio = audio_[i:]
else:
audio_ = audio[last_idx*1280 -1280*8:-1280*8]
raw_audios.append(audio_)
i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
offset_index.append(i)
last_audio_ret = audio[-1280*8:-1280*8+256]
audio = audio_[i:]
last_idx = idx
# print(f'write {output_path}/out_{audio_index}')
# soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
audios.append(audio)
# print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
if idx>1500:
break
if stop:
break
idx+=1
at = time.time()
for (i,a) in enumerate(audios):
print(f'write {output_path}/out_{i}')
soundfile.write(f"{output_path}/out_{i}.wav", a.float().detach().cpu().numpy(), 32000)
print(f"frist token: {et - st:.4f} seconds")
print(f"all token: {at - st:.4f} seconds")
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
audio = torch.cat(audios, dim=0)
soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000)
audio_raw = torch.cat(raw_audios, dim=0)
soundfile.write(f"{output_path}/out.raw.wav", audio_raw.float().detach().cpu().numpy(), 32000)
colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
max_duration = full_audios[-1].shape[0]
plt.xlim(0, max_duration)
last_line = 0
for i,a in enumerate(full_audios):
plt.plot((a+2.0*i).float().detach().cpu().numpy(), color=colors[i], alpha=0.5, label=f"Audio {i}")
# plt.axvline(x=last_line, color=colors[i], linestyle='--')
last_line = a.shape[0]-8*1280
plt.axvline(x=last_line, color=colors[i], linestyle='--')
plt.plot((audio-2.0).float().detach().cpu().numpy(), color='black', label='Final Audio')
plt.plot((audio_raw-4.0).float().detach().cpu().numpy(), color='cyan', label='Raw Audio')
print("offset_index:", offset_index)
plt.show()
def export_prov2(
gpt_path,
vits_path,
version,
ref_audio_path,
ref_text,
output_path,
device="cpu",
is_half=True,
lang="auto",
):
if export_torch_script.sv_cn_model == None:
init_sv_cn(device,is_half)
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ssl = SSLModel()
print(f"device: {device}")
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
ref_text, lang, "v2"
)
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T
if is_half:
ref_bert = ref_bert.half()
ref_bert = ref_bert.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一个简单的示例,真没想到这么简单就完成了.The King and His Stories.Once there was a king.He likes to write stories, but his stories were not good.", "auto", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T
if is_half:
text_bert = text_bert.half()
text_bert = text_bert.to(text_seq.device)
ssl_content = ssl(ref_audio)
if is_half:
ssl_content = ssl_content.half()
ssl_content = ssl_content.to(device)
sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path, version,is_half=is_half,device=device)
vits.eval()
vits = StepVitsModel(vits, sv_model)
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
# dict_s1 = torch.load(gpt_path, map_location=device)
dict_s1 = torch.load(gpt_path, weights_only=False)
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print("#### get_raw_t2s_model ####")
print(raw_t2s.config)
if is_half:
raw_t2s = raw_t2s.half()
t2s_m = T2SModel(raw_t2s)
t2s_m.eval()
# t2s = torch.jit.script(t2s_m).to(device)
t2s = t2s_m
print("#### script t2s_m ####")
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
stream_t2s = StreamT2SModel(t2s).to(device)
stream_t2s = torch.jit.script(stream_t2s)
ref_audio_sr = resamplex(ref_audio, 16000, 32000)
ref_audio_sr = ref_audio_sr.to(device)
if is_half:
ref_audio_sr = ref_audio_sr.half()
top_k = 15
prompts = vits.extract_latent(ssl_content)
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
sv_emb = sv_model(audio_16k)
print("text_seq",text_seq.shape)
# torch.jit.trace()
refer,sv_emb = vits.ref_handle(ref_audio_sr)
st = time.time()
et = time.time()
y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
idx = 1
print("y.shape:", y.shape)
while True:
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
# print("y.shape:", y.shape)
idx+=1
# print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
if idx>1500:
break
if last_token == t2s.EOS:
break
at = time.time()
print("EOS:",t2s.EOS)
print(f"frist token: {et - st:.4f} seconds")
print(f"all token: {at - st:.4f} seconds")
print("sv_emb", sv_emb.shape)
print("refer",refer.shape)
y = y[:,-idx:].unsqueeze(0)
print("y", y.shape)
audio = vits(y, text_seq, refer, sv_emb)
soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
torch._dynamo.mark_dynamic(ssl_content, 2)
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
torch._dynamo.mark_dynamic(ref_seq, 1)
torch._dynamo.mark_dynamic(text_seq, 1)
torch._dynamo.mark_dynamic(ref_bert, 0)
torch._dynamo.mark_dynamic(text_bert, 0)
torch._dynamo.mark_dynamic(refer, 2)
torch._dynamo.mark_dynamic(y, 2)
inputs = {
"forward": (y, text_seq, refer, sv_emb),
"extract_latent": ssl_content,
"ref_handle": ref_audio_sr,
}
stream_t2s.save(f"{output_path}/t2s.pt")
torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt")
torch.jit.script(find_best_audio_offset_fast, optimize=True).save(f"{output_path}/find_best_audio_offset_fast.pt")
import argparse
import os
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument(
"--sovits_model", required=True, help="Path to the SoVITS model file"
)
parser.add_argument(
"--ref_audio", required=True, help="Path to the reference audio file"
)
parser.add_argument(
"--ref_text", required=True, help="Path to the reference text file"
)
parser.add_argument(
"--output_path", required=True, help="Path to the output directory"
)
parser.add_argument("--device", help="Device to use", default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--version", help="version of the model", default="v2Pro")
parser.add_argument("--no-half", action="store_true", help = "Do not use half precision for model weights")
parser.add_argument("--lang", default="auto", help="Language for text processing (default: auto)")
args = parser.parse_args()
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
is_half = not args.no_half
with torch.no_grad():
export_prov2(
gpt_path=args.gpt_model,
vits_path=args.sovits_model,
version=args.version,
ref_audio_path=args.ref_audio,
ref_text=args.ref_text,
output_path=args.output_path,
device=args.device,
is_half=is_half,
lang=args.lang,
)

View File

@ -238,6 +238,46 @@ def _expand_number(m):
return _inflect.number_to_words(num, andword="")
# 加减乘除
RE_ASMD = re.compile(
r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))\s+([\+\-\×÷=])\s+((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))"
)
# RE_ASMD = re.compile(
# r"\b((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))\b"
# )
asmd_map = {"+": " plus ", "-": " minus ", "×": " times ", "÷": " divided by ", "=": " Equals "}
def replace_asmd(match) -> str:
"""
Args:
match (re.Match)
Returns:
str
"""
result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
return result
RE_INTEGER = re.compile(r"(?:^|\s+)(-)" r"(\d+)")
def replace_negative_num(match) -> str:
"""
Args:
match (re.Match)
Returns:
str
"""
sign = match.group(1)
number = match.group(2)
sign: str = "negative " if sign else ""
result = f"{sign}{number}"
return result
def normalize(text):
"""
!!! 所有的处理都需要正确的输入 !!!
@ -245,7 +285,13 @@ def normalize(text):
"""
text = re.sub(_ordinal_number_re, _convert_ordinal, text)
text = re.sub(r"(?<!\d)-|-(?!\d)", " minus ", text)
# 处理数学运算
# 替换text = re.sub(r"(?<!\d)-|-(?!\d)", " minus ", text)
while RE_ASMD.search(text):
text = RE_ASMD.sub(replace_asmd, text)
text = RE_INTEGER.sub(replace_negative_num, text)
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_time_re, _expand_time, text)
text = re.sub(_measurement_re, _expand_measurement, text)

View File

@ -347,7 +347,7 @@ Use v4 from v1/v2/v3 environment:
2. Clone the latest codes from github.
3. Download v4 pretrained models (gsv-v4-pretrained/s2v4.ckpt, and gsv-v4-pretrained/vocoder.pth) from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main) and put them into `GPT_SoVITS/pretrained_models`.
3. Download v4 pretrained models (gsv-v4-pretrained/s2v4.pth, and gsv-v4-pretrained/vocoder.pth) from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main) and put them into `GPT_SoVITS/pretrained_models`.
## V2Pro Release Notes

126
api_v2.py
View File

@ -27,20 +27,23 @@ POST:
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_k": 15, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"streaming_mode": False, # bool. whether to return a streaming response.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False # bool. whether to use super-sampling for audio when using VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
}
```
@ -101,7 +104,7 @@ RESP:
import os
import sys
import traceback
from typing import Generator
from typing import Generator, Union
now_dir = os.getcwd()
sys.path.append(now_dir)
@ -121,6 +124,7 @@ from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
from pydantic import BaseModel
import threading
# print(sys.path)
i18n = I18nAuto()
@ -154,7 +158,7 @@ class TTS_Request(BaseModel):
aux_ref_audio_paths: list = None
prompt_lang: str = None
prompt_text: str = ""
top_k: int = 5
top_k: int = 15
top_p: float = 1
temperature: float = 1
text_split_method: str = "cut5"
@ -165,17 +169,58 @@ class TTS_Request(BaseModel):
fragment_interval: float = 0.3
seed: int = -1
media_type: str = "wav"
streaming_mode: bool = False
streaming_mode: Union[bool, int] = False
parallel_infer: bool = True
repetition_penalty: float = 1.35
sample_steps: int = 32
super_sampling: bool = False
overlap_length: int = 2
min_chunk_length: int = 16
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
audio_file.write(data)
# Author: AkagawaTsurunaki
# Issue:
# Stack overflow probabilistically occurs
# when the function `sf_writef_short` of `libsndfile_64bit.dll` is called
# using the Python library `soundfile`
# Note:
# This is an issue related to `libsndfile`, not this project itself.
# It happens when you generate a large audio tensor (about 499804 frames in my PC)
# and try to convert it to an ogg file.
# Related:
# https://github.com/RVC-Boss/GPT-SoVITS/issues/1199
# https://github.com/libsndfile/libsndfile/issues/1023
# https://github.com/bastibe/python-soundfile/issues/396
# Suggestion:
# Or split the whole audio data into smaller audio segment to avoid stack overflow?
def handle_pack_ogg():
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
audio_file.write(data)
# See: https://docs.python.org/3/library/threading.html
# The stack size of this thread is at least 32768
# If stack overflow error still occurs, just modify the `stack_size`.
# stack_size = n * 4096, where n should be a positive integer.
# Here we chose n = 4096.
stack_size = 4096 * 4096
try:
threading.stack_size(stack_size)
pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
pack_ogg_thread.start()
pack_ogg_thread.join()
except RuntimeError as e:
# If changing the thread stack size is unsupported, a RuntimeError is raised.
print("RuntimeError: {}".format(e))
print("Changing the thread stack size is unsupported.")
except ValueError as e:
# If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified.
print("ValueError: {}".format(e))
print("The specified stack size is invalid.")
return io_buffer
@ -286,8 +331,8 @@ def check_params(req: dict):
)
if media_type not in ["wav", "raw", "ogg", "aac"]:
return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
elif media_type == "ogg" and not streaming_mode:
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
# elif media_type == "ogg" and not streaming_mode:
# return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
if text_split_method not in cut_method_names:
return JSONResponse(
@ -307,25 +352,26 @@ async def tts_handle(req: dict):
"text": "", # str.(required) text to be synthesized
"text_lang: "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_k": 15, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
"streaming_mode": False, # bool. whether to return a streaming response.
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
}
returns:
StreamingResponse: audio stream response.
@ -338,9 +384,35 @@ async def tts_handle(req: dict):
check_res = check_params(req)
if check_res is not None:
return check_res
if streaming_mode == 0:
streaming_mode = False
return_fragment = False
fixed_length_chunk = False
elif streaming_mode == 1:
streaming_mode = False
return_fragment = True
fixed_length_chunk = False
elif streaming_mode == 2:
streaming_mode = True
return_fragment = False
fixed_length_chunk = False
elif streaming_mode == 3:
streaming_mode = True
return_fragment = False
fixed_length_chunk = True
else:
return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"})
req["streaming_mode"] = streaming_mode
req["return_fragment"] = return_fragment
req["fixed_length_chunk"] = fixed_length_chunk
print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}")
streaming_mode = streaming_mode or return_fragment
if streaming_mode or return_fragment:
req["return_fragment"] = True
try:
tts_generator = tts_pipeline.run(req)
@ -388,10 +460,10 @@ async def tts_get_endpoint(
aux_ref_audio_paths: list = None,
prompt_lang: str = None,
prompt_text: str = "",
top_k: int = 5,
top_k: int = 15,
top_p: float = 1,
temperature: float = 1,
text_split_method: str = "cut0",
text_split_method: str = "cut5",
batch_size: int = 1,
batch_threshold: float = 0.75,
split_bucket: bool = True,
@ -399,11 +471,13 @@ async def tts_get_endpoint(
fragment_interval: float = 0.3,
seed: int = -1,
media_type: str = "wav",
streaming_mode: bool = False,
parallel_infer: bool = True,
repetition_penalty: float = 1.35,
sample_steps: int = 32,
super_sampling: bool = False,
streaming_mode: Union[bool, int] = False,
overlap_length: int = 2,
min_chunk_length: int = 16,
):
req = {
"text": text,
@ -428,6 +502,8 @@ async def tts_get_endpoint(
"repetition_penalty": float(repetition_penalty),
"sample_steps": int(sample_steps),
"super_sampling": super_sampling,
"overlap_length": int(overlap_length),
"min_chunk_length": int(min_chunk_length),
}
return await tts_handle(req)

View File

@ -373,7 +373,7 @@ if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ]; then
location=$(pip show torch | grep Location | awk -F ": " '{print $2}')
cd "${location}"/torch/lib/ || exit
rm libhsa-runtime64.so*
cp /opt/rocm/lib/libhsa-runtime64.so.1.2 libhsa-runtime64.so
cp "$(readlink -f /opt/rocm/lib/libhsa-runtime64.so)" libhsa-runtime64.so
echo -e "${SUCCESS}ROCm Runtime Lib Updated..."
fi

View File

@ -16,10 +16,10 @@ pypinyin
pyopenjtalk>=0.4.1
g2p_en
torchaudio
modelscope==1.10.0
modelscope
sentencepiece
transformers>=4.43,<=4.50
peft
peft<0.18.0
chardet
PyYAML
psutil
@ -39,7 +39,5 @@ x_transformers
torchmetrics<=1.5
pydantic<=2.10.6
ctranslate2>=4.0,<5
huggingface_hub>=0.13
tokenizers>=0.13,<1
av>=11
tqdm

View File

@ -1,34 +1,13 @@
import os
def check_fw_local_models():
"""
启动时检查本地是否有 Faster Whisper 模型.
"""
model_size_list = [
"medium",
"medium.en",
"distil-large-v2",
"distil-large-v3",
"large-v1",
"large-v2",
"large-v3",
]
for i, size in enumerate(model_size_list):
if os.path.exists(f"tools/asr/models/faster-whisper-{size}"):
model_size_list[i] = size + "-local"
return model_size_list
def get_models():
model_size_list = [
"medium",
"medium.en",
"distil-large-v2",
"distil-large-v3",
"large-v1",
"large-v2",
"large-v3",
"large-v3-turbo",
#"distil-large-v2",
#"distil-large-v3",
#"distil-large-v3.5",
]
return model_size_list
@ -36,7 +15,7 @@ def get_models():
asr_dict = {
"达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]},
"Faster Whisper (多语种)": {
"lang": ["auto", "zh", "en", "ja", "ko", "yue"],
"lang": ["auto", "en", "ja", "ko"],
"size": get_models(),
"path": "fasterwhisper_asr.py",
"precision": ["float32", "float16", "int8"],

View File

@ -1,12 +1,12 @@
import argparse
import os
import time
import traceback
import requests
import torch
from faster_whisper import WhisperModel
from huggingface_hub import snapshot_download
from huggingface_hub.errors import LocalEntryNotFoundError
from huggingface_hub import snapshot_download as snapshot_download_hf
from modelscope import snapshot_download as snapshot_download_ms
from tqdm import tqdm
from tools.asr.config import get_models
@ -40,11 +40,32 @@ language_code_list = [
def download_model(model_size: str):
if "distil" in model_size:
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
url = "https://huggingface.co/api/models/gpt2"
try:
requests.get(url, timeout=3)
source = "HF"
except Exception:
source = "ModelScope"
model_path = ""
if source == "HF":
if "distil" in model_size:
if "3.5" in model_size:
repo_id = "distil-whisper/distil-large-v3.5-ct2"
model_path = "tools/asr/models/faster-distil-whisper-large-v3.5"
else:
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
elif model_size == "large-v3-turbo":
repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo"
model_path = "tools/asr/models/faster-whisper-large-v3-turbo"
else:
repo_id = f"Systran/faster-whisper-{model_size}"
model_path = (
model_path or f"tools/asr/models/{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}"
)
else:
repo_id = f"Systran/faster-whisper-{model_size}"
model_path = f"tools/asr/models/{repo_id.strip('Systran/')}"
repo_id = "XXXXRT/faster-whisper"
model_path = "tools/asr/models"
files: list[str] = [
"config.json",
@ -52,32 +73,31 @@ def download_model(model_size: str):
"tokenizer.json",
"vocabulary.txt",
]
if model_size == "large-v3" or "distil" in model_size:
if "large-v3" in model_size or "distil" in model_size:
files.append("preprocessor_config.json")
files.append("vocabulary.json")
files.remove("vocabulary.txt")
for attempt in range(2):
try:
snapshot_download(
repo_id=repo_id,
allow_patterns=files,
local_dir=model_path,
)
break
except LocalEntryNotFoundError:
if attempt < 1:
time.sleep(2)
else:
print("[ERROR] LocalEntryNotFoundError and no fallback.")
traceback.print_exc()
exit(1)
except Exception as e:
print(f"[ERROR] Unexpected error on attempt {attempt + 1}: {e}")
traceback.print_exc()
exit(1)
if source == "ModelScope":
files = [f"faster-whisper-{model_size}/{file}".replace("whisper-distil", "distil-whisper") for file in files]
if source == "HF":
print(f"Downloading model from HuggingFace: {repo_id} to {model_path}")
snapshot_download_hf(
repo_id,
local_dir=model_path,
local_dir_use_symlinks=False,
allow_patterns=files,
)
else:
print(f"Downloading model from ModelScope: {repo_id} to {model_path}")
snapshot_download_ms(
repo_id,
local_dir=model_path,
allow_patterns=files,
)
return model_path + f"/faster-whisper-{model_size}".replace("whisper-distil", "distil-whisper")
return model_path
@ -106,7 +126,7 @@ def execute_asr(input_folder, output_folder, model_path, language, precision):
)
text = ""
if info.language == "zh":
if info.language in ["zh", "yue"]:
print("检测为中文文本, 转 FunASR 处理")
text = only_asr(file_path, language=info.language.lower())

View File

@ -4,9 +4,8 @@ import argparse
import os
import traceback
# from funasr.utils import version_checker
# version_checker.check_for_update = lambda: None
from funasr import AutoModel
from modelscope import snapshot_download
from tqdm import tqdm
funasr_models = {} # 存储模型避免重复加载
@ -16,40 +15,43 @@ def only_asr(input_file, language):
try:
model = create_model(language)
text = model.generate(input=input_file)[0]["text"]
except:
except Exception:
text = ""
print(traceback.format_exc())
return text
def create_model(language="zh"):
path_vad = "tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch"
path_punc = "tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
path_vad = path_vad if os.path.exists(path_vad) else "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
path_punc = path_punc if os.path.exists(path_punc) else "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
vad_model_revision = punc_model_revision = "v2.0.4"
if language == "zh":
path_vad = "tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch"
path_punc = "tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
path_asr = "tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
path_asr = (
path_asr
if os.path.exists(path_asr)
else "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
snapshot_download(
"iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
local_dir="tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch",
)
snapshot_download(
"iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
local_dir="tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
)
snapshot_download(
"iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
local_dir="tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
)
model_revision = "v2.0.4"
elif language == "yue":
path_asr = "tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
path_asr = (
path_asr
if os.path.exists(path_asr)
else "iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
snapshot_download(
"iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
local_dir="tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
)
model_revision = "master"
path_vad = path_punc = None
vad_model_revision = punc_model_revision = None
###友情提示粤语带VAD识别可能会有少量shape不对报错的但是不带VAD可以.不带vad只能分阶段单独加标点。不过标点模型对粤语效果真的不行…
vad_model_revision = punc_model_revision = ""
model_revision = "master"
else:
raise ValueError("FunASR 不支持该语言" + ": " + language)
raise ValueError(f"{language} is not supported")
vad_model_revision = punc_model_revision = "v2.0.4"
if language in funasr_models:
return funasr_models[language]
@ -83,7 +85,7 @@ def execute_asr(input_folder, output_folder, model_size, language):
file_path = os.path.join(input_folder, file_name)
text = model.generate(input=file_path)[0]["text"]
output.append(f"{file_path}|{output_file_name}|{language.upper()}|{text}")
except:
except Exception:
print(traceback.format_exc())
output_folder = output_folder or "output/asr_opt"

View File

@ -38,7 +38,7 @@
"hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size: FO hop size, the smaller the value, the higher the accuracy",
"max:归一化后最大值多少": "Loudness multiplier after normalized",
"max_sil_kept:切完后静音最多留多长": "Maximum length for silence to be kept",
"min_interval:最短切割间隔": "Minumum interval for audio cutting",
"min_interval:最短切割间隔": "Minimum interval for audio cutting",
"min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length: the minimum length of each segment. If the first segment is too short, it will be concatenated with the next segment until it exceeds this value",
"temperature": "temperature",
"threshold:音量小于这个值视作静音的备选切割点": "Noise gate threshold (loudness below this value will be treated as noise",
@ -176,7 +176,7 @@
"语音降噪": "Speech Denoising",
"请上传3~10秒内参考音频超过会报错": "Please upload a reference audio within the 3-10 second range; if it exceeds this duration, it will raise errors.",
"请上传参考音频": "Please Upload the Reference Audio",
"请填入推理文本": "Please Fill in the Terget Text",
"请填入推理文本": "Please Fill in the Target Text",
"请填入正确的List路径": "Please Fill in the Correct List Path",
"请填入正确的音频文件夹路径": "Please Fill in the Correct Audio Folder Path",
"请输入有效文本": "Please enter valid text.",

View File

@ -86,7 +86,6 @@ from config import (
from tools import my_utils
from tools.my_utils import check_details, check_for_existance
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu
@ -117,8 +116,8 @@ def set_default():
gpu_info = "\n".join(gpu_infos)
if is_gpu_ok:
minmem = min(mem)
default_batch_size = minmem // 2 if version not in v3v4set else minmem // 8
default_batch_size_s1 = minmem // 2
default_batch_size = int(minmem // 2 if version not in v3v4set else minmem // 8)
default_batch_size_s1 = int(minmem // 2)
else:
default_batch_size = default_batch_size_s1 = int(psutil.virtual_memory().total / 1024 / 1024 / 1024 / 4)
if version not in v3v4set:
@ -343,7 +342,7 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so
os.environ["sovits_path"] = sovits_path
os.environ["cnhubert_base_path"] = cnhubert_base_path
os.environ["bert_path"] = bert_path
os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_number(gpu_number)
os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number))
os.environ["is_half"] = str(is_half)
os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
os.environ["is_share"] = str(is_share)
@ -628,7 +627,7 @@ def open1Bb(
data["output_dir"] = "%s/logs_s1_%s" % (s1_dir, version)
# data["version"]=version
os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_numbers(gpu_numbers.replace("-", ","))
os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ",")))
os.environ["hz"] = "25hz"
tmp_config_path = "%s/tmp_s1.yaml" % tmp
with open(tmp_config_path, "w") as f:
@ -801,7 +800,7 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
{
"i_part": str(i_part),
"all_parts": str(all_parts),
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
"is_half": str(is_half),
}
)
@ -892,7 +891,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
{
"i_part": str(i_part),
"all_parts": str(all_parts),
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
os.environ.update(config)
@ -914,7 +913,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
{
"i_part": str(i_part),
"all_parts": str(all_parts),
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
os.environ.update(config)
@ -986,7 +985,7 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
{
"i_part": str(i_part),
"all_parts": str(all_parts),
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
os.environ.update(config)
@ -1086,7 +1085,7 @@ def open1abc(
{
"i_part": str(i_part),
"all_parts": str(all_parts),
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
os.environ.update(config)
@ -1133,7 +1132,7 @@ def open1abc(
{
"i_part": str(i_part),
"all_parts": str(all_parts),
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
os.environ.update(config)
@ -1155,7 +1154,7 @@ def open1abc(
{
"i_part": str(i_part),
"all_parts": str(all_parts),
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
os.environ.update(config)
@ -1195,7 +1194,7 @@ def open1abc(
{
"i_part": str(i_part),
"all_parts": str(all_parts),
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
os.environ.update(config)