From 3124fcf4971ab5b7fc74df7dacbdb6d30bc7b6fd Mon Sep 17 00:00:00 2001 From: SapphireLab <36986837+SapphireLab@users.noreply.github.com> Date: Mon, 15 Apr 2024 23:47:39 +0800 Subject: [PATCH] =?UTF-8?q?[ASR]=20=E4=BF=AE=E5=A4=8DFasterWhisper?= =?UTF-8?q?=E9=81=8D=E5=8E=86=E8=BE=93=E5=85=A5=E8=B7=AF=E5=BE=84=E5=A4=B1?= =?UTF-8?q?=E8=B4=A5=20(#956)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * remove glob * rename * reset mirror pos --- tools/asr/fasterwhisper_asr.py | 42 +++++++++++++++++++--------------- tools/asr/funasr_asr.py | 7 +++--- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/tools/asr/fasterwhisper_asr.py b/tools/asr/fasterwhisper_asr.py index f7b31aa..669ac3a 100644 --- a/tools/asr/fasterwhisper_asr.py +++ b/tools/asr/fasterwhisper_asr.py @@ -1,18 +1,16 @@ import argparse import os -os.environ["HF_ENDPOINT"]="https://hf-mirror.com" import traceback -import requests -from glob import glob -import torch +os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + +import torch from faster_whisper import WhisperModel from tqdm import tqdm from tools.asr.config import check_fw_local_models -os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" - language_code_list = [ "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo", @@ -36,7 +34,7 @@ language_code_list = [ "vi", "yi", "yo", "zh", "yue", "auto"] -def execute_asr(input_folder, output_folder, model_size, language,precision): +def execute_asr(input_folder, output_folder, model_size, language, precision): if '-local' in model_size: model_size = model_size[:-6] model_path = f'tools/asr/models/faster-whisper-{model_size}' @@ -50,17 +48,18 @@ def execute_asr(input_folder, output_folder, model_size, language,precision): model = WhisperModel(model_path, device=device, compute_type=precision) except: return print(traceback.format_exc()) + + input_file_names = os.listdir(input_folder) + input_file_names.sort() + output = [] output_file_name = os.path.basename(input_folder) - output_file_path = os.path.abspath(f'{output_folder}/{output_file_name}.list') - - if not os.path.exists(output_folder): - os.makedirs(output_folder) - - for file in tqdm(glob(os.path.join(input_folder, '**/*.wav'), recursive=True)): + + for file_name in tqdm(input_file_names): try: + file_path = os.path.join(input_folder, file_name) segments, info = model.transcribe( - audio = file, + audio = file_path, beam_size = 5, vad_filter = True, vad_parameters = dict(min_silence_duration_ms=700), @@ -68,18 +67,23 @@ def execute_asr(input_folder, output_folder, model_size, language,precision): text = '' if info.language == "zh": - print("检测为中文文本,转funasr处理") + print("检测为中文文本, 转 FunASR 处理") if("only_asr"not in globals()): - from tools.asr.funasr_asr import only_asr##如果用英文就不需要导入下载模型 - text = only_asr(file) + from tools.asr.funasr_asr import \ + only_asr # #如果用英文就不需要导入下载模型 + text = only_asr(file_path) if text == '': for segment in segments: text += segment.text - output.append(f"{file}|{output_file_name}|{info.language.upper()}|{text}") + output.append(f"{file_path}|{output_file_name}|{info.language.upper()}|{text}") except: return print(traceback.format_exc()) - + + output_folder = output_folder or "output/asr_opt" + os.makedirs(output_folder, exist_ok=True) + output_file_path = os.path.abspath(f'{output_folder}/{output_file_name}.list') + with open(output_file_path, "w", encoding="utf-8") as f: f.write("\n".join(output)) print(f"ASR 任务完成->标注文件路径: {output_file_path}\n") diff --git a/tools/asr/funasr_asr.py b/tools/asr/funasr_asr.py index 6aa3038..831da6c 100644 --- a/tools/asr/funasr_asr.py +++ b/tools/asr/funasr_asr.py @@ -38,10 +38,11 @@ def execute_asr(input_folder, output_folder, model_size, language): output = [] output_file_name = os.path.basename(input_folder) - for name in tqdm(input_file_names): + for file_name in tqdm(input_file_names): try: - text = model.generate(input="%s/%s"%(input_folder, name))[0]["text"] - output.append(f"{input_folder}/{name}|{output_file_name}|{language.upper()}|{text}") + file_path = os.path.join(input_folder, file_name) + text = model.generate(input=file_path)[0]["text"] + output.append(f"{file_path}|{output_file_name}|{language.upper()}|{text}") except: print(traceback.format_exc())