diff --git a/api.py b/api.py
index 3b17394..a6ec11f 100644
--- a/api.py
+++ b/api.py
@@ -195,8 +195,24 @@ def is_full(*items): # 任意一项为空返回False
     return True
 
 
-def change_sovits_weights(sovits_path):
-    global vq_model, hps
+class Speaker:
+    def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None):
+        self.name = name
+        self.sovits = sovits
+        self.gpt = gpt
+        self.phones = phones
+        self.bert = bert
+        self.prompt = prompt
+
+speaker_list = {}
+
+
+class Sovits:
+    def __init__(self, vq_model, hps):
+        self.vq_model = vq_model
+        self.hps = hps
+
+def get_sovits_weights(sovits_path):
     dict_s2 = torch.load(sovits_path, map_location="cpu")
     hps = dict_s2["config"]
     hps = DictToAttrRecursive(hps)
@@ -222,10 +238,17 @@ def change_sovits_weights(sovits_path):
     vq_model.eval()
     vq_model.load_state_dict(dict_s2["weight"], strict=False)
 
+    sovits = Sovits(vq_model, hps)
+    return sovits
 
-def change_gpt_weights(gpt_path):
-    global hz, max_sec, t2s_model, config
-    hz = 50
+class Gpt:
+    def __init__(self, max_sec, t2s_model):
+        self.max_sec = max_sec
+        self.t2s_model = t2s_model
+
+global hz
+hz = 50
+def get_gpt_weights(gpt_path):
     dict_s1 = torch.load(gpt_path, map_location="cpu")
     config = dict_s1["config"]
     max_sec = config["data"]["max_sec"]
@@ -238,6 +261,19 @@ def change_gpt_weights(gpt_path):
     total = sum([param.nelement() for param in t2s_model.parameters()])
     logger.info("Number of parameter: %.2fM" % (total / 1e6))
 
+    gpt = Gpt(max_sec, t2s_model)
+    return gpt
+
+def change_gpt_sovits_weights(gpt_path,sovits_path):
+    try:
+        gpt = get_gpt_weights(gpt_path)
+        sovits = get_sovits_weights(sovits_path)
+    except Exception as e:
+        return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
+
+    speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
+    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
+
 
 def get_bert_feature(text, word2ph):
     with torch.no_grad():
@@ -504,7 +540,15 @@ def only_punc(text):
     return not any(t.isalnum() or t.isalpha() for t in text)
 
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1, spk = "default"):
+    infer_sovits = speaker_list[spk].sovits
+    vq_model = infer_sovits.vq_model
+    hps = infer_sovits.hps
+
+    infer_gpt = speaker_list[spk].gpt
+    t2s_model = infer_gpt.t2s_model
+    max_sec = infer_gpt.max_sec
+
     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
     prompt_language, text = prompt_language, text.strip("\n")
@@ -523,6 +567,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
         codes = vq_model.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
+        prompt = prompt_semantic.unsqueeze(0).to(device)
     t1 = ttime()
     version = vq_model.version
     os.environ['version'] = version
@@ -544,7 +589,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
         bert = bert.to(device).unsqueeze(0)
         all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
-        prompt = prompt_semantic.unsqueeze(0).to(device)
         t2 = ttime()
         with torch.no_grad():
             # pred_semantic = t2s_model.model.infer(
@@ -763,9 +807,7 @@ if is_half:
 else:
     bert_model = bert_model.to(device)
     ssl_model = ssl_model.to(device)
-change_sovits_weights(sovits_path)
-change_gpt_weights(gpt_path)
-
+change_gpt_sovits_weights(gpt_path = gpt_path, sovits_path = sovits_path)
 
 
 
@@ -777,14 +819,18 @@ app = FastAPI()
 @app.post("/set_model")
 async def set_model(request: Request):
     json_post_raw = await request.json()
-    global gpt_path
-    gpt_path=json_post_raw.get("gpt_model_path")
-    global sovits_path
-    sovits_path=json_post_raw.get("sovits_model_path")
-    logger.info("gptpath"+gpt_path+";vitspath"+sovits_path)
-    change_sovits_weights(sovits_path)
-    change_gpt_weights(gpt_path)
-    return "ok"
+    return change_gpt_sovits_weights(
+        gpt_path = json_post_raw.get("gpt_model_path"),
+        sovits_path = json_post_raw.get("sovits_model_path")
+    )
+
+
+@app.get("/set_model")
+async def set_model(
+        gpt_model_path: str = None,
+        sovits_model_path: str = None,
+):
+    return change_gpt_sovits_weights(gpt_path = gpt_model_path, sovits_path = sovits_model_path)
 
 
 @app.post("/control")