修改`api.py`使其可以正常兼容v4模型。已经过测试，可以使用。
This commit is contained in:
Karasukaigan 2025-05-09 03:24:50 +08:00
parent 9ebae35a7d
commit b122988be0

40
api.py
View File

@ -176,9 +176,9 @@ import subprocess
class DefaultRefer: class DefaultRefer:
def __init__(self, path, text, language): def __init__(self, path, text, language):
self.path = args.default_refer_path self.path = path
self.text = args.default_refer_text self.text = text
self.language = args.default_refer_language self.language = language
def is_ready(self) -> bool: def is_ready(self) -> bool:
return is_full(self.path, self.text, self.language) return is_full(self.path, self.text, self.language)
@ -294,11 +294,13 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
def get_sovits_weights(sovits_path): def get_sovits_weights(sovits_path):
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth" path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
is_exist_s2gv3 = os.path.exists(path_sovits_v3) path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
if if_lora_v3 == True and is_exist_s2gv3 == False: if if_lora_v3 == True and not os.path.exists(path_sovits_v3):
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
if model_version == "v4" and not os.path.exists(path_sovits_v4):
logger.info("SoVITS V4 底模缺失,无法加载相应 LoRA 权重")
dict_s2 = load_sovits_new(sovits_path) dict_s2 = load_sovits_new(sovits_path)
hps = dict_s2["config"] hps = dict_s2["config"]
@ -310,12 +312,13 @@ def get_sovits_weights(sovits_path):
hps.model.version = "v1" hps.model.version = "v1"
else: else:
hps.model.version = "v2" hps.model.version = "v2"
if model_version == "v3": if model_version == "v3":
hps.model.version = "v3" hps.model.version = "v3"
if model_version == "v4":
hps.model.version = "v4"
model_params_dict = vars(hps.model) model_params_dict = vars(hps.model)
if model_version != "v3": if model_version != "v3" and model_version != "v4":
vq_model = SynthesizerTrn( vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
@ -342,10 +345,13 @@ def get_sovits_weights(sovits_path):
else: else:
vq_model = vq_model.to(device) vq_model = vq_model.to(device)
vq_model.eval() vq_model.eval()
if if_lora_v3 == False: if if_lora_v3 == False or model_version != "v4":
vq_model.load_state_dict(dict_s2["weight"], strict=False) vq_model.load_state_dict(dict_s2["weight"], strict=False)
else: else:
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False) if model_version == "v4":
vq_model.load_state_dict(load_sovits_new(path_sovits_v4)["weight"], strict=False)
else:
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
lora_rank = dict_s2["lora_rank"] lora_rank = dict_s2["lora_rank"]
lora_config = LoraConfig( lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"], target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@ -394,13 +400,11 @@ def change_gpt_sovits_weights(gpt_path, sovits_path):
try: try:
gpt = get_gpt_weights(gpt_path) gpt = get_gpt_weights(gpt_path)
sovits = get_sovits_weights(sovits_path) sovits = get_sovits_weights(sovits_path)
speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
except Exception as e: except Exception as e:
return JSONResponse({"code": 400, "message": str(e)}, status_code=400) return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def get_bert_feature(text, word2ph): def get_bert_feature(text, word2ph):
with torch.no_grad(): with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt") inputs = tokenizer(text, return_tensors="pt")
@ -759,7 +763,7 @@ def get_tts_wav(
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device) prompt = prompt_semantic.unsqueeze(0).to(device)
if version != "v3": if version != "v3" and version != "v4":
refers = [] refers = []
if inp_refs: if inp_refs:
for path in inp_refs: for path in inp_refs:
@ -811,7 +815,7 @@ def get_tts_wav(
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
t3 = ttime() t3 = ttime()
if version != "v3": if version != "v3" and version != "v4":
audio = ( audio = (
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed) vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
.detach() .detach()
@ -880,7 +884,7 @@ def get_tts_wav(
audio_opt = np.concatenate(audio_opt, 0) audio_opt = np.concatenate(audio_opt, 0)
t4 = ttime() t4 = ttime()
sr = hps.data.sampling_rate if version != "v3" else 24000 sr = hps.data.sampling_rate if version != "v3" and version != "v4" else 24000
if if_sr and sr == 24000: if if_sr and sr == 24000:
audio_opt = torch.from_numpy(audio_opt).float().to(device) audio_opt = torch.from_numpy(audio_opt).float().to(device)
audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr) audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
@ -901,7 +905,7 @@ def get_tts_wav(
if not stream_mode == "normal": if not stream_mode == "normal":
if media_type == "wav": if media_type == "wav":
sr = 48000 if if_sr else 24000 sr = 48000 if if_sr else 24000
sr = hps.data.sampling_rate if version != "v3" else sr sr = hps.data.sampling_rate if version != "v3" and version != "v4" else sr
audio_bytes = pack_wav(audio_bytes, sr) audio_bytes = pack_wav(audio_bytes, sr)
yield audio_bytes.getvalue() yield audio_bytes.getvalue()