mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-03 20:40:30 +08:00
Merge branch 'RVC-Boss:main' into main
This commit is contained in:
commit
0235857b89
@ -707,10 +707,12 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
if idx == 0:
|
||||
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
|
||||
logits = logits[:, :-1]
|
||||
else:
|
||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||
|
||||
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
|
||||
samples = sample(
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||
)[0]
|
||||
@ -794,7 +796,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_list = []
|
||||
idx_list = []
|
||||
for i in range(len(x)):
|
||||
y, idx = self.infer_panel_naive(
|
||||
y, idx = next(self.infer_panel_naive(
|
||||
x[i].unsqueeze(0),
|
||||
x_lens[i],
|
||||
prompts[i].unsqueeze(0) if prompts is not None else None,
|
||||
@ -805,7 +807,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
**kwargs,
|
||||
)
|
||||
))
|
||||
y_list.append(y[0])
|
||||
idx_list.append(idx)
|
||||
|
||||
@ -822,8 +824,15 @@ class Text2SemanticDecoder(nn.Module):
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
streaming_mode: bool = False,
|
||||
chunk_length: int = 24,
|
||||
**kwargs,
|
||||
):
|
||||
mute_emb_sim_matrix = kwargs.get("mute_emb_sim_matrix", None)
|
||||
chunk_split_thershold = kwargs.get("chunk_split_thershold", 0.3)
|
||||
check_token_num = 2
|
||||
|
||||
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
x = self.ar_text_position(x)
|
||||
@ -875,7 +884,10 @@ class Text2SemanticDecoder(nn.Module):
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
)
|
||||
|
||||
token_counter = 0
|
||||
curr_ptr = prefix_len
|
||||
for idx in tqdm(range(1500)):
|
||||
token_counter+=1
|
||||
if xy_attn_mask is not None:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
|
||||
else:
|
||||
@ -900,22 +912,56 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
||||
stop = True
|
||||
y=y[:, :-1]
|
||||
token_counter -= 1
|
||||
|
||||
if idx == 1499:
|
||||
stop = True
|
||||
|
||||
if stop:
|
||||
if y.shape[1] == 0:
|
||||
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||
print("bad zero prediction")
|
||||
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
||||
# print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
||||
if streaming_mode:
|
||||
yield y[:, curr_ptr:] if curr_ptr<y.shape[1] else None, True
|
||||
break
|
||||
|
||||
|
||||
if streaming_mode and (mute_emb_sim_matrix is not None) and (token_counter >= chunk_length+check_token_num):
|
||||
score = mute_emb_sim_matrix[y[0, curr_ptr:]] - chunk_split_thershold
|
||||
score[score<0]=-1
|
||||
score[:-1]=score[:-1]+score[1:] ##考虑连续两个token
|
||||
argmax_idx = score.argmax()
|
||||
|
||||
if score[argmax_idx]>=0 and argmax_idx+1>=chunk_length:
|
||||
print(f"\n\ncurr_ptr:{curr_ptr}")
|
||||
yield y[:, curr_ptr:], False
|
||||
token_counter -= argmax_idx+1
|
||||
curr_ptr += argmax_idx+1
|
||||
|
||||
|
||||
elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length):
|
||||
yield y[:, -token_counter:], False
|
||||
curr_ptr+=token_counter
|
||||
token_counter = 0
|
||||
|
||||
|
||||
|
||||
####################### update next step ###################################
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
:, y_len + idx
|
||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||
|
||||
if ref_free:
|
||||
return y[:, :-1], 0
|
||||
return y[:, :-1], idx
|
||||
|
||||
|
||||
if not streaming_mode:
|
||||
if ref_free:
|
||||
yield y, 0
|
||||
yield y, idx
|
||||
|
||||
|
||||
|
||||
def infer_panel(
|
||||
self,
|
||||
@ -930,6 +976,6 @@ class Text2SemanticDecoder(nn.Module):
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs,
|
||||
):
|
||||
return self.infer_panel_naive(
|
||||
return next(self.infer_panel_naive(
|
||||
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
|
||||
)
|
||||
))
|
||||
|
||||
@ -275,6 +275,15 @@ class TTS_Config:
|
||||
v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
||||
v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"]
|
||||
languages: list = v2_languages
|
||||
mute_tokens: dict = {
|
||||
"v1" : 486,
|
||||
"v2" : 486,
|
||||
"v2Pro": 486,
|
||||
"v2ProPlus": 486,
|
||||
"v3" : 486,
|
||||
"v4" : 486,
|
||||
}
|
||||
mute_emb_sim_matrix: torch.Tensor = None
|
||||
# "all_zh",#全部按中文识别
|
||||
# "en",#全部按英文识别#######不变
|
||||
# "all_ja",#全部按日文识别
|
||||
@ -598,6 +607,11 @@ class TTS:
|
||||
if self.configs.is_half and str(self.configs.device) != "cpu":
|
||||
self.t2s_model = self.t2s_model.half()
|
||||
|
||||
codebook = t2s_model.model.ar_audio_embedding.weight.clone()
|
||||
mute_emb = codebook[self.configs.mute_tokens[self.configs.version]].unsqueeze(0)
|
||||
sim_matrix = F.cosine_similarity(mute_emb.float(), codebook.float(), dim=-1)
|
||||
self.configs.mute_emb_sim_matrix = sim_matrix
|
||||
|
||||
def init_vocoder(self, version: str):
|
||||
if version == "v3":
|
||||
if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
|
||||
@ -994,21 +1008,25 @@ class TTS:
|
||||
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
||||
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
||||
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
||||
"top_k": 5, # int. top k sampling
|
||||
"top_k": 15, # int. top k sampling
|
||||
"top_p": 1, # float. top p sampling
|
||||
"temperature": 1, # float. temperature for sampling
|
||||
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
|
||||
"text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details.
|
||||
"batch_size": 1, # int. batch size for inference
|
||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
||||
"return_fragment": False, # bool. step by step return the audio fragment.
|
||||
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
|
||||
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
|
||||
}
|
||||
returns:
|
||||
Tuple[int, np.ndarray]: sampling rate and audio data.
|
||||
@ -1021,10 +1039,10 @@ class TTS:
|
||||
aux_ref_audio_paths: list = inputs.get("aux_ref_audio_paths", [])
|
||||
prompt_text: str = inputs.get("prompt_text", "")
|
||||
prompt_lang: str = inputs.get("prompt_lang", "")
|
||||
top_k: int = inputs.get("top_k", 5)
|
||||
top_k: int = inputs.get("top_k", 15)
|
||||
top_p: float = inputs.get("top_p", 1)
|
||||
temperature: float = inputs.get("temperature", 1)
|
||||
text_split_method: str = inputs.get("text_split_method", "cut0")
|
||||
text_split_method: str = inputs.get("text_split_method", "cut1")
|
||||
batch_size = inputs.get("batch_size", 1)
|
||||
batch_threshold = inputs.get("batch_threshold", 0.75)
|
||||
speed_factor = inputs.get("speed_factor", 1.0)
|
||||
@ -1038,19 +1056,43 @@ class TTS:
|
||||
repetition_penalty = inputs.get("repetition_penalty", 1.35)
|
||||
sample_steps = inputs.get("sample_steps", 32)
|
||||
super_sampling = inputs.get("super_sampling", False)
|
||||
streaming_mode = inputs.get("streaming_mode", False)
|
||||
overlap_length = inputs.get("overlap_length", 2)
|
||||
min_chunk_length = inputs.get("min_chunk_length", 16)
|
||||
fixed_length_chunk = inputs.get("fixed_length_chunk", False)
|
||||
chunk_split_thershold = 0.0 # 该值代表语义token与mute token的余弦相似度阈值,若大于该阈值,则视为可切分点。
|
||||
|
||||
if parallel_infer:
|
||||
if parallel_infer and not streaming_mode:
|
||||
print(i18n("并行推理模式已开启"))
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
|
||||
elif not parallel_infer and streaming_mode and not self.configs.use_vocoder:
|
||||
print(i18n("流式推理模式已开启"))
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
|
||||
elif streaming_mode and self.configs.use_vocoder:
|
||||
print(i18n("SoVits V3/4模型不支持流式推理模式,已自动回退到分段返回模式"))
|
||||
streaming_mode = False
|
||||
return_fragment = True
|
||||
if parallel_infer:
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
|
||||
else:
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
|
||||
# self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
|
||||
elif parallel_infer and streaming_mode:
|
||||
print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式"))
|
||||
parallel_infer = False
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
|
||||
else:
|
||||
print(i18n("并行推理模式已关闭"))
|
||||
print(i18n("朴素推理模式已开启"))
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
|
||||
|
||||
if return_fragment:
|
||||
print(i18n("分段返回模式已开启"))
|
||||
if split_bucket:
|
||||
split_bucket = False
|
||||
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
|
||||
if return_fragment and streaming_mode:
|
||||
print(i18n("流式推理模式不支持分段返回,已自动关闭分段返回"))
|
||||
return_fragment = False
|
||||
|
||||
if (return_fragment or streaming_mode) and split_bucket:
|
||||
print(i18n("分段返回模式/流式推理模式不支持分桶处理,已自动关闭分桶处理"))
|
||||
split_bucket = False
|
||||
|
||||
|
||||
if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer):
|
||||
print(i18n("分桶处理模式已开启"))
|
||||
@ -1063,9 +1105,9 @@ class TTS:
|
||||
else:
|
||||
print(i18n("分桶处理模式已关闭"))
|
||||
|
||||
if fragment_interval < 0.01:
|
||||
fragment_interval = 0.01
|
||||
print(i18n("分段间隔过小,已自动设置为0.01"))
|
||||
# if fragment_interval < 0.01:
|
||||
# fragment_interval = 0.01
|
||||
# print(i18n("分段间隔过小,已自动设置为0.01"))
|
||||
|
||||
no_prompt_text = False
|
||||
if prompt_text in [None, ""]:
|
||||
@ -1126,7 +1168,7 @@ class TTS:
|
||||
###### text preprocessing ########
|
||||
t1 = time.perf_counter()
|
||||
data: list = None
|
||||
if not return_fragment:
|
||||
if not (return_fragment or streaming_mode):
|
||||
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
|
||||
if len(data) == 0:
|
||||
yield 16000, np.zeros(int(16000), dtype=np.int16)
|
||||
@ -1186,10 +1228,11 @@ class TTS:
|
||||
t_34 = 0.0
|
||||
t_45 = 0.0
|
||||
audio = []
|
||||
is_first_package = True
|
||||
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
|
||||
for item in data:
|
||||
t3 = time.perf_counter()
|
||||
if return_fragment:
|
||||
if return_fragment or streaming_mode:
|
||||
item = make_batch(item)
|
||||
if item is None:
|
||||
continue
|
||||
@ -1211,108 +1254,228 @@ class TTS:
|
||||
self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
|
||||
)
|
||||
|
||||
print(f"############ {i18n('预测语义Token')} ############")
|
||||
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
|
||||
all_phoneme_ids,
|
||||
all_phoneme_lens,
|
||||
prompt,
|
||||
all_bert_features,
|
||||
# prompt_phone_len=ph_offset,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
early_stop_num=self.configs.hz * self.configs.max_sec,
|
||||
max_len=max_len,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
t4 = time.perf_counter()
|
||||
t_34 += t4 - t3
|
||||
|
||||
refer_audio_spec = []
|
||||
if self.is_v2pro:
|
||||
sv_emb = []
|
||||
|
||||
sv_emb = [] if self.is_v2pro else None
|
||||
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
|
||||
spec = spec.to(dtype=self.precision, device=self.configs.device)
|
||||
refer_audio_spec.append(spec)
|
||||
if self.is_v2pro:
|
||||
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
|
||||
|
||||
batch_audio_fragment = []
|
||||
if not streaming_mode:
|
||||
print(f"############ {i18n('预测语义Token')} ############")
|
||||
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
|
||||
all_phoneme_ids,
|
||||
all_phoneme_lens,
|
||||
prompt,
|
||||
all_bert_features,
|
||||
# prompt_phone_len=ph_offset,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
early_stop_num=self.configs.hz * self.configs.max_sec,
|
||||
max_len=max_len,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
t4 = time.perf_counter()
|
||||
t_34 += t4 - t3
|
||||
|
||||
# ## vits并行推理 method 1
|
||||
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
||||
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
|
||||
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
|
||||
# max_len = 0
|
||||
# for i in range(0, len(batch_phones)):
|
||||
# max_len = max(max_len, batch_phones[i].shape[-1])
|
||||
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
|
||||
# batch_phones = batch_phones.to(self.configs.device)
|
||||
# batch_audio_fragment = (self.vits_model.batched_decode(
|
||||
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
|
||||
# ))
|
||||
print(f"############ {i18n('合成音频')} ############")
|
||||
if not self.configs.use_vocoder:
|
||||
if speed_factor == 1.0:
|
||||
print(f"{i18n('并行合成中')}...")
|
||||
# ## vits并行推理 method 2
|
||||
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
||||
upsample_rate = math.prod(self.vits_model.upsample_rates)
|
||||
audio_frag_idx = [
|
||||
pred_semantic_list[i].shape[0] * 2 * upsample_rate
|
||||
for i in range(0, len(pred_semantic_list))
|
||||
]
|
||||
audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))]
|
||||
all_pred_semantic = (
|
||||
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
|
||||
)
|
||||
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
|
||||
if self.is_v2pro != True:
|
||||
_batch_audio_fragment = self.vits_model.decode(
|
||||
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
|
||||
).detach()[0, 0, :]
|
||||
else:
|
||||
_batch_audio_fragment = self.vits_model.decode(
|
||||
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
|
||||
).detach()[0, 0, :]
|
||||
audio_frag_end_idx.insert(0, 0)
|
||||
batch_audio_fragment = [
|
||||
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
|
||||
for i in range(1, len(audio_frag_end_idx))
|
||||
]
|
||||
else:
|
||||
# ## vits串行推理
|
||||
for i, idx in enumerate(tqdm(idx_list)):
|
||||
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
|
||||
_pred_semantic = (
|
||||
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
if self.is_v2pro != True:
|
||||
audio_fragment = self.vits_model.decode(
|
||||
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
|
||||
).detach()[0, 0, :]
|
||||
else:
|
||||
audio_fragment = self.vits_model.decode(
|
||||
_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
|
||||
).detach()[0, 0, :]
|
||||
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
|
||||
else:
|
||||
if parallel_infer:
|
||||
print(f"{i18n('并行合成中')}...")
|
||||
audio_fragments = self.using_vocoder_synthesis_batched_infer(
|
||||
idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
|
||||
)
|
||||
batch_audio_fragment.extend(audio_fragments)
|
||||
else:
|
||||
for i, idx in enumerate(tqdm(idx_list)):
|
||||
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
|
||||
_pred_semantic = (
|
||||
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
audio_fragment = self.using_vocoder_synthesis(
|
||||
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
|
||||
|
||||
batch_audio_fragment = []
|
||||
|
||||
# ## vits并行推理 method 1
|
||||
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
||||
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
|
||||
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
|
||||
# max_len = 0
|
||||
# for i in range(0, len(batch_phones)):
|
||||
# max_len = max(max_len, batch_phones[i].shape[-1])
|
||||
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
|
||||
# batch_phones = batch_phones.to(self.configs.device)
|
||||
# batch_audio_fragment = (self.vits_model.batched_decode(
|
||||
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
|
||||
# ))
|
||||
print(f"############ {i18n('合成音频')} ############")
|
||||
if not self.configs.use_vocoder:
|
||||
if speed_factor == 1.0:
|
||||
print(f"{i18n('并行合成中')}...")
|
||||
# ## vits并行推理 method 2
|
||||
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
||||
upsample_rate = math.prod(self.vits_model.upsample_rates)
|
||||
audio_frag_idx = [
|
||||
pred_semantic_list[i].shape[0] * 2 * upsample_rate
|
||||
for i in range(0, len(pred_semantic_list))
|
||||
]
|
||||
audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))]
|
||||
all_pred_semantic = (
|
||||
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
|
||||
)
|
||||
batch_audio_fragment.append(audio_fragment)
|
||||
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
|
||||
|
||||
_batch_audio_fragment = self.vits_model.decode(
|
||||
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
|
||||
).detach()[0, 0, :]
|
||||
|
||||
audio_frag_end_idx.insert(0, 0)
|
||||
batch_audio_fragment = [
|
||||
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
|
||||
for i in range(1, len(audio_frag_end_idx))
|
||||
]
|
||||
else:
|
||||
# ## vits串行推理
|
||||
for i, idx in enumerate(tqdm(idx_list)):
|
||||
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
|
||||
_pred_semantic = (
|
||||
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
audio_fragment = self.vits_model.decode(
|
||||
_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
|
||||
).detach()[0, 0, :]
|
||||
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
|
||||
else:
|
||||
if parallel_infer:
|
||||
print(f"{i18n('并行合成中')}...")
|
||||
audio_fragments = self.using_vocoder_synthesis_batched_infer(
|
||||
idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
|
||||
)
|
||||
batch_audio_fragment.extend(audio_fragments)
|
||||
else:
|
||||
for i, idx in enumerate(tqdm(idx_list)):
|
||||
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
|
||||
_pred_semantic = (
|
||||
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
audio_fragment = self.using_vocoder_synthesis(
|
||||
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
|
||||
)
|
||||
batch_audio_fragment.append(audio_fragment)
|
||||
|
||||
else:
|
||||
# refer_audio_spec: torch.Tensor = [
|
||||
# item.to(dtype=self.precision, device=self.configs.device)
|
||||
# for item in self.prompt_cache["refer_spec"]
|
||||
# ]
|
||||
semantic_token_generator =self.t2s_model.model.infer_panel(
|
||||
all_phoneme_ids[0].unsqueeze(0),
|
||||
all_phoneme_lens,
|
||||
prompt,
|
||||
all_bert_features[0].unsqueeze(0),
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
early_stop_num=self.configs.hz * self.configs.max_sec,
|
||||
max_len=max_len,
|
||||
repetition_penalty=repetition_penalty,
|
||||
streaming_mode=True,
|
||||
chunk_length=min_chunk_length,
|
||||
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix if not fixed_length_chunk else None,
|
||||
chunk_split_thershold=chunk_split_thershold,
|
||||
)
|
||||
t4 = time.perf_counter()
|
||||
t_34 += t4 - t3
|
||||
phones = batch_phones[0].unsqueeze(0).to(self.configs.device)
|
||||
is_first_chunk = True
|
||||
|
||||
if not self.configs.use_vocoder:
|
||||
# if speed_factor == 1.0:
|
||||
# upsample_rate = math.prod(self.vits_model.upsample_rates)*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
|
||||
# else:
|
||||
upsample_rate = math.prod(self.vits_model.upsample_rates)*((2 if self.vits_model.semantic_frame_rate == "25hz" else 1)/speed_factor)
|
||||
else:
|
||||
# if speed_factor == 1.0:
|
||||
# upsample_rate = self.vocoder_configs["upsample_rate"]*(3.875 if self.configs.version == "v3" else 4)
|
||||
# else:
|
||||
upsample_rate = self.vocoder_configs["upsample_rate"]*((3.875 if self.configs.version == "v3" else 4)/speed_factor)
|
||||
|
||||
last_audio_chunk = None
|
||||
# last_tokens = None
|
||||
last_latent = None
|
||||
previous_tokens = []
|
||||
overlap_len = overlap_length
|
||||
overlap_size = math.ceil(overlap_length*upsample_rate)
|
||||
for semantic_tokens, is_final in semantic_token_generator:
|
||||
if semantic_tokens is None and last_audio_chunk is not None:
|
||||
yield self.audio_postprocess(
|
||||
[[last_audio_chunk[-overlap_size:]]],
|
||||
output_sr,
|
||||
None,
|
||||
speed_factor,
|
||||
False,
|
||||
0.0,
|
||||
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
|
||||
)
|
||||
break
|
||||
|
||||
_semantic_tokens = semantic_tokens
|
||||
print(f"semantic_tokens shape:{semantic_tokens.shape}")
|
||||
|
||||
previous_tokens.append(semantic_tokens)
|
||||
|
||||
_semantic_tokens = torch.cat(previous_tokens, dim=-1)
|
||||
|
||||
if not is_first_chunk and semantic_tokens.shape[-1] < 10:
|
||||
overlap_len = overlap_length+(10-semantic_tokens.shape[-1])
|
||||
else:
|
||||
overlap_len = overlap_length
|
||||
|
||||
|
||||
if not self.configs.use_vocoder:
|
||||
token_padding_length = 0
|
||||
# token_padding_length = int(phones.shape[-1]*2)-_semantic_tokens.shape[-1]
|
||||
# if token_padding_length>0:
|
||||
# _semantic_tokens = F.pad(_semantic_tokens, (0, token_padding_length), "constant", 486)
|
||||
# else:
|
||||
# token_padding_length = 0
|
||||
|
||||
audio_chunk, latent, latent_mask = self.vits_model.decode_streaming(
|
||||
_semantic_tokens.unsqueeze(0),
|
||||
phones, refer_audio_spec,
|
||||
speed=speed_factor,
|
||||
sv_emb=sv_emb,
|
||||
result_length=semantic_tokens.shape[-1]+overlap_len if not is_first_chunk else None,
|
||||
overlap_frames=last_latent[:,:,-overlap_len*(2 if self.vits_model.semantic_frame_rate == "25hz" else 1):] \
|
||||
if last_latent is not None else None,
|
||||
padding_length=token_padding_length
|
||||
)
|
||||
audio_chunk=audio_chunk.detach()[0, 0, :]
|
||||
else:
|
||||
raise RuntimeError(i18n("SoVits V3/4模型不支持流式推理模式"))
|
||||
|
||||
if overlap_len>overlap_length:
|
||||
audio_chunk=audio_chunk[-int((overlap_length+semantic_tokens.shape[-1])*upsample_rate):]
|
||||
|
||||
audio_chunk_ = audio_chunk
|
||||
if is_first_chunk and not is_final:
|
||||
is_first_chunk = False
|
||||
audio_chunk_ = audio_chunk_[:-overlap_size]
|
||||
elif is_first_chunk and is_final:
|
||||
is_first_chunk = False
|
||||
elif not is_first_chunk and not is_final:
|
||||
audio_chunk_ = self.sola_algorithm([last_audio_chunk, audio_chunk_], overlap_size)
|
||||
audio_chunk_ = (
|
||||
audio_chunk_[last_audio_chunk.shape[0]-overlap_size:-overlap_size] if not is_final \
|
||||
else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:]
|
||||
)
|
||||
|
||||
last_latent = latent
|
||||
last_audio_chunk = audio_chunk
|
||||
yield self.audio_postprocess(
|
||||
[[audio_chunk_]],
|
||||
output_sr,
|
||||
None,
|
||||
speed_factor,
|
||||
False,
|
||||
0.0,
|
||||
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
|
||||
)
|
||||
|
||||
if is_first_package:
|
||||
print(f"first_package_delay: {time.perf_counter()-t0:.3f}")
|
||||
is_first_package = False
|
||||
|
||||
|
||||
yield output_sr, np.zeros(int(output_sr*fragment_interval), dtype=np.int16)
|
||||
|
||||
t5 = time.perf_counter()
|
||||
t_45 += t5 - t4
|
||||
@ -1327,17 +1490,18 @@ class TTS:
|
||||
fragment_interval,
|
||||
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
|
||||
)
|
||||
elif streaming_mode:...
|
||||
else:
|
||||
audio.append(batch_audio_fragment)
|
||||
|
||||
if self.stop_flag:
|
||||
yield 16000, np.zeros(int(16000), dtype=np.int16)
|
||||
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
|
||||
return
|
||||
|
||||
if not return_fragment:
|
||||
if not (return_fragment or streaming_mode):
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
|
||||
if len(audio) == 0:
|
||||
yield 16000, np.zeros(int(16000), dtype=np.int16)
|
||||
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
|
||||
return
|
||||
yield self.audio_postprocess(
|
||||
audio,
|
||||
@ -1384,16 +1548,17 @@ class TTS:
|
||||
fragment_interval: float = 0.3,
|
||||
super_sampling: bool = False,
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
zero_wav = torch.zeros(
|
||||
int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
|
||||
)
|
||||
if fragment_interval>0:
|
||||
zero_wav = torch.zeros(
|
||||
int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
|
||||
)
|
||||
|
||||
for i, batch in enumerate(audio):
|
||||
for j, audio_fragment in enumerate(batch):
|
||||
max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音
|
||||
if max_audio > 1:
|
||||
audio_fragment /= max_audio
|
||||
audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
|
||||
audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) if fragment_interval>0 else audio_fragment
|
||||
audio[i][j] = audio_fragment
|
||||
|
||||
if split_bucket:
|
||||
@ -1413,13 +1578,18 @@ class TTS:
|
||||
max_audio = np.abs(audio).max()
|
||||
if max_audio > 1:
|
||||
audio /= max_audio
|
||||
audio = (audio * 32768).astype(np.int16)
|
||||
t2 = time.perf_counter()
|
||||
print(f"超采样用时:{t2 - t1:.3f}s")
|
||||
else:
|
||||
# audio = audio.float() * 32768
|
||||
# audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy()
|
||||
|
||||
audio = audio.cpu().numpy()
|
||||
|
||||
audio = (audio * 32768).astype(np.int16)
|
||||
|
||||
|
||||
# try:
|
||||
# if speed_factor != 1.0:
|
||||
# audio = speed_change(audio, speed=speed_factor, sr=int(sr))
|
||||
@ -1612,24 +1782,43 @@ class TTS:
|
||||
self,
|
||||
audio_fragments: List[torch.Tensor],
|
||||
overlap_len: int,
|
||||
search_len:int= 320
|
||||
):
|
||||
# overlap_len-=search_len
|
||||
|
||||
dtype = audio_fragments[0].dtype
|
||||
|
||||
for i in range(len(audio_fragments) - 1):
|
||||
f1 = audio_fragments[i]
|
||||
f2 = audio_fragments[i + 1]
|
||||
f1 = audio_fragments[i].float()
|
||||
f2 = audio_fragments[i + 1].float()
|
||||
w1 = f1[-overlap_len:]
|
||||
w2 = f2[:overlap_len]
|
||||
assert w1.shape == w2.shape
|
||||
corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1), padding=w2.shape[-1] // 2).view(-1)[:-1]
|
||||
idx = corr.argmax()
|
||||
f1_ = f1[: -(overlap_len - idx)]
|
||||
w2 = f2[:overlap_len+search_len]
|
||||
# w2 = w2[-w2.shape[-1]//2:]
|
||||
# assert w1.shape == w2.shape
|
||||
corr_norm = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1)
|
||||
|
||||
corr_den = F.conv1d(w2.view(1, 1, -1)**2, torch.ones_like(w1).view(1, 1, -1)).view(-1)+ 1e-8
|
||||
idx = (corr_norm/corr_den.sqrt()).argmax()
|
||||
|
||||
print(f"seg_idx: {idx}")
|
||||
|
||||
# idx = corr.argmax()
|
||||
f1_ = f1[: -overlap_len]
|
||||
audio_fragments[i] = f1_
|
||||
|
||||
f2_ = f2[idx:]
|
||||
window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype)
|
||||
f2_[: (overlap_len - idx)] = (
|
||||
window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)]
|
||||
+ window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :]
|
||||
window = torch.hann_window((overlap_len) * 2, device=f1.device, dtype=f1.dtype)
|
||||
f2_[: overlap_len] = (
|
||||
window[: overlap_len] * f2_[: overlap_len]
|
||||
+ window[overlap_len :] * f1[-overlap_len :]
|
||||
)
|
||||
|
||||
# window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx))
|
||||
# f2_[: (overlap_len - idx)] = (
|
||||
# window * f2_[: (overlap_len - idx)]
|
||||
# + (1-window) * f1[-(overlap_len - idx) :]
|
||||
# )
|
||||
|
||||
audio_fragments[i + 1] = f2_
|
||||
|
||||
return torch.cat(audio_fragments, 0)
|
||||
return torch.cat(audio_fragments, 0).to(dtype)
|
||||
|
||||
@ -261,41 +261,21 @@ class T2SBlock:
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
# attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||
# attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
|
||||
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||
|
||||
if padding_mask is not None:
|
||||
for i in range(batch_size):
|
||||
# mask = padding_mask[i,:,0]
|
||||
if self.false.device != padding_mask.device:
|
||||
self.false = self.false.to(padding_mask.device)
|
||||
idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
|
||||
x_item = x[i, idx, :].unsqueeze(0)
|
||||
attn_item = attn[i, idx, :].unsqueeze(0)
|
||||
x_item = x_item + attn_item
|
||||
x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x_item = x_item + self.mlp.forward(x_item)
|
||||
x_item = F.layer_norm(
|
||||
x_item,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
x[i, idx, :] = x_item.squeeze(0)
|
||||
x = self.to_mask(x, padding_mask)
|
||||
else:
|
||||
x = x + attn
|
||||
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
x = x + attn
|
||||
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
|
||||
|
||||
@ -417,7 +417,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
|
||||
)
|
||||
with gr.Row():
|
||||
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
|
||||
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True)
|
||||
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
|
||||
with gr.Row():
|
||||
temperature = gr.Slider(
|
||||
|
||||
@ -37,6 +37,10 @@ from einops import rearrange, repeat
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from module.distrib import broadcast_tensors, is_distributed
|
||||
from module.ddp_utils import SyncFunction
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
@ -69,27 +73,40 @@ def sample_vectors(samples, num: int):
|
||||
return samples[indices]
|
||||
|
||||
|
||||
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
||||
dim, dtype = samples.shape[-1], samples.dtype
|
||||
max_kmeans_samples = 500
|
||||
samples = samples[:max_kmeans_samples, :]
|
||||
def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64):
|
||||
N, D = samples.shape
|
||||
dtype, device = samples.dtype, samples.device
|
||||
|
||||
if frames_to_use < N:
|
||||
indices = torch.randperm(N, device=device)[:frames_to_use]
|
||||
samples = samples[indices]
|
||||
|
||||
means = sample_vectors(samples, num_clusters)
|
||||
|
||||
print("kmeans start ... ")
|
||||
for _ in tqdm(range(num_iters)):
|
||||
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
|
||||
dists = -(diffs**2).sum(dim=-1)
|
||||
# Store cluster assignments
|
||||
all_assignments = []
|
||||
|
||||
buckets = dists.max(dim=-1).indices
|
||||
for i in range(0, samples.shape[0], batch_size):
|
||||
batch = samples[i : i + batch_size] # [B, D]
|
||||
dists = torch.cdist(batch, means, p=2) # [B, C]
|
||||
assignments = dists.argmin(dim=1) # [B]
|
||||
all_assignments.append(assignments)
|
||||
|
||||
buckets = torch.cat(all_assignments, dim=0) # [N]
|
||||
bins = torch.bincount(buckets, minlength=num_clusters)
|
||||
zero_mask = bins == 0
|
||||
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
||||
|
||||
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
||||
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
||||
new_means = new_means / bins_min_clamped[..., None]
|
||||
# Compute new means
|
||||
new_means = torch.zeros_like(means)
|
||||
for i in range(num_clusters):
|
||||
mask = buckets == i
|
||||
if mask.any():
|
||||
new_means[i] = samples[mask].mean(dim=0)
|
||||
|
||||
means = torch.where(zero_mask[..., None], means, new_means)
|
||||
means = torch.where(zero_mask[:, None], means, new_means)
|
||||
|
||||
return means, bins
|
||||
|
||||
@ -141,13 +158,24 @@ class EuclideanCodebook(nn.Module):
|
||||
if self.inited:
|
||||
return
|
||||
|
||||
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
# [B * T * world_size, D]
|
||||
data = SyncFunction.apply(data)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
||||
else:
|
||||
embed = torch.empty_like(self.embed)
|
||||
cluster_size = torch.empty_like(self.cluster_size)
|
||||
dist.broadcast(embed, src=0)
|
||||
dist.broadcast(cluster_size, src=0)
|
||||
|
||||
self.embed.data.copy_(embed)
|
||||
self.embed_avg.data.copy_(embed.clone())
|
||||
self.cluster_size.data.copy_(cluster_size)
|
||||
self.inited.data.copy_(torch.Tensor([True]))
|
||||
# Make sure all buffers across workers are in sync after initialization
|
||||
# broadcast_tensors(self.buffers())
|
||||
broadcast_tensors(self.buffers())
|
||||
|
||||
def replace_(self, samples, mask):
|
||||
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
|
||||
@ -161,9 +189,17 @@ class EuclideanCodebook(nn.Module):
|
||||
if not torch.any(expired_codes):
|
||||
return
|
||||
|
||||
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
||||
self.replace_(batch_samples, mask=expired_codes)
|
||||
# broadcast_tensors(self.buffers())
|
||||
if is_distributed():
|
||||
# [B * T * world_size, D]
|
||||
batch_samples = SyncFunction.apply(batch_samples)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
new_embeds = sample_vectors(batch_samples, expired_codes.sum())
|
||||
else:
|
||||
new_embeds = torch.zeros(expired_codes.sum(), self.embed.size(1), device=self.embed.device)
|
||||
dist.broadcast(new_embeds, src=0)
|
||||
self.embed.data[expired_codes] = new_embeds
|
||||
broadcast_tensors(self.buffers())
|
||||
|
||||
def preprocess(self, x):
|
||||
x = rearrange(x, "... d -> (...) d")
|
||||
@ -208,17 +244,26 @@ class EuclideanCodebook(nn.Module):
|
||||
quantize = self.dequantize(embed_ind)
|
||||
|
||||
if self.training:
|
||||
### Update codebook by EMA
|
||||
embed_onehot_sum = embed_onehot.sum(0) # [cb-size,]
|
||||
embed_sum = x.t() @ embed_onehot # [D, cb-size]
|
||||
if is_distributed():
|
||||
dist.all_reduce(embed_onehot_sum)
|
||||
dist.all_reduce(embed_sum)
|
||||
# Update ema cluster count N_i^t, eq. (6) in vqvae paper
|
||||
self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
|
||||
# Update ema embed: eq. (7) in vqvae paper
|
||||
self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
|
||||
# apply laplace smoothing
|
||||
n = self.cluster_size.sum()
|
||||
cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
|
||||
# Update ema embed: eq. (8) in vqvae paper
|
||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
||||
self.embed.data.copy_(embed_normalized)
|
||||
|
||||
# We do the expiry of code at that point as buffers are in sync
|
||||
# and all the workers will take the same decision.
|
||||
self.expire_codes_(x)
|
||||
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
||||
embed_sum = x.t() @ embed_onehot
|
||||
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
||||
cluster_size = (
|
||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
|
||||
)
|
||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
||||
self.embed.data.copy_(embed_normalized)
|
||||
|
||||
return quantize, embed_ind
|
||||
|
||||
|
||||
181
GPT_SoVITS/module/ddp_utils.py
Normal file
181
GPT_SoVITS/module/ddp_utils.py
Normal file
@ -0,0 +1,181 @@
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.nn.parallel.distributed import _find_tensors
|
||||
from packaging import version
|
||||
|
||||
|
||||
# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
|
||||
class SyncFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# @torch.no_grad()
|
||||
def forward(ctx, tensor):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
# Collect batch sizes from all processes
|
||||
local_bs = torch.tensor([tensor.shape[0]], device=tensor.device)
|
||||
batch_sizes = [torch.zeros_like(local_bs) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(batch_sizes, local_bs)
|
||||
|
||||
# Convert to integer list and find the minimum
|
||||
batch_sizes_int = [bs.item() for bs in batch_sizes]
|
||||
min_bs = min(batch_sizes_int)
|
||||
|
||||
# Crop the tensor to the minimum batch size if needed
|
||||
cropped_tensor = tensor[:min_bs] if tensor.shape[0] > min_bs else tensor
|
||||
|
||||
# Prepare for gathering
|
||||
out_shape = (min_bs * world_size,) + tensor.shape[1:]
|
||||
gathered_tensor = torch.zeros(out_shape, dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
# Build tensor list for all_gather
|
||||
tensor_list = list(torch.chunk(gathered_tensor, world_size))
|
||||
|
||||
# Perform all_gather using the cropped tensors
|
||||
torch.distributed.all_gather(tensor_list, cropped_tensor)
|
||||
|
||||
# Save for backward pass
|
||||
ctx.min_bs = min_bs
|
||||
ctx.world_size = world_size
|
||||
ctx.orig_shape = tensor.shape
|
||||
|
||||
return gathered_tensor
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
assert False
|
||||
grad_input = grad_output.clone()
|
||||
torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
|
||||
|
||||
idx_from = torch.distributed.get_rank() * ctx.batch_size
|
||||
idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
|
||||
return grad_input[idx_from:idx_to]
|
||||
|
||||
class DDP(DistributedDataParallel):
|
||||
"""
|
||||
Override the forward call in lightning so it goes to training and validation step respectively
|
||||
"""
|
||||
|
||||
def forward(self, *inputs, **kwargs): # pragma: no cover
|
||||
if version.parse(torch.__version__[:6]) < version.parse("1.11"):
|
||||
self._sync_params()
|
||||
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
||||
assert len(self.device_ids) == 1
|
||||
if self.module.training:
|
||||
output = self.module.training_step(*inputs[0], **kwargs[0])
|
||||
elif self.module.testing:
|
||||
output = self.module.test_step(*inputs[0], **kwargs[0])
|
||||
else:
|
||||
output = self.module.validation_step(*inputs[0], **kwargs[0])
|
||||
if torch.is_grad_enabled():
|
||||
# We'll return the output object verbatim since it is a freeform
|
||||
# object. We need to find any tensors in this object, though,
|
||||
# because we need to figure out which parameters were used during
|
||||
# this forward pass, to ensure we short circuit reduction for any
|
||||
# unused parameters. Only if `find_unused_parameters` is set.
|
||||
if self.find_unused_parameters:
|
||||
self.reducer.prepare_for_backward(list(_find_tensors(output)))
|
||||
else:
|
||||
self.reducer.prepare_for_backward([])
|
||||
else:
|
||||
from torch.nn.parallel.distributed import (
|
||||
Join,
|
||||
_DDPSink,
|
||||
_tree_flatten_with_rref,
|
||||
_tree_unflatten_with_rref,
|
||||
)
|
||||
|
||||
with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
|
||||
if torch.is_grad_enabled() and self.require_backward_grad_sync:
|
||||
self.logger.set_runtime_stats_and_log()
|
||||
self.num_iterations += 1
|
||||
self.reducer.prepare_for_forward()
|
||||
|
||||
# Notify the join context that this process has not joined, if
|
||||
# needed
|
||||
work = Join.notify_join_context(self)
|
||||
if work:
|
||||
self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)
|
||||
|
||||
# Calling _rebuild_buckets before forward compuation,
|
||||
# It may allocate new buckets before deallocating old buckets
|
||||
# inside _rebuild_buckets. To save peak memory usage,
|
||||
# call _rebuild_buckets before the peak memory usage increases
|
||||
# during forward computation.
|
||||
# This should be called only once during whole training period.
|
||||
if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
|
||||
print("Reducer buckets have been rebuilt in this iteration.")
|
||||
self._has_rebuilt_buckets = True
|
||||
|
||||
# sync params according to location (before/after forward) user
|
||||
# specified as part of hook, if hook was specified.
|
||||
buffer_hook_registered = hasattr(self, "buffer_hook")
|
||||
if self._check_sync_bufs_pre_fwd():
|
||||
self._sync_buffers()
|
||||
|
||||
if self._join_config.enable:
|
||||
# Notify joined ranks whether they should sync in backwards pass or not.
|
||||
self._check_global_requires_backward_grad_sync(is_joined_rank=False)
|
||||
|
||||
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
||||
if self.module.training:
|
||||
output = self.module.training_step(*inputs[0], **kwargs[0])
|
||||
elif self.module.testing:
|
||||
output = self.module.test_step(*inputs[0], **kwargs[0])
|
||||
else:
|
||||
output = self.module.validation_step(*inputs[0], **kwargs[0])
|
||||
|
||||
# sync params according to location (before/after forward) user
|
||||
# specified as part of hook, if hook was specified.
|
||||
if self._check_sync_bufs_post_fwd():
|
||||
self._sync_buffers()
|
||||
|
||||
if torch.is_grad_enabled() and self.require_backward_grad_sync:
|
||||
self.require_forward_param_sync = True
|
||||
# We'll return the output object verbatim since it is a freeform
|
||||
# object. We need to find any tensors in this object, though,
|
||||
# because we need to figure out which parameters were used during
|
||||
# this forward pass, to ensure we short circuit reduction for any
|
||||
# unused parameters. Only if `find_unused_parameters` is set.
|
||||
if self.find_unused_parameters and not self.static_graph:
|
||||
# Do not need to populate this for static graph.
|
||||
self.reducer.prepare_for_backward(list(_find_tensors(output)))
|
||||
else:
|
||||
self.reducer.prepare_for_backward([])
|
||||
else:
|
||||
self.require_forward_param_sync = False
|
||||
|
||||
# TODO: DDPSink is currently enabled for unused parameter detection and
|
||||
# static graph training for first iteration.
|
||||
if (self.find_unused_parameters and not self.static_graph) or (
|
||||
self.static_graph and self.num_iterations == 1
|
||||
):
|
||||
state_dict = {
|
||||
"static_graph": self.static_graph,
|
||||
"num_iterations": self.num_iterations,
|
||||
}
|
||||
|
||||
output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
|
||||
output_placeholders = [None for _ in range(len(output_tensor_list))]
|
||||
# Do not touch tensors that have no grad_fn, which can cause issues
|
||||
# such as https://github.com/pytorch/pytorch/issues/60733
|
||||
for i, output in enumerate(output_tensor_list):
|
||||
if torch.is_tensor(output) and output.grad_fn is None:
|
||||
output_placeholders[i] = output
|
||||
|
||||
# When find_unused_parameters=True, makes tensors which require grad
|
||||
# run through the DDPSink backward pass. When not all outputs are
|
||||
# used in loss, this makes those corresponding tensors receive
|
||||
# undefined gradient which the reducer then handles to ensure
|
||||
# param.grad field is not touched and we don't error out.
|
||||
passthrough_tensor_list = _DDPSink.apply(
|
||||
self.reducer,
|
||||
state_dict,
|
||||
*output_tensor_list,
|
||||
)
|
||||
for i in range(len(output_placeholders)):
|
||||
if output_placeholders[i] is None:
|
||||
output_placeholders[i] = passthrough_tensor_list[i]
|
||||
|
||||
# Reconstruct output data structure.
|
||||
output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
|
||||
return output
|
||||
123
GPT_SoVITS/module/distrib.py
Normal file
123
GPT_SoVITS/module/distrib.py
Normal file
@ -0,0 +1,123 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Torch distributed utilities."""
|
||||
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def rank():
|
||||
if torch.distributed.is_initialized():
|
||||
return torch.distributed.get_rank()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def world_size():
|
||||
if torch.distributed.is_initialized():
|
||||
return torch.distributed.get_world_size()
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def is_distributed():
|
||||
return world_size() > 1
|
||||
|
||||
|
||||
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
|
||||
if is_distributed():
|
||||
return torch.distributed.all_reduce(tensor, op)
|
||||
|
||||
|
||||
def _is_complex_or_float(tensor):
|
||||
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
||||
|
||||
|
||||
def _check_number_of_params(params: tp.List[torch.Tensor]):
|
||||
# utility function to check that the number of params in all workers is the same,
|
||||
# and thus avoid a deadlock with distributed all reduce.
|
||||
if not is_distributed() or not params:
|
||||
return
|
||||
# print('params[0].device ', params[0].device)
|
||||
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
|
||||
all_reduce(tensor)
|
||||
if tensor.item() != len(params) * world_size():
|
||||
# If not all the workers have the same number, for at least one of them,
|
||||
# this inequality will be verified.
|
||||
raise RuntimeError(
|
||||
f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
|
||||
)
|
||||
|
||||
|
||||
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
|
||||
"""Broadcast the tensors from the given parameters to all workers.
|
||||
This can be used to ensure that all workers have the same model to start with.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return
|
||||
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
|
||||
_check_number_of_params(tensors)
|
||||
handles = []
|
||||
for tensor in tensors:
|
||||
handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
|
||||
handles.append(handle)
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
|
||||
|
||||
def sync_buffer(buffers, average=True):
|
||||
"""
|
||||
Sync grad for buffers. If average is False, broadcast instead of averaging.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return
|
||||
handles = []
|
||||
for buffer in buffers:
|
||||
if torch.is_floating_point(buffer.data):
|
||||
if average:
|
||||
handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
|
||||
else:
|
||||
handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
|
||||
handles.append((buffer, handle))
|
||||
for buffer, handle in handles:
|
||||
handle.wait()
|
||||
if average:
|
||||
buffer.data /= world_size
|
||||
|
||||
|
||||
def sync_grad(params):
|
||||
"""
|
||||
Simpler alternative to DistributedDataParallel, that doesn't rely
|
||||
on any black magic. For simple models it can also be as fast.
|
||||
Just call this on your model parameters after the call to backward!
|
||||
"""
|
||||
if not is_distributed():
|
||||
return
|
||||
handles = []
|
||||
for p in params:
|
||||
if p.grad is not None:
|
||||
handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
|
||||
handles.append((p, handle))
|
||||
for p, handle in handles:
|
||||
handle.wait()
|
||||
p.grad.data /= world_size()
|
||||
|
||||
|
||||
def average_metrics(metrics: tp.Dict[str, float], count=1.0):
|
||||
"""Average a dictionary of metrics across all workers, using the optional
|
||||
`count` as unormalized weight.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return metrics
|
||||
keys, values = zip(*metrics.items())
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
|
||||
tensor *= count
|
||||
all_reduce(tensor)
|
||||
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
|
||||
return dict(zip(keys, averaged))
|
||||
@ -151,6 +151,8 @@ class DurationPredictor(nn.Module):
|
||||
return x * x_mask
|
||||
|
||||
|
||||
WINDOW = {}
|
||||
|
||||
class TextEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -209,7 +211,7 @@ class TextEncoder(nn.Module):
|
||||
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
|
||||
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||
|
||||
y = self.ssl_proj(y * y_mask) * y_mask
|
||||
@ -222,13 +224,44 @@ class TextEncoder(nn.Module):
|
||||
text = self.text_embedding(text).transpose(1, 2)
|
||||
text = self.encoder_text(text * text_mask, text_mask)
|
||||
y = self.mrte(y, y_mask, text, text_mask, ge)
|
||||
|
||||
if padding_length is not None and padding_length!=0:
|
||||
y = y[:, :, :-padding_length]
|
||||
y_mask = y_mask[:, :, :-padding_length]
|
||||
|
||||
|
||||
y = self.encoder2(y * y_mask, y_mask)
|
||||
|
||||
if result_length is not None:
|
||||
y = y[:, :, -result_length:]
|
||||
y_mask = y_mask[:, :, -result_length:]
|
||||
|
||||
if overlap_frames is not None:
|
||||
overlap_len = overlap_frames.shape[-1]
|
||||
window = WINDOW.get(overlap_len, None)
|
||||
if window is None:
|
||||
# WINDOW[overlap_len] = torch.hann_window(overlap_len*2, device=y.device, dtype=y.dtype)
|
||||
WINDOW[overlap_len] = torch.sin(torch.arange(overlap_len*2, device=y.device) * torch.pi / (overlap_len*2))
|
||||
window = WINDOW[overlap_len]
|
||||
|
||||
|
||||
window = window.to(y.device)
|
||||
y[:,:,:overlap_len] = (
|
||||
window[:overlap_len].view(1, 1, -1) * y[:,:,:overlap_len]
|
||||
+ window[overlap_len:].view(1, 1, -1) * overlap_frames
|
||||
)
|
||||
|
||||
y_ = y
|
||||
y_mask_ = y_mask
|
||||
|
||||
|
||||
|
||||
if speed != 1:
|
||||
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
|
||||
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
|
||||
stats = self.proj(y) * y_mask
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return y, m, logs, y_mask
|
||||
return y, m, logs, y_mask, y_, y_mask_
|
||||
|
||||
def extract_latent(self, x):
|
||||
x = self.ssl_proj(x)
|
||||
@ -921,7 +954,7 @@ class SynthesizerTrn(nn.Module):
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
|
||||
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
|
||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
||||
z_p = self.flow(z, y_mask, g=ge)
|
||||
|
||||
@ -949,7 +982,7 @@ class SynthesizerTrn(nn.Module):
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
|
||||
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
@ -957,6 +990,7 @@ class SynthesizerTrn(nn.Module):
|
||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
return o, y_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
|
||||
def get_ge(refer, sv_emb):
|
||||
@ -989,7 +1023,7 @@ class SynthesizerTrn(nn.Module):
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
x, m_p, logs_p, y_mask, _, _ = self.enc_p(
|
||||
quantized,
|
||||
y_lengths,
|
||||
text,
|
||||
@ -1004,6 +1038,59 @@ class SynthesizerTrn(nn.Module):
|
||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
return o
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
|
||||
def get_ge(refer, sv_emb):
|
||||
ge = None
|
||||
if refer is not None:
|
||||
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
||||
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
else:
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
if self.is_v2pro:
|
||||
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
|
||||
ge += sv_emb.unsqueeze(-1)
|
||||
ge = self.prelu(ge)
|
||||
return ge
|
||||
|
||||
if type(refer) == list:
|
||||
ges = []
|
||||
for idx, _refer in enumerate(refer):
|
||||
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
|
||||
ges.append(ge)
|
||||
ge = torch.stack(ges, 0).mean(0)
|
||||
else:
|
||||
ge = get_ge(refer, sv_emb)
|
||||
|
||||
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
||||
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
result_length = (2*result_length) if result_length is not None else None
|
||||
padding_length = (2*padding_length) if padding_length is not None else None
|
||||
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(
|
||||
quantized,
|
||||
y_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
|
||||
speed,
|
||||
result_length=result_length,
|
||||
overlap_frames=overlap_frames,
|
||||
padding_length=padding_length
|
||||
)
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
|
||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
return o, y_, y_mask_
|
||||
|
||||
def extract_latent(self, x):
|
||||
ssl = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||
@ -1226,7 +1313,7 @@ class SynthesizerTrnV3(nn.Module):
|
||||
ssl = self.ssl_proj(ssl)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
fea = self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
|
||||
fea, y_mask_ = self.wns1(
|
||||
@ -1260,7 +1347,7 @@ class SynthesizerTrnV3(nn.Module):
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
|
||||
x, m_p, logs_p, y_mask, _, _ = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
|
||||
fea = self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest") ##BCT
|
||||
####more wn paramter to learn mel
|
||||
@ -1377,7 +1464,7 @@ class SynthesizerTrnV3b(nn.Module):
|
||||
ssl = self.ssl_proj(ssl)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
||||
z_p = self.flow(z, y_mask, g=ge)
|
||||
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
|
||||
@ -1420,7 +1507,7 @@ class SynthesizerTrnV3b(nn.Module):
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
x, m_p, logs_p, y_mask, y_, y_mask_ = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
fea = self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
|
||||
####more wn paramter to learn mel
|
||||
|
||||
@ -124,7 +124,7 @@ def run(rank, n_gpus, hps):
|
||||
collate_fn=collate_fn,
|
||||
batch_sampler=train_sampler,
|
||||
persistent_workers=True,
|
||||
prefetch_factor=4,
|
||||
prefetch_factor=3,
|
||||
)
|
||||
# if rank == 0:
|
||||
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
|
||||
|
||||
@ -118,13 +118,13 @@ def run(rank, n_gpus, hps):
|
||||
collate_fn = TextAudioSpeakerCollate()
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
num_workers=6,
|
||||
num_workers=5,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_sampler=train_sampler,
|
||||
persistent_workers=True,
|
||||
prefetch_factor=4,
|
||||
prefetch_factor=3,
|
||||
)
|
||||
# if rank == 0:
|
||||
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
|
||||
|
||||
@ -120,13 +120,13 @@ def run(rank, n_gpus, hps):
|
||||
collate_fn = TextAudioSpeakerCollate()
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
num_workers=6,
|
||||
num_workers=5,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_sampler=train_sampler,
|
||||
persistent_workers=True,
|
||||
prefetch_factor=4,
|
||||
prefetch_factor=3,
|
||||
)
|
||||
save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
|
||||
os.makedirs(save_root, exist_ok=True)
|
||||
|
||||
611
GPT_SoVITS/stream_v2pro.py
Normal file
611
GPT_SoVITS/stream_v2pro.py
Normal file
@ -0,0 +1,611 @@
|
||||
# 这是一个实验性质的实现,旨在探索 stream infer 的可能性。(xiao hai xie zhe wan de)
|
||||
from typing import List
|
||||
from export_torch_script import ExportERes2NetV2, SSLModel, T2SModel, VitsModel, get_raw_t2s_model, init_sv_cn, resamplex, sample, spectrogram_torch
|
||||
import export_torch_script
|
||||
from my_utils import load_audio
|
||||
import torch
|
||||
from torch import LongTensor, Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
import soundfile
|
||||
from inference_webui import get_phones_and_bert
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class StreamT2SModel(nn.Module):
|
||||
def __init__(self, t2s: T2SModel):
|
||||
super(StreamT2SModel, self).__init__()
|
||||
self.t2s = t2s
|
||||
|
||||
@torch.jit.export
|
||||
def pre_infer(
|
||||
self,
|
||||
prompts: LongTensor,
|
||||
ref_seq: LongTensor,
|
||||
text_seq: LongTensor,
|
||||
ref_bert: torch.Tensor,
|
||||
text_bert: torch.Tensor,
|
||||
top_k: int,
|
||||
) -> tuple[int, Tensor, Tensor, List[Tensor], List[Tensor]]:
|
||||
bert = torch.cat([ref_bert.T, text_bert.T], 1)
|
||||
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
||||
bert = bert.unsqueeze(0)
|
||||
|
||||
x = self.t2s.ar_text_embedding(all_phoneme_ids)
|
||||
x = x + self.t2s.bert_proj(bert.transpose(1, 2))
|
||||
x: torch.Tensor = self.t2s.ar_text_position(x)
|
||||
|
||||
# [1,N,512] [1,N]
|
||||
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||
y = prompts
|
||||
# x_example = x[:,:,0] * 0.0
|
||||
|
||||
x_len = x.shape[1]
|
||||
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
||||
|
||||
y_emb = self.t2s.ar_audio_embedding(y)
|
||||
y_len: int = y_emb.shape[1]
|
||||
prefix_len = y.shape[1]
|
||||
y_pos = self.t2s.ar_audio_position(y_emb)
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
|
||||
bsz = x.shape[0]
|
||||
src_len = x_len + y_len
|
||||
x_attn_mask_pad = F.pad(
|
||||
x_attn_mask,
|
||||
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
||||
value=True,
|
||||
)
|
||||
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = (
|
||||
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
.unsqueeze(0)
|
||||
.expand(bsz * self.t2s.num_head, -1, -1)
|
||||
.view(bsz, self.t2s.num_head, src_len, src_len)
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
)
|
||||
|
||||
xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.process_prompt(
|
||||
xy_pos, xy_attn_mask, None
|
||||
)
|
||||
|
||||
logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
|
||||
logits = logits[:, :-1]
|
||||
samples = sample(
|
||||
logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0
|
||||
)[0]
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
y_emb: Tensor = self.t2s.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos: Tensor = (
|
||||
y_emb * self.t2s.ar_audio_position.x_scale
|
||||
+ self.t2s.ar_audio_position.alpha
|
||||
* self.t2s.ar_audio_position.pe[:, y_len].to(
|
||||
dtype=y_emb.dtype, device=y_emb.device
|
||||
)
|
||||
)
|
||||
|
||||
return y_len, y, xy_pos, k_cache, v_cache
|
||||
|
||||
@torch.jit.export
|
||||
def decode_next_token(
|
||||
self,
|
||||
idx: int, # 记住从1开始 到1500
|
||||
top_k: int,
|
||||
y_len: int,
|
||||
y: Tensor,
|
||||
xy_pos: Tensor,
|
||||
k_cache: List[Tensor],
|
||||
v_cache: List[Tensor],
|
||||
) -> tuple[Tensor, Tensor, int, List[Tensor], List[Tensor]]:
|
||||
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
|
||||
xy_dec, k_cache, v_cache = self.t2s.t2s_transformer.decode_next_token(
|
||||
xy_pos, k_cache, v_cache
|
||||
)
|
||||
logits = self.t2s.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
|
||||
samples = sample(
|
||||
logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0
|
||||
)[0]
|
||||
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
last_token = int(samples[0, 0])
|
||||
|
||||
# if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
# stop = True
|
||||
if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
|
||||
return y[:,:-1], xy_pos, self.t2s.EOS, k_cache, v_cache
|
||||
|
||||
# if stop:
|
||||
# if y.shape[1] == 0:
|
||||
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||
# break
|
||||
|
||||
y_emb = self.t2s.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = (
|
||||
y_emb * self.t2s.ar_audio_position.x_scale
|
||||
+ self.t2s.ar_audio_position.alpha
|
||||
* self.t2s.ar_audio_position.pe[:, y_len + idx].to(
|
||||
dtype=y_emb.dtype, device=y_emb.device
|
||||
)
|
||||
)
|
||||
return y, xy_pos, last_token, k_cache, v_cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
idx: int, # 记住从1开始 到1500
|
||||
top_k: int,
|
||||
y_len: int,
|
||||
y: Tensor,
|
||||
xy_pos: Tensor,
|
||||
k_cache: List[Tensor],
|
||||
v_cache: List[Tensor],
|
||||
):
|
||||
return self.decode_next_token(idx,top_k,y_len,y,xy_pos,k_cache,v_cache)
|
||||
|
||||
|
||||
class StepVitsModel(nn.Module):
|
||||
def __init__(self, vits: VitsModel,sv_model:ExportERes2NetV2):
|
||||
super().__init__()
|
||||
self.hps = vits.hps
|
||||
self.vq_model = vits.vq_model
|
||||
self.hann_window = vits.hann_window
|
||||
self.sv = sv_model
|
||||
|
||||
def ref_handle(self, ref_audio_32k):
|
||||
refer = spectrogram_torch(
|
||||
self.hann_window,
|
||||
ref_audio_32k.float(),
|
||||
self.hps.data.filter_length,
|
||||
self.hps.data.sampling_rate,
|
||||
self.hps.data.hop_length,
|
||||
self.hps.data.win_length,
|
||||
center=False,
|
||||
)
|
||||
refer = refer.to(ref_audio_32k.dtype)
|
||||
ref_audio_16k = resamplex(ref_audio_32k, 32000, 16000).to(ref_audio_32k.dtype).to(ref_audio_32k.device)
|
||||
sv_emb = self.sv(ref_audio_16k)
|
||||
return refer, sv_emb
|
||||
|
||||
def extract_latent(self, ssl_content):
|
||||
codes = self.vq_model.extract_latent(ssl_content)
|
||||
return codes[0]
|
||||
|
||||
def forward(self, pred_semantic, text_seq, refer, sv_emb=None):
|
||||
return self.vq_model(
|
||||
pred_semantic, text_seq, refer, speed=1.0, sv_emb=sv_emb
|
||||
)[0, 0]
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def find_best_audio_offset_fast(reference_audio: Tensor, search_audio: Tensor):
|
||||
ref_len = len(reference_audio)
|
||||
search_len = len(search_audio)
|
||||
|
||||
if search_len < ref_len:
|
||||
raise ValueError(
|
||||
f"搜索音频长度 ({search_len}) 必须大于等于参考音频长度 ({ref_len})"
|
||||
)
|
||||
|
||||
# 使用F.conv1d计算原始互相关
|
||||
reference_flipped = reference_audio.unsqueeze(0).unsqueeze(0)
|
||||
search_padded = search_audio.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# 计算点积
|
||||
dot_products = F.conv1d(search_padded, reference_flipped).squeeze()
|
||||
|
||||
if len(dot_products.shape) == 0:
|
||||
dot_products = dot_products.unsqueeze(0)
|
||||
|
||||
# 计算参考音频的平方和
|
||||
ref_squared_sum = torch.sum(reference_audio**2)
|
||||
|
||||
# 计算搜索音频每个位置的平方和(滑动窗口)
|
||||
search_squared = search_audio**2
|
||||
search_squared_padded = search_squared.unsqueeze(0).unsqueeze(0)
|
||||
ones_kernel = torch.ones(
|
||||
1, 1, ref_len, dtype=search_audio.dtype, device=search_audio.device
|
||||
)
|
||||
|
||||
segment_squared_sums = F.conv1d(search_squared_padded, ones_kernel).squeeze()
|
||||
|
||||
if len(segment_squared_sums.shape) == 0:
|
||||
segment_squared_sums = segment_squared_sums.unsqueeze(0)
|
||||
|
||||
# 计算归一化因子
|
||||
ref_norm = torch.sqrt(ref_squared_sum)
|
||||
segment_norms = torch.sqrt(segment_squared_sums)
|
||||
|
||||
# 避免除零
|
||||
epsilon = 1e-8
|
||||
normalization_factor = ref_norm * segment_norms + epsilon
|
||||
|
||||
# 归一化互相关
|
||||
correlation_scores = dot_products / normalization_factor
|
||||
|
||||
best_offset = torch.argmax(correlation_scores).item()
|
||||
|
||||
return best_offset, correlation_scores
|
||||
|
||||
|
||||
import time
|
||||
|
||||
def test_stream(
|
||||
gpt_path,
|
||||
vits_path,
|
||||
version,
|
||||
ref_audio_path,
|
||||
ref_text,
|
||||
output_path,
|
||||
device="cpu",
|
||||
is_half=True,
|
||||
):
|
||||
if export_torch_script.sv_cn_model == None:
|
||||
init_sv_cn(device,is_half)
|
||||
|
||||
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
|
||||
ssl = SSLModel()
|
||||
|
||||
print(f"device: {device}")
|
||||
|
||||
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
|
||||
ref_text, "all_zh", "v2"
|
||||
)
|
||||
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
||||
ref_bert = ref_bert_T.T
|
||||
if is_half:
|
||||
ref_bert = ref_bert.half()
|
||||
ref_bert = ref_bert.to(ref_seq.device)
|
||||
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||
"这是一个简单的示例,真没想到这么简单就完成了,真的神奇,接下来我们说说狐狸,可能这就是狐狸吧.它有长长的尾巴,尖尖的耳朵,传说中还有九条尾巴。你觉得狐狸神奇吗?", "auto", "v2"
|
||||
)
|
||||
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||
text_bert = text_bert_T.T
|
||||
if is_half:
|
||||
text_bert = text_bert.half()
|
||||
text_bert = text_bert.to(text_seq.device)
|
||||
|
||||
ssl_content = ssl(ref_audio)
|
||||
if is_half:
|
||||
ssl_content = ssl_content.half()
|
||||
ssl_content = ssl_content.to(device)
|
||||
|
||||
sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)
|
||||
|
||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||
vits = VitsModel(vits_path, version,is_half=is_half,device=device)
|
||||
vits.eval()
|
||||
|
||||
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||
# dict_s1 = torch.load(gpt_path, map_location=device)
|
||||
dict_s1 = torch.load(gpt_path, weights_only=False)
|
||||
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||
print("#### get_raw_t2s_model ####")
|
||||
print(raw_t2s.config)
|
||||
if is_half:
|
||||
raw_t2s = raw_t2s.half()
|
||||
t2s_m = T2SModel(raw_t2s)
|
||||
t2s_m.eval()
|
||||
# t2s = torch.jit.script(t2s_m).to(device)
|
||||
t2s = t2s_m
|
||||
print("#### script t2s_m ####")
|
||||
|
||||
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
|
||||
|
||||
stream_t2s = StreamT2SModel(t2s).to(device)
|
||||
stream_t2s = torch.jit.script(stream_t2s)
|
||||
|
||||
ref_audio_sr = resamplex(ref_audio, 16000, 32000)
|
||||
if is_half:
|
||||
ref_audio_sr = ref_audio_sr.half()
|
||||
ref_audio_sr = ref_audio_sr.to(device)
|
||||
|
||||
top_k = 15
|
||||
|
||||
codes = vits.vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompts = prompt_semantic.unsqueeze(0)
|
||||
|
||||
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
|
||||
sv_emb = sv_model(audio_16k)
|
||||
print("text_seq",text_seq.shape)
|
||||
|
||||
refer = spectrogram_torch(
|
||||
vits.hann_window,
|
||||
ref_audio_sr,
|
||||
vits.hps.data.filter_length,
|
||||
vits.hps.data.sampling_rate,
|
||||
vits.hps.data.hop_length,
|
||||
vits.hps.data.win_length,
|
||||
center=False,
|
||||
)
|
||||
|
||||
st = time.time()
|
||||
et = time.time()
|
||||
|
||||
y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
idx = 1
|
||||
last_idx = 0
|
||||
audios = []
|
||||
raw_audios = []
|
||||
last_audio_ret = None
|
||||
offset_index = []
|
||||
full_audios = []
|
||||
print("y.shape:", y.shape)
|
||||
cut_id = 0
|
||||
while True:
|
||||
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
|
||||
# print("y.shape:", y.shape)
|
||||
stop = last_token==t2s.EOS
|
||||
print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx)
|
||||
|
||||
if last_token < 50 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
|
||||
cut_id = idx + 7
|
||||
print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape)
|
||||
# y = torch.cat([y, y[:,-1:]], dim=1)
|
||||
# idx+=1
|
||||
|
||||
if stop :
|
||||
idx -=1
|
||||
print('stop')
|
||||
print(idx, y[:,-idx+last_idx:])
|
||||
print(idx,last_idx, y.shape)
|
||||
print(y[:,-idx:-idx+20])
|
||||
|
||||
|
||||
# 玄学这档子事说不清楚
|
||||
if idx == cut_id or stop:
|
||||
print(f"idx: {idx}, last_idx: {last_idx}, cut_id: {cut_id}, stop: {stop}")
|
||||
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
|
||||
full_audios.append(audio)
|
||||
if last_idx == 0:
|
||||
last_audio_ret = audio[-1280*8:-1280*8+256]
|
||||
audio = audio[:-1280*8]
|
||||
raw_audios.append(audio)
|
||||
et = time.time()
|
||||
else:
|
||||
if stop:
|
||||
audio_ = audio[last_idx*1280 -1280*8:]
|
||||
raw_audios.append(audio_)
|
||||
i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
|
||||
offset_index.append(i)
|
||||
audio = audio_[i:]
|
||||
else:
|
||||
audio_ = audio[last_idx*1280 -1280*8:-1280*8]
|
||||
raw_audios.append(audio_)
|
||||
i, x = find_best_audio_offset_fast(last_audio_ret, audio_[:1280])
|
||||
offset_index.append(i)
|
||||
last_audio_ret = audio[-1280*8:-1280*8+256]
|
||||
audio = audio_[i:]
|
||||
last_idx = idx
|
||||
# print(f'write {output_path}/out_{audio_index}')
|
||||
# soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
audios.append(audio)
|
||||
# print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
|
||||
if idx>1500:
|
||||
break
|
||||
|
||||
if stop:
|
||||
break
|
||||
|
||||
idx+=1
|
||||
|
||||
at = time.time()
|
||||
|
||||
for (i,a) in enumerate(audios):
|
||||
print(f'write {output_path}/out_{i}')
|
||||
soundfile.write(f"{output_path}/out_{i}.wav", a.float().detach().cpu().numpy(), 32000)
|
||||
|
||||
print(f"frist token: {et - st:.4f} seconds")
|
||||
print(f"all token: {at - st:.4f} seconds")
|
||||
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
|
||||
soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
audio = torch.cat(audios, dim=0)
|
||||
soundfile.write(f"{output_path}/out.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
audio_raw = torch.cat(raw_audios, dim=0)
|
||||
soundfile.write(f"{output_path}/out.raw.wav", audio_raw.float().detach().cpu().numpy(), 32000)
|
||||
|
||||
|
||||
colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
|
||||
|
||||
max_duration = full_audios[-1].shape[0]
|
||||
plt.xlim(0, max_duration)
|
||||
|
||||
last_line = 0
|
||||
|
||||
for i,a in enumerate(full_audios):
|
||||
plt.plot((a+2.0*i).float().detach().cpu().numpy(), color=colors[i], alpha=0.5, label=f"Audio {i}")
|
||||
# plt.axvline(x=last_line, color=colors[i], linestyle='--')
|
||||
last_line = a.shape[0]-8*1280
|
||||
plt.axvline(x=last_line, color=colors[i], linestyle='--')
|
||||
|
||||
plt.plot((audio-2.0).float().detach().cpu().numpy(), color='black', label='Final Audio')
|
||||
|
||||
plt.plot((audio_raw-4.0).float().detach().cpu().numpy(), color='cyan', label='Raw Audio')
|
||||
|
||||
print("offset_index:", offset_index)
|
||||
plt.show()
|
||||
|
||||
|
||||
def export_prov2(
|
||||
gpt_path,
|
||||
vits_path,
|
||||
version,
|
||||
ref_audio_path,
|
||||
ref_text,
|
||||
output_path,
|
||||
device="cpu",
|
||||
is_half=True,
|
||||
lang="auto",
|
||||
):
|
||||
if export_torch_script.sv_cn_model == None:
|
||||
init_sv_cn(device,is_half)
|
||||
|
||||
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
|
||||
ssl = SSLModel()
|
||||
|
||||
print(f"device: {device}")
|
||||
|
||||
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
|
||||
ref_text, lang, "v2"
|
||||
)
|
||||
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
||||
ref_bert = ref_bert_T.T
|
||||
if is_half:
|
||||
ref_bert = ref_bert.half()
|
||||
ref_bert = ref_bert.to(ref_seq.device)
|
||||
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||
"这是一个简单的示例,真没想到这么简单就完成了.The King and His Stories.Once there was a king.He likes to write stories, but his stories were not good.", "auto", "v2"
|
||||
)
|
||||
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||
text_bert = text_bert_T.T
|
||||
if is_half:
|
||||
text_bert = text_bert.half()
|
||||
text_bert = text_bert.to(text_seq.device)
|
||||
|
||||
ssl_content = ssl(ref_audio)
|
||||
if is_half:
|
||||
ssl_content = ssl_content.half()
|
||||
ssl_content = ssl_content.to(device)
|
||||
|
||||
sv_model = ExportERes2NetV2(export_torch_script.sv_cn_model)
|
||||
|
||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||
vits = VitsModel(vits_path, version,is_half=is_half,device=device)
|
||||
vits.eval()
|
||||
vits = StepVitsModel(vits, sv_model)
|
||||
|
||||
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||
# dict_s1 = torch.load(gpt_path, map_location=device)
|
||||
dict_s1 = torch.load(gpt_path, weights_only=False)
|
||||
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||
print("#### get_raw_t2s_model ####")
|
||||
print(raw_t2s.config)
|
||||
if is_half:
|
||||
raw_t2s = raw_t2s.half()
|
||||
t2s_m = T2SModel(raw_t2s)
|
||||
t2s_m.eval()
|
||||
# t2s = torch.jit.script(t2s_m).to(device)
|
||||
t2s = t2s_m
|
||||
print("#### script t2s_m ####")
|
||||
|
||||
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
|
||||
|
||||
stream_t2s = StreamT2SModel(t2s).to(device)
|
||||
stream_t2s = torch.jit.script(stream_t2s)
|
||||
|
||||
ref_audio_sr = resamplex(ref_audio, 16000, 32000)
|
||||
ref_audio_sr = ref_audio_sr.to(device)
|
||||
if is_half:
|
||||
ref_audio_sr = ref_audio_sr.half()
|
||||
|
||||
top_k = 15
|
||||
|
||||
prompts = vits.extract_latent(ssl_content)
|
||||
|
||||
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
|
||||
sv_emb = sv_model(audio_16k)
|
||||
print("text_seq",text_seq.shape)
|
||||
# torch.jit.trace()
|
||||
|
||||
refer,sv_emb = vits.ref_handle(ref_audio_sr)
|
||||
|
||||
st = time.time()
|
||||
et = time.time()
|
||||
|
||||
y_len, y, xy_pos, k_cache, v_cache = stream_t2s.pre_infer(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
idx = 1
|
||||
print("y.shape:", y.shape)
|
||||
while True:
|
||||
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
|
||||
# print("y.shape:", y.shape)
|
||||
|
||||
idx+=1
|
||||
# print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
|
||||
if idx>1500:
|
||||
break
|
||||
|
||||
if last_token == t2s.EOS:
|
||||
break
|
||||
|
||||
at = time.time()
|
||||
print("EOS:",t2s.EOS)
|
||||
|
||||
print(f"frist token: {et - st:.4f} seconds")
|
||||
print(f"all token: {at - st:.4f} seconds")
|
||||
print("sv_emb", sv_emb.shape)
|
||||
print("refer",refer.shape)
|
||||
y = y[:,-idx:].unsqueeze(0)
|
||||
print("y", y.shape)
|
||||
audio = vits(y, text_seq, refer, sv_emb)
|
||||
soundfile.write(f"{output_path}/out_final.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
|
||||
torch._dynamo.mark_dynamic(ssl_content, 2)
|
||||
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
|
||||
torch._dynamo.mark_dynamic(ref_seq, 1)
|
||||
torch._dynamo.mark_dynamic(text_seq, 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert, 0)
|
||||
torch._dynamo.mark_dynamic(text_bert, 0)
|
||||
torch._dynamo.mark_dynamic(refer, 2)
|
||||
torch._dynamo.mark_dynamic(y, 2)
|
||||
|
||||
inputs = {
|
||||
"forward": (y, text_seq, refer, sv_emb),
|
||||
"extract_latent": ssl_content,
|
||||
"ref_handle": ref_audio_sr,
|
||||
}
|
||||
|
||||
stream_t2s.save(f"{output_path}/t2s.pt")
|
||||
torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt")
|
||||
torch.jit.script(find_best_audio_offset_fast, optimize=True).save(f"{output_path}/find_best_audio_offset_fast.pt")
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
parser.add_argument(
|
||||
"--sovits_model", required=True, help="Path to the SoVITS model file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ref_audio", required=True, help="Path to the reference audio file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ref_text", required=True, help="Path to the reference text file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path", required=True, help="Path to the output directory"
|
||||
)
|
||||
parser.add_argument("--device", help="Device to use", default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--version", help="version of the model", default="v2Pro")
|
||||
parser.add_argument("--no-half", action="store_true", help = "Do not use half precision for model weights")
|
||||
parser.add_argument("--lang", default="auto", help="Language for text processing (default: auto)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.output_path):
|
||||
os.makedirs(args.output_path)
|
||||
|
||||
is_half = not args.no_half
|
||||
with torch.no_grad():
|
||||
export_prov2(
|
||||
gpt_path=args.gpt_model,
|
||||
vits_path=args.sovits_model,
|
||||
version=args.version,
|
||||
ref_audio_path=args.ref_audio,
|
||||
ref_text=args.ref_text,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
is_half=is_half,
|
||||
lang=args.lang,
|
||||
)
|
||||
@ -238,6 +238,46 @@ def _expand_number(m):
|
||||
return _inflect.number_to_words(num, andword="")
|
||||
|
||||
|
||||
# 加减乘除
|
||||
RE_ASMD = re.compile(
|
||||
r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))\s+([\+\-\×÷=])\s+((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))"
|
||||
)
|
||||
# RE_ASMD = re.compile(
|
||||
# r"\b((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))\b"
|
||||
# )
|
||||
|
||||
asmd_map = {"+": " plus ", "-": " minus ", "×": " times ", "÷": " divided by ", "=": " Equals "}
|
||||
|
||||
|
||||
def replace_asmd(match) -> str:
|
||||
"""
|
||||
Args:
|
||||
match (re.Match)
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
|
||||
return result
|
||||
|
||||
|
||||
RE_INTEGER = re.compile(r"(?:^|\s+)(-)" r"(\d+)")
|
||||
|
||||
|
||||
def replace_negative_num(match) -> str:
|
||||
"""
|
||||
Args:
|
||||
match (re.Match)
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
sign = match.group(1)
|
||||
number = match.group(2)
|
||||
sign: str = "negative " if sign else ""
|
||||
result = f"{sign}{number}"
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def normalize(text):
|
||||
"""
|
||||
!!! 所有的处理都需要正确的输入 !!!
|
||||
@ -245,7 +285,13 @@ def normalize(text):
|
||||
"""
|
||||
|
||||
text = re.sub(_ordinal_number_re, _convert_ordinal, text)
|
||||
text = re.sub(r"(?<!\d)-|-(?!\d)", " minus ", text)
|
||||
|
||||
# 处理数学运算
|
||||
# 替换text = re.sub(r"(?<!\d)-|-(?!\d)", " minus ", text)
|
||||
while RE_ASMD.search(text):
|
||||
text = RE_ASMD.sub(replace_asmd, text)
|
||||
text = RE_INTEGER.sub(replace_negative_num, text)
|
||||
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_time_re, _expand_time, text)
|
||||
text = re.sub(_measurement_re, _expand_measurement, text)
|
||||
|
||||
@ -347,7 +347,7 @@ Use v4 from v1/v2/v3 environment:
|
||||
|
||||
2. Clone the latest codes from github.
|
||||
|
||||
3. Download v4 pretrained models (gsv-v4-pretrained/s2v4.ckpt, and gsv-v4-pretrained/vocoder.pth) from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main) and put them into `GPT_SoVITS/pretrained_models`.
|
||||
3. Download v4 pretrained models (gsv-v4-pretrained/s2v4.pth, and gsv-v4-pretrained/vocoder.pth) from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main) and put them into `GPT_SoVITS/pretrained_models`.
|
||||
|
||||
## V2Pro Release Notes
|
||||
|
||||
|
||||
126
api_v2.py
126
api_v2.py
@ -27,20 +27,23 @@ POST:
|
||||
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
||||
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
||||
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
||||
"top_k": 5, # int. top k sampling
|
||||
"top_k": 15, # int. top k sampling
|
||||
"top_p": 1, # float. top p sampling
|
||||
"temperature": 1, # float. temperature for sampling
|
||||
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
|
||||
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
|
||||
"batch_size": 1, # int. batch size for inference
|
||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"streaming_mode": False, # bool. whether to return a streaming response.
|
||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
}
|
||||
```
|
||||
|
||||
@ -101,7 +104,7 @@ RESP:
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Generator
|
||||
from typing import Generator, Union
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
@ -121,6 +124,7 @@ from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
||||
from pydantic import BaseModel
|
||||
import threading
|
||||
|
||||
# print(sys.path)
|
||||
i18n = I18nAuto()
|
||||
@ -154,7 +158,7 @@ class TTS_Request(BaseModel):
|
||||
aux_ref_audio_paths: list = None
|
||||
prompt_lang: str = None
|
||||
prompt_text: str = ""
|
||||
top_k: int = 5
|
||||
top_k: int = 15
|
||||
top_p: float = 1
|
||||
temperature: float = 1
|
||||
text_split_method: str = "cut5"
|
||||
@ -165,17 +169,58 @@ class TTS_Request(BaseModel):
|
||||
fragment_interval: float = 0.3
|
||||
seed: int = -1
|
||||
media_type: str = "wav"
|
||||
streaming_mode: bool = False
|
||||
streaming_mode: Union[bool, int] = False
|
||||
parallel_infer: bool = True
|
||||
repetition_penalty: float = 1.35
|
||||
sample_steps: int = 32
|
||||
super_sampling: bool = False
|
||||
overlap_length: int = 2
|
||||
min_chunk_length: int = 16
|
||||
|
||||
|
||||
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
|
||||
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
|
||||
audio_file.write(data)
|
||||
# Author: AkagawaTsurunaki
|
||||
# Issue:
|
||||
# Stack overflow probabilistically occurs
|
||||
# when the function `sf_writef_short` of `libsndfile_64bit.dll` is called
|
||||
# using the Python library `soundfile`
|
||||
# Note:
|
||||
# This is an issue related to `libsndfile`, not this project itself.
|
||||
# It happens when you generate a large audio tensor (about 499804 frames in my PC)
|
||||
# and try to convert it to an ogg file.
|
||||
# Related:
|
||||
# https://github.com/RVC-Boss/GPT-SoVITS/issues/1199
|
||||
# https://github.com/libsndfile/libsndfile/issues/1023
|
||||
# https://github.com/bastibe/python-soundfile/issues/396
|
||||
# Suggestion:
|
||||
# Or split the whole audio data into smaller audio segment to avoid stack overflow?
|
||||
|
||||
def handle_pack_ogg():
|
||||
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
|
||||
audio_file.write(data)
|
||||
|
||||
|
||||
|
||||
# See: https://docs.python.org/3/library/threading.html
|
||||
# The stack size of this thread is at least 32768
|
||||
# If stack overflow error still occurs, just modify the `stack_size`.
|
||||
# stack_size = n * 4096, where n should be a positive integer.
|
||||
# Here we chose n = 4096.
|
||||
stack_size = 4096 * 4096
|
||||
try:
|
||||
threading.stack_size(stack_size)
|
||||
pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
|
||||
pack_ogg_thread.start()
|
||||
pack_ogg_thread.join()
|
||||
except RuntimeError as e:
|
||||
# If changing the thread stack size is unsupported, a RuntimeError is raised.
|
||||
print("RuntimeError: {}".format(e))
|
||||
print("Changing the thread stack size is unsupported.")
|
||||
except ValueError as e:
|
||||
# If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified.
|
||||
print("ValueError: {}".format(e))
|
||||
print("The specified stack size is invalid.")
|
||||
|
||||
return io_buffer
|
||||
|
||||
|
||||
@ -286,8 +331,8 @@ def check_params(req: dict):
|
||||
)
|
||||
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
||||
return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
|
||||
elif media_type == "ogg" and not streaming_mode:
|
||||
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
|
||||
# elif media_type == "ogg" and not streaming_mode:
|
||||
# return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
|
||||
|
||||
if text_split_method not in cut_method_names:
|
||||
return JSONResponse(
|
||||
@ -307,25 +352,26 @@ async def tts_handle(req: dict):
|
||||
"text": "", # str.(required) text to be synthesized
|
||||
"text_lang: "", # str.(required) language of the text to be synthesized
|
||||
"ref_audio_path": "", # str.(required) reference audio path
|
||||
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
|
||||
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
||||
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
||||
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
||||
"top_k": 5, # int. top k sampling
|
||||
"top_k": 15, # int. top k sampling
|
||||
"top_p": 1, # float. top p sampling
|
||||
"temperature": 1, # float. temperature for sampling
|
||||
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
|
||||
"batch_size": 1, # int. batch size for inference
|
||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
||||
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
|
||||
"streaming_mode": False, # bool. whether to return a streaming response.
|
||||
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
|
||||
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
}
|
||||
returns:
|
||||
StreamingResponse: audio stream response.
|
||||
@ -338,9 +384,35 @@ async def tts_handle(req: dict):
|
||||
check_res = check_params(req)
|
||||
if check_res is not None:
|
||||
return check_res
|
||||
|
||||
if streaming_mode == 0:
|
||||
streaming_mode = False
|
||||
return_fragment = False
|
||||
fixed_length_chunk = False
|
||||
elif streaming_mode == 1:
|
||||
streaming_mode = False
|
||||
return_fragment = True
|
||||
fixed_length_chunk = False
|
||||
elif streaming_mode == 2:
|
||||
streaming_mode = True
|
||||
return_fragment = False
|
||||
fixed_length_chunk = False
|
||||
elif streaming_mode == 3:
|
||||
streaming_mode = True
|
||||
return_fragment = False
|
||||
fixed_length_chunk = True
|
||||
|
||||
else:
|
||||
return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"})
|
||||
|
||||
req["streaming_mode"] = streaming_mode
|
||||
req["return_fragment"] = return_fragment
|
||||
req["fixed_length_chunk"] = fixed_length_chunk
|
||||
|
||||
print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}")
|
||||
|
||||
streaming_mode = streaming_mode or return_fragment
|
||||
|
||||
if streaming_mode or return_fragment:
|
||||
req["return_fragment"] = True
|
||||
|
||||
try:
|
||||
tts_generator = tts_pipeline.run(req)
|
||||
@ -388,10 +460,10 @@ async def tts_get_endpoint(
|
||||
aux_ref_audio_paths: list = None,
|
||||
prompt_lang: str = None,
|
||||
prompt_text: str = "",
|
||||
top_k: int = 5,
|
||||
top_k: int = 15,
|
||||
top_p: float = 1,
|
||||
temperature: float = 1,
|
||||
text_split_method: str = "cut0",
|
||||
text_split_method: str = "cut5",
|
||||
batch_size: int = 1,
|
||||
batch_threshold: float = 0.75,
|
||||
split_bucket: bool = True,
|
||||
@ -399,11 +471,13 @@ async def tts_get_endpoint(
|
||||
fragment_interval: float = 0.3,
|
||||
seed: int = -1,
|
||||
media_type: str = "wav",
|
||||
streaming_mode: bool = False,
|
||||
parallel_infer: bool = True,
|
||||
repetition_penalty: float = 1.35,
|
||||
sample_steps: int = 32,
|
||||
super_sampling: bool = False,
|
||||
streaming_mode: Union[bool, int] = False,
|
||||
overlap_length: int = 2,
|
||||
min_chunk_length: int = 16,
|
||||
):
|
||||
req = {
|
||||
"text": text,
|
||||
@ -428,6 +502,8 @@ async def tts_get_endpoint(
|
||||
"repetition_penalty": float(repetition_penalty),
|
||||
"sample_steps": int(sample_steps),
|
||||
"super_sampling": super_sampling,
|
||||
"overlap_length": int(overlap_length),
|
||||
"min_chunk_length": int(min_chunk_length),
|
||||
}
|
||||
return await tts_handle(req)
|
||||
|
||||
|
||||
@ -373,7 +373,7 @@ if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ]; then
|
||||
location=$(pip show torch | grep Location | awk -F ": " '{print $2}')
|
||||
cd "${location}"/torch/lib/ || exit
|
||||
rm libhsa-runtime64.so*
|
||||
cp /opt/rocm/lib/libhsa-runtime64.so.1.2 libhsa-runtime64.so
|
||||
cp "$(readlink -f /opt/rocm/lib/libhsa-runtime64.so)" libhsa-runtime64.so
|
||||
echo -e "${SUCCESS}ROCm Runtime Lib Updated..."
|
||||
fi
|
||||
|
||||
|
||||
@ -16,10 +16,10 @@ pypinyin
|
||||
pyopenjtalk>=0.4.1
|
||||
g2p_en
|
||||
torchaudio
|
||||
modelscope==1.10.0
|
||||
modelscope
|
||||
sentencepiece
|
||||
transformers>=4.43,<=4.50
|
||||
peft
|
||||
peft<0.18.0
|
||||
chardet
|
||||
PyYAML
|
||||
psutil
|
||||
@ -39,7 +39,5 @@ x_transformers
|
||||
torchmetrics<=1.5
|
||||
pydantic<=2.10.6
|
||||
ctranslate2>=4.0,<5
|
||||
huggingface_hub>=0.13
|
||||
tokenizers>=0.13,<1
|
||||
av>=11
|
||||
tqdm
|
||||
|
||||
@ -1,34 +1,13 @@
|
||||
import os
|
||||
|
||||
|
||||
def check_fw_local_models():
|
||||
"""
|
||||
启动时检查本地是否有 Faster Whisper 模型.
|
||||
"""
|
||||
model_size_list = [
|
||||
"medium",
|
||||
"medium.en",
|
||||
"distil-large-v2",
|
||||
"distil-large-v3",
|
||||
"large-v1",
|
||||
"large-v2",
|
||||
"large-v3",
|
||||
]
|
||||
for i, size in enumerate(model_size_list):
|
||||
if os.path.exists(f"tools/asr/models/faster-whisper-{size}"):
|
||||
model_size_list[i] = size + "-local"
|
||||
return model_size_list
|
||||
|
||||
|
||||
def get_models():
|
||||
model_size_list = [
|
||||
"medium",
|
||||
"medium.en",
|
||||
"distil-large-v2",
|
||||
"distil-large-v3",
|
||||
"large-v1",
|
||||
"large-v2",
|
||||
"large-v3",
|
||||
"large-v3-turbo",
|
||||
#"distil-large-v2",
|
||||
#"distil-large-v3",
|
||||
#"distil-large-v3.5",
|
||||
]
|
||||
return model_size_list
|
||||
|
||||
@ -36,7 +15,7 @@ def get_models():
|
||||
asr_dict = {
|
||||
"达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]},
|
||||
"Faster Whisper (多语种)": {
|
||||
"lang": ["auto", "zh", "en", "ja", "ko", "yue"],
|
||||
"lang": ["auto", "en", "ja", "ko"],
|
||||
"size": get_models(),
|
||||
"path": "fasterwhisper_asr.py",
|
||||
"precision": ["float32", "float16", "int8"],
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from faster_whisper import WhisperModel
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||
from huggingface_hub import snapshot_download as snapshot_download_hf
|
||||
from modelscope import snapshot_download as snapshot_download_ms
|
||||
from tqdm import tqdm
|
||||
|
||||
from tools.asr.config import get_models
|
||||
@ -40,11 +40,32 @@ language_code_list = [
|
||||
|
||||
|
||||
def download_model(model_size: str):
|
||||
if "distil" in model_size:
|
||||
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
|
||||
url = "https://huggingface.co/api/models/gpt2"
|
||||
try:
|
||||
requests.get(url, timeout=3)
|
||||
source = "HF"
|
||||
except Exception:
|
||||
source = "ModelScope"
|
||||
|
||||
model_path = ""
|
||||
if source == "HF":
|
||||
if "distil" in model_size:
|
||||
if "3.5" in model_size:
|
||||
repo_id = "distil-whisper/distil-large-v3.5-ct2"
|
||||
model_path = "tools/asr/models/faster-distil-whisper-large-v3.5"
|
||||
else:
|
||||
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
|
||||
elif model_size == "large-v3-turbo":
|
||||
repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo"
|
||||
model_path = "tools/asr/models/faster-whisper-large-v3-turbo"
|
||||
else:
|
||||
repo_id = f"Systran/faster-whisper-{model_size}"
|
||||
model_path = (
|
||||
model_path or f"tools/asr/models/{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}"
|
||||
)
|
||||
else:
|
||||
repo_id = f"Systran/faster-whisper-{model_size}"
|
||||
model_path = f"tools/asr/models/{repo_id.strip('Systran/')}"
|
||||
repo_id = "XXXXRT/faster-whisper"
|
||||
model_path = "tools/asr/models"
|
||||
|
||||
files: list[str] = [
|
||||
"config.json",
|
||||
@ -52,32 +73,31 @@ def download_model(model_size: str):
|
||||
"tokenizer.json",
|
||||
"vocabulary.txt",
|
||||
]
|
||||
if model_size == "large-v3" or "distil" in model_size:
|
||||
if "large-v3" in model_size or "distil" in model_size:
|
||||
files.append("preprocessor_config.json")
|
||||
files.append("vocabulary.json")
|
||||
|
||||
files.remove("vocabulary.txt")
|
||||
|
||||
for attempt in range(2):
|
||||
try:
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
allow_patterns=files,
|
||||
local_dir=model_path,
|
||||
)
|
||||
break
|
||||
except LocalEntryNotFoundError:
|
||||
if attempt < 1:
|
||||
time.sleep(2)
|
||||
else:
|
||||
print("[ERROR] LocalEntryNotFoundError and no fallback.")
|
||||
traceback.print_exc()
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Unexpected error on attempt {attempt + 1}: {e}")
|
||||
traceback.print_exc()
|
||||
exit(1)
|
||||
if source == "ModelScope":
|
||||
files = [f"faster-whisper-{model_size}/{file}".replace("whisper-distil", "distil-whisper") for file in files]
|
||||
|
||||
if source == "HF":
|
||||
print(f"Downloading model from HuggingFace: {repo_id} to {model_path}")
|
||||
snapshot_download_hf(
|
||||
repo_id,
|
||||
local_dir=model_path,
|
||||
local_dir_use_symlinks=False,
|
||||
allow_patterns=files,
|
||||
)
|
||||
else:
|
||||
print(f"Downloading model from ModelScope: {repo_id} to {model_path}")
|
||||
snapshot_download_ms(
|
||||
repo_id,
|
||||
local_dir=model_path,
|
||||
allow_patterns=files,
|
||||
)
|
||||
return model_path + f"/faster-whisper-{model_size}".replace("whisper-distil", "distil-whisper")
|
||||
return model_path
|
||||
|
||||
|
||||
@ -106,7 +126,7 @@ def execute_asr(input_folder, output_folder, model_path, language, precision):
|
||||
)
|
||||
text = ""
|
||||
|
||||
if info.language == "zh":
|
||||
if info.language in ["zh", "yue"]:
|
||||
print("检测为中文文本, 转 FunASR 处理")
|
||||
text = only_asr(file_path, language=info.language.lower())
|
||||
|
||||
|
||||
@ -4,9 +4,8 @@ import argparse
|
||||
import os
|
||||
import traceback
|
||||
|
||||
# from funasr.utils import version_checker
|
||||
# version_checker.check_for_update = lambda: None
|
||||
from funasr import AutoModel
|
||||
from modelscope import snapshot_download
|
||||
from tqdm import tqdm
|
||||
|
||||
funasr_models = {} # 存储模型避免重复加载
|
||||
@ -16,40 +15,43 @@ def only_asr(input_file, language):
|
||||
try:
|
||||
model = create_model(language)
|
||||
text = model.generate(input=input_file)[0]["text"]
|
||||
except:
|
||||
except Exception:
|
||||
text = ""
|
||||
print(traceback.format_exc())
|
||||
return text
|
||||
|
||||
|
||||
def create_model(language="zh"):
|
||||
path_vad = "tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||
path_punc = "tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||
path_vad = path_vad if os.path.exists(path_vad) else "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||
path_punc = path_punc if os.path.exists(path_punc) else "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||
vad_model_revision = punc_model_revision = "v2.0.4"
|
||||
|
||||
if language == "zh":
|
||||
path_vad = "tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||
path_punc = "tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||
path_asr = "tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
path_asr = (
|
||||
path_asr
|
||||
if os.path.exists(path_asr)
|
||||
else "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
snapshot_download(
|
||||
"iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
local_dir="tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
)
|
||||
snapshot_download(
|
||||
"iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||
local_dir="tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||
)
|
||||
snapshot_download(
|
||||
"iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
local_dir="tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
)
|
||||
model_revision = "v2.0.4"
|
||||
elif language == "yue":
|
||||
path_asr = "tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
|
||||
path_asr = (
|
||||
path_asr
|
||||
if os.path.exists(path_asr)
|
||||
else "iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
|
||||
snapshot_download(
|
||||
"iic/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
|
||||
local_dir="tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
|
||||
)
|
||||
model_revision = "master"
|
||||
path_vad = path_punc = None
|
||||
vad_model_revision = punc_model_revision = None
|
||||
###友情提示:粤语带VAD识别可能会有少量shape不对报错的,但是不带VAD可以.不带vad只能分阶段单独加标点。不过标点模型对粤语效果真的不行…
|
||||
vad_model_revision = punc_model_revision = ""
|
||||
model_revision = "master"
|
||||
else:
|
||||
raise ValueError("FunASR 不支持该语言" + ": " + language)
|
||||
raise ValueError(f"{language} is not supported")
|
||||
|
||||
vad_model_revision = punc_model_revision = "v2.0.4"
|
||||
|
||||
if language in funasr_models:
|
||||
return funasr_models[language]
|
||||
@ -83,7 +85,7 @@ def execute_asr(input_folder, output_folder, model_size, language):
|
||||
file_path = os.path.join(input_folder, file_name)
|
||||
text = model.generate(input=file_path)[0]["text"]
|
||||
output.append(f"{file_path}|{output_file_name}|{language.upper()}|{text}")
|
||||
except:
|
||||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
|
||||
output_folder = output_folder or "output/asr_opt"
|
||||
|
||||
@ -38,7 +38,7 @@
|
||||
"hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)": "hop_size: FO hop size, the smaller the value, the higher the accuracy)",
|
||||
"max:归一化后最大值多少": "Loudness multiplier after normalized",
|
||||
"max_sil_kept:切完后静音最多留多长": "Maximum length for silence to be kept",
|
||||
"min_interval:最短切割间隔": "Minumum interval for audio cutting",
|
||||
"min_interval:最短切割间隔": "Minimum interval for audio cutting",
|
||||
"min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值": "min_length: the minimum length of each segment. If the first segment is too short, it will be concatenated with the next segment until it exceeds this value",
|
||||
"temperature": "temperature",
|
||||
"threshold:音量小于这个值视作静音的备选切割点": "Noise gate threshold (loudness below this value will be treated as noise",
|
||||
@ -176,7 +176,7 @@
|
||||
"语音降噪": "Speech Denoising",
|
||||
"请上传3~10秒内参考音频,超过会报错!": "Please upload a reference audio within the 3-10 second range; if it exceeds this duration, it will raise errors.",
|
||||
"请上传参考音频": "Please Upload the Reference Audio",
|
||||
"请填入推理文本": "Please Fill in the Terget Text",
|
||||
"请填入推理文本": "Please Fill in the Target Text",
|
||||
"请填入正确的List路径": "Please Fill in the Correct List Path",
|
||||
"请填入正确的音频文件夹路径": "Please Fill in the Correct Audio Folder Path",
|
||||
"请输入有效文本": "Please enter valid text.",
|
||||
|
||||
25
webui.py
25
webui.py
@ -86,7 +86,6 @@ from config import (
|
||||
from tools import my_utils
|
||||
from tools.my_utils import check_details, check_for_existance
|
||||
|
||||
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
|
||||
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu
|
||||
@ -117,8 +116,8 @@ def set_default():
|
||||
gpu_info = "\n".join(gpu_infos)
|
||||
if is_gpu_ok:
|
||||
minmem = min(mem)
|
||||
default_batch_size = minmem // 2 if version not in v3v4set else minmem // 8
|
||||
default_batch_size_s1 = minmem // 2
|
||||
default_batch_size = int(minmem // 2 if version not in v3v4set else minmem // 8)
|
||||
default_batch_size_s1 = int(minmem // 2)
|
||||
else:
|
||||
default_batch_size = default_batch_size_s1 = int(psutil.virtual_memory().total / 1024 / 1024 / 1024 / 4)
|
||||
if version not in v3v4set:
|
||||
@ -343,7 +342,7 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so
|
||||
os.environ["sovits_path"] = sovits_path
|
||||
os.environ["cnhubert_base_path"] = cnhubert_base_path
|
||||
os.environ["bert_path"] = bert_path
|
||||
os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_number(gpu_number)
|
||||
os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number))
|
||||
os.environ["is_half"] = str(is_half)
|
||||
os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
|
||||
os.environ["is_share"] = str(is_share)
|
||||
@ -628,7 +627,7 @@ def open1Bb(
|
||||
data["output_dir"] = "%s/logs_s1_%s" % (s1_dir, version)
|
||||
# data["version"]=version
|
||||
|
||||
os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_numbers(gpu_numbers.replace("-", ","))
|
||||
os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ",")))
|
||||
os.environ["hz"] = "25hz"
|
||||
tmp_config_path = "%s/tmp_s1.yaml" % tmp
|
||||
with open(tmp_config_path, "w") as f:
|
||||
@ -801,7 +800,7 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
|
||||
{
|
||||
"i_part": str(i_part),
|
||||
"all_parts": str(all_parts),
|
||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
||||
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||
"is_half": str(is_half),
|
||||
}
|
||||
)
|
||||
@ -892,7 +891,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
|
||||
{
|
||||
"i_part": str(i_part),
|
||||
"all_parts": str(all_parts),
|
||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
||||
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||
}
|
||||
)
|
||||
os.environ.update(config)
|
||||
@ -914,7 +913,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
|
||||
{
|
||||
"i_part": str(i_part),
|
||||
"all_parts": str(all_parts),
|
||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
||||
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||
}
|
||||
)
|
||||
os.environ.update(config)
|
||||
@ -986,7 +985,7 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
|
||||
{
|
||||
"i_part": str(i_part),
|
||||
"all_parts": str(all_parts),
|
||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
||||
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||
}
|
||||
)
|
||||
os.environ.update(config)
|
||||
@ -1086,7 +1085,7 @@ def open1abc(
|
||||
{
|
||||
"i_part": str(i_part),
|
||||
"all_parts": str(all_parts),
|
||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
||||
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||
}
|
||||
)
|
||||
os.environ.update(config)
|
||||
@ -1133,7 +1132,7 @@ def open1abc(
|
||||
{
|
||||
"i_part": str(i_part),
|
||||
"all_parts": str(all_parts),
|
||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
||||
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||
}
|
||||
)
|
||||
os.environ.update(config)
|
||||
@ -1155,7 +1154,7 @@ def open1abc(
|
||||
{
|
||||
"i_part": str(i_part),
|
||||
"all_parts": str(all_parts),
|
||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
||||
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||
}
|
||||
)
|
||||
os.environ.update(config)
|
||||
@ -1195,7 +1194,7 @@ def open1abc(
|
||||
{
|
||||
"i_part": str(i_part),
|
||||
"all_parts": str(all_parts),
|
||||
"_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
|
||||
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
|
||||
}
|
||||
)
|
||||
os.environ.update(config)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user