diff --git a/api.py b/api.py index 6de3fa15..a2f8422e 100644 --- a/api.py +++ b/api.py @@ -998,9 +998,9 @@ else: # 初始化模型 cnhubert.cnhubert_base_path = cnhubert_base_path # tokenizer = AutoTokenizer.from_pretrained(bert_path) -tokenizer = AutoTokenizer.from_pretrained(os.path.join(os.getcwd(), bert_path)) +tokenizer = AutoTokenizer.from_pretrained(os.path.join(os.getcwd(), bert_path), local_files_only=True) -bert_model = AutoModelForMaskedLM.from_pretrained(os.path.join(os.getcwd(), bert_path)) +bert_model = AutoModelForMaskedLM.from_pretrained(os.path.join(os.getcwd(), bert_path), local_files_only=True) ssl_model = cnhubert.get_model() if is_half: bert_model = bert_model.half().to(device)