"""
# api.py usage

` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" `

## Command-line arguments:

`-s` - `SoVITS model path, can also be set in config.py`
`-g` - `GPT model path, can also be set in config.py`

Used when a request does not supply reference audio:
`-dr` - `default reference audio path`
`-dt` - `default reference audio text`
`-dl` - `default reference audio language: "中文","英文","日文","zh","en","ja"`
`-cut` - `default text cutting method: "cut1" (group every four sentences),
         "cut2" (group roughly every 50 characters), "cut3" (split on the
         Chinese full stop 。), "cut4" (split on the English period .),
         "cut5" (split on punctuation); defaults to "cut1"`

`-d` - `inference device: "cuda","cpu","mps"`
`-a` - `bind address, default "0.0.0.0"`
`-p` - `bind port, default 9880, can also be set in config.py`
`-fp` - `override config.py and use full precision`
`-hp` - `override config.py and use half precision`

`-hb` - `cnhubert path`
`-b` - `bert path`

## Calling the API:

### Inference

endpoint: `/`

Using the reference audio and cutting method given on the command line:
GET:
    `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
POST:
```json
{
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_language": "zh"
}
```

Explicitly choosing the reference audio and cutting method for one request:
GET:
    `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&how_to_cut=cut1`
POST:
```json
{
    "refer_wav_path": "123.wav",
    "prompt_text": "一二三。",
    "prompt_language": "zh",
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_language": "zh",
    "how_to_cut": "cut1"
}
```

RESP:
Success: the wav audio stream is returned directly, http code 200
Failure: json containing the error message, http code 400


### Changing the default reference audio and cutting method

endpoint: `/change_refer`

The keys are the same as for the inference endpoint.

GET:
    `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&how_to_cut=cut1`
POST:
```json
{
    "refer_wav_path": "123.wav",
    "prompt_text": "一二三。",
    "prompt_language": "zh",
    "how_to_cut": "cut1"
}
```

RESP:
Success: json, http code 200
Failure: json, 400


### Switching models at runtime

endpoint: `/set_model`

POST:
```json
{
    "gpt_model_path": "path/to/gpt.ckpt",
    "sovits_model_path": "path/to/sovits.pth"
}
```

RESP:
Success: "ok"


### Process control

endpoint: `/control`

command:
"restart": re-execute the server
"exit": terminate the server

GET:
    `http://127.0.0.1:9880/control?command=restart`
POST:
```json
{
    "command": "restart"
}
```

RESP: none

"""

import argparse
import os,re
import sys

# Make the current working directory and its GPT_SoVITS sub-directory
# importable before the project-local imports further down are executed.
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))

import signal
from time import time as ttime
import torch
import librosa
import soundfile as sf
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
from feature_extractor import cnhubert
from io import BytesIO
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
import config as global_config

g_config = global_config.Config()

# ---------------------------------------------------------------------------
# Command-line interface. Defaults for model paths, device and port come from
# the project-level config object.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="GPT-SoVITS api")

parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")

parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
parser.add_argument("-cut", "--default_how_to_cut", type=str, default="cut1", help="默认切分方式")

parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu / mps")
# NOTE: the default binds on every interface, not only on loopback.
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
# Boolean flags are used as `python ./api.py -fp ...`;
# in that case full_precision == True and half_precision == False.
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")

args = parser.parse_args()

sovits_path = args.sovits_path
gpt_path = args.gpt_path

# Characters treated as sentence boundaries by split(). Both the ASCII and the
# full-width (CJK) form of each mark are listed; previously the ASCII forms
# were simply duplicated, so full-width punctuation never split the text.
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }

# Human-readable description of every supported cutting method, used only for
# the startup log below.
_DEFAULT_CUT_INFO = {
    "cut1": "凑四句一切",
    "cut2": "凑50字一切",
    "cut3": "按中文句号。切",
    "cut4": "按英文句号.切",
    "cut5": "按标点符号切",
}


