""" # api.py usage ` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" ` ## 执行参数: `-s` - `SoVITS模型路径, 可在 config.py 中指定` `-g` - `GPT模型路径, 可在 config.py 中指定` 调用请求缺少参考音频时使用 `-dr` - `默认参考音频路径` `-dt` - `默认参考音频文本` `-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"` `-d` - `推理设备, "cuda","cpu","mps"` `-a` - `绑定地址, 默认"0.0.0.0"` `-p` - `绑定端口, 默认9880, 可在 config.py 中指定` `-fp` - `覆盖 config.py 使用全精度` `-hp` - `覆盖 config.py 使用半精度` `-hb` - `cnhubert路径` `-b` - `bert路径` ## 调用: ### 推理 endpoint: `/` 使用执行参数指定的参考音频: GET: `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` POST: ```json { "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", "text_language": "zh" } ``` 手动指定当次推理所使用的参考音频: GET: `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` POST: ```json { "refer_wav_path": "123.wav", "prompt_text": "一二三。", "prompt_language": "zh", "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", "text_language": "zh" } ``` RESP: 成功: 直接返回 wav 音频流, http code 200 失败: 返回包含错误信息的 json, http code 400 ### 更换默认参考音频 endpoints: `/change_refer`, `/set_refer` key与推理端一样 GET: `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh` POST: ```json { "refer_wav_path": "123.wav", "prompt_text": "一二三。", "prompt_language": "zh" } ``` RESP: 成功: json, http code 200 失败: json, 400 ### 更换模型 endpoints: `/change_model`, `/change_weight`, `/set_model`, `/set_weight` key alias: "gpt", "gpt_path", "gpt_model_path" "sovits", "sovits_path", "sovits_model_path" GET: `http://127.0.0.1:9880/change_weight?gpt=./GPT_weights/suijiSUI-e20.ckpt&sovits=./SoVITS_weights/suijiSUI_e20_s3280.pth` POST: ```json { "gpt": "./GPT_weights/suijiSUI-e20.ckpt", "sovits": "./SoVITS_weights/suijiSUI_e20_s3280.pth" } ``` RESP: 成功: json, http code 200 失败: json, 400 | "Internal Server Error" ### 命令控制 endpoint: `/control` command: "restart": 重新运行 "exit": 结束运行 GET: `http://127.0.0.1:9880/control?command=restart` POST: ```json { "command": "restart" } ``` RESP: 无 """ import argparse import os import sys now_dir = os.getcwd() sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) import signal from time import time as ttime import torch import librosa import soundfile as sf from fastapi import FastAPI, Request, HTTPException from fastapi.responses import StreamingResponse, JSONResponse import uvicorn from transformers import AutoModelForMaskedLM, AutoTokenizer import numpy as np from feature_extractor import cnhubert from io import BytesIO from module.models import SynthesizerTrn from AR.models.t2s_lightning_module import Text2SemanticLightningModule from text import cleaned_text_to_sequence from text.cleaner import clean_text from module.mel_processing import spectrogram_torch from my_utils import load_audio import config as global_config g_config = global_config.Config() parser = argparse.ArgumentParser(description="GPT-SoVITS api") parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径") parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径") parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径") parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本") parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种") parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu / mps") parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0") parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880") parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度") parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度") # bool值的用法为 `python ./api.py -fp ...` # 此时 full_precision==True, half_precision==False parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") args = parser.parse_args() sovits_path = args.sovits_path gpt_path = args.gpt_path class DefaultRefer: def __init__(self, path, text, language): self.path = args.default_refer_path self.text = args.default_refer_text self.language = args.default_refer_language def is_ready(self) -> bool: return is_full(self.path, self.text, self.language) default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language) device = args.device port = args.port host = args.bind_addr if sovits_path == "": sovits_path = g_config.pretrained_sovits_path print(f"[WARN] 未指定SoVITS模型路径, fallback后当前值: {sovits_path}") if gpt_path == "": gpt_path = g_config.pretrained_gpt_path print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}") # 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用 if default_refer.path == "" or default_refer.text == "" or default_refer.language == "": default_refer.path, default_refer.text, default_refer.language = "", "", "" print("[INFO] 未指定默认参考音频") else: print(f"[INFO] 默认参考音频路径: {default_refer.path}") print(f"[INFO] 默认参考音频文本: {default_refer.text}") print(f"[INFO] 默认参考音频语种: {default_refer.language}") is_half = g_config.is_half if args.full_precision: is_half = False if args.half_precision: is_half = True if args.full_precision and args.half_precision: is_half = g_config.is_half # 炒饭fallback print(f"[INFO] 半精: {is_half}") cnhubert_base_path = args.hubert_path bert_path = args.bert_path cnhubert.cnhubert_base_path = cnhubert_base_path tokenizer = AutoTokenizer.from_pretrained(bert_path) bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) if is_half: bert_model = bert_model.half().to(device) else: bert_model = bert_model.to(device) def is_empty(*items): # 任意一项不为空返回False for item in items: if item: return False return True def is_full(*items): # 任意一项为空返回False for item in items: if not item: return False return True def change_sovits_weights(sovits_path): global vq_model, hps dict_s2 = torch.load(sovits_path, map_location="cpu") hps = dict_s2["config"] hps = DictToAttrRecursive(hps) hps.model.semantic_frame_rate = "25hz" vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, **hps.model ) if ("pretrained" not in sovits_path): del vq_model.enc_q if is_half == True: vq_model = vq_model.half().to(device) else: vq_model = vq_model.to(device) vq_model.eval() print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) with open("./sweight.txt", "w", encoding="utf-8") as f: f.write(sovits_path) def change_gpt_weights(gpt_path): global hz, max_sec, t2s_model, config hz = 50 dict_s1 = torch.load(gpt_path, map_location="cpu") config = dict_s1["config"] max_sec = config["data"]["max_sec"] t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) t2s_model.load_state_dict(dict_s1["weight"]) if is_half == True: t2s_model = t2s_model.half() t2s_model = t2s_model.to(device) t2s_model.eval() total = sum([param.nelement() for param in t2s_model.parameters()]) print("Number of parameter: %.2fM" % (total / 1e6)) with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path) def get_bert_feature(text, word2ph): with torch.no_grad(): inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model res = bert_model(**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] assert len(word2ph) == len(text) phone_level_feature = [] for i in range(len(word2ph)): repeat_feature = res[i].repeat(word2ph[i], 1) phone_level_feature.append(repeat_feature) phone_level_feature = torch.cat(phone_level_feature, dim=0) # if(is_half==True):phone_level_feature=phone_level_feature.half() return phone_level_feature.T n_semantic = 1024 dict_s2 = torch.load(sovits_path, map_location="cpu") hps = dict_s2["config"] class DictToAttrRecursive: def __init__(self, input_dict): for key, value in input_dict.items(): if isinstance(value, dict): # 如果值是字典,递归调用构造函数 setattr(self, key, DictToAttrRecursive(value)) else: setattr(self, key, value) hps = DictToAttrRecursive(hps) hps.model.semantic_frame_rate = "25hz" dict_s1 = torch.load(gpt_path, map_location="cpu") config = dict_s1["config"] ssl_model = cnhubert.get_model() if is_half: ssl_model = ssl_model.half().to(device) else: ssl_model = ssl_model.to(device) vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, **hps.model) if is_half: vq_model = vq_model.half().to(device) else: vq_model = vq_model.to(device) vq_model.eval() print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) hz = 50 max_sec = config['data']['max_sec'] t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) t2s_model.load_state_dict(dict_s1["weight"]) if is_half: t2s_model = t2s_model.half() t2s_model = t2s_model.to(device) t2s_model.eval() total = sum([param.nelement() for param in t2s_model.parameters()]) print("Number of parameter: %.2fM" % (total / 1e6)) def get_spepc(hps, filename): audio = load_audio(filename, int(hps.data.sampling_rate)) audio = torch.FloatTensor(audio) audio_norm = audio audio_norm = audio_norm.unsqueeze(0) spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, center=False) return spec dict_language = { "中文": "zh", "英文": "en", "日文": "ja", "ZH": "zh", "EN": "en", "JA": "ja", "zh": "zh", "en": "en", "ja": "ja" } def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language): t0 = ttime() prompt_text = prompt_text.strip("\n") prompt_language, text = prompt_language, text.strip("\n") zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) with torch.no_grad(): wav16k, sr = librosa.load(ref_wav_path, sr=16000) wav16k = torch.from_numpy(wav16k) zero_wav_torch = torch.from_numpy(zero_wav) if (is_half == True): wav16k = wav16k.half().to(device) zero_wav_torch = zero_wav_torch.half().to(device) else: wav16k = wav16k.to(device) zero_wav_torch = zero_wav_torch.to(device) wav16k = torch.cat([wav16k, zero_wav_torch]) ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() codes = vq_model.extract_latent(ssl_content) prompt_semantic = codes[0, 0] t1 = ttime() prompt_language = dict_language[prompt_language] text_language = dict_language[text_language] phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language) phones1 = cleaned_text_to_sequence(phones1) texts = text.split("\n") audio_opt = [] for text in texts: phones2, word2ph2, norm_text2 = clean_text(text, text_language) phones2 = cleaned_text_to_sequence(phones2) if (prompt_language == "zh"): bert1 = get_bert_feature(norm_text1, word2ph1).to(device) else: bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to( device) if (text_language == "zh"): bert2 = get_bert_feature(norm_text2, word2ph2).to(device) else: bert2 = torch.zeros((1024, len(phones2))).to(bert1) bert = torch.cat([bert1, bert2], 1) all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) bert = bert.to(device).unsqueeze(0) all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) prompt = prompt_semantic.unsqueeze(0).to(device) t2 = ttime() with torch.no_grad(): # pred_semantic = t2s_model.model.infer( pred_semantic, idx = t2s_model.model.infer_panel( all_phoneme_ids, all_phoneme_len, prompt, bert, # prompt_phone_len=ph_offset, top_k=config['inference']['top_k'], early_stop_num=hz * max_sec) t3 = ttime() # print(pred_semantic.shape,idx) pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 refer = get_spepc(hps, ref_wav_path) # .to(device) if (is_half == True): refer = refer.half().to(device) else: refer = refer.to(device) # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] audio = \ vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[ 0, 0] ###试试重建不带上prompt部分 audio_opt.append(audio) audio_opt.append(zero_wav) t4 = ttime() print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16) def handle_control(command): if command == "restart": os.execl(g_config.python_exec, g_config.python_exec, *sys.argv) elif command == "exit": os.kill(os.getpid(), signal.SIGTERM) exit(0) def handle_change_refer(path, text, language): if is_empty(path, text, language): return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400) if path != "" or path is not None: default_refer.path = path if text != "" or text is not None: default_refer.text = text if language != "" or language is not None: default_refer.language = language print(f"[INFO] 当前默认参考音频路径: {default_refer.path}") print(f"[INFO] 当前默认参考音频文本: {default_refer.text}") print(f"[INFO] 当前默认参考音频语种: {default_refer.language}") print(f"[INFO] is_ready: {default_refer.is_ready()}") return JSONResponse({"code": 0, "message": "Success"}, status_code=200) def handle_refer(refer_wav_path, prompt_text, prompt_language, text, text_language): if ( # 缺任意一个 not refer_wav_path or not prompt_text or not prompt_language ): # 使用全局 refer_wav_path, prompt_text, prompt_language = ( default_refer.path, default_refer.text, default_refer.language, ) if not default_refer.is_ready(): return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) with torch.no_grad(): gen = get_tts_wav( refer_wav_path, prompt_text, prompt_language, text, text_language ) sampling_rate, audio_data = next(gen) wav = BytesIO() sf.write(wav, audio_data, sampling_rate, format="wav") wav.seek(0) if device == "cuda": torch.cuda.empty_cache() if device == "mps": print('executed torch.mps.empty_cache()') torch.mps.empty_cache() return StreamingResponse(wav, media_type="audio/wav") def handle_change_weights(gpt, sovits): if is_empty(gpt, sovits): return JSONResponse({"code": 400, "message": f"缺少任意一项以下参数: {gpt_alias}, {sovits_alias}"}, status_code=400) global gpt_path, sovits_path if gpt: gpt_path = gpt print(f"New gpt_path: {gpt_path}") change_gpt_weights(gpt_path) if sovits: sovits_path = sovits print(f"New sovits_path: {sovits_path}") change_sovits_weights(sovits_path) return JSONResponse({"code": 0, "message": "Success"}, status_code=200) gpt_alias = ( "gpt", "gpt_path", "gpt_model_path" # @JavaAndPython55 用的这个key, 嫌太长直接alias了 ) sovits_alias = ( "sovits", "sovits_path", "sovits_model_path" ) app = FastAPI() # clark新增-----2024-02-21 # 可在启动后动态修改模型,以此满足同一个api不同的朗读者请求 @app.post("/set_model") @app.post("/set_weight") @app.post("/change_model") @app.post("/change_weight") async def change_weight(request: Request): json_post_raw = await request.json() gpt, sovits = "", "" for ga in gpt_alias: g = json_post_raw.get(ga) if g: gpt = g break for sa in sovits_alias: s = json_post_raw.get(sa) if s: sovits = s break return handle_change_weights(gpt, sovits) @app.get("/set_model") @app.get("/set_weight") @app.get("/change_model") @app.get("/change_weight") async def change_weight( gpt: str = None, gpt_path: str = None, gpt_model_path: str = None, sovits: str = None, sovits_path: str = None, sovits_model_path: str = None, ): GPT, SOVITS = "", "" for gg in (gpt, gpt_path, gpt_model_path): if gg: GPT = gg break for ss in (sovits, sovits_path, sovits_model_path): if ss: SOVITS = ss break return handle_change_weights(GPT, SOVITS) @app.post("/control") async def control(request: Request): json_post_raw = await request.json() return handle_control(json_post_raw.get("command")) @app.get("/control") async def control(command: str = None): return handle_control(command) @app.post("/set_refer") @app.post("/change_refer") async def change_refer(request: Request): json_post_raw = await request.json() return handle_change_refer( json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language") ) @app.get("/set_refer") @app.get("/change_refer") async def change_refer( refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None ): return handle_change_refer(refer_wav_path, prompt_text, prompt_language) @app.post("/") async def tts_endpoint(request: Request): json_post_raw = await request.json() return handle_refer( json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language"), json_post_raw.get("text"), json_post_raw.get("text_language"), ) @app.get("/") async def tts_endpoint( refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None, text: str = None, text_language: str = None, ): return handle_refer(refer_wav_path, prompt_text, prompt_language, text, text_language) if __name__ == "__main__": uvicorn.run(app, host=host, port=port, workers=1)