优化代码结构

This commit is contained in:
chasonjiang 2024-03-19 22:29:07 +08:00
parent df22a4fc04
commit c3ac108ed4
2 changed files with 24 additions and 28 deletions

View File

@ -576,22 +576,22 @@ class TTS:
Args:
inputs (dict):
{
"text": "", # str. text to be synthesized
"text_lang: "", # str. language of the text to be synthesized
"ref_audio_path": "", # str. reference audio path
"prompt_text": "", # str. prompt text for the reference audio
"prompt_lang": "", # str. language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "", # str. text split method, see text_segmentaion_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"text": "", # str.(required) text to be synthesized
"text_lang: "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut0", # str. text split method, see text_segmentaion_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
}
returns:
tulpe[int, np.ndarray]: sampling rate and audio data.
@ -606,7 +606,7 @@ class TTS:
top_k:int = inputs.get("top_k", 5)
top_p:float = inputs.get("top_p", 1)
temperature:float = inputs.get("temperature", 1)
text_split_method:str = inputs.get("text_split_method", "")
text_split_method:str = inputs.get("text_split_method", "cut0")
batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75)
speed_factor = inputs.get("speed_factor", 1.0)
@ -824,16 +824,13 @@ class TTS:
if not return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
yield [
self.audio_postprocess(audio,
yield self.audio_postprocess(audio,
self.configs.sampling_rate,
batch_index_list,
speed_factor,
split_bucket,
fragment_interval
),
f"<strong>text:</strong> {text} <strong>text_lang:</strong> {text_lang} <strong>prompt_text:</strong> {prompt_text} <strong>prompt_lang:</strong> {prompt_lang} <strong>top_k:</strong> {top_k} <strong>top_p:</strong> {top_p} <strong>temperature:</strong> {temperature} <strong>batch_size:</strong> {batch_size} <strong>batch_threshold:</strong> {batch_threshold} <strong>split_bucket:</strong> {split_bucket} <strong>return_fragment:</strong> {return_fragment} <strong>speed_factor:</strong> {speed_factor} <strong>fragment_interval:</strong> {fragment_interval} <strong>seed:</strong> {actual_seed}"
]
)
except Exception as e:
traceback.print_exc()

View File

@ -6,6 +6,7 @@
全部按英文识别
全部按日文识别
'''
import random
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
@ -94,6 +95,7 @@ def inference(text, text_lang,
split_bucket,fragment_interval,
seed,
):
actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
inputs={
"text": text,
"text_lang": dict_language[text_lang],
@ -109,11 +111,10 @@ def inference(text, text_lang,
"split_bucket":split_bucket,
"return_fragment":False,
"fragment_interval":fragment_interval,
"seed":seed,
"seed":actual_seed,
}
for item in tts_pipline.run(inputs):
yield item
yield item, actual_seed
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分
@ -211,8 +212,6 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
with gr.Row():
inference_button = gr.Button(i18n("合成语音"), variant="primary")
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
with gr.Row():
inference_details = gr.Markdown()
inference_button.click(
@ -226,7 +225,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
split_bucket,fragment_interval,
seed
],
[output, inference_details],
[output, seed],
)
stop_infer.click(tts_pipline.stop, [], [])