Merge b2d725ed9a2d50465f7882ecdd06715d66ef66bd into 0d2f2734024ccd1c44aead7608214b924bffbebe

This commit is contained in:
Songhuadanjiang 2025-06-10 19:15:58 +08:00 committed by GitHub
commit 0b1a6935a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

8
api.py
View File

@@ -313,9 +313,11 @@ def get_sovits_weights(sovits_path):
if model_version == "v3":
hps.model.version = "v3"
+if model_version == "v4":
+hps.model.version = "v4"
model_params_dict = vars(hps.model)
-if model_version != "v3":
+if model_version != "v3" and model_version != "v4":
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
@@ -759,7 +761,7 @@ def get_tts_wav(
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device)
-if version != "v3":
+if version != "v3" and version != "v4":
refers = []
if inp_refs:
for path in inp_refs:
@@ -811,7 +813,7 @@ def get_tts_wav(
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
t3 = ttime()
-if version != "v3":
+if version != "v3" and version != "v4":
audio = (
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
.detach()