class DefaultRefer:
    """Mutable holder for the server-wide default reference audio settings.

    Attributes:
        path: path of the default reference audio file.
        text: transcript of the default reference audio.
        language: language of the default reference audio.
        how_to_cut: name of the default text cutting method ("cut1".."cut5").
    """

    def __init__(self, path, text, language, how_to_cut):
        # Store the values that were actually passed in rather than reaching
        # for the global ``args`` (which silently ignored the parameters).
        self.path = path
        self.text = text
        self.language = language
        self.how_to_cut = how_to_cut

    def is_ready(self) -> bool:
        """Return True when every default field is non-empty."""
        return is_full(self.path, self.text, self.language, self.how_to_cut)


default_refer = DefaultRefer(
    args.default_refer_path,
    args.default_refer_text,
    args.default_refer_language,
    args.default_how_to_cut,
)

device = args.device
port = args.port
host = args.bind_addr

# Fall back to the pretrained checkpoints when no model path was configured.
if sovits_path == "":
    sovits_path = g_config.pretrained_sovits_path
    print(f"[WARN] 未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
if gpt_path == "":
    gpt_path = g_config.pretrained_gpt_path
    print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")

# Default reference audio, used when a caller supplies none or only part of
# the reference parameters. A partially specified default is discarded.
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
    default_refer.path, default_refer.text, default_refer.language = "", "", ""
    print("[INFO] 未指定默认参考音频")
else:
    # .get() keeps startup alive for an unrecognised -cut value instead of
    # failing with NameError on an unassigned ``cut_info``.
    cut_info = _DEFAULT_CUT_INFO.get(default_refer.how_to_cut)
    if cut_info is None:
        cut_info = default_refer.how_to_cut
        print(f"[WARN] 未知的默认切分方式: {default_refer.how_to_cut}")
    print(f"[INFO] 默认参考音频路径: {default_refer.path}")
    print(f"[INFO] 默认参考音频文本: {default_refer.text}")
    print(f"[INFO] 默认参考音频语种: {default_refer.language}")
    print(f"[INFO] 默认切割方式: {cut_info}")

# Precision selection: a single flag overrides the config; passing both flags
# at once is contradictory and falls back to the config value.
is_half = g_config.is_half
if args.full_precision:
    is_half = False
if args.half_precision:
    is_half = True
if args.full_precision and args.half_precision:
    is_half = g_config.is_half  # contradictory flags: fall back to config
print(f"[INFO] 半精: {is_half}")
cnhubert_base_path = args.hubert_path
bert_path = args.bert_path

# Presumably consumed by cnhubert.get_model() further down.
cnhubert.cnhubert_base_path = cnhubert_base_path
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half:
    bert_model = bert_model.half().to(device)
else:
    bert_model = bert_model.to(device)


def is_empty(*items):
    """Return True only when every item is None or the empty string."""
    return all(item is None or item == "" for item in items)


def is_full(*items):
    """Return True only when no item is None or the empty string."""
    return not any(item is None or item == "" for item in items)


class DictToAttrRecursive:
    """Expose a (possibly nested) dict through attribute access."""

    def __init__(self, input_dict):
        for key, value in input_dict.items():
            if isinstance(value, dict):
                # Nested dicts are wrapped recursively.
                setattr(self, key, DictToAttrRecursive(value))
            else:
                setattr(self, key, value)


def change_sovits_weights(sovits_path):
    """Load a SoVITS checkpoint and replace the global ``vq_model`` / ``hps``.

    Args:
        sovits_path: path of the checkpoint; it must contain "config" and
            "weight" entries.

    Side effects: rebinds the globals ``vq_model`` and ``hps`` and writes the
    path to ./sweight.txt.
    """
    global vq_model, hps
    dict_s2 = torch.load(sovits_path, map_location="cpu")
    hps = dict_s2["config"]
    hps = DictToAttrRecursive(hps)
    hps.model.semantic_frame_rate = "25hz"
    vq_model = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model
    )
    # enc_q is dropped for every checkpoint whose path lacks "pretrained".
    if ("pretrained" not in sovits_path):
        del vq_model.enc_q
    if is_half == True:
        vq_model = vq_model.half().to(device)
    else:
        vq_model = vq_model.to(device)
    vq_model.eval()
    print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
    with open("./sweight.txt", "w", encoding="utf-8") as f:
        f.write(sovits_path)


