Merge pull request #200 from Lion-Wu/mps

Fix Several Severe Errors
This commit is contained in:
RVC-Boss 2024-01-26 10:28:15 +08:00 committed by GitHub
commit 8d2db8254d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 14 deletions

View File

@ -47,12 +47,12 @@ if os.path.exists(txt_path) == False:
bert_dir = "%s/3-bert" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(bert_dir, exist_ok=True)
if torch.cuda.is_available():
device = "cuda:0"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
if torch.cuda.is_available():
device = "cuda:0"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if is_half == True:

View File

@ -38,12 +38,12 @@ semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
if os.path.exists(semantic_path) == False:
os.makedirs(opt_dir, exist_ok=True)
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
hps = utils.get_hparams_from_file(s2config_path)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,

3
api.py
View File

@ -13,7 +13,7 @@
`-dt` - `默认参考音频文本`
`-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"`
`-d` - `推理设备, "cuda","cpu"`
`-d` - `推理设备, "cuda","cpu","mps"`
`-a` - `绑定地址, 默认"127.0.0.1"`
`-p` - `绑定端口, 默认9880, 可在 config.py 中指定`
`-fp` - `覆盖 config.py 使用全精度`
@ -139,7 +139,6 @@ parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu / mps")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")