diff --git a/GPT_SoVITS/prepare_datasets/1-get-text.py b/GPT_SoVITS/prepare_datasets/1-get-text.py index 9499db4..88c9d85 100644 --- a/GPT_SoVITS/prepare_datasets/1-get-text.py +++ b/GPT_SoVITS/prepare_datasets/1-get-text.py @@ -47,12 +47,12 @@ if os.path.exists(txt_path) == False: bert_dir = "%s/3-bert" % (opt_dir) os.makedirs(opt_dir, exist_ok=True) os.makedirs(bert_dir, exist_ok=True) -if torch.cuda.is_available(): - device = "cuda:0" -elif torch.backends.mps.is_available(): - device = "mps" -else: - device = "cpu" + if torch.cuda.is_available(): + device = "cuda:0" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir) bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir) if is_half == True: diff --git a/GPT_SoVITS/prepare_datasets/3-get-semantic.py b/GPT_SoVITS/prepare_datasets/3-get-semantic.py index a3cf0a3..9ab56a4 100644 --- a/GPT_SoVITS/prepare_datasets/3-get-semantic.py +++ b/GPT_SoVITS/prepare_datasets/3-get-semantic.py @@ -38,12 +38,12 @@ semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part) if os.path.exists(semantic_path) == False: os.makedirs(opt_dir, exist_ok=True) -if torch.cuda.is_available(): - device = "cuda" -elif torch.backends.mps.is_available(): - device = "mps" -else: - device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" hps = utils.get_hparams_from_file(s2config_path) vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, diff --git a/api.py b/api.py index 1b5b6a0..60ed9ff 100644 --- a/api.py +++ b/api.py @@ -13,7 +13,7 @@ `-dt` - `默认参考音频文本` `-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"` -`-d` - `推理设备, "cuda","cpu"` +`-d` - `推理设备, "cuda","cpu","mps"` `-a` - `绑定地址, 默认"127.0.0.1"` `-p` - `绑定端口, 默认9880, 可在 config.py 中指定` `-fp` - `覆盖 config.py 使用全精度` @@ -139,7 +139,6 @@ parser.add_argument("-dt", "--default_refer_text", type=str, default="", help=" parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种") parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu / mps") -parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880") parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880") parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")