def change_gpt_weights(gpt_path):
    """Load a GPT checkpoint and replace the global text-to-semantic model.

    Args:
        gpt_path: path of the checkpoint; it must contain "config" and
            "weight" entries.

    Side effects: rebinds the globals ``hz``, ``max_sec``, ``t2s_model`` and
    ``config`` and writes the path to ./gweight.txt.
    """
    global hz, max_sec, t2s_model, config
    hz = 50
    dict_s1 = torch.load(gpt_path, map_location="cpu")
    config = dict_s1["config"]
    max_sec = config["data"]["max_sec"]
    t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
    t2s_model.load_state_dict(dict_s1["weight"])
    if is_half == True:
        t2s_model = t2s_model.half()
    t2s_model = t2s_model.to(device)
    t2s_model.eval()
    total = sum([param.nelement() for param in t2s_model.parameters()])
    print("Number of parameter: %.2fM" % (total / 1e6))
    with open("./gweight.txt", "w", encoding="utf-8") as f:
        f.write(gpt_path)


def get_bert_feature(text, word2ph):
    """Return phone-level BERT features for ``text``.

    Args:
        text: normalized text; must have one character per ``word2ph`` entry.
        word2ph: for each character, how many phones it maps to.

    Returns:
        A tensor whose columns are the per-phone feature vectors (the
        per-character hidden state repeated ``word2ph[i]`` times, transposed).
    """
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for i in inputs:
            # Inputs are integer ids, so precision is not a concern here; the
            # output precision follows bert_model.
            inputs[i] = inputs[i].to(device)
        res = bert_model(**inputs, output_hidden_states=True)
        # Third-from-last hidden layer, first batch item, with the first and
        # last token positions dropped.
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
    assert len(word2ph) == len(text)
    phone_level_feature = []
    for i in range(len(word2ph)):
        repeat_feature = res[i].repeat(word2ph[i], 1)
        phone_level_feature.append(repeat_feature)
    phone_level_feature = torch.cat(phone_level_feature, dim=0)
    return phone_level_feature.T


# ---------------------------------------------------------------------------
# Initial model loading (runs at import time).
# ---------------------------------------------------------------------------
n_semantic = 1024
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]

ssl_model = cnhubert.get_model()
if is_half:
    ssl_model = ssl_model.half().to(device)
else:
    ssl_model = ssl_model.to(device)

vq_model = SynthesizerTrn(
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model)
if is_half:
    vq_model = vq_model.half().to(device)
else:
    vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))

hz = 50
max_sec = config['data']['max_sec']
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half:
    t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))


def get_spepc(hps, filename):
    """Load ``filename`` at the model sampling rate and return its spectrogram."""
    audio = load_audio(filename, int(hps.data.sampling_rate))
    audio = torch.FloatTensor(audio)
    audio_norm = audio
    audio_norm = audio_norm.unsqueeze(0)
    spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
                             hps.data.win_length, center=False)
    return spec


# Accepted language names (Chinese names and upper/lower-case codes) mapped to
# the internal language code.
dict_language = {
    "中文": "zh",
    "英文": "en",
    "日文": "ja",
    "ZH": "zh",
    "EN": "en",
    "JA": "ja",
    "zh": "zh",
    "en": "en",
    "ja": "ja"
}


def split(todo_text):
    """Split text into sentences, each ending with a character from ``splits``.

    A trailing "。" is appended when the text does not already end with a
    split character, so every returned piece ends with punctuation. Empty
    input yields an empty list.
    """
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if not todo_text:
        return []
    if todo_text[-1] not in splits:
        todo_text += "。"
    i_split_head = i_split_tail = 0
    len_text = len(todo_text)
    todo_texts = []
    while 1:
        if i_split_head >= len_text:
            # The text is guaranteed to end with punctuation, so the final
            # piece was already appended on the previous iteration.
            break
        if todo_text[i_split_head] in splits:
            i_split_head += 1
            todo_texts.append(todo_text[i_split_tail:i_split_head])
            i_split_tail = i_split_head
        else:
            i_split_head += 1
    return todo_texts


def cut1(inp):
    """Group sentences four at a time, one group per output line.

    The last boundary is replaced by None, so the final group runs to the end
    of the text (and therefore absorbs the remainder).
    """
    inp = inp.strip("\n")
    inps = split(inp)
    if not inps:
        return inp
    split_idx = list(range(0, len(inps), 4))
    split_idx[-1] = None
    if len(split_idx) > 1:
        opts = ["".join(inps[start:end]) for start, end in zip(split_idx, split_idx[1:])]
    else:
        opts = [inp]
    return "\n".join(opts)


