GPT-SoVITS/tools/asr/fasterwhisper_asr.py
SapphireLab 94e34568dd fix ASR
2024-02-15 01:39:13 +08:00

141 lines
5.9 KiB
Python

import argparse
import os
import traceback
import torch
from faster_whisper import WhisperModel
from tools.asr.config import fw_model_size_list, BaseASR
from tools.asr.funasr_asr import FunASR
from tools.my_utils import ASR_Logger
# Process-wide environment tweaks:
#   HF_ENDPOINT           - endpoint used for Hugging Face downloads
#                           (pointed at the hf-mirror.com mirror).
#   KMP_DUPLICATE_LIB_OK  - OpenMP runtime switch; presumably set to tolerate
#                           a duplicated OpenMP library - TODO confirm.
os.environ.update({
    "HF_ENDPOINT": "https://hf-mirror.com",
    "KMP_DUPLICATE_LIB_OK": "TRUE",
})

# Language codes accepted by the ``--language`` command line option.
# The trailing "auto" entry is translated to ``language=None`` by
# ``FasterWhisperASR.inference``.
language_code_list = [
    "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo",
    "br", "bs", "ca", "cs", "cy", "da", "de", "el", "en", "es",
    "et", "eu", "fa", "fi", "fo", "fr", "gl", "gu", "ha", "haw",
    "he", "hi", "hr", "ht", "hu", "hy", "id", "is", "it", "ja",
    "jw", "ka", "kk", "km", "kn", "ko", "la", "lb", "ln", "lo",
    "lt", "lv", "mg", "mi", "mk", "ml", "mn", "mr", "ms", "mt",
    "my", "ne", "nl", "nn", "no", "oc", "pa", "pl", "ps", "pt",
    "ro", "ru", "sa", "sd", "si", "sk", "sl", "sn", "so", "sq",
    "sr", "su", "sv", "sw", "ta", "te", "tg", "th", "tk", "tl",
    "tr", "tt", "uk", "ur", "uz", "vi", "yi", "yo", "zh", "yue",
    "auto",
]
class FasterWhisperASR(BaseASR):
    """ASR backend that transcribes audio with ``faster_whisper.WhisperModel``.

    When Whisper reports the audio language as ``"zh"``, the file is
    transcribed again with a lazily created ``FunASR`` instance and that text
    is returned instead of the Whisper segments.
    """

    def __init__(self, model_size, device="cuda", precision="float16"):
        """Discover local models and load the requested one.

        Args:
            model_size: Model identifier; either a plain size name or one of
                the ``<size>-<flag>`` names produced by
                ``check_local_models``.
            device: Device passed to ``WhisperModel``. Replaced by ``"cpu"``
                when ``torch.cuda.is_available()`` is false.
            precision: Compute type passed to ``WhisperModel``. Replaced by
                ``"float32"`` when CUDA is unavailable.

        Raises:
            AssertionError: If ``load_model`` returned ``None``.
        """
        if not torch.cuda.is_available():
            device, precision = "cpu", "float32"
        self.check_local_models()
        self.model = self.load_model(model_size, device, precision)
        if self.model is None:
            # Explicit check instead of ``assert`` so it still runs under
            # ``python -O``; AssertionError is kept for existing callers.
            ASR_Logger.error('模型不存在')
            raise AssertionError('模型不存在')
        # FunASR model is created on first use in ``inference``.
        self.zh_model = None

    @classmethod
    def check_local_models(cls):
        """Check at startup which Faster Whisper models are available locally.

        Rebuilds the class attributes ``model_size_list`` (a copy of
        ``fw_model_size_list`` in which every size found on disk is renamed
        to ``"<size>-<flag>"``) and ``model_path_dict`` (renamed entry ->
        model path).

        Returns:
            The updated ``model_size_list``.
        """
        cls.model_size_list = fw_model_size_list.copy()
        cls.model_path_dict = {}
        cache_path = os.path.normpath(
            os.path.expanduser("~/.cache/huggingface/hub/"))
        for i, size in enumerate(cls.model_size_list):
            model_name = f"faster-whisper-{size}"
            # NOTE(review): ``cls`` is passed explicitly, which presumes
            # ``BaseASR.check_local_model`` is a plain (non-class) method -
            # confirm against tools/asr/config.py.
            model_path, flag = super().check_local_model(
                cls,
                model_name=model_name,
                model_file='model.bin',
                cache_path=cache_path)
            if flag:
                cls.model_size_list[i] = f"{size}-{flag}"
                cls.model_path_dict[cls.model_size_list[i]] = model_path
        return cls.model_size_list

    def load_model(self, model_size, device="cuda", precision="float16"):
        """Create the ``WhisperModel`` for ``model_size``.

        Names containing ``-local`` or ``-cache`` are resolved through
        ``model_path_dict``; any other name is handed to ``WhisperModel``
        unchanged.

        Args:
            model_size: Model identifier (see ``__init__``).
            device: Device passed to ``WhisperModel``.
            precision: Value passed as ``compute_type``.

        Returns:
            The loaded ``WhisperModel``, or ``None`` when creating it raised.

        Raises:
            KeyError: If a ``-local``/``-cache`` name is missing from
                ``model_path_dict`` (the lookup happens outside the ``try``).
        """
        if '-local' in model_size or '-cache' in model_size:
            model_path = self.model_path_dict[model_size]
            # Drop the 6-character "-local" / "-cache" suffix.
            model_size = model_size[:-6]
            ASR_Logger.info(f"加载模型: 从 {model_path} 加载 faster-whisper-{model_size} 模型.")
            if 'huggingface' in model_path:
                ASR_Logger.warning(f"可将 {model_path} 移动到 tools/asr/models/ 文件夹下并重命名为 faster-whisper-{model_size}.")
        else:
            model_path = model_size
            ASR_Logger.warning(f"下载模型: 从 https://hf-mirror.com/Systran/faster-whisper-{model_size} 下载 faster-whisper-{model_size} 模型.")
        try:
            model = WhisperModel(model_path, device=device, compute_type=precision)
            if model.model.device != 'cpu':
                device_name = torch.cuda.get_device_name(model.model.device)
            else:
                device_name = 'CPU'
            ASR_Logger.info(f"运行设备: {device_name}, 设定精度: {precision}.")
            ASR_Logger.info("创建模型: Faster Whisper 完成.\n")
            return model
        except Exception:
            # ``Exception`` rather than a bare ``except`` so Ctrl-C and
            # SystemExit are not reported as a model loading failure.
            ASR_Logger.info(traceback.format_exc())
            ASR_Logger.error(f"模型加载失败 or 下载失败, 可访问 https://hf-mirror.com/Systran/faster-whisper-{model_size} 自行下载, 并放置于 tools/asr/models/ 文件夹下")
            return None

    def inference(self, file_path, language='auto'):
        """Transcribe a single audio file.

        Args:
            file_path: Path of the audio file, passed to ``transcribe``.
            language: Language code; ``'auto'`` is passed to ``transcribe``
                as ``None``.

        Returns:
            ``(text, detected_language)`` where ``detected_language`` is
            ``info.language`` from Whisper, or ``('', '')`` when
            transcription raised.
        """
        try:
            if language == 'auto':
                language = None
            segments, info = self.model.transcribe(
                audio=file_path,
                beam_size=5,
                vad_filter=True,
                vad_parameters=dict(min_silence_duration_ms=700),
                language=language)
            if info.language == "zh":
                ASR_Logger.info("检测为中文文本, 转 FunASR 处理.")
                if self.zh_model is None:
                    self.zh_model = FunASR()
                # Only the text is used; the reported language stays the one
                # detected by Whisper.
                text, _ = self.zh_model.inference(file_path)
            else:
                text = ''.join([segment.text for segment in segments])
            return text, info.language
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt/SystemExit propagate instead of being
            # logged as an invalid audio file.
            ASR_Logger.error(f"当前文件 {file_path} 转写失败, 可能不是有效的音频文件.")
            ASR_Logger.error(traceback.format_exc())
            return '', ''
