Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-08-17 23:19:47 +08:00
parent 9ebae35a7d
commit b122988be0

Changed file: api.py (36 changes)
@@ -176,9 +176,9 @@ import subprocess
 class DefaultRefer:
     def __init__(self, path, text, language):
-        self.path = args.default_refer_path
-        self.text = args.default_refer_text
-        self.language = args.default_refer_language
+        self.path = path
+        self.text = text
+        self.language = language
 
     def is_ready(self) -> bool:
         return is_full(self.path, self.text, self.language)
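
The constructor previously ignored its parameters and read the module-level args directly; after this hunk it stores what it is given, so a DefaultRefer can be built from any reference clip. A minimal sketch under that reading, with a stand-in is_full (the real helper lives elsewhere in api.py):

# Sketch only: `is_full` here is a stub for the helper defined elsewhere in api.py.
def is_full(*items) -> bool:
    """True only when every field is a non-empty string."""
    return all(isinstance(x, str) and x.strip() for x in items)


class DefaultRefer:
    def __init__(self, path, text, language):
        # Values now come from the caller rather than the global `args` namespace.
        self.path = path
        self.text = text
        self.language = language

    def is_ready(self) -> bool:
        return is_full(self.path, self.text, self.language)


# Typical construction at startup, e.g.:
# default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
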
@@ -294,11 +294,13 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 def get_sovits_weights(sovits_path):
     path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
-    is_exist_s2gv3 = os.path.exists(path_sovits_v3)
+    path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
 
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
-    if if_lora_v3 == True and is_exist_s2gv3 == False:
+    if if_lora_v3 == True and not os.path.exists(path_sovits_v3):
         logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
+    if model_version == "v4" and not os.path.exists(path_sovits_v4):
+        logger.info("SoVITS V4 底模缺失,无法加载相应 LoRA 权重")
 
     dict_s2 = load_sovits_new(sovits_path)
     hps = dict_s2["config"]
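
Both logger messages translate to "the SoVITS V3 (or V4) base model is missing, so the corresponding LoRA weights cannot be loaded". A hedged sketch of the availability check on its own, with the checkpoint paths copied from the hunk and an English-wording stand-in logger:

import logging
import os

logger = logging.getLogger("gpt_sovits_api")  # stand-in for api.py's logger

# Checkpoint paths as they appear in the hunk above.
PATH_SOVITS_V3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
PATH_SOVITS_V4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"


def warn_if_base_missing(model_version: str, is_lora: bool) -> None:
    """Log when a required pretrained base checkpoint is absent."""
    if is_lora and not os.path.exists(PATH_SOVITS_V3):
        logger.info("SoVITS V3 base model is missing; the matching LoRA weights cannot be loaded")
    if model_version == "v4" and not os.path.exists(PATH_SOVITS_V4):
        logger.info("SoVITS V4 base model is missing; the matching LoRA weights cannot be loaded")
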
@@ -310,12 +312,13 @@ def get_sovits_weights(sovits_path):
         hps.model.version = "v1"
     else:
         hps.model.version = "v2"
 
     if model_version == "v3":
         hps.model.version = "v3"
+    if model_version == "v4":
+        hps.model.version = "v4"
 
     model_params_dict = vars(hps.model)
-    if model_version != "v3":
+    if model_version != "v3" and model_version != "v4":
         vq_model = SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
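
v4 now joins v3 on the newer decoder path: hps.model.version is stamped with the detected model version, and the plain SynthesizerTrn is only built for v1/v2. A sketch of that decision with the symbol check and model classes abstracted away (names are assumptions, not api.py's):

def resolve_versions(has_v1_symbols: bool, model_version: str):
    """Return (hps_model_version, use_legacy_synthesizer) following the hunk above.

    `has_v1_symbols` stands in for the config check that precedes this hunk.
    """
    hps_version = "v1" if has_v1_symbols else "v2"
    if model_version in ("v3", "v4"):
        hps_version = model_version
    # Only v1/v2 go through the original SynthesizerTrn constructor.
    use_legacy_synthesizer = model_version not in ("v3", "v4")
    return hps_version, use_legacy_synthesizer
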
@@ -342,8 +345,11 @@ def get_sovits_weights(sovits_path):
     else:
         vq_model = vq_model.to(device)
     vq_model.eval()
-    if if_lora_v3 == False:
+    if if_lora_v3 == False or model_version != "v4":
         vq_model.load_state_dict(dict_s2["weight"], strict=False)
     else:
+        if model_version == "v4":
+            vq_model.load_state_dict(load_sovits_new(path_sovits_v4)["weight"], strict=False)
+        else:
             vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
         lora_rank = dict_s2["lora_rank"]
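
Inside the else branch the pretrained base checkpoint is now picked by model version, loaded with strict=False, and only then is lora_rank read from the uploaded file. A small sketch of that selection (the function name and defaults are assumptions for illustration):

def base_checkpoint_for(model_version: str,
                        path_v3: str = "GPT_SoVITS/pretrained_models/s2Gv3.pth",
                        path_v4: str = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth") -> str:
    """Pick the pretrained base weights a LoRA checkpoint should be applied on."""
    return path_v4 if model_version == "v4" else path_v3

# Roughly: vq_model.load_state_dict(load_sovits_new(base_checkpoint_for(v))["weight"], strict=False)
# followed by applying the LoRA with rank dict_s2["lora_rank"].
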
@@ -394,12 +400,10 @@ def change_gpt_sovits_weights(gpt_path, sovits_path):
     try:
         gpt = get_gpt_weights(gpt_path)
         sovits = get_sovits_weights(sovits_path)
     except Exception as e:
         return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
 
     speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
     return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
 
-    except Exception as e:
-        return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
 
 def get_bert_feature(text, word2ph):
     with torch.no_grad():
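
With the duplicated trailing except block dropped, a failed checkpoint load returns 400 immediately and the default speaker is registered only after both loads succeed. A minimal sketch of that control flow, with stand-ins for the loaders and the Speaker container (not api.py's actual definitions):

from fastapi.responses import JSONResponse

speaker_list = {}


def get_gpt_weights(gpt_path):        # stand-in loader
    return {"path": gpt_path}


def get_sovits_weights(sovits_path):  # stand-in loader
    return {"path": sovits_path}


def change_gpt_sovits_weights(gpt_path, sovits_path):
    try:
        gpt = get_gpt_weights(gpt_path)
        sovits = get_sovits_weights(sovits_path)
    except Exception as e:
        # Loading failed: report and stop here.
        return JSONResponse({"code": 400, "message": str(e)}, status_code=400)

    # Runs only when both loads succeeded.
    speaker_list["default"] = {"name": "default", "gpt": gpt, "sovits": sovits}
    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
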
@@ -759,7 +763,7 @@ def get_tts_wav(
     prompt_semantic = codes[0, 0]
     prompt = prompt_semantic.unsqueeze(0).to(device)
 
-    if version != "v3":
+    if version != "v3" and version != "v4":
         refers = []
         if inp_refs:
             for path in inp_refs:
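
The remaining hunks all widen the same guard from version != "v3" to also exclude v4, so v3 and v4 share the reference handling, decoding, and sample-rate branches. A hypothetical helper that expresses that gate once (api.py spells the condition out inline each time):

NEW_DECODER_VERSIONS = {"v3", "v4"}


def uses_legacy_pipeline(version: str) -> bool:
    """True for v1/v2, which collect reference spectrograms and decode via SynthesizerTrn."""
    return version not in NEW_DECODER_VERSIONS
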
@@ -811,7 +815,7 @@ def get_tts_wav(
         pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
         t3 = ttime()
 
-        if version != "v3":
+        if version != "v3" and version != "v4":
             audio = (
                 vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
                 .detach()
@@ -880,7 +884,7 @@ def get_tts_wav(
     audio_opt = np.concatenate(audio_opt, 0)
     t4 = ttime()
 
-    sr = hps.data.sampling_rate if version != "v3" else 24000
+    sr = hps.data.sampling_rate if version != "v3" and version != "v4" else 24000
     if if_sr and sr == 24000:
         audio_opt = torch.from_numpy(audio_opt).float().to(device)
         audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
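
Final sample rate: v1/v2 keep the model's native hps.data.sampling_rate, while v3/v4 produce 24 kHz audio that audio_sr can optionally upsample to 48 kHz. A sketch of that selection as a pure function (the name and signature are assumptions):

def pick_output_sr(version: str, native_sr: int, if_sr: bool) -> int:
    """Sample rate of the waveform that ends up in the response."""
    if version not in ("v3", "v4"):
        return native_sr              # v1/v2: the model's own rate
    return 48000 if if_sr else 24000  # v3/v4: 24 kHz, or 48 kHz after audio_sr
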
@@ -901,7 +905,7 @@ def get_tts_wav(
     if not stream_mode == "normal":
         if media_type == "wav":
             sr = 48000 if if_sr else 24000
-            sr = hps.data.sampling_rate if version != "v3" else sr
+            sr = hps.data.sampling_rate if version != "v3" and version != "v4" else sr
             audio_bytes = pack_wav(audio_bytes, sr)
             yield audio_bytes.getvalue()
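
For non-streaming WAV responses the header must carry the rate chosen above. Roughly what a pack_wav-style helper has to do, sketched with the standard-library wave module (this is not api.py's implementation):

import io
import wave

import numpy as np


def pack_wav_sketch(samples: np.ndarray, sample_rate: int) -> bytes:
    """Wrap mono int16 PCM samples in an in-memory WAV container."""
    pcm = np.asarray(samples, dtype=np.int16)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(1)           # mono
        wf.setsampwidth(2)           # 16-bit samples
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())
    return buf.getvalue()
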