def set_seed(seed: int) -> int:
    """Seed the Python, NumPy and PyTorch RNGs and return the seed applied.

    Args:
        seed: Requested seed. ``-1`` (or ``None``) means "pick a random
            seed". Float values are truncated with ``int()``. Any other
            value is wrapped into ``[0, 2**32)`` because
            ``np.random.seed`` rejects seeds outside that range, and all
            generators should receive the same value.

    Returns:
        The seed that was actually applied, so callers can log or reuse it.
    """
    seed_range = 1 << 32
    if seed is None or int(seed) == -1:
        seed = random.randrange(seed_range)
    else:
        # Wrap instead of letting np.random.seed raise on out-of-range input.
        seed = int(seed) % seed_range
    print(f"Set seed to {seed}")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    try:
        if torch.cuda.is_available():
            # manual_seed_all seeds every visible device, including the
            # current one, so a separate torch.cuda.manual_seed is redundant.
            torch.cuda.manual_seed_all(seed)
    except RuntimeError as err:
        # Best effort: CPU-side RNGs are already seeded, so report the CUDA
        # failure instead of aborting the caller.
        print(f"Failed to seed CUDA RNGs: {err}")
    # NOTE(review): cuDNN deterministic/benchmark flags are not modified here,
    # so GPU results may still vary between runs - confirm this is desired.
    return seed
@@ -584,6 +591,9 @@ class TTS: split_bucket = inputs.get("split_bucket", True) return_fragment = inputs.get("return_fragment", False) fragment_interval = inputs.get("fragment_interval", 0.3) + seed = inputs.get("seed", -1) + + set_seed(seed) if return_fragment: # split_bucket = False diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index be0aad6..505b665 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -92,6 +92,7 @@ def inference(text, text_lang, text_split_method, batch_size, speed_factor, ref_text_free, split_bucket,fragment_interval, + seed, ): inputs={ "text": text, @@ -108,6 +109,7 @@ def inference(text, text_lang, "split_bucket":split_bucket, "return_fragment":False, "fragment_interval":fragment_interval, + "seed":seed, } for item in tts_pipline.run(inputs): @@ -203,6 +205,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: ) with gr.Row(): split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True) + seed = gr.Number(label=i18n("随机种子"),value=-1) # with gr.Column(): output = gr.Audio(label=i18n("输出的语音")) with gr.Row(): @@ -219,6 +222,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: how_to_cut, batch_size, speed_factor, ref_text_free, split_bucket,fragment_interval, + seed ], [output], )