From a2f2a5f4a7d05ba908a34290f5baab1425ea5955 Mon Sep 17 00:00:00 2001 From: chasonjiang <1440499136@qq.com> Date: Fri, 15 Mar 2024 14:34:10 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E8=87=AA=E5=AE=9A?= =?UTF-8?q?=E4=B9=89=E4=BF=AE=E6=94=B9=E9=9A=8F=E6=9C=BA=E6=95=B0=E7=A7=8D?= =?UTF-8?q?=E5=AD=90=EF=BC=8C=E6=96=B9=E4=BE=BF=E5=A4=8D=E7=8E=B0=E7=BB=93?= =?UTF-8?q?=E6=9E=9C=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 32 +++++++++++++++++++++----------- GPT_SoVITS/inference_webui.py | 4 ++++ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index d82ed5d..17d72b3 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -51,17 +51,23 @@ custom: """ -# def set_seed(seed): -# random.seed(seed) -# os.environ['PYTHONHASHSEED'] = str(seed) -# np.random.seed(seed) -# torch.manual_seed(seed) -# torch.cuda.manual_seed(seed) -# torch.cuda.manual_seed_all(seed) -# torch.backends.cudnn.deterministic = True -# torch.backends.cudnn.benchmark = False -# torch.backends.cudnn.enabled = True -# set_seed(1234) +def set_seed(seed:int): + seed = int(seed) + seed = seed if seed != -1 else random.randrange(1 << 32) + print(f"Set seed to {seed}") + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + try: + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + # torch.backends.cudnn.enabled = True + except Exception: + pass + return seed class TTS_Config: default_configs={ @@ -563,6 +569,7 @@ class TTS: "return_fragment": False, # bool. step by step return the audio fragment. "speed_factor":1.0, # float. control the speed of the synthesized audio. "fragment_interval":0.3, # float. to control the interval of the audio fragment. 
+ "seed": -1, # int. random seed for reproducibility. } returns: tulpe[int, np.ndarray]: sampling rate and audio data. @@ -584,6 +591,9 @@ class TTS: split_bucket = inputs.get("split_bucket", True) return_fragment = inputs.get("return_fragment", False) fragment_interval = inputs.get("fragment_interval", 0.3) + seed = inputs.get("seed", -1) + + set_seed(seed) if return_fragment: # split_bucket = False diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index be0aad6..505b665 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -92,6 +92,7 @@ def inference(text, text_lang, text_split_method, batch_size, speed_factor, ref_text_free, split_bucket,fragment_interval, + seed, ): inputs={ "text": text, @@ -108,6 +109,7 @@ def inference(text, text_lang, "split_bucket":split_bucket, "return_fragment":False, "fragment_interval":fragment_interval, + "seed":seed, } for item in tts_pipline.run(inputs): @@ -203,6 +205,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: ) with gr.Row(): split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True) + seed = gr.Number(label=i18n("随机种子"),value=-1) # with gr.Column(): output = gr.Audio(label=i18n("输出的语音")) with gr.Row(): @@ -219,6 +222,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: how_to_cut, batch_size, speed_factor, ref_text_free, split_bucket,fragment_interval, + seed ], [output], )