def cut2(inp):
    """Group sentences into lines of just over 50 characters each.

    A final line shorter than 50 characters is merged into the previous one.
    """
    inp = inp.strip("\n")
    inps = split(inp)
    if len(inps) < 2:
        return inp
    opts = []
    summ = 0
    tmp_str = ""
    for piece in inps:
        summ += len(piece)
        tmp_str += piece
        if summ > 50:
            summ = 0
            opts.append(tmp_str)
            tmp_str = ""
    if tmp_str != "":
        opts.append(tmp_str)
    if len(opts) > 1 and len(opts[-1]) < 50:
        # The last line is too short: merge it into the previous one.
        opts[-2] = opts[-2] + opts[-1]
        opts = opts[:-1]
    return "\n".join(opts)


def cut3(inp):
    """Put every piece separated by the Chinese full stop "。" on its own line."""
    inp = inp.strip("\n")
    return "\n".join(["%s" % item for item in inp.strip("。").split("。")])


def cut4(inp):
    """Put every piece separated by the English period "." on its own line."""
    inp = inp.strip("\n")
    return "\n".join(["%s" % item for item in inp.strip(".").split(".")])


# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
def cut5(inp):
    """Start a new line after every punctuation mark (ASCII or full-width)."""
    inp = inp.strip("\n")
    # Both ASCII and full-width forms are listed; previously the ASCII forms
    # were duplicated and full-width marks were not recognised.
    punds = r'[,.;?!、,。?!;:…]'
    items = re.split(f'({punds})', inp)
    mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
    # Keep the text complete when it has no punctuation at all or does not end
    # with a punctuation mark.
    if len(items) % 2 == 1:
        mergeitems.append(items[-1])
    opt = "\n".join(mergeitems)
    return opt


# Cutting method name -> (human-readable description, implementation).
_CUT_METHODS = {
    "cut1": ("凑四句一切", cut1),
    "cut2": ("凑50字一切", cut2),
    "cut3": ("按中文句号。切", cut3),
    "cut4": ("按英文句号.切", cut4),
    "cut5": ("按标点符号切", cut5),
}


def _lookup_cut(how_to_cut):
    """Return the (description, function) pair for a cut name, or None."""
    if not isinstance(how_to_cut, str):
        return None
    return _CUT_METHODS.get(how_to_cut)


def _is_known_language(language):
    """Return True when ``language`` is a key of ``dict_language``."""
    return isinstance(language, str) and language in dict_language


def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
    """Synthesize ``text`` in the voice of the reference audio.

    Args:
        ref_wav_path: path of the reference audio file.
        prompt_text: transcript of the reference audio.
        prompt_language: key of ``dict_language`` for the reference text.
        text: target text; each newline-separated line is synthesized
            separately and the results are concatenated.
        text_language: key of ``dict_language`` for the target text.

    Yields:
        A single ``(sampling_rate, int16 audio array)`` tuple.

    Raises:
        KeyError: if either language is not a key of ``dict_language``.
    """
    t0 = ttime()
    prompt_text = prompt_text.strip("\n")
    text = text.strip("\n")
    # 0.3 s of silence, appended to the reference and after every segment.
    zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
    with torch.no_grad():
        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
        wav16k = torch.from_numpy(wav16k)
        zero_wav_torch = torch.from_numpy(zero_wav)
        if (is_half == True):
            wav16k = wav16k.half().to(device)
            zero_wav_torch = zero_wav_torch.half().to(device)
        else:
            wav16k = wav16k.to(device)
            zero_wav_torch = zero_wav_torch.to(device)
        wav16k = torch.cat([wav16k, zero_wav_torch])
        ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
        codes = vq_model.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
    t1 = ttime()
    prompt_language = dict_language[prompt_language]
    text_language = dict_language[text_language]
    phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
    phones1 = cleaned_text_to_sequence(phones1)
    texts = text.split("\n")
    audio_opt = []

    # The prompt features and the reference spectrogram do not depend on the
    # target segment, so they are computed once instead of once per segment.
    if (prompt_language == "zh"):
        bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
    else:
        bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
            device)
    refer = get_spepc(hps, ref_wav_path)
    if (is_half == True):
        refer = refer.half().to(device)
    else:
        refer = refer.to(device)
    prompt = prompt_semantic.unsqueeze(0).to(device)

    for segment in texts:
        phones2, word2ph2, norm_text2 = clean_text(segment, text_language)
        phones2 = cleaned_text_to_sequence(phones2)
        if (text_language == "zh"):
            bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
        else:
            # Non-Chinese text gets all-zero features matching bert1's
            # dtype and device.
            bert2 = torch.zeros((1024, len(phones2))).to(bert1)
        bert = torch.cat([bert1, bert2], 1)

        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
        t2 = ttime()
        with torch.no_grad():
            pred_semantic, idx = t2s_model.model.infer_panel(
                all_phoneme_ids,
                all_phoneme_len,
                prompt,
                bert,
                top_k=config['inference']['top_k'],
                early_stop_num=hz * max_sec)
        t3 = ttime()
        # Keep only the last ``idx`` predicted tokens.
        pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
        # Decode with the target phones only (prompt part left out).
        audio = vq_model.decode(
            pred_semantic,
            torch.LongTensor(phones2).to(device).unsqueeze(0),
            refer,
        ).detach().cpu().numpy()[0, 0]
        audio_opt.append(audio)
        audio_opt.append(zero_wav)
        t4 = ttime()
    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
    yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)


