From 72c5d3224e8ea22a97c74fbd54bc884f51553bfb Mon Sep 17 00:00:00 2001 From: zpeng11 Date: Sun, 24 Aug 2025 02:11:47 -0400 Subject: [PATCH] utility updates --- .gitignore | 4 +++- playground/export_bert.py | 19 +++++++++++++++++++ requirements.txt | 1 + 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index d645f020..b98ec8f4 100644 --- a/.gitignore +++ b/.gitignore @@ -194,7 +194,9 @@ cython_debug/ # PyPI configuration file .pypirc +#onnx onnx/ *.onnx tokenizer.json -output.wav \ No newline at end of file +output.wav +config.json \ No newline at end of file diff --git a/playground/export_bert.py b/playground/export_bert.py index c6341a08..b8d413c9 100644 --- a/playground/export_bert.py +++ b/playground/export_bert.py @@ -86,6 +86,25 @@ def export_bert_to_onnx( print(f"Copied tokenizer.json to: {dest_tokenizer_path}") else: print("Warning: tokenizer.json not found") + + # Copy config.json if it exists + if tokenizer_cache_dir and os.path.isdir(tokenizer_cache_dir): + config_json_path = os.path.join(tokenizer_cache_dir, "config.json") + else: + # For models from HuggingFace cache + cache_dir = os.path.expanduser("~/.cache/huggingface/transformers") + config_json_path = None + for root, dirs, files in os.walk(cache_dir): + if "config.json" in files and model_name.replace("/", "--") in root: + config_json_path = os.path.join(root, "config.json") + break + + if config_json_path and os.path.exists(config_json_path): + dest_config_path = os.path.join(output_dir, "config.json") + shutil.copy2(config_json_path, dest_config_path) + print(f"Copied config.json to: {dest_config_path}") + else: + print("Warning: config.json not found") print(f"Model exported successfully to: {output_dir}") return combined_model, onnx_path diff --git a/requirements.txt b/requirements.txt index 3b71a9e6..d6f9a9ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ onnx onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64" onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64" onnxsim +onnxruntime-tools tqdm funasr==1.0.27 cn2an