修改`api.py`使其可以正常兼容v4模型。已经过测试,可以使用。
This commit is contained in:
Karasukaigan 2025-05-09 03:24:50 +08:00
parent 9ebae35a7d
commit b122988be0

38
api.py
View File

@ -176,9 +176,9 @@ import subprocess
class DefaultRefer:
def __init__(self, path, text, language):
self.path = args.default_refer_path
self.text = args.default_refer_text
self.language = args.default_refer_language
self.path = path
self.text = text
self.language = language
def is_ready(self) -> bool:
return is_full(self.path, self.text, self.language)
@ -294,11 +294,13 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
def get_sovits_weights(sovits_path):
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
if if_lora_v3 == True and is_exist_s2gv3 == False:
if if_lora_v3 == True and not os.path.exists(path_sovits_v3):
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
if model_version == "v4" and not os.path.exists(path_sovits_v4):
logger.info("SoVITS V4 底模缺失,无法加载相应 LoRA 权重")
dict_s2 = load_sovits_new(sovits_path)
hps = dict_s2["config"]
@ -310,12 +312,13 @@ def get_sovits_weights(sovits_path):
hps.model.version = "v1"
else:
hps.model.version = "v2"
if model_version == "v3":
hps.model.version = "v3"
if model_version == "v4":
hps.model.version = "v4"
model_params_dict = vars(hps.model)
if model_version != "v3":
if model_version != "v3" and model_version != "v4":
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
@ -342,10 +345,13 @@ def get_sovits_weights(sovits_path):
else:
vq_model = vq_model.to(device)
vq_model.eval()
if if_lora_v3 == False:
if if_lora_v3 == False or model_version != "v4":
vq_model.load_state_dict(dict_s2["weight"], strict=False)
else:
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
if model_version == "v4":
vq_model.load_state_dict(load_sovits_new(path_sovits_v4)["weight"], strict=False)
else:
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
lora_rank = dict_s2["lora_rank"]
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@ -394,13 +400,11 @@ def change_gpt_sovits_weights(gpt_path, sovits_path):
try:
gpt = get_gpt_weights(gpt_path)
sovits = get_sovits_weights(sovits_path)
speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
except Exception as e:
return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
@ -759,7 +763,7 @@ def get_tts_wav(
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device)
if version != "v3":
if version != "v3" and version != "v4":
refers = []
if inp_refs:
for path in inp_refs:
@ -811,7 +815,7 @@ def get_tts_wav(
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
t3 = ttime()
if version != "v3":
if version != "v3" and version != "v4":
audio = (
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
.detach()
@ -880,7 +884,7 @@ def get_tts_wav(
audio_opt = np.concatenate(audio_opt, 0)
t4 = ttime()
sr = hps.data.sampling_rate if version != "v3" else 24000
sr = hps.data.sampling_rate if version != "v3" and version != "v4" else 24000
if if_sr and sr == 24000:
audio_opt = torch.from_numpy(audio_opt).float().to(device)
audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
@ -901,7 +905,7 @@ def get_tts_wav(
if not stream_mode == "normal":
if media_type == "wav":
sr = 48000 if if_sr else 24000
sr = hps.data.sampling_rate if version != "v3" else sr
sr = hps.data.sampling_rate if version != "v3" and version != "v4" else sr
audio_bytes = pack_wav(audio_bytes, sr)
yield audio_bytes.getvalue()