diff --git a/tools/asr/fasterwhisper_asr.py b/tools/asr/fasterwhisper_asr.py index 02f4e66..fa07cf7 100644 --- a/tools/asr/fasterwhisper_asr.py +++ b/tools/asr/fasterwhisper_asr.py @@ -7,7 +7,7 @@ from glob import glob from faster_whisper import WhisperModel from tqdm import tqdm -from config import check_fw_local_models +from tools.asr.config import check_fw_local_models os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" @@ -34,7 +34,7 @@ language_code_list = [ "vi", "yi", "yo", "zh", "yue", "auto"] -def execute_asr(input_folder, output_folder, model_size, language): +def execute_asr(input_folder, output_folder, model_size, language,precision): if 'local' in model_size: model_size = model_size.split('(')[0] model_path = f'tools/asr/models/faster-whisper-{model_size}' @@ -42,12 +42,11 @@ def execute_asr(input_folder, output_folder, model_size, language): model_path = model_size if language == 'auto': language = None #不设置语种由模型自动输出概率最高的语种 - + print("loading faster whisper model:",model_size,model_path) try: - model = WhisperModel(model_path, device="cuda", compute_type="float16") + model = WhisperModel(model_path, device="cuda", compute_type=precision) except: return print(traceback.format_exc()) - output = [] output_file_name = os.path.basename(input_folder) output_file_path = os.path.abspath(f'{output_folder}/{output_file_name}.list') @@ -84,9 +83,11 @@ if __name__ == '__main__': parser.add_argument("-s", "--model_size", type=str, default='large-v3', choices=check_fw_local_models(), help="Model Size of Faster Whisper") - parser.add_argument("-l", "--language", type=str, default='zh', + parser.add_argument("-l", "--language", type=str, default='ja', choices=language_code_list, help="Language of the audio files.") + parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32'], + help="fp16 or fp32") cmd = parser.parse_args() output_file_path = execute_asr( @@ -94,4 +95,5 @@ if __name__ == '__main__': output_folder = cmd.output_folder, model_size = cmd.model_size, language = cmd.language, + precision = cmd.precision, ) \ No newline at end of file diff --git a/tools/asr/funasr_asr.py b/tools/asr/funasr_asr.py index 9d6a5e4..f6673b7 100644 --- a/tools/asr/funasr_asr.py +++ b/tools/asr/funasr_asr.py @@ -56,6 +56,8 @@ if __name__ == '__main__': help="Model Size of FunASR is Large") parser.add_argument("-l", "--language", type=str, default='zh', choices=['zh'], help="Language of the audio files.") + parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32'], + help="fp16 or fp32")#还没接入 cmd = parser.parse_args() execute_asr(