def handle_control(command):
    """Execute a process-control command ("restart" or "exit").

    "restart" replaces the current process with a fresh interpreter running
    the same argv; "exit" sends SIGTERM to the current process. Any other
    command is ignored and None is returned.
    """
    if command == "restart":
        os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
    elif command == "exit":
        os.kill(os.getpid(), signal.SIGTERM)
        exit(0)


def handle_change(path, text, language, how_to_cut):
    """Update the server-wide default reference audio and cutting method.

    Fields that are None or empty keep their current default. The new
    cutting method is validated before anything is modified.

    Returns:
        JSONResponse with status 400 when every argument is empty or the
        resulting cutting method is unknown, otherwise status 200.
    """
    if is_empty(path, text, language, how_to_cut):
        return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language","how_to_cut"'}, status_code=400)

    if is_empty(path):
        path = default_refer.path
    if is_empty(text):
        text = default_refer.text
    if is_empty(language):
        language = default_refer.language
    if is_empty(how_to_cut):
        how_to_cut = default_refer.how_to_cut

    # Validate before mutating: an unknown method used to raise
    # UnboundLocalError on ``cut_info`` after the defaults had been changed.
    cut_method = _lookup_cut(how_to_cut)
    if cut_method is None:
        return JSONResponse(
            {"code": 400, "message": f'未知的切分方式 "how_to_cut": {how_to_cut}, 可选: cut1, cut2, cut3, cut4, cut5'},
            status_code=400,
        )
    cut_info = cut_method[0]

    default_refer.path = path
    default_refer.text = text
    default_refer.language = language
    default_refer.how_to_cut = how_to_cut

    print(f"[INFO] 当前默认参考音频路径: {default_refer.path}")
    print(f"[INFO] 当前默认参考音频文本: {default_refer.text}")
    print(f"[INFO] 当前默认参考音频语种: {default_refer.language}")
    print(f"[INFO] 默认切割方式: {cut_info}")
    print(f"[INFO] is_ready: {default_refer.is_ready()}")

    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)


def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut):
    """Run one inference request and return the audio as a wav stream.

    Missing reference parameters fall back to the server defaults, and a
    missing ``how_to_cut`` falls back to the default cutting method whether
    or not reference audio was supplied. An unrecognised ``how_to_cut``
    leaves the text uncut.

    Returns:
        StreamingResponse (audio/wav) on success; JSONResponse with status
        400 when no usable reference audio, target text or language is given.
    """
    if not is_full(refer_wav_path, prompt_text, prompt_language):
        refer_wav_path, prompt_text, prompt_language = (
            default_refer.path,
            default_refer.text,
            default_refer.language,
        )
        if not default_refer.is_ready():
            return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)

    if is_empty(how_to_cut):
        how_to_cut = default_refer.how_to_cut

    # Reject bad input with the documented 400 JSON instead of letting it
    # surface as an unhandled AttributeError / KeyError (HTTP 500).
    if not isinstance(text, str) or not text.strip():
        return JSONResponse({"code": 400, "message": "未指定待合成文本"}, status_code=400)
    for language in (prompt_language, text_language):
        if not _is_known_language(language):
            return JSONResponse({"code": 400, "message": f"不支持的语种: {language}"}, status_code=400)

    cut_method = _lookup_cut(how_to_cut)
    if cut_method is not None:
        text = cut_method[1](text)

    # Collapse blank lines so no empty segment is synthesized.
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")
    print("实际输入的目标文本(切句后):", text)

    with torch.no_grad():
        gen = get_tts_wav(
            refer_wav_path, prompt_text, prompt_language, text, text_language
        )
        sampling_rate, audio_data = next(gen)

    wav = BytesIO()
    sf.write(wav, audio_data, sampling_rate, format="wav")
    wav.seek(0)

    torch.cuda.empty_cache()
    if device == "mps":
        print('executed torch.mps.empty_cache()')
        torch.mps.empty_cache()
    return StreamingResponse(wav, media_type="audio/wav")


app = FastAPI()


# Added by clark, 2024-02-21: allows switching models after startup so that a
# single API instance can serve requests for different speakers.
@app.post("/set_model")
async def set_model(request: Request):
    """Load new GPT and SoVITS checkpoints named in the JSON body.

    Expects the keys "gpt_model_path" and "sovits_model_path". Returns "ok"
    on success and a 400 JSON response when either key is missing or empty.
    """
    global gpt_path, sovits_path
    json_post_raw = await request.json()
    new_gpt_path = json_post_raw.get("gpt_model_path")
    new_sovits_path = json_post_raw.get("sovits_model_path")
    if not is_full(new_gpt_path, new_sovits_path):
        return JSONResponse(
            {"code": 400, "message": '缺少参数: "gpt_model_path", "sovits_model_path"'},
            status_code=400,
        )
    print(f"gptpath{new_gpt_path};vitspath{new_sovits_path}")
    change_sovits_weights(new_sovits_path)
    change_gpt_weights(new_gpt_path)
    # Record the new paths only once both checkpoints have loaded.
    gpt_path = new_gpt_path
    sovits_path = new_sovits_path
    return "ok"


@app.post("/control")
async def control(request: Request):
    """POST variant of /control: reads "command" from the JSON body."""
    json_post_raw = await request.json()
    return handle_control(json_post_raw.get("command"))


@app.get("/control")
async def control(command: str = None):
    """GET variant of /control: reads "command" from the query string."""
    return handle_control(command)


@app.post("/change_refer")
async def change_refer(request: Request):
    """POST variant of /change_refer: reads the fields from the JSON body."""
    json_post_raw = await request.json()
    return handle_change(
        json_post_raw.get("refer_wav_path"),
        json_post_raw.get("prompt_text"),
        json_post_raw.get("prompt_language"),
        json_post_raw.get("how_to_cut")
    )


@app.get("/change_refer")
async def change_refer(
        refer_wav_path: str = None,
        prompt_text: str = None,
        prompt_language: str = None,
        how_to_cut: str = None
):
    """GET variant of /change_refer: reads the fields from the query string."""
    return handle_change(refer_wav_path, prompt_text, prompt_language, how_to_cut)


@app.post("/")
async def tts_endpoint(request: Request):
    """POST variant of the inference endpoint: reads the JSON body."""
    json_post_raw = await request.json()
    return handle(
        json_post_raw.get("refer_wav_path"),
        json_post_raw.get("prompt_text"),
        json_post_raw.get("prompt_language"),
        json_post_raw.get("text"),
        json_post_raw.get("text_language"),
        json_post_raw.get("how_to_cut")
    )


@app.get("/")
async def tts_endpoint(
        refer_wav_path: str = None,
        prompt_text: str = None,
        prompt_language: str = None,
        text: str = None,
        text_language: str = None,
        how_to_cut: str = None
):
    """GET variant of the inference endpoint: reads the query string."""
    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut)


if __name__ == "__main__":
    uvicorn.run(app, host=host, port=port, workers=1)