diff --git a/GPT_SoVITS/stream_v2pro.py b/GPT_SoVITS/stream_v2pro.py index e80293b7..0e615a7e 100644 --- a/GPT_SoVITS/stream_v2pro.py +++ b/GPT_SoVITS/stream_v2pro.py @@ -566,18 +566,42 @@ def export_prov2( torch.jit.trace_module(vits, inputs=inputs, optimize=True).save(f"{output_path}/vits.pt") torch.jit.script(find_best_audio_offset_fast, optimize=True).save(f"{output_path}/find_best_audio_offset_fast.pt") +import argparse +import os if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") + parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") + parser.add_argument( + "--sovits_model", required=True, help="Path to the SoVITS model file" + ) + parser.add_argument( + "--ref_audio", required=True, help="Path to the reference audio file" + ) + parser.add_argument( + "--ref_text", required=True, help="Path to the reference text file" + ) + parser.add_argument( + "--output_path", required=True, help="Path to the output directory" + ) + parser.add_argument("--device", help="Device to use", default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--version", help="version of the model", default="v2Pro") + parser.add_argument("--no-half", action="store_true", help = "Do not use half precision for model weights") + + args = parser.parse_args() + + if not os.path.exists(args.output_path): + os.makedirs(args.output_path) + + is_half = not args.no_half with torch.no_grad(): - test_stream( - gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt", - vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth", - version="v2Pro", - # ref_audio_path="/mnt/g/ad_ref.wav", - # ref_text="你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说.", - ref_audio_path="output/denoise_opt/ht/ht.mp4_0000026560_0000147200.wav", - ref_text='说真的,这件衣服才配得上本小姐嘛', - output_path="streaming", - device="cuda", - is_half=True, + export_prov2( + gpt_path=args.gpt_model, + vits_path=args.sovits_model, + version=args.version, + ref_audio_path=args.ref_audio, + ref_text=args.ref_text, + output_path=args.output_path, + device=args.device, + is_half=is_half, )