mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-17 23:19:47 +08:00
parent
9ebae35a7d
commit
b122988be0
36
api.py
36
api.py
@ -176,9 +176,9 @@ import subprocess
|
|||||||
|
|
||||||
class DefaultRefer:
|
class DefaultRefer:
|
||||||
def __init__(self, path, text, language):
|
def __init__(self, path, text, language):
|
||||||
self.path = args.default_refer_path
|
self.path = path
|
||||||
self.text = args.default_refer_text
|
self.text = text
|
||||||
self.language = args.default_refer_language
|
self.language = language
|
||||||
|
|
||||||
def is_ready(self) -> bool:
|
def is_ready(self) -> bool:
|
||||||
return is_full(self.path, self.text, self.language)
|
return is_full(self.path, self.text, self.language)
|
||||||
@ -294,11 +294,13 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
|||||||
|
|
||||||
def get_sovits_weights(sovits_path):
|
def get_sovits_weights(sovits_path):
|
||||||
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||||
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
|
||||||
|
|
||||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||||
if if_lora_v3 == True and is_exist_s2gv3 == False:
|
if if_lora_v3 == True and not os.path.exists(path_sovits_v3):
|
||||||
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
|
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
|
||||||
|
if model_version == "v4" and not os.path.exists(path_sovits_v4):
|
||||||
|
logger.info("SoVITS V4 底模缺失,无法加载相应 LoRA 权重")
|
||||||
|
|
||||||
dict_s2 = load_sovits_new(sovits_path)
|
dict_s2 = load_sovits_new(sovits_path)
|
||||||
hps = dict_s2["config"]
|
hps = dict_s2["config"]
|
||||||
@ -310,12 +312,13 @@ def get_sovits_weights(sovits_path):
|
|||||||
hps.model.version = "v1"
|
hps.model.version = "v1"
|
||||||
else:
|
else:
|
||||||
hps.model.version = "v2"
|
hps.model.version = "v2"
|
||||||
|
|
||||||
if model_version == "v3":
|
if model_version == "v3":
|
||||||
hps.model.version = "v3"
|
hps.model.version = "v3"
|
||||||
|
if model_version == "v4":
|
||||||
|
hps.model.version = "v4"
|
||||||
|
|
||||||
model_params_dict = vars(hps.model)
|
model_params_dict = vars(hps.model)
|
||||||
if model_version != "v3":
|
if model_version != "v3" and model_version != "v4":
|
||||||
vq_model = SynthesizerTrn(
|
vq_model = SynthesizerTrn(
|
||||||
hps.data.filter_length // 2 + 1,
|
hps.data.filter_length // 2 + 1,
|
||||||
hps.train.segment_size // hps.data.hop_length,
|
hps.train.segment_size // hps.data.hop_length,
|
||||||
@ -342,8 +345,11 @@ def get_sovits_weights(sovits_path):
|
|||||||
else:
|
else:
|
||||||
vq_model = vq_model.to(device)
|
vq_model = vq_model.to(device)
|
||||||
vq_model.eval()
|
vq_model.eval()
|
||||||
if if_lora_v3 == False:
|
if if_lora_v3 == False or model_version != "v4":
|
||||||
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
|
else:
|
||||||
|
if model_version == "v4":
|
||||||
|
vq_model.load_state_dict(load_sovits_new(path_sovits_v4)["weight"], strict=False)
|
||||||
else:
|
else:
|
||||||
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
|
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
|
||||||
lora_rank = dict_s2["lora_rank"]
|
lora_rank = dict_s2["lora_rank"]
|
||||||
@ -394,12 +400,10 @@ def change_gpt_sovits_weights(gpt_path, sovits_path):
|
|||||||
try:
|
try:
|
||||||
gpt = get_gpt_weights(gpt_path)
|
gpt = get_gpt_weights(gpt_path)
|
||||||
sovits = get_sovits_weights(sovits_path)
|
sovits = get_sovits_weights(sovits_path)
|
||||||
except Exception as e:
|
|
||||||
return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
|
|
||||||
|
|
||||||
speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
|
speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
|
||||||
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
|
||||||
|
|
||||||
def get_bert_feature(text, word2ph):
|
def get_bert_feature(text, word2ph):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -759,7 +763,7 @@ def get_tts_wav(
|
|||||||
prompt_semantic = codes[0, 0]
|
prompt_semantic = codes[0, 0]
|
||||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||||
|
|
||||||
if version != "v3":
|
if version != "v3" and version != "v4":
|
||||||
refers = []
|
refers = []
|
||||||
if inp_refs:
|
if inp_refs:
|
||||||
for path in inp_refs:
|
for path in inp_refs:
|
||||||
@ -811,7 +815,7 @@ def get_tts_wav(
|
|||||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
||||||
t3 = ttime()
|
t3 = ttime()
|
||||||
|
|
||||||
if version != "v3":
|
if version != "v3" and version != "v4":
|
||||||
audio = (
|
audio = (
|
||||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
|
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
|
||||||
.detach()
|
.detach()
|
||||||
@ -880,7 +884,7 @@ def get_tts_wav(
|
|||||||
audio_opt = np.concatenate(audio_opt, 0)
|
audio_opt = np.concatenate(audio_opt, 0)
|
||||||
t4 = ttime()
|
t4 = ttime()
|
||||||
|
|
||||||
sr = hps.data.sampling_rate if version != "v3" else 24000
|
sr = hps.data.sampling_rate if version != "v3" and version != "v4" else 24000
|
||||||
if if_sr and sr == 24000:
|
if if_sr and sr == 24000:
|
||||||
audio_opt = torch.from_numpy(audio_opt).float().to(device)
|
audio_opt = torch.from_numpy(audio_opt).float().to(device)
|
||||||
audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
|
audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
|
||||||
@ -901,7 +905,7 @@ def get_tts_wav(
|
|||||||
if not stream_mode == "normal":
|
if not stream_mode == "normal":
|
||||||
if media_type == "wav":
|
if media_type == "wav":
|
||||||
sr = 48000 if if_sr else 24000
|
sr = 48000 if if_sr else 24000
|
||||||
sr = hps.data.sampling_rate if version != "v3" else sr
|
sr = hps.data.sampling_rate if version != "v3" and version != "v4" else sr
|
||||||
audio_bytes = pack_wav(audio_bytes, sr)
|
audio_bytes = pack_wav(audio_bytes, sr)
|
||||||
yield audio_bytes.getvalue()
|
yield audio_bytes.getvalue()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user