support mps, optimized device selection

This commit is contained in:
Wu Zichen 2024-01-24 17:27:58 +08:00
parent cb9d8fe8a5
commit a8e603445f

View File

@ -35,7 +35,13 @@ from my_utils import load_audio
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
device = "cuda"
if torch.cuda.is_available():
device = "cuda"
elif torch.mps.is_available():
device = "mps"
else:
device = "cpu"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half == True: