适配V3版本

This commit is contained in:
ChasonJiang 2025-03-05 18:16:29 +08:00
parent b5fa9dd85b
commit bf06ac589c

View File

@ -19,7 +19,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from feature_extractor.cnhubert import CNHubert from feature_extractor.cnhubert import CNHubert
from module.models import SynthesizerTrn from module.models import SynthesizerTrn, SynthesizerTrnV3
from peft import LoraConfig, get_peft_model
import librosa import librosa
from time import time as ttime from time import time as ttime
from tools.i18n.i18n import I18nAuto, scan_language_list from tools.i18n.i18n import I18nAuto, scan_language_list
@ -29,6 +30,7 @@ from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from BigVGAN.bigvgan import BigVGAN from BigVGAN.bigvgan import BigVGAN
from module.mel_processing import spectrogram_torch,mel_spectrogram_torch from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
language=os.environ.get("language","Auto") language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language) i18n = I18nAuto(language=language)
@ -84,6 +86,14 @@ default_v2:
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
version: v2 version: v2
default_v3:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
version: v3
""" """
def set_seed(seed:int): def set_seed(seed:int):
@ -110,7 +120,7 @@ def set_seed(seed:int):
class TTS_Config: class TTS_Config:
default_configs={ default_configs={
"default":{ "v1":{
"device": "cpu", "device": "cpu",
"is_half": False, "is_half": False,
"version": "v1", "version": "v1",
@ -119,7 +129,7 @@ class TTS_Config:
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
}, },
"default_v2":{ "v2":{
"device": "cpu", "device": "cpu",
"is_half": False, "is_half": False,
"version": "v2", "version": "v2",
@ -128,6 +138,15 @@ class TTS_Config:
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
}, },
"v3":{
"device": "cpu",
"is_half": False,
"version": "v3",
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
} }
configs:dict = None configs:dict = None
v1_languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] v1_languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@ -164,12 +183,9 @@ class TTS_Config:
assert isinstance(configs, dict) assert isinstance(configs, dict)
version = configs.get("version", "v2").lower() version = configs.get("version", "v2").lower()
assert version in ["v1", "v2"] assert version in ["v1", "v2", "v3"]
self.default_configs["default"] = configs.get("default", self.default_configs["default"]) self.default_configs[version] = configs.get(version, self.default_configs[version])
self.default_configs["default_v2"] = configs.get("default_v2", self.default_configs["default_v2"]) self.configs:dict = configs.get("custom", deepcopy(self.default_configs[version]))
default_config_key = "default"if version=="v1" else "default_v2"
self.configs:dict = configs.get("custom", deepcopy(self.default_configs[default_config_key]))
self.device = self.configs.get("device", torch.device("cpu")) self.device = self.configs.get("device", torch.device("cpu"))
@ -177,7 +193,7 @@ class TTS_Config:
print(f"Warning: CUDA is not available, set device to CPU.") print(f"Warning: CUDA is not available, set device to CPU.")
self.device = torch.device("cpu") self.device = torch.device("cpu")
# self.is_half = self.configs.get("is_half", False) self.is_half = self.configs.get("is_half", False)
# if str(self.device) == "cpu" and self.is_half: # if str(self.device) == "cpu" and self.is_half:
# print(f"Warning: Half precision is not supported on CPU, set is_half to False.") # print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
# self.is_half = False # self.is_half = False
@ -187,22 +203,22 @@ class TTS_Config:
self.vits_weights_path = self.configs.get("vits_weights_path", None) self.vits_weights_path = self.configs.get("vits_weights_path", None)
self.bert_base_path = self.configs.get("bert_base_path", None) self.bert_base_path = self.configs.get("bert_base_path", None)
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None) self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
self.languages = self.v2_languages if self.version=="v2" else self.v1_languages self.languages = self.v1_languages if self.version=="v1" else self.v2_languages
self.is_v3_synthesizer:bool = False self.is_v3_synthesizer:bool = False
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)): if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
self.t2s_weights_path = self.default_configs[default_config_key]['t2s_weights_path'] self.t2s_weights_path = self.default_configs[version]['t2s_weights_path']
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}") print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)): if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
self.vits_weights_path = self.default_configs[default_config_key]['vits_weights_path'] self.vits_weights_path = self.default_configs[version]['vits_weights_path']
print(f"fall back to default vits_weights_path: {self.vits_weights_path}") print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)): if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
self.bert_base_path = self.default_configs[default_config_key]['bert_base_path'] self.bert_base_path = self.default_configs[version]['bert_base_path']
print(f"fall back to default bert_base_path: {self.bert_base_path}") print(f"fall back to default bert_base_path: {self.bert_base_path}")
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)): if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
self.cnhuhbert_base_path = self.default_configs[default_config_key]['cnhuhbert_base_path'] self.cnhuhbert_base_path = self.default_configs[version]['cnhuhbert_base_path']
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}") print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
self.update_configs() self.update_configs()
@ -254,7 +270,7 @@ class TTS_Config:
def update_version(self, version:str)->None: def update_version(self, version:str)->None:
self.version = version self.version = version
self.languages = self.v2_languages if self.version=="v2" else self.v1_languages self.languages = self.v1_languages if self.version=="v1" else self.v2_languages
def __str__(self): def __str__(self):
self.configs = self.update_configs() self.configs = self.update_configs()
@ -282,7 +298,7 @@ class TTS:
self.configs:TTS_Config = TTS_Config(configs) self.configs:TTS_Config = TTS_Config(configs)
self.t2s_model:Text2SemanticLightningModule = None self.t2s_model:Text2SemanticLightningModule = None
self.vits_model:SynthesizerTrn = None self.vits_model:Union[SynthesizerTrn, SynthesizerTrnV3] = None
self.bert_tokenizer:AutoTokenizer = None self.bert_tokenizer:AutoTokenizer = None
self.bert_model:AutoModelForMaskedLM = None self.bert_model:AutoModelForMaskedLM = None
self.cnhuhbert_model:CNHubert = None self.cnhuhbert_model:CNHubert = None
@ -341,38 +357,81 @@ class TTS:
self.bert_model = self.bert_model.half() self.bert_model = self.bert_model.half()
def init_vits_weights(self, weights_path: str): def init_vits_weights(self, weights_path: str):
print(f"Loading VITS weights from {weights_path}")
self.configs.vits_weights_path = weights_path self.configs.vits_weights_path = weights_path
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(weights_path)
path_sovits_v3=self.configs.default_configs["v3"]["vits_weights_path"]
if if_lora_v3==True and os.path.exists(path_sovits_v3)==False:
info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
raise FileExistsError(info)
dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False) dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
hps = dict_s2["config"] hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.configs.update_version("v1")
else:
self.configs.update_version("v2")
self.configs.save_configs()
hps["model"]["version"] = self.configs.version hps["model"]["semantic_frame_rate"] = "25hz"
if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
hps["model"]["version"] = "v2"#v3model,v2sybomls
elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps["model"]["version"] = "v1"
else:
hps["model"]["version"] = "v2"
version = hps["model"]["version"]
self.configs.filter_length = hps["data"]["filter_length"] self.configs.filter_length = hps["data"]["filter_length"]
self.configs.segment_size = hps["train"]["segment_size"] self.configs.segment_size = hps["train"]["segment_size"]
self.configs.sampling_rate = hps["data"]["sampling_rate"] self.configs.sampling_rate = hps["data"]["sampling_rate"]
self.configs.hop_length = hps["data"]["hop_length"] self.configs.hop_length = hps["data"]["hop_length"]
self.configs.win_length = hps["data"]["win_length"] self.configs.win_length = hps["data"]["win_length"]
self.configs.n_speakers = hps["data"]["n_speakers"] self.configs.n_speakers = hps["data"]["n_speakers"]
self.configs.semantic_frame_rate = "25hz" self.configs.semantic_frame_rate = hps["model"]["semantic_frame_rate"]
kwargs = hps["model"] kwargs = hps["model"]
self.configs.update_version(version)
if model_version!="v3":
vits_model = SynthesizerTrn( vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1, self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length, self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers, n_speakers=self.configs.n_speakers,
**kwargs **kwargs
) )
model_version=version
if hasattr(vits_model, "enc_q"): if hasattr(vits_model, "enc_q"):
del vits_model.enc_q del vits_model.enc_q
self.configs.is_v3_synthesizer = False
else:
vits_model = SynthesizerTrnV3(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers,
**kwargs
)
self.configs.is_v3_synthesizer = True
self.init_bigvgan()
if if_lora_v3==False:
print(f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2["weight"], strict=False)}")
else:
print(f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)}")
lora_rank=dict_s2["lora_rank"]
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights=True,
)
vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2["weight"], strict=False)}")
vits_model.cfm = vits_model.cfm.merge_and_unload()
vits_model = vits_model.to(self.configs.device) vits_model = vits_model.to(self.configs.device)
vits_model = vits_model.eval() vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model self.vits_model = vits_model
if self.configs.is_half and str(self.configs.device)!="cpu": if self.configs.is_half and str(self.configs.device)!="cpu":
self.vits_model = self.vits_model.half() self.vits_model = self.vits_model.half()
@ -396,6 +455,8 @@ class TTS:
def init_bigvgan(self): def init_bigvgan(self):
if self.bigvgan_model is not None:
return
self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode # remove weight norm in the model and set to eval mode
self.bigvgan_model.remove_weight_norm() self.bigvgan_model.remove_weight_norm()
@ -746,6 +807,7 @@ class TTS:
actual_seed = set_seed(seed) actual_seed = set_seed(seed)
parallel_infer = inputs.get("parallel_infer", True) parallel_infer = inputs.get("parallel_infer", True)
repetition_penalty = inputs.get("repetition_penalty", 1.35) repetition_penalty = inputs.get("repetition_penalty", 1.35)
sample_steps = inputs.get("sample_steps", 16)
if parallel_infer: if parallel_infer:
print(i18n("并行推理模式已开启")) print(i18n("并行推理模式已开启"))
@ -944,6 +1006,7 @@ class TTS:
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
# )) # ))
if not self.configs.is_v3_synthesizer:
if speed_factor == 1.0: if speed_factor == 1.0:
# ## vits并行推理 method 2 # ## vits并行推理 method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
@ -968,6 +1031,16 @@ class TTS:
batch_audio_fragment.append( batch_audio_fragment.append(
audio_fragment audio_fragment
) ###试试重建不带上prompt部分 ) ###试试重建不带上prompt部分
else:
for i, idx in enumerate(idx_list):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment = self.v3_synthesis(
_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.append(
audio_fragment
)
t5 = ttime() t5 = ttime()
t_45 += t5 - t4 t_45 += t5 - t4
@ -1124,6 +1197,7 @@ class TTS:
wav_gen = self.bigvgan_model(cmf_res) wav_gen = self.bigvgan_model(cmf_res)
audio=wav_gen[0][0]#.cpu().detach().numpy() audio=wav_gen[0][0]#.cpu().detach().numpy()
return audio
def speed_change(input_audio:np.ndarray, speed:float, sr:int): def speed_change(input_audio:np.ndarray, speed:float, sr:int):