适配v4版本

This commit is contained in:
ChasonJiang 2025-04-21 21:39:11 +08:00
parent 9d481da610
commit 30fdb60295
3 changed files with 174 additions and 65 deletions

View File

@ -25,7 +25,7 @@ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from BigVGAN.bigvgan import BigVGAN
from feature_extractor.cnhubert import CNHubert
from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
from module.models import SynthesizerTrn, SynthesizerTrnV3
from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
from peft import LoraConfig, get_peft_model
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
from transformers import AutoModelForMaskedLM, AutoTokenizer
@ -67,6 +67,20 @@ mel_fn = lambda x: mel_spectrogram_torch(
},
)
mel_fn_v4 = lambda x: mel_spectrogram_torch(
x,
**{
"n_fft": 1280,
"win_size": 1280,
"hop_size": 320,
"num_mels": 100,
"sampling_rate": 32000,
"fmin": 0,
"fmax": None,
"center": False,
},
)
def speed_change(input_audio: np.ndarray, speed: float, sr: int):
# 将 NumPy 数组转换为原始 PCM 流
@ -141,7 +155,7 @@ custom:
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
version: v2
default:
v1:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
@ -149,7 +163,7 @@ default:
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
version: v1
default_v2:
v2:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
@ -157,7 +171,7 @@ default_v2:
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
version: v2
default_v3:
v3:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
@ -165,6 +179,14 @@ default_v3:
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
version: v3
v4:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v4
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth
"""
@ -220,6 +242,16 @@ class TTS_Config:
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
"v4": {
"device": "cpu",
"is_half": False,
"version": "v4",
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
}
configs: dict = None
v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@ -255,7 +287,7 @@ class TTS_Config:
assert isinstance(configs, dict)
version = configs.get("version", "v2").lower()
assert version in ["v1", "v2", "v3"]
assert version in ["v1", "v2", "v3", "v4"]
self.default_configs[version] = configs.get(version, self.default_configs[version])
self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
@ -276,7 +308,7 @@ class TTS_Config:
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
self.languages = self.v1_languages if self.version == "v1" else self.v2_languages
self.is_v3_synthesizer: bool = False
self.use_vocoder: bool = False
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
self.t2s_weights_path = self.default_configs[version]["t2s_weights_path"]
@ -369,10 +401,18 @@ class TTS:
self.bert_tokenizer: AutoTokenizer = None
self.bert_model: AutoModelForMaskedLM = None
self.cnhuhbert_model: CNHubert = None
self.bigvgan_model: BigVGAN = None
self.vocoder = None
self.sr_model: AP_BWE = None
self.sr_model_not_exist: bool = False
self.vocoder_configs: dict = {
"sr": None,
"T_ref": None,
"T_chunk": None,
"upsample_rate": None,
"overlapped_len": None,
}
self._init_models()
self.text_preprocessor: TextPreprocessor = TextPreprocessor(
@ -391,6 +431,7 @@ class TTS:
"aux_ref_audio_paths": [],
}
self.stop_flag: bool = False
self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32
@ -456,14 +497,14 @@ class TTS:
# print(f"model_version:{model_version}")
# print(f'hps["model"]["version"]:{hps["model"]["version"]}')
if model_version != "v3":
if model_version not in ["v3", "v4"]:
vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers,
**kwargs,
)
self.configs.is_v3_synthesizer = False
self.configs.use_vocoder = False
else:
vits_model = SynthesizerTrnV3(
self.configs.filter_length // 2 + 1,
@ -471,8 +512,8 @@ class TTS:
n_speakers=self.configs.n_speakers,
**kwargs,
)
self.configs.is_v3_synthesizer = True
self.init_bigvgan()
self.configs.use_vocoder = True
self.init_vocoder(model_version)
if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
del vits_model.enc_q
@ -481,8 +522,9 @@ class TTS:
f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
)
else:
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
print(
f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}"
f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits)['weight'], strict=False)}"
)
lora_rank = dict_s2["lora_rank"]
lora_config = LoraConfig(
@ -521,20 +563,64 @@ class TTS:
if self.configs.is_half and str(self.configs.device) != "cpu":
self.t2s_model = self.t2s_model.half()
def init_bigvgan(self):
if self.bigvgan_model is not None:
return
self.bigvgan_model = BigVGAN.from_pretrained(
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
use_cuda_kernel=False,
) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode
self.bigvgan_model.remove_weight_norm()
self.bigvgan_model = self.bigvgan_model.eval()
def init_vocoder(self, version: str):
if version == "v3":
if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
return
if self.vocoder is not None:
self.vocoder.cpu()
del self.vocoder
self.empty_cache()
self.vocoder = BigVGAN.from_pretrained(
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
use_cuda_kernel=False,
) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode
self.vocoder.remove_weight_norm()
self.vocoder_configs["sr"] = 24000
self.vocoder_configs["T_ref"] = 468
self.vocoder_configs["T_chunk"] = 934
self.vocoder_configs["upsample_rate"] = 256
self.vocoder_configs["overlapped_len"] = 12
elif version == "v4":
if self.vocoder is not None and self.vocoder.__class__.__name__ == "Generator":
return
if self.vocoder is not None:
self.vocoder.cpu()
del self.vocoder
self.empty_cache()
self.vocoder = Generator(
initial_channel=100,
resblock="1",
resblock_kernel_sizes=[3, 7, 11],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
upsample_rates=[10, 6, 2, 2, 2],
upsample_initial_channel=512,
upsample_kernel_sizes=[20, 12, 4, 4, 4],
gin_channels=0, is_bias=True
)
self.vocoder.remove_weight_norm()
state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
print("loading vocoder",self.vocoder.load_state_dict(state_dict_g))
self.vocoder_configs["sr"] = 48000
self.vocoder_configs["T_ref"] = 500
self.vocoder_configs["T_chunk"] = 1000
self.vocoder_configs["upsample_rate"] = 480
self.vocoder_configs["overlapped_len"] = 12
self.vocoder = self.vocoder.eval()
if self.configs.is_half == True:
self.bigvgan_model = self.bigvgan_model.half().to(self.configs.device)
self.vocoder = self.vocoder.half().to(self.configs.device)
else:
self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
self.vocoder = self.vocoder.to(self.configs.device)
def init_sr_model(self):
if self.sr_model is not None:
@ -913,13 +999,13 @@ class TTS:
split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
if split_bucket and speed_factor == 1.0 and not (self.configs.is_v3_synthesizer and parallel_infer):
if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer):
print(i18n("分桶处理模式已开启"))
elif speed_factor != 1.0:
print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理"))
split_bucket = False
elif self.configs.is_v3_synthesizer and parallel_infer:
print(i18n("当开启并行推理模式时SoVits V3模型不支持分桶处理,已自动关闭分桶处理"))
elif self.configs.use_vocoder and parallel_infer:
print(i18n("当开启并行推理模式时SoVits V3/4模型不支持分桶处理,已自动关闭分桶处理"))
split_bucket = False
else:
print(i18n("分桶处理模式已关闭"))
@ -936,7 +1022,7 @@ class TTS:
if not no_prompt_text:
assert prompt_lang in self.configs.languages
if no_prompt_text and self.configs.is_v3_synthesizer:
if no_prompt_text and self.configs.use_vocoder:
raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3")
if ref_audio_path in [None, ""] and (
@ -1044,7 +1130,7 @@ class TTS:
t_34 = 0.0
t_45 = 0.0
audio = []
output_sr = self.configs.sampling_rate if not self.configs.is_v3_synthesizer else 24000
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
for item in data:
t3 = time.perf_counter()
if return_fragment:
@ -1106,7 +1192,7 @@ class TTS:
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
# ))
print(f"############ {i18n('合成音频')} ############")
if not self.configs.is_v3_synthesizer:
if not self.configs.use_vocoder:
if speed_factor == 1.0:
print(f"{i18n('并行合成中')}...")
# ## vits并行推理 method 2
@ -1143,7 +1229,7 @@ class TTS:
else:
if parallel_infer:
print(f"{i18n('并行合成中')}...")
audio_fragments = self.v3_synthesis_batched_infer(
audio_fragments = self.useing_vocoder_synthesis_batched_infer(
idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.extend(audio_fragments)
@ -1153,7 +1239,7 @@ class TTS:
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment = self.v3_synthesis(
audio_fragment = self.useing_vocoder_synthesis(
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.append(audio_fragment)
@ -1169,7 +1255,7 @@ class TTS:
speed_factor,
False,
fragment_interval,
super_sampling if self.configs.is_v3_synthesizer else False,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
else:
audio.append(batch_audio_fragment)
@ -1190,7 +1276,7 @@ class TTS:
speed_factor,
split_bucket,
fragment_interval,
super_sampling if self.configs.is_v3_synthesizer else False,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
except Exception as e:
@ -1272,7 +1358,7 @@ class TTS:
return sr, audio
def v3_synthesis(
def useing_vocoder_synthesis(
self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32
):
prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
@ -1285,19 +1371,23 @@ class TTS:
ref_audio = ref_audio.to(self.configs.device).float()
if ref_audio.shape[0] == 2:
ref_audio = ref_audio.mean(0).unsqueeze(0)
if ref_sr != 24000:
tgt_sr = self.vocoder_configs["sr"]
if ref_sr != tgt_sr:
ref_audio = resample(ref_audio, ref_sr, self.configs.device)
mel2 = mel_fn(ref_audio)
mel2 = mel_fn(ref_audio) if self.configs.version == "v3" else mel_fn_v4(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
if T_min > 468:
mel2 = mel2[:, :, -468:]
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
T_ref = self.vocoder_configs["T_ref"]
T_chunk = self.vocoder_configs["T_chunk"]
if T_min > T_ref:
mel2 = mel2[:, :, -T_ref:]
fea_ref = fea_ref[:, :, -T_ref:]
T_min = T_ref
chunk_len = T_chunk - T_min
mel2 = mel2.to(self.precision)
fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
@ -1324,12 +1414,12 @@ class TTS:
cfm_res = denorm_spec(cfm_res)
with torch.inference_mode():
wav_gen = self.bigvgan_model(cfm_res)
wav_gen = self.vocoder(cfm_res)
audio = wav_gen[0][0] # .cpu().detach().numpy()
return audio
def v3_synthesis_batched_infer(
def useing_vocoder_synthesis_batched_infer(
self,
idx_list: List[int],
semantic_tokens_list: List[torch.Tensor],
@ -1350,21 +1440,27 @@ class TTS:
if ref_sr != 24000:
ref_audio = resample(ref_audio, ref_sr, self.configs.device)
mel2 = mel_fn(ref_audio)
tgt_sr = self.vocoder_configs["sr"]
if ref_sr != tgt_sr:
ref_audio = resample(ref_audio, ref_sr, self.configs.device)
mel2 = mel_fn(ref_audio) if self.configs.version == "v3" else mel_fn_v4(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
if T_min > 468:
mel2 = mel2[:, :, -468:]
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
T_ref = self.vocoder_configs["T_ref"]
T_chunk = self.vocoder_configs["T_chunk"]
if T_min > T_ref:
mel2 = mel2[:, :, -T_ref:]
fea_ref = fea_ref[:, :, -T_ref:]
T_min = T_ref
chunk_len = T_chunk - T_min
mel2 = mel2.to(self.precision)
# #### batched inference
overlapped_len = 12
overlapped_len = self.vocoder_configs["overlapped_len"]
feat_chunks = []
feat_lens = []
feat_list = []
@ -1413,11 +1509,11 @@ class TTS:
pred_spec = denorm_spec(pred_spec)
with torch.no_grad():
wav_gen = self.bigvgan_model(pred_spec)
wav_gen = self.vocoder(pred_spec)
audio = wav_gen[0][0] # .cpu().detach().numpy()
audio_fragments = []
upsample_rate = 256
upsample_rate = self.vocoder_configs["upsample_rate"]
pos = 0
while pos < audio.shape[-1]:

View File

@ -30,3 +30,11 @@ v3:
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v3
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
v4:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v4
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth

View File

@ -195,26 +195,31 @@ def change_choices():
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
pretrained_sovits_name = [
"GPT_SoVITS/pretrained_models/s2G488k.pth",
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
path_sovits_v3,
"GPT_SoVITS/pretrained_models/s2Gv3.pth",
"GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
]
pretrained_gpt_name = [
"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
]
_ = [[], []]
for i in range(3):
for i in range(4):
if os.path.exists(pretrained_gpt_name[i]):
_[0].append(pretrained_gpt_name[i])
if os.path.exists(pretrained_sovits_name[i]):
_[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name, pretrained_sovits_name = _
if os.path.exists("./weight.json"):
pass
else:
@ -232,8 +237,8 @@ with open("./weight.json", "r", encoding="utf-8") as file:
sovits_path = sovits_path[0]
SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
for path in SoVITS_weight_root + GPT_weight_root:
os.makedirs(path, exist_ok=True)
@ -257,13 +262,14 @@ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
from process_ckpt import get_sovits_version_from_path_fast
v3v4set={"v3","v4"}
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
global version, model_version, dict_language, if_lora_v3
# global vq_model, hps, version, model_version, dict_language, if_lora_v3
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
# print(sovits_path,version, model_version, if_lora_v3)
if if_lora_v3 and not os.path.exists(path_sovits_v3):
info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
print(sovits_path,version, model_version, if_lora_v3)
is_exist=is_exist_s2gv3 if model_version=="v3"else is_exist_s2gv4
if if_lora_v3 == True and is_exist == False:
info = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
gr.Warning(info)
raise FileExistsError(info)
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
@ -281,13 +287,12 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
else:
text_update = {"__type__": "update", "value": ""}
text_language_update = {"__type__": "update", "value": i18n("中文")}
if model_version == "v3":
if model_version in v3v4set:
visible_sample_steps = True
visible_inp_refs = False
else:
visible_sample_steps = False
visible_inp_refs = True
# prompt_language,text_language,prompt_text,prompt_language,text,text_language,inp_refs,ref_text_free,
yield (
{"__type__": "update", "choices": list(dict_language.keys())},
{"__type__": "update", "choices": list(dict_language.keys())},