From d2cf5193486219ca4c571e6440e4203a525df12e Mon Sep 17 00:00:00 2001
From: CMDHL
Date: Sun, 28 Jan 2024 08:36:08 -0500
Subject: [PATCH] modified api.py: support runtime model switching based on
 speaker name, and allow connecting from SillyTavern in silero format

---
 api.py | 248 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 247 insertions(+), 1 deletion(-)

diff --git a/api.py b/api.py
index 60ed9fff..a14e8a25 100644
--- a/api.py
+++ b/api.py
@@ -1,4 +1,60 @@
 """
+# 1.28.2024 Changes made on top of the original api.py
+
+## Overview
+
+- The original endpoints are unchanged. Endpoints modeled on the silero-api-server format have been added, so the server can be connected to SillyTavern:
+  - Run api.py and wait until it prints http://127.0.0.1:9880
+  - In the staging version of SillyTavern, go to Extensions > TTS > Select TTS Provider and choose silero
+  - Enter http://127.0.0.1:9880 as the Provider Endpoint and click reload
+  - When "TTS Provider Loaded" appears above Select TTS Provider, the connection succeeded; configure the rest as usual.
+
+- The voice model and reference audio can be switched automatically at runtime based on the speaker name.
+  - If a voice-model root directory is supplied with -vd when starting api.py, the model and reference audio are switched based on the speaker name (a subfolder name, or "default"). For example:
+
+    python api.py -vd "D:/Voices"
+
+  - The voice loaded the original way is named "default", and its reference audio can still be changed the original way. The default voice's models can also be changed by POSTing to /set_default_models, e.g. from a new PowerShell window:
+
+    Invoke-RestMethod -Uri "http://127.0.0.1:9880/set_default_models" -Method Post -ContentType "application/json" -Body (@{gpt_path="D:\Voices\ZB\ZB.ckpt"; sovits_path="D:\Voices\ZB\ZB.pth"} | ConvertTo-Json)
+
+- The default output language is Chinese. A different output language can be chosen with -ol when starting api.py, or changed later by POSTing to /language.
+  - For example, to switch the output language to English, run in a new PowerShell window:
+
+    Invoke-RestMethod -Uri "http://127.0.0.1:9880/language" -Method Post -Body '{"language": "en"}' -ContentType "application/json"
+
+## Voice-model root directory layout
+
+    Voices
+    ├─XXX
+    ├ ├───XXX.ckpt
+    ├ ├───XXX.pth
+    ├ ├───XXX.wav
+    ├ └───XXX.txt
+    ├─YYY
+    ├ ├───YYY.wav
+    ├ └───YYY.txt
+    ├─...
+    ├
+    └─ZZZ
+      ├───ZZZ.ckpt
+      ├───ZZZ.pth
+      ├───ZZZ.wav
+      └───ZZZ.txt
+
+- A voice without GPT and SoVITS model files (YYY above, for example) uses the default voice models configured the original way.
+- The txt file in each folder holds the reference-audio text. Only single-language text is currently supported; the format is {language}|{reference text}, for example:
+
+    zh|这是一段参考文本
+
+## New command-line arguments
+
+`-vd` - `voice-model root directory; subfolders are named after the speakers`
+`-ol` - `output audio language: "中文", "英文", "日文", "zh", "en", "ja"`
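+
+## Quick test of the new endpoints (optional)
+
+- The silero-format endpoints can also be exercised without SillyTavern, e.g. from a new PowerShell window. A minimal check, assuming the server is running on the default http://127.0.0.1:9880 (out.wav is an arbitrary output file name):
+
+    Invoke-RestMethod -Uri "http://127.0.0.1:9880/speakers" -Method Get
+
+    Invoke-RestMethod -Uri "http://127.0.0.1:9880/generate" -Method Post -ContentType "application/json" -Body (@{speaker="default"; text="你好。"} | ConvertTo-Json) -OutFile "out.wav"
+
+  The first call lists the available speakers; the second synthesizes the given text with the named speaker and saves the resulting wav.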
+
+
 # api.py usage
 ` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" `
@@ -125,12 +181,18 @@ from module.mel_processing import spectrogram_torch
 from my_utils import load_audio
 import config as global_config
 
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+
 g_config = global_config.Config()
 
 # AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu"
 
 parser = argparse.ArgumentParser(description="GPT-SoVITS api")
+parser.add_argument("-vd", "--voices_dir", type=str, default="", help="voice-model root directory; subfolders are named after the speakers")
+parser.add_argument("-ol", "--output_language", type=str, default="zh", help="output audio language")
+
 parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
 parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")
@@ -151,6 +213,11 @@ parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, h
 
 args = parser.parse_args()
 
+voices_dir = args.voices_dir
+current_language = args.output_language
+current_gpt_path = args.gpt_path
+current_sovits_path = args.sovits_path
+
 sovits_path = args.sovits_path
 gpt_path = args.gpt_path
@@ -165,6 +232,47 @@ class DefaultRefer:
         return is_full(self.path, self.text, self.language)
 
 
+class Voice:
+    def __init__(self, folder):
+        # a reference wav and its transcript are required for every voice folder
+        self.refer_wav_path = os.path.join(voices_dir, folder, f"{folder}.wav")
+        if not os.path.isfile(self.refer_wav_path):
+            raise ValueError(f"Reference audio not found: {self.refer_wav_path}")
+        refer_txt_path = os.path.join(voices_dir, folder, f"{folder}.txt")
+        if not os.path.isfile(refer_txt_path):
+            raise ValueError(f"Reference text not found: {refer_txt_path}")
+        with open(refer_txt_path, 'r', encoding='utf-8') as file:
+            content = file.read()
+        parts = content.split('|', 1)
+        if len(parts) == 2:
+            self.refer_lang, self.refer_text = parts
+        else:
+            raise ValueError("Malformed reference text. Mark the language with '|'; only single-language text is supported, e.g.:\nzh|这是一段参考文本。")
+        # model files are optional; missing ones fall back to the default models
+        self.sovits_path = os.path.join(voices_dir, folder, f"{folder}.pth")
+        if not os.path.isfile(self.sovits_path):
+            self.sovits_path = None
+            print(f"[WARNING] no dedicated SoVITS model found for {folder}; this voice will use the default SoVITS model")
+        self.gpt_path = os.path.join(voices_dir, folder, f"{folder}.ckpt")
+        if not os.path.isfile(self.gpt_path):
+            self.gpt_path = None
+            print(f"[WARNING] no dedicated GPT model found for {folder}; this voice will use the default GPT model")
+        self.name = folder
+
+
+voices = {}
+if voices_dir != "":
+    print(f"[INFO] voice-model root directory: {voices_dir}")
+    for folder in os.listdir(voices_dir):
+        if os.path.isdir(os.path.join(voices_dir, folder)):
+            try:
+                voices[folder] = Voice(folder)
+                print(f"[INFO] found voice: {folder}")
+            except Exception as e:
+                print(f"[WARNING] malformed voice folder {folder}: {e}")
+
 default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
 
 device = args.device
@@ -414,6 +522,54 @@ def handle_change(path, text, language):
     return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
 
 
+def handle_load(new_gpt_path, new_sovits_path):
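+    """Reload the GPT/SoVITS weights at runtime.
+
+    Empty or None paths fall back to the default model paths; when the
+    requested paths already match the loaded ones, the reload is skipped.
+    """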
+    global gpt_path, sovits_path, current_gpt_path, current_sovits_path
+    if new_gpt_path == "" or new_gpt_path is None:
+        new_gpt_path = gpt_path
+        if gpt_path == "" or gpt_path is None:
+            print("[ERROR] default GPT model path is not set")
+            raise ValueError("default GPT model path is not set")
+    if new_sovits_path == "" or new_sovits_path is None:
+        new_sovits_path = sovits_path
+        if sovits_path == "" or sovits_path is None:
+            print("[ERROR] default SoVITS model path is not set")
+            raise ValueError("default SoVITS model path is not set")
+    # skip the reload when the requested models are already loaded
+    if (os.path.normpath(os.path.abspath(current_gpt_path)) == os.path.normpath(os.path.abspath(new_gpt_path))
+            and os.path.normpath(os.path.abspath(current_sovits_path)) == os.path.normpath(os.path.abspath(new_sovits_path))):
+        return
+    print(f"current models: {current_gpt_path}, {current_sovits_path}")
+    print(f"loading new models: {new_gpt_path}, {new_sovits_path}")
+    current_gpt_path = new_gpt_path
+    current_sovits_path = new_sovits_path
+
+    global dict_s2, hps, dict_s1, config, vq_model, max_sec, t2s_model
+    dict_s2 = torch.load(new_sovits_path, map_location="cpu")
+    hps = dict_s2["config"]
+    hps = DictToAttrRecursive(hps)
+    hps.model.semantic_frame_rate = "25hz"
+    dict_s1 = torch.load(new_gpt_path, map_location="cpu")
+    config = dict_s1["config"]
+    vq_model = SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model)
+    if is_half:
+        vq_model = vq_model.half().to(device)
+    else:
+        vq_model = vq_model.to(device)
+    vq_model.eval()
+    print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
+    max_sec = config['data']['max_sec']
+    t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+    t2s_model.load_state_dict(dict_s1["weight"])
+    if is_half:
+        t2s_model = t2s_model.half()
+    t2s_model = t2s_model.to(device)
+    t2s_model.eval()
+
+
 def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
     if (
         refer_wav_path == "" or refer_wav_path is None
@@ -439,7 +595,8 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
     wav.seek(0)
 
     torch.cuda.empty_cache()
-    torch.mps.empty_cache()
+    if device == "mps":  # only clear the MPS cache when actually running on MPS
+        torch.mps.empty_cache()
     return StreamingResponse(wav, media_type="audio/wav")
@@ -457,6 +614,15 @@ async def control(command: str = None):
     return handle_control(command)
 
 
+@app.post("/set_default_models")
+async def set_default_models(request: Request):
+    global gpt_path, sovits_path
+    json_post_raw = await request.json()
+    gpt_path = json_post_raw.get("gpt_path")
+    sovits_path = json_post_raw.get("sovits_path")
+    return JSONResponse({"gpt_path": gpt_path, "sovits_path": sovits_path}, status_code=200)
+
+
 @app.post("/change_refer")
 async def change_refer(request: Request):
     json_post_raw = await request.json()
@@ -499,5 +665,85 @@ async def tts_endpoint(
     return handle(refer_wav_path, prompt_text, prompt_language, text, text_language)
 
 
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.get("/speakers")
+async def speakers(request: Request):
+    voices_info = [
+        {
+            "name": "default",
+            "voice_id": "default",
+            "preview_url": f"{str(request.base_url)}sample/default"
+        }
+    ]
+    for v in voices.values():
+        voices_info.append(
+            {
+                "name": v.name,
+                "voice_id": v.name,
+                "preview_url": f"{str(request.base_url)}sample/{v.name}"
+            }
+        )
+    return voices_info
+
+
+@app.post("/generate")
+async def generate(request: Request):
+    json_post_raw = await request.json()
+    speaker = json_post_raw.get("speaker")
+    if speaker == "default":
+        handle_load(gpt_path, sovits_path)
+        return handle(
+            None,
+            None,
+            None,
+            json_post_raw.get("text"),
+            current_language,
+        )
+    handle_load(voices[speaker].gpt_path, voices[speaker].sovits_path)
+    return handle(
+        voices[speaker].refer_wav_path,
+        voices[speaker].refer_text,
+        voices[speaker].refer_lang,
+        json_post_raw.get("text"),
+        current_language,
+    )
+
+
+@app.get("/sample/{speaker}")
+async def play_sample(speaker: str = 'default'):
+    if speaker == 'default':
+        return FileResponse(default_refer.path, status_code=200)
+    print(f"sending {voices[speaker].refer_wav_path}")
+    return FileResponse(voices[speaker].refer_wav_path, status_code=200)
+
+
+@app.post("/session")  # placeholder endpoint; no session state is kept
+async def session(request: Request):
+    return JSONResponse({}, status_code=200)
+
+
+@app.get("/language")
+async def get_languages():
+    return JSONResponse(list(dict_language.keys()), headers={'Content-Type': 'text/plain; charset=utf-8'}, status_code=200)
+
+
+@app.post("/language")
+async def set_language(request: Request):
+    global current_language
+    json_post_raw = await request.json()
+    current_language = json_post_raw.get("language")
+    print(f"[INFO] output language set to: {current_language}")
+    return JSONResponse(f"current language: {current_language}", status_code=200)
+
+
 if __name__ == "__main__":
     uvicorn.run(app, host=host, port=port, workers=1)
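
Note: /speakers returns the silero-style speaker list that SillyTavern reads. A sample response (assuming the server on the default http://127.0.0.1:9880 and a voices directory containing a single folder named XXX):

    [
      {"name": "default", "voice_id": "default", "preview_url": "http://127.0.0.1:9880/sample/default"},
      {"name": "XXX", "voice_id": "XXX", "preview_url": "http://127.0.0.1:9880/sample/XXX"}
    ]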