mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
model control
This commit is contained in:
parent
b3e8eb40c2
commit
c9b6945b22
82
api.py
82
api.py
@ -195,8 +195,24 @@ def is_full(*items): # 任意一项为空返回False
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def change_sovits_weights(sovits_path):
|
class Speaker:
|
||||||
global vq_model, hps
|
def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None):
|
||||||
|
self.name = name
|
||||||
|
self.sovits = sovits
|
||||||
|
self.gpt = gpt
|
||||||
|
self.phones = phones
|
||||||
|
self.bert = bert
|
||||||
|
self.prompt = prompt
|
||||||
|
|
||||||
|
speaker_list = {}
|
||||||
|
|
||||||
|
|
||||||
|
class Sovits:
|
||||||
|
def __init__(self, vq_model, hps):
|
||||||
|
self.vq_model = vq_model
|
||||||
|
self.hps = hps
|
||||||
|
|
||||||
|
def get_sovits_weights(sovits_path):
|
||||||
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
||||||
hps = dict_s2["config"]
|
hps = dict_s2["config"]
|
||||||
hps = DictToAttrRecursive(hps)
|
hps = DictToAttrRecursive(hps)
|
||||||
@ -222,10 +238,17 @@ def change_sovits_weights(sovits_path):
|
|||||||
vq_model.eval()
|
vq_model.eval()
|
||||||
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
|
|
||||||
|
sovits = Sovits(vq_model, hps)
|
||||||
|
return sovits
|
||||||
|
|
||||||
def change_gpt_weights(gpt_path):
|
class Gpt:
|
||||||
global hz, max_sec, t2s_model, config
|
def __init__(self, max_sec, t2s_model):
|
||||||
hz = 50
|
self.max_sec = max_sec
|
||||||
|
self.t2s_model = t2s_model
|
||||||
|
|
||||||
|
global hz
|
||||||
|
hz = 50
|
||||||
|
def get_gpt_weights(gpt_path):
|
||||||
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
||||||
config = dict_s1["config"]
|
config = dict_s1["config"]
|
||||||
max_sec = config["data"]["max_sec"]
|
max_sec = config["data"]["max_sec"]
|
||||||
@ -238,6 +261,19 @@ def change_gpt_weights(gpt_path):
|
|||||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||||
logger.info("Number of parameter: %.2fM" % (total / 1e6))
|
logger.info("Number of parameter: %.2fM" % (total / 1e6))
|
||||||
|
|
||||||
|
gpt = Gpt(max_sec, t2s_model)
|
||||||
|
return gpt
|
||||||
|
|
||||||
|
def change_gpt_sovits_weights(gpt_path,sovits_path):
|
||||||
|
try:
|
||||||
|
gpt = get_gpt_weights(gpt_path)
|
||||||
|
sovits = get_sovits_weights(sovits_path)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
|
||||||
|
|
||||||
|
speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
|
||||||
|
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
def get_bert_feature(text, word2ph):
|
def get_bert_feature(text, word2ph):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -504,7 +540,15 @@ def only_punc(text):
|
|||||||
return not any(t.isalnum() or t.isalpha() for t in text)
|
return not any(t.isalnum() or t.isalpha() for t in text)
|
||||||
|
|
||||||
|
|
||||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1):
|
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1, spk = "default"):
|
||||||
|
infer_sovits = speaker_list[spk].sovits
|
||||||
|
vq_model = infer_sovits.vq_model
|
||||||
|
hps = infer_sovits.hps
|
||||||
|
|
||||||
|
infer_gpt = speaker_list[spk].gpt
|
||||||
|
t2s_model = infer_gpt.t2s_model
|
||||||
|
max_sec = infer_gpt.max_sec
|
||||||
|
|
||||||
t0 = ttime()
|
t0 = ttime()
|
||||||
prompt_text = prompt_text.strip("\n")
|
prompt_text = prompt_text.strip("\n")
|
||||||
prompt_language, text = prompt_language, text.strip("\n")
|
prompt_language, text = prompt_language, text.strip("\n")
|
||||||
@ -523,6 +567,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
||||||
codes = vq_model.extract_latent(ssl_content)
|
codes = vq_model.extract_latent(ssl_content)
|
||||||
prompt_semantic = codes[0, 0]
|
prompt_semantic = codes[0, 0]
|
||||||
|
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||||
t1 = ttime()
|
t1 = ttime()
|
||||||
version = vq_model.version
|
version = vq_model.version
|
||||||
os.environ['version'] = version
|
os.environ['version'] = version
|
||||||
@ -544,7 +589,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||||
bert = bert.to(device).unsqueeze(0)
|
bert = bert.to(device).unsqueeze(0)
|
||||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
|
||||||
t2 = ttime()
|
t2 = ttime()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# pred_semantic = t2s_model.model.infer(
|
# pred_semantic = t2s_model.model.infer(
|
||||||
@ -763,9 +807,7 @@ if is_half:
|
|||||||
else:
|
else:
|
||||||
bert_model = bert_model.to(device)
|
bert_model = bert_model.to(device)
|
||||||
ssl_model = ssl_model.to(device)
|
ssl_model = ssl_model.to(device)
|
||||||
change_sovits_weights(sovits_path)
|
change_gpt_sovits_weights(gpt_path = gpt_path, sovits_path = sovits_path)
|
||||||
change_gpt_weights(gpt_path)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -777,14 +819,18 @@ app = FastAPI()
|
|||||||
@app.post("/set_model")
|
@app.post("/set_model")
|
||||||
async def set_model(request: Request):
|
async def set_model(request: Request):
|
||||||
json_post_raw = await request.json()
|
json_post_raw = await request.json()
|
||||||
global gpt_path
|
return change_gpt_sovits_weights(
|
||||||
gpt_path=json_post_raw.get("gpt_model_path")
|
gpt_path = json_post_raw.get("gpt_model_path"),
|
||||||
global sovits_path
|
sovits_path = json_post_raw.get("sovits_model_path")
|
||||||
sovits_path=json_post_raw.get("sovits_model_path")
|
)
|
||||||
logger.info("gptpath"+gpt_path+";vitspath"+sovits_path)
|
|
||||||
change_sovits_weights(sovits_path)
|
|
||||||
change_gpt_weights(gpt_path)
|
@app.get("/set_model")
|
||||||
return "ok"
|
async def set_model(
|
||||||
|
gpt_model_path: str = None,
|
||||||
|
sovits_model_path: str = None,
|
||||||
|
):
|
||||||
|
return change_gpt_sovits_weights(gpt_path = gpt_model_path, sovits_path = sovits_model_path)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/control")
|
@app.post("/control")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user