modified: GPT_SoVITS/TTS_infer_pack/TTS.py

modified:   GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
	modified:   GPT_SoVITS/inference_webui_fast.py
This commit is contained in:
ChasonJiang 2025-03-04 23:23:44 +08:00
parent 6dd2f72090
commit e9791aff9c
3 changed files with 139 additions and 6 deletions

View File

@ -4,6 +4,7 @@ import os, sys, gc
import random
import traceback
import torchaudio
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
@ -26,10 +27,37 @@ from tools.my_utils import load_audio
from module.mel_processing import spectrogram_torch
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from BigVGAN.bigvgan import BigVGAN
from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
spec_min = -12
spec_max = 2
def norm_spec(x):
return (x - spec_min) / (spec_max - spec_min) * 2 - 1
def denorm_spec(x):
return (x + 1) / 2 * (spec_max - spec_min) + spec_min
mel_fn=lambda x: mel_spectrogram_torch(x, **{
"n_fft": 1024,
"win_size": 1024,
"hop_size": 256,
"num_mels": 100,
"sampling_rate": 24000,
"fmin": 0,
"fmax": None,
"center": False
})
# configs/tts_infer.yaml
"""
custom:
@ -157,6 +185,8 @@ class TTS_Config:
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
self.languages = self.v2_languages if self.version=="v2" else self.v1_languages
self.is_v3_synthesizer:bool = False
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
self.t2s_weights_path = self.default_configs[default_config_key]['t2s_weights_path']
@ -252,6 +282,7 @@ class TTS:
self.bert_tokenizer:AutoTokenizer = None
self.bert_model:AutoModelForMaskedLM = None
self.cnhuhbert_model:CNHubert = None
self.bigvgan_model:BigVGAN = None
self._init_models()
@ -359,6 +390,19 @@ class TTS:
if self.configs.is_half and str(self.configs.device)!="cpu":
self.t2s_model = self.t2s_model.half()
def init_bigvgan(self):
self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode
self.bigvgan_model.remove_weight_norm()
self.bigvgan_model = self.bigvgan_model.eval()
if self.configs.is_half == True:
self.bigvgan_model = self.bigvgan_model.half().to(self.configs.device)
else:
self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
def enable_half_precision(self, enable: bool = True, save: bool = True):
'''
To enable half precision for the TTS model.
@ -383,6 +427,8 @@ class TTS:
self.bert_model =self.bert_model.half()
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.half()
if self.bigvgan_model is not None:
self.bigvgan_model = self.bigvgan_model.half()
else:
if self.t2s_model is not None:
self.t2s_model = self.t2s_model.float()
@ -392,6 +438,8 @@ class TTS:
self.bert_model = self.bert_model.float()
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.float()
if self.bigvgan_model is not None:
self.bigvgan_model = self.bigvgan_model.float()
def set_device(self, device: torch.device, save: bool = True):
'''
@ -728,6 +776,9 @@ class TTS:
if not no_prompt_text:
assert prompt_lang in self.configs.languages
if no_prompt_text and self.configs.is_v3_synthesizer:
raise RuntimeError("prompt_text cannot be empty when using SoVITS_V3")
if ref_audio_path in [None, ""] and \
((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])):
raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
@ -1014,6 +1065,61 @@ class TTS:
return sr, audio
def v3_synthesis(self,
semantic_tokens:torch.Tensor,
phones:torch.Tensor,
speed:float=1.0,
sample_steps:int=16
):
prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
fea_ref,ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
ref_audio:torch.Tensor = self.prompt_cache["raw_audio"]
ref_sr = self.prompt_cache["raw_sr"]
ref_audio=ref_audio.to(self.configs.device).float()
if (ref_audio.shape[0] == 2):
ref_audio = ref_audio.mean(0).unsqueeze(0)
if ref_sr!=24000:
ref_audio=resample(ref_audio, ref_sr)
# print("ref_audio",ref_audio.abs().mean())
mel2 = mel_fn(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
if (T_min > 468):
mel2 = mel2[:, :, -468:]
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
mel2=mel2.to(self.precision)
fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
cfm_resss = []
idx = 0
while (1):
fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
if (fea_todo_chunk.shape[-1] == 0): break
idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
cfm_res = cfm_res[:, :, mel2.shape[2]:]
mel2 = cfm_res[:, :, -T_min:]
fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2)
cmf_res = denorm_spec(cmf_res)
with torch.inference_mode():
wav_gen = self.bigvgan_model(cmf_res)
audio=wav_gen[0][0]#.cpu().detach().numpy()
def speed_change(input_audio:np.ndarray, speed:float, sr:int):
@ -1036,3 +1142,14 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
processed_audio = np.frombuffer(out, np.int16)
return processed_audio
resample_transform_dict={}
def resample(audio_tensor, sr0, device):
global resample_transform_dict
if sr0 not in resample_transform_dict:
resample_transform_dict[sr0] = torchaudio.transforms.Resample(
sr0, 24000
).to(device)
return resample_transform_dict[sr0](audio_tensor)

View File

@ -118,11 +118,11 @@ class TextPreprocessor:
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")
# language = language.replace("all_","")
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "zh":
if language == "all_zh":
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
@ -130,7 +130,7 @@ class TextPreprocessor:
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"yue",version)

View File

@ -194,10 +194,20 @@ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
tts_pipeline.init_vits_weights(sovits_path)
global version, dict_language
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
# print(sovits_path,version, model_version, if_lora_v3)
path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
if if_lora_v3 and not os.path.exists(path_sovits_v3):
info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
gr.Warning(info)
raise FileExistsError(info)
tts_pipeline.init_vits_weights(sovits_path)
dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
if prompt_language is not None and text_language is not None:
if prompt_language in list(dict_language.keys()):
@ -210,7 +220,13 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
else:
text_update = {'__type__':'update', 'value':''}
text_language_update = {'__type__':'update', 'value':i18n("中文")}
return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
if model_version=="v3":
visible_sample_steps=True
visible_inp_refs=False
else:
visible_sample_steps=False
visible_inp_refs=True
yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}