2025-6-1 14:13

This commit is contained in:
Songhuadanjiang 2025-06-01 14:13:10 +08:00 committed by GitHub
parent 968952fd2a
commit b2d725ed9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

8
api.py
View File

@ -313,9 +313,11 @@ def get_sovits_weights(sovits_path):
if model_version == "v3": if model_version == "v3":
hps.model.version = "v3" hps.model.version = "v3"
if model_version == "v4":
hps.model.version = "v4"
model_params_dict = vars(hps.model) model_params_dict = vars(hps.model)
if model_version != "v3": if model_version != "v3" and model_version != "v4":
vq_model = SynthesizerTrn( vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
@ -759,7 +761,7 @@ def get_tts_wav(
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device) prompt = prompt_semantic.unsqueeze(0).to(device)
if version != "v3": if version != "v3" and version != "v4":
refers = [] refers = []
if inp_refs: if inp_refs:
for path in inp_refs: for path in inp_refs:
@ -811,7 +813,7 @@ def get_tts_wav(
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
t3 = ttime() t3 = ttime()
if version != "v3": if version != "v3" and version != "v4":
audio = ( audio = (
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed) vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
.detach() .detach()