mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-17 17:10:01 +08:00
support sovits v2Pro v2ProPlus
support sovits v2Pro v2ProPlus
This commit is contained in:
parent
0621259549
commit
92819d0b31
@ -35,7 +35,16 @@ from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
from tools.my_utils import load_audio
|
||||
from TTS_infer_pack.text_segmentation_method import splits
|
||||
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
|
||||
|
||||
from sv import SV
|
||||
resample_transform_dict={}
|
||||
def resample(audio_tensor, sr0,sr1,device):
|
||||
global resample_transform_dict
|
||||
key="%s-%s-%s"%(sr0,sr1,str(device))
|
||||
if key not in resample_transform_dict:
|
||||
resample_transform_dict[key] = torchaudio.transforms.Resample(
|
||||
sr0, sr1
|
||||
).to(device)
|
||||
return resample_transform_dict[key](audio_tensor)
|
||||
language = os.environ.get("language", "Auto")
|
||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
i18n = I18nAuto(language=language)
|
||||
@ -102,18 +111,6 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
|
||||
|
||||
return processed_audio
|
||||
|
||||
|
||||
resample_transform_dict = {}
|
||||
|
||||
|
||||
def resample(audio_tensor, sr0, sr1, device):
|
||||
global resample_transform_dict
|
||||
key = "%s-%s" % (sr0, sr1)
|
||||
if key not in resample_transform_dict:
|
||||
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
|
||||
return resample_transform_dict[key](audio_tensor)
|
||||
|
||||
|
||||
class DictToAttrRecursive(dict):
|
||||
def __init__(self, input_dict):
|
||||
super().__init__(input_dict)
|
||||
@ -252,6 +249,24 @@ class TTS_Config:
|
||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
},
|
||||
"v2Pro": {
|
||||
"device": "cpu",
|
||||
"is_half": False,
|
||||
"version": "v2Pro",
|
||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
|
||||
"vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro_pre1.pth",
|
||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
},
|
||||
"v2ProPlus": {
|
||||
"device": "cpu",
|
||||
"is_half": False,
|
||||
"version": "v2ProPlus",
|
||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
|
||||
"vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus_pre1.pth",
|
||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
},
|
||||
}
|
||||
configs: dict = None
|
||||
v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
||||
@ -287,7 +302,7 @@ class TTS_Config:
|
||||
|
||||
assert isinstance(configs, dict)
|
||||
version = configs.get("version", "v2").lower()
|
||||
assert version in ["v1", "v2", "v3", "v4"]
|
||||
assert version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]
|
||||
self.default_configs[version] = configs.get(version, self.default_configs[version])
|
||||
self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
|
||||
|
||||
@ -403,6 +418,7 @@ class TTS:
|
||||
self.cnhuhbert_model: CNHubert = None
|
||||
self.vocoder = None
|
||||
self.sr_model: AP_BWE = None
|
||||
self.sv_model = None
|
||||
self.sr_model_not_exist: bool = False
|
||||
|
||||
self.vocoder_configs: dict = {
|
||||
@ -463,6 +479,8 @@ class TTS:
|
||||
def init_vits_weights(self, weights_path: str):
|
||||
self.configs.vits_weights_path = weights_path
|
||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
|
||||
if "Pro"in model_version:
|
||||
self.init_sv_model()
|
||||
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
|
||||
|
||||
if if_lora_v3 == True and os.path.exists(path_sovits) == False:
|
||||
@ -472,7 +490,6 @@ class TTS:
|
||||
# dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
|
||||
dict_s2 = load_sovits_new(weights_path)
|
||||
hps = dict_s2["config"]
|
||||
|
||||
hps["model"]["semantic_frame_rate"] = "25hz"
|
||||
if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
|
||||
hps["model"]["version"] = "v2" # v3model,v2sybomls
|
||||
@ -480,7 +497,15 @@ class TTS:
|
||||
hps["model"]["version"] = "v1"
|
||||
else:
|
||||
hps["model"]["version"] = "v2"
|
||||
# version = hps["model"]["version"]
|
||||
version = hps["model"]["version"]
|
||||
v3v4set={"v3", "v4"}
|
||||
if model_version not in v3v4set:
|
||||
if "Pro"not in model_version:
|
||||
model_version = version
|
||||
else:
|
||||
hps["model"]["version"] = model_version
|
||||
else:
|
||||
hps["model"]["version"] = model_version
|
||||
|
||||
self.configs.filter_length = hps["data"]["filter_length"]
|
||||
self.configs.segment_size = hps["train"]["segment_size"]
|
||||
@ -496,7 +521,7 @@ class TTS:
|
||||
|
||||
# print(f"model_version:{model_version}")
|
||||
# print(f'hps["model"]["version"]:{hps["model"]["version"]}')
|
||||
if model_version not in {"v3", "v4"}:
|
||||
if model_version not in v3v4set:
|
||||
vits_model = SynthesizerTrn(
|
||||
self.configs.filter_length // 2 + 1,
|
||||
self.configs.segment_size // self.configs.hop_length,
|
||||
@ -517,6 +542,8 @@ class TTS:
|
||||
if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
|
||||
del vits_model.enc_q
|
||||
|
||||
self.is_v2pro=model_version in {"v2Pro","v2ProPlus"}
|
||||
|
||||
if if_lora_v3 == False:
|
||||
print(
|
||||
f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
|
||||
@ -551,7 +578,7 @@ class TTS:
|
||||
self.configs.t2s_weights_path = weights_path
|
||||
self.configs.save_configs()
|
||||
self.configs.hz = 50
|
||||
dict_s1 = torch.load(weights_path, map_location=self.configs.device)
|
||||
dict_s1 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
|
||||
config = dict_s1["config"]
|
||||
self.configs.max_sec = config["data"]["max_sec"]
|
||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
||||
@ -605,7 +632,7 @@ class TTS:
|
||||
)
|
||||
self.vocoder.remove_weight_norm()
|
||||
state_dict_g = torch.load(
|
||||
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
|
||||
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
|
||||
)
|
||||
print("loading vocoder", self.vocoder.load_state_dict(state_dict_g))
|
||||
|
||||
@ -631,6 +658,11 @@ class TTS:
|
||||
print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
|
||||
self.sr_model_not_exist = True
|
||||
|
||||
def init_sv_model(self):
|
||||
if self.sv_model is not None:
|
||||
return
|
||||
self.sv_model = SV(self.configs.device, self.configs.is_half)
|
||||
|
||||
def enable_half_precision(self, enable: bool = True, save: bool = True):
|
||||
"""
|
||||
To enable half precision for the TTS model.
|
||||
@ -706,11 +738,11 @@ class TTS:
|
||||
self.prompt_cache["ref_audio_path"] = ref_audio_path
|
||||
|
||||
def _set_ref_spec(self, ref_audio_path):
|
||||
spec = self._get_ref_spec(ref_audio_path)
|
||||
spec_audio = self._get_ref_spec(ref_audio_path)
|
||||
if self.prompt_cache["refer_spec"] in [[], None]:
|
||||
self.prompt_cache["refer_spec"] = [spec]
|
||||
self.prompt_cache["refer_spec"] = [spec_audio]
|
||||
else:
|
||||
self.prompt_cache["refer_spec"][0] = spec
|
||||
self.prompt_cache["refer_spec"][0] = spec_audio
|
||||
|
||||
def _get_ref_spec(self, ref_audio_path):
|
||||
raw_audio, raw_sr = torchaudio.load(ref_audio_path)
|
||||
@ -718,25 +750,33 @@ class TTS:
|
||||
self.prompt_cache["raw_audio"] = raw_audio
|
||||
self.prompt_cache["raw_sr"] = raw_sr
|
||||
|
||||
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
|
||||
audio = torch.FloatTensor(audio)
|
||||
if raw_sr != self.configs.sampling_rate:
|
||||
audio = raw_audio.to(self.configs.device)
|
||||
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
|
||||
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
|
||||
else:
|
||||
audio = raw_audio.to(self.configs.device)
|
||||
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
|
||||
|
||||
maxx = audio.abs().max()
|
||||
if maxx > 1:
|
||||
audio /= min(2, maxx)
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm,
|
||||
audio,
|
||||
self.configs.filter_length,
|
||||
self.configs.sampling_rate,
|
||||
self.configs.hop_length,
|
||||
self.configs.win_length,
|
||||
center=False,
|
||||
)
|
||||
spec = spec.to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
spec = spec.half()
|
||||
return spec
|
||||
if self.is_v2pro == True:
|
||||
audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
|
||||
if self.configs.is_half:
|
||||
audio = audio.half()
|
||||
else:audio=None
|
||||
return spec,audio
|
||||
|
||||
def _set_prompt_semantic(self, ref_wav_path: str):
|
||||
zero_wav = np.zeros(
|
||||
@ -1171,10 +1211,13 @@ class TTS:
|
||||
t4 = time.perf_counter()
|
||||
t_34 += t4 - t3
|
||||
|
||||
refer_audio_spec: torch.Tensor = [
|
||||
item.to(dtype=self.precision, device=self.configs.device)
|
||||
for item in self.prompt_cache["refer_spec"]
|
||||
]
|
||||
refer_audio_spec = []
|
||||
if self.is_v2pro:sv_emb=[]
|
||||
for spec,audio_tensor in self.prompt_cache["refer_spec"]:
|
||||
spec=spec.to(dtype=self.precision, device=self.configs.device)
|
||||
refer_audio_spec.append(spec)
|
||||
if self.is_v2pro:
|
||||
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
|
||||
|
||||
batch_audio_fragment = []
|
||||
|
||||
@ -1206,9 +1249,10 @@ class TTS:
|
||||
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
|
||||
)
|
||||
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
|
||||
_batch_audio_fragment = self.vits_model.decode(
|
||||
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
|
||||
).detach()[0, 0, :]
|
||||
if self.is_v2pro!=True:
|
||||
_batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
|
||||
else:
|
||||
_batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
|
||||
audio_frag_end_idx.insert(0, 0)
|
||||
batch_audio_fragment = [
|
||||
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
|
||||
@ -1221,9 +1265,10 @@ class TTS:
|
||||
_pred_semantic = (
|
||||
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
audio_fragment = self.vits_model.decode(
|
||||
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
|
||||
).detach()[0, 0, :]
|
||||
if self.is_v2pro != True:
|
||||
audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
|
||||
else:
|
||||
audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
|
||||
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
|
||||
else:
|
||||
if parallel_infer:
|
||||
|
Loading…
x
Reference in New Issue
Block a user