if __name__ == '__main__':
    # Command line entry point: parse the options, build the ASR backend and
    # transcribe the given file or folder.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-i", "--input_file_or_folder",
        type=str,
        required=True,
        help="Input audio file path or folder contain audio files.",
    )
    arg_parser.add_argument(
        "-o", "--output_folder",
        type=str,
        required=True,
        help="Output folder to store transcriptions.",
    )
    # The selectable sizes come from the local model scan, so entries found on
    # disk appear under their renamed "<size>-<flag>" form.
    selectable_sizes = FasterWhisperASR.check_local_models()
    arg_parser.add_argument(
        "-s", "--model_size",
        type=str,
        default='large-v3',
        choices=selectable_sizes,
        help="Model Size of Faster Whisper",
    )
    arg_parser.add_argument(
        "-l", "--language",
        type=str,
        default='auto',
        choices=language_code_list,
        help="Language of the audio files.",
    )
    arg_parser.add_argument(
        "-p", "--precision",
        type=str,
        default='float16',
        choices=['float16', 'float32'],
        help="fp16 or fp32",
    )
    options = arg_parser.parse_args()

    transcriber = FasterWhisperASR(
        model_size=options.model_size,
        precision=options.precision,
    )
    transcriber.inference_file_or_folder(
        input_file_or_folder=options.input_file_or_folder,
        output_folder=options.output_folder,
        language=options.language,
    )