diff --git a/api.py b/api.py index 0f34149f..c8a1f5d6 100644 --- a/api.py +++ b/api.py @@ -1007,6 +1007,7 @@ else: logger.info(f"数据类型: int16") # 初始化模型 +os.environ["bert_path"] = bert_path cnhubert.cnhubert_base_path = cnhubert_base_path tokenizer = AutoTokenizer.from_pretrained(bert_path) bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)