添加了自定义修改随机数种子,方便复现结果。

This commit is contained in:
chasonjiang 2024-03-15 14:34:10 +08:00
parent b8ce03fd1b
commit a2f2a5f4a7
2 changed files with 25 additions and 11 deletions

View File

@ -51,17 +51,23 @@ custom:
"""
# def set_seed(seed):
# random.seed(seed)
# os.environ['PYTHONHASHSEED'] = str(seed)
# np.random.seed(seed)
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.enabled = True
# set_seed(1234)
def set_seed(seed: int) -> int:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Seed to apply. Anything accepted by ``int()`` works (e.g. the
            float produced by a numeric UI field). The sentinel ``-1`` means
            "draw a fresh random seed".

    Returns:
        The seed that was actually applied, so callers can log or reuse it
        to reproduce a result.

    Raises:
        ValueError: If the seed is neither ``-1`` nor within
            ``[0, 2**32 - 1]`` (the range ``np.random.seed`` accepts).
    """
    seed = int(seed)
    if seed == -1:
        seed = random.randrange(1 << 32)
    elif not 0 <= seed < (1 << 32):
        # Reject up front: np.random.seed would raise on this value anyway,
        # but only after random.seed had already been re-seeded, leaving the
        # generators in an inconsistent state.
        raise ValueError(f"seed must be -1 or in [0, 2**32 - 1], got {seed}")
    print(f"Set seed to {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    try:
        if torch.cuda.is_available():
            # manual_seed_all seeds every visible GPU, including the current one.
            torch.cuda.manual_seed_all(seed)
            # NOTE: the cuDNN flags (deterministic / benchmark / enabled) are
            # deliberately left untouched here, so GPU kernels may still be
            # non-deterministic between runs.
    except Exception as exc:
        # CUDA seeding is best-effort: report the problem but keep going,
        # since the CPU generators have already been seeded.
        print(f"Failed to seed CUDA generators: {exc}")
    return seed
class TTS_Config:
default_configs={
@ -563,6 +569,7 @@ class TTS:
"return_fragment": False, # bool. step by step return the audio fragment.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
}
returns:
tuple[int, np.ndarray]: sampling rate and audio data.
@ -584,6 +591,9 @@ class TTS:
split_bucket = inputs.get("split_bucket", True)
return_fragment = inputs.get("return_fragment", False)
fragment_interval = inputs.get("fragment_interval", 0.3)
seed = inputs.get("seed", -1)
set_seed(seed)
if return_fragment:
# split_bucket = False

View File

@ -92,6 +92,7 @@ def inference(text, text_lang,
text_split_method, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed,
):
inputs={
"text": text,
@ -108,6 +109,7 @@ def inference(text, text_lang,
"split_bucket":split_bucket,
"return_fragment":False,
"fragment_interval":fragment_interval,
"seed":seed,
}
for item in tts_pipline.run(inputs):
@ -203,6 +205,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
)
with gr.Row():
split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
seed = gr.Number(label=i18n("随机种子"),value=-1)
# with gr.Column():
output = gr.Audio(label=i18n("输出的语音"))
with gr.Row():
@ -219,6 +222,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
how_to_cut, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed
],
[output],
)