spellcheck (#916)

Co-authored-by: starylan <starylan@outlook.com>
This commit is contained in:
SapphireLab 2024-04-03 17:42:23 +08:00 committed by GitHub
parent ed75ecdd6d
commit 72c0eca0a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 39 deletions

View File

@ -140,7 +140,7 @@ class TTS_Config:
self.win_length:int = 2048
self.n_speakers:int = 300
self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
self.languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
# print(self)
def _load_configs(self, configs_path: str)->dict:
@ -207,19 +207,19 @@ class TTS:
self.prompt_cache:dict = {
"ref_audio_path":None,
"prompt_semantic":None,
"refer_spepc":None,
"prompt_text":None,
"prompt_lang":None,
"phones":None,
"bert_features":None,
"norm_text":None,
"ref_audio_path" : None,
"prompt_semantic": None,
"refer_spec" : None,
"prompt_text" : None,
"prompt_lang" : None,
"phones" : None,
"bert_features" : None,
"norm_text" : None,
}
self.stop_flag:bool = False
self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
self.precision:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
def _init_models(self,):
self.init_t2s_weights(self.configs.t2s_weights_path)
@ -312,7 +312,7 @@ class TTS:
return
self.configs.is_half = enable
self.precison = torch.float16 if enable else torch.float32
self.precision = torch.float16 if enable else torch.float32
self.configs.save_configs()
if enable:
if self.t2s_model is not None:
@ -358,9 +358,9 @@ class TTS:
ref_audio_path: str, the path of the reference audio.
'''
self._set_prompt_semantic(ref_audio_path)
self._set_ref_spepc(ref_audio_path)
self._set_ref_spec(ref_audio_path)
def _set_ref_spepc(self, ref_audio_path):
def _set_ref_spec(self, ref_audio_path):
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
@ -376,8 +376,8 @@ class TTS:
spec = spec.to(self.configs.device)
if self.configs.is_half:
spec = spec.half()
# self.refer_spepc = spec
self.prompt_cache["refer_spepc"] = spec
# self.refer_spec = spec
self.prompt_cache["refer_spec"] = spec
def _set_prompt_semantic(self, ref_wav_path:str):
@ -435,7 +435,7 @@ class TTS:
threshold:float=0.75,
split_bucket:bool=True,
device:torch.device=torch.device("cpu"),
precison:torch.dtype=torch.float32,
precision:torch.dtype=torch.float32,
):
_data:list = []
@ -488,13 +488,13 @@ class TTS:
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
.to(dtype=precison, device=device)
.to(dtype=precision, device=device)
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
phones = torch.LongTensor(item["phones"]).to(device)
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]\
.to(dtype=precison, device=device)
.to(dtype=precision, device=device)
phones = torch.LongTensor(item["phones"]).to(device)
all_phones = phones
# norm_text = item["norm_text"]
@ -519,7 +519,7 @@ class TTS:
#### 直接对phones和bert_features进行pad会增大复读概率。
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
# all_bert_features_batch = all_bert_features_list
# all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precison, device=device)
# all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
# for idx, item in enumerate(all_bert_features_list):
# all_bert_features_batch[idx, :, : item.shape[-1]] = item
@ -555,8 +555,8 @@ class TTS:
Returns:
list (List[np.ndarray]): the data in the original order.
'''
lenght = len(sum(batch_index_list, []))
_data = [None]*lenght
length = len(sum(batch_index_list, []))
_data = [None]*length
for i, index_list in enumerate(batch_index_list):
for j, index in enumerate(index_list):
_data[index] = data[i][j]
@ -584,7 +584,7 @@ class TTS:
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut0", # str. text split method, see text_segmentaion_method.py for details.
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
@ -594,7 +594,7 @@ class TTS:
"seed": -1, # int. random seed for reproducibility.
}
returns:
tulpe[int, np.ndarray]: sampling rate and audio data.
tuple[int, np.ndarray]: sampling rate and audio data.
"""
########## variables initialization ###########
self.stop_flag:bool = False
@ -635,12 +635,12 @@ class TTS:
if prompt_text in [None, ""]:
no_prompt_text = True
assert text_lang in self.configs.langauges
assert text_lang in self.configs.languages
if not no_prompt_text:
assert prompt_lang in self.configs.langauges
assert prompt_lang in self.configs.languages
if ref_audio_path in [None, ""] and \
((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spepc"] is None)):
((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] is None)):
raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
@ -682,7 +682,7 @@ class TTS:
threshold=batch_threshold,
split_bucket=split_bucket,
device=self.configs.device,
precison=self.precison
precision=self.precision
)
else:
print(i18n("############ 切分文本 ############"))
@ -714,7 +714,7 @@ class TTS:
threshold=batch_threshold,
split_bucket=False,
device=self.configs.device,
precison=self.precison
precision=self.precision
)
return batch[0]
@ -760,8 +760,8 @@ class TTS:
t4 = ttime()
t_34 += t4 - t3
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
.to(dtype=self.precison, device=self.configs.device)
refer_audio_spec:torch.Tensor = self.prompt_cache["refer_spec"]\
.to(dtype=self.precision, device=self.configs.device)
batch_audio_fragment = []
@ -775,7 +775,7 @@ class TTS:
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
# batch_phones = batch_phones.to(self.configs.device)
# batch_audio_fragment = (self.vits_model.batched_decode(
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
# ))
# ## vits并行推理 method 2
@ -786,7 +786,7 @@ class TTS:
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = (self.vits_model.decode(
all_pred_semantic, _batch_phones,refer_audio_spepc
all_pred_semantic, _batch_phones, refer_audio_spec
).detach()[0, 0, :])
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
@ -797,7 +797,7 @@ class TTS:
# phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
# _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
# audio_fragment =(self.vits_model.decode(
# _pred_semantic, phones, refer_audio_spepc
# _pred_semantic, phones, refer_audio_spec
# ).detach()[0, 0, :])
# batch_audio_fragment.append(
# audio_fragment
@ -866,7 +866,7 @@ class TTS:
)->tuple[int, np.ndarray]:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval),
dtype=self.precison,
dtype=self.precision,
device=self.configs.device
)

View File

@ -82,7 +82,7 @@ if bert_path is not None:
tts_config.bert_base_path = bert_path
print(tts_config)
tts_pipline = TTS(tts_config)
tts_pipeline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
@ -113,7 +113,7 @@ def inference(text, text_lang,
"fragment_interval":fragment_interval,
"seed":actual_seed,
}
for item in tts_pipline.run(inputs):
for item in tts_pipeline.run(inputs):
yield item, actual_seed
def custom_sort_key(s):
@ -162,8 +162,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
SoVITS_dropdown.change(tts_pipline.init_vits_weights, [SoVITS_dropdown], [])
GPT_dropdown.change(tts_pipline.init_t2s_weights, [GPT_dropdown], [])
SoVITS_dropdown.change(tts_pipeline.init_vits_weights, [SoVITS_dropdown], [])
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
with gr.Row():
with gr.Column():
@ -227,7 +227,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
],
[output, seed],
)
stop_infer.click(tts_pipline.stop, [], [])
stop_infer.click(tts_pipeline.stop, [], [])
with gr.Group():
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))