适配V3版本

This commit is contained in:
ChasonJiang 2025-03-05 18:16:29 +08:00
parent b5fa9dd85b
commit bf06ac589c

View File

@ -19,7 +19,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from feature_extractor.cnhubert import CNHubert
from module.models import SynthesizerTrn
from module.models import SynthesizerTrn, SynthesizerTrnV3
from peft import LoraConfig, get_peft_model
import librosa
from time import time as ttime
from tools.i18n.i18n import I18nAuto, scan_language_list
@ -29,6 +30,7 @@ from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from BigVGAN.bigvgan import BigVGAN
from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
@ -84,6 +86,14 @@ default_v2:
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
version: v2
default_v3:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
version: v3
"""
def set_seed(seed:int):
@ -110,7 +120,7 @@ def set_seed(seed:int):
class TTS_Config:
default_configs={
"default":{
"v1":{
"device": "cpu",
"is_half": False,
"version": "v1",
@ -119,7 +129,7 @@ class TTS_Config:
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
"default_v2":{
"v2":{
"device": "cpu",
"is_half": False,
"version": "v2",
@ -128,6 +138,15 @@ class TTS_Config:
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
"v3":{
"device": "cpu",
"is_half": False,
"version": "v3",
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
}
configs:dict = None
v1_languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@ -164,12 +183,9 @@ class TTS_Config:
assert isinstance(configs, dict)
version = configs.get("version", "v2").lower()
assert version in ["v1", "v2"]
self.default_configs["default"] = configs.get("default", self.default_configs["default"])
self.default_configs["default_v2"] = configs.get("default_v2", self.default_configs["default_v2"])
default_config_key = "default"if version=="v1" else "default_v2"
self.configs:dict = configs.get("custom", deepcopy(self.default_configs[default_config_key]))
assert version in ["v1", "v2", "v3"]
self.default_configs[version] = configs.get(version, self.default_configs[version])
self.configs:dict = configs.get("custom", deepcopy(self.default_configs[version]))
self.device = self.configs.get("device", torch.device("cpu"))
@ -177,7 +193,7 @@ class TTS_Config:
print(f"Warning: CUDA is not available, set device to CPU.")
self.device = torch.device("cpu")
# self.is_half = self.configs.get("is_half", False)
self.is_half = self.configs.get("is_half", False)
# if str(self.device) == "cpu" and self.is_half:
# print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
# self.is_half = False
@ -187,22 +203,22 @@ class TTS_Config:
self.vits_weights_path = self.configs.get("vits_weights_path", None)
self.bert_base_path = self.configs.get("bert_base_path", None)
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
self.languages = self.v2_languages if self.version=="v2" else self.v1_languages
self.languages = self.v1_languages if self.version=="v1" else self.v2_languages
self.is_v3_synthesizer:bool = False
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
self.t2s_weights_path = self.default_configs[default_config_key]['t2s_weights_path']
self.t2s_weights_path = self.default_configs[version]['t2s_weights_path']
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
self.vits_weights_path = self.default_configs[default_config_key]['vits_weights_path']
self.vits_weights_path = self.default_configs[version]['vits_weights_path']
print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
self.bert_base_path = self.default_configs[default_config_key]['bert_base_path']
self.bert_base_path = self.default_configs[version]['bert_base_path']
print(f"fall back to default bert_base_path: {self.bert_base_path}")
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
self.cnhuhbert_base_path = self.default_configs[default_config_key]['cnhuhbert_base_path']
self.cnhuhbert_base_path = self.default_configs[version]['cnhuhbert_base_path']
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
self.update_configs()
@ -254,7 +270,7 @@ class TTS_Config:
def update_version(self, version:str)->None:
self.version = version
self.languages = self.v2_languages if self.version=="v2" else self.v1_languages
self.languages = self.v1_languages if self.version=="v1" else self.v2_languages
def __str__(self):
self.configs = self.update_configs()
@ -282,7 +298,7 @@ class TTS:
self.configs:TTS_Config = TTS_Config(configs)
self.t2s_model:Text2SemanticLightningModule = None
self.vits_model:SynthesizerTrn = None
self.vits_model:Union[SynthesizerTrn, SynthesizerTrnV3] = None
self.bert_tokenizer:AutoTokenizer = None
self.bert_model:AutoModelForMaskedLM = None
self.cnhuhbert_model:CNHubert = None
@ -341,38 +357,81 @@ class TTS:
self.bert_model = self.bert_model.half()
def init_vits_weights(self, weights_path: str):
print(f"Loading VITS weights from {weights_path}")
self.configs.vits_weights_path = weights_path
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(weights_path)
path_sovits_v3=self.configs.default_configs["v3"]["vits_weights_path"]
if if_lora_v3==True and os.path.exists(path_sovits_v3)==False:
info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
raise FileExistsError(info)
dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.configs.update_version("v1")
else:
self.configs.update_version("v2")
self.configs.save_configs()
hps["model"]["version"] = self.configs.version
hps["model"]["semantic_frame_rate"] = "25hz"
if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
hps["model"]["version"] = "v2"#v3model,v2sybomls
elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps["model"]["version"] = "v1"
else:
hps["model"]["version"] = "v2"
version = hps["model"]["version"]
self.configs.filter_length = hps["data"]["filter_length"]
self.configs.segment_size = hps["train"]["segment_size"]
self.configs.sampling_rate = hps["data"]["sampling_rate"]
self.configs.hop_length = hps["data"]["hop_length"]
self.configs.win_length = hps["data"]["win_length"]
self.configs.n_speakers = hps["data"]["n_speakers"]
self.configs.semantic_frame_rate = "25hz"
self.configs.semantic_frame_rate = hps["model"]["semantic_frame_rate"]
kwargs = hps["model"]
vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers,
**kwargs
)
if hasattr(vits_model, "enc_q"):
del vits_model.enc_q
self.configs.update_version(version)
if model_version!="v3":
vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers,
**kwargs
)
model_version=version
if hasattr(vits_model, "enc_q"):
del vits_model.enc_q
self.configs.is_v3_synthesizer = False
else:
vits_model = SynthesizerTrnV3(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers,
**kwargs
)
self.configs.is_v3_synthesizer = True
self.init_bigvgan()
if if_lora_v3==False:
print(f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2["weight"], strict=False)}")
else:
print(f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)}")
lora_rank=dict_s2["lora_rank"]
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights=True,
)
vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2["weight"], strict=False)}")
vits_model.cfm = vits_model.cfm.merge_and_unload()
vits_model = vits_model.to(self.configs.device)
vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model
if self.configs.is_half and str(self.configs.device)!="cpu":
self.vits_model = self.vits_model.half()
@ -396,6 +455,8 @@ class TTS:
def init_bigvgan(self):
if self.bigvgan_model is not None:
return
self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode
self.bigvgan_model.remove_weight_norm()
@ -746,6 +807,7 @@ class TTS:
actual_seed = set_seed(seed)
parallel_infer = inputs.get("parallel_infer", True)
repetition_penalty = inputs.get("repetition_penalty", 1.35)
sample_steps = inputs.get("sample_steps", 16)
if parallel_infer:
print(i18n("并行推理模式已开启"))
@ -944,30 +1006,41 @@ class TTS:
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
# ))
if speed_factor == 1.0:
# ## vits并行推理 method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = (self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :])
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
if not self.configs.is_v3_synthesizer:
if speed_factor == 1.0:
# ## vits并行推理 method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = (self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :])
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
else:
# ## vits串行推理
for i, idx in enumerate(idx_list):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment =(self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :])
batch_audio_fragment.append(
audio_fragment
) ###试试重建不带上prompt部分
else:
# ## vits串行推理
for i, idx in enumerate(idx_list):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment =(self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :])
audio_fragment = self.v3_synthesis(
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.append(
audio_fragment
) ###试试重建不带上prompt部分
)
t5 = ttime()
t_45 += t5 - t4
@ -1076,54 +1149,55 @@ class TTS:
sample_steps:int=16
):
prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
fea_ref,ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
ref_audio:torch.Tensor = self.prompt_cache["raw_audio"]
ref_sr = self.prompt_cache["raw_sr"]
ref_audio=ref_audio.to(self.configs.device).float()
if (ref_audio.shape[0] == 2):
ref_audio = ref_audio.mean(0).unsqueeze(0)
if ref_sr!=24000:
ref_audio=resample(ref_audio, ref_sr)
# print("ref_audio",ref_audio.abs().mean())
mel2 = mel_fn(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
if (T_min > 468):
mel2 = mel2[:, :, -468:]
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
fea_ref,ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
ref_audio:torch.Tensor = self.prompt_cache["raw_audio"]
ref_sr = self.prompt_cache["raw_sr"]
ref_audio=ref_audio.to(self.configs.device).float()
if (ref_audio.shape[0] == 2):
ref_audio = ref_audio.mean(0).unsqueeze(0)
if ref_sr!=24000:
ref_audio=resample(ref_audio, ref_sr)
# print("ref_audio",ref_audio.abs().mean())
mel2 = mel_fn(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
if (T_min > 468):
mel2 = mel2[:, :, -468:]
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
mel2=mel2.to(self.precision)
fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
mel2=mel2.to(self.precision)
fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
cfm_resss = []
idx = 0
while (1):
fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
if (fea_todo_chunk.shape[-1] == 0): break
idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
cfm_resss = []
idx = 0
while (1):
fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
if (fea_todo_chunk.shape[-1] == 0): break
idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
cfm_res = cfm_res[:, :, mel2.shape[2]:]
mel2 = cfm_res[:, :, -T_min:]
fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2)
cmf_res = denorm_spec(cmf_res)
with torch.inference_mode():
wav_gen = self.bigvgan_model(cmf_res)
audio=wav_gen[0][0]#.cpu().detach().numpy()
cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
cfm_res = cfm_res[:, :, mel2.shape[2]:]
mel2 = cfm_res[:, :, -T_min:]
fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2)
cmf_res = denorm_spec(cmf_res)
with torch.inference_mode():
wav_gen = self.bigvgan_model(cmf_res)
audio=wav_gen[0][0]#.cpu().detach().numpy()
return audio
def speed_change(input_audio:np.ndarray, speed:float, sr:int):