Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-05-09 22:29:06 +08:00)

modified: GPT_SoVITS/TTS_infer_pack/TTS.py
modified: GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
modified: GPT_SoVITS/inference_webui_fast.py

commit e9791aff9c (parent 6dd2f72090)
GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -4,6 +4,7 @@ import os, sys, gc
 import random
 import traceback
 
+import torchaudio
 from tqdm import tqdm
 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -26,10 +27,37 @@ from tools.my_utils import load_audio
 from module.mel_processing import spectrogram_torch
 from TTS_infer_pack.text_segmentation_method import splits
 from TTS_infer_pack.TextPreprocessor import TextPreprocessor
+from BigVGAN.bigvgan import BigVGAN
+from module.mel_processing import spectrogram_torch, mel_spectrogram_torch
 language=os.environ.get("language","Auto")
 language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)
 
+spec_min = -12
+spec_max = 2
+
+def norm_spec(x):
+    return (x - spec_min) / (spec_max - spec_min) * 2 - 1
+
+def denorm_spec(x):
+    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
+
+mel_fn = lambda x: mel_spectrogram_torch(x, **{
+    "n_fft": 1024,
+    "win_size": 1024,
+    "hop_size": 256,
+    "num_mels": 100,
+    "sampling_rate": 24000,
+    "fmin": 0,
+    "fmax": None,
+    "center": False
+})
 
 
 # configs/tts_infer.yaml
 """
 custom:
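Note: the added norm_spec/denorm_spec pair linearly maps the expected mel range [spec_min, spec_max] = [-12, 2] onto [-1, 1] for the flow-matching model and back for the vocoder. A quick self-contained round-trip check (illustrative only, not part of the commit):

import torch

spec_min, spec_max = -12, 2
x = torch.tensor([-12.0, -5.0, 2.0])
y = (x - spec_min) / (spec_max - spec_min) * 2 - 1    # norm_spec: [-1., 0., 1.]
x2 = (y + 1) / 2 * (spec_max - spec_min) + spec_min   # denorm_spec: back to x
assert torch.allclose(x, x2)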
@@ -157,6 +185,8 @@ class TTS_Config:
         self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
         self.languages = self.v2_languages if self.version=="v2" else self.v1_languages
+
+        self.is_v3_synthesizer:bool = False
 
         if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
             self.t2s_weights_path = self.default_configs[default_config_key]['t2s_weights_path']
@@ -252,6 +282,7 @@ class TTS:
         self.bert_tokenizer:AutoTokenizer = None
         self.bert_model:AutoModelForMaskedLM = None
         self.cnhuhbert_model:CNHubert = None
+        self.bigvgan_model:BigVGAN = None
 
         self._init_models()
@@ -359,6 +390,19 @@ class TTS:
         if self.configs.is_half and str(self.configs.device)!="cpu":
             self.t2s_model = self.t2s_model.half()
 
+    def init_bigvgan(self):
+        self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False)  # if True, RuntimeError: Ninja is required to load C++ extensions
+        # remove weight norm in the model and set to eval mode
+        self.bigvgan_model.remove_weight_norm()
+        self.bigvgan_model = self.bigvgan_model.eval()
+        if self.configs.is_half == True:
+            self.bigvgan_model = self.bigvgan_model.half().to(self.configs.device)
+        else:
+            self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
+
     def enable_half_precision(self, enable: bool = True, save: bool = True):
         '''
         To enable half precision for the TTS model.
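Note: init_bigvgan() is not wired into __init__ in this hunk, so a plausible call pattern (an assumption, not shown in the diff) is lazy initialization on the v3 path:

# hypothetical guard before v3 synthesis
if tts_pipeline.configs.is_v3_synthesizer and tts_pipeline.bigvgan_model is None:
    tts_pipeline.init_bigvgan()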
@@ -383,6 +427,8 @@ class TTS:
                 self.bert_model = self.bert_model.half()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.half()
+            if self.bigvgan_model is not None:
+                self.bigvgan_model = self.bigvgan_model.half()
         else:
             if self.t2s_model is not None:
                 self.t2s_model = self.t2s_model.float()
@@ -392,6 +438,8 @@ class TTS:
                 self.bert_model = self.bert_model.float()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.float()
+            if self.bigvgan_model is not None:
+                self.bigvgan_model = self.bigvgan_model.float()
 
     def set_device(self, device: torch.device, save: bool = True):
         '''
@@ -728,6 +776,9 @@ class TTS:
         if not no_prompt_text:
             assert prompt_lang in self.configs.languages
 
+        if no_prompt_text and self.configs.is_v3_synthesizer:
+            raise RuntimeError("prompt_text cannot be empty when using SoVITS_V3")
+
         if ref_audio_path in [None, ""] and \
             ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])):
             raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
@@ -1014,6 +1065,61 @@ class TTS:
         return sr, audio
 
+    def v3_synthesis(self,
+                     semantic_tokens:torch.Tensor,
+                     phones:torch.Tensor,
+                     speed:float=1.0,
+                     sample_steps:int=16
+                     ):
+
+        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
+        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
+        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
+
+        fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
+        ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
+        ref_sr = self.prompt_cache["raw_sr"]
+        ref_audio = ref_audio.to(self.configs.device).float()
+        if ref_audio.shape[0] == 2:
+            ref_audio = ref_audio.mean(0).unsqueeze(0)
+        if ref_sr != 24000:
+            ref_audio = resample(ref_audio, ref_sr, self.configs.device)
+        # print("ref_audio", ref_audio.abs().mean())
+        mel2 = mel_fn(ref_audio)
+        mel2 = norm_spec(mel2)
+        T_min = min(mel2.shape[2], fea_ref.shape[2])
+        mel2 = mel2[:, :, :T_min]
+        fea_ref = fea_ref[:, :, :T_min]
+        if T_min > 468:
+            mel2 = mel2[:, :, -468:]
+            fea_ref = fea_ref[:, :, -468:]
+            T_min = 468
+        chunk_len = 934 - T_min
+
+        mel2 = mel2.to(self.precision)
+        fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
+
+        cfm_resss = []
+        idx = 0
+        while True:
+            fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
+            if fea_todo_chunk.shape[-1] == 0:
+                break
+            idx += chunk_len
+            fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
+            cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
+            cfm_res = cfm_res[:, :, mel2.shape[2]:]
+            mel2 = cfm_res[:, :, -T_min:]
+            fea_ref = fea_todo_chunk[:, :, -T_min:]
+            cfm_resss.append(cfm_res)
+        cmf_res = torch.cat(cfm_resss, 2)
+        cmf_res = denorm_spec(cmf_res)
+
+        with torch.inference_mode():
+            wav_gen = self.bigvgan_model(cmf_res)
+            audio = wav_gen[0][0]  # .cpu().detach().numpy()
+        return audio
+
 def speed_change(input_audio:np.ndarray, speed:float, sr:int):
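Note: v3_synthesis decodes in windows of at most 934 frames, T_min frames of context plus a chunk_len = 934 - T_min chunk, with each chunk's tail rolled forward as the next context. A standalone sketch of that arithmetic with assumed shapes (plain torch, not this commit's code):

import torch

T_min, total = 300, 1500
chunk_len = 934 - T_min                           # context + chunk capped at 934 frames
fea_todo = torch.randn(1, 512, total)             # features still to be vocoded (fake)
fea_ref = torch.randn(1, 512, T_min)              # rolling acoustic context (fake)

idx = 0
while idx < fea_todo.shape[-1]:
    chunk = fea_todo[:, :, idx:idx + chunk_len]
    idx += chunk_len
    window = torch.cat([fea_ref, chunk], dim=2)   # shape [1, 512, <=934]
    fea_ref = chunk[:, :, -T_min:]                # this chunk's tail seeds the next window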
@@ -1036,3 +1142,14 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
     processed_audio = np.frombuffer(out, np.int16)
 
     return processed_audio
+
+
+resample_transform_dict = {}
+def resample(audio_tensor, sr0, device):
+    # cache one Resample module per source rate; v3 synthesis runs at 24 kHz
+    global resample_transform_dict
+    if sr0 not in resample_transform_dict:
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
+            sr0, 24000
+        ).to(device)
+    return resample_transform_dict[sr0](audio_tensor)
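Note: assumed usage of the resample cache added above; the Resample kernel is built once per source rate and reused across calls:

import torch
wav_44k = torch.randn(1, 44100)                         # one second of fake 44.1 kHz audio
wav_24k = resample(wav_44k, 44100, torch.device("cpu"))
print(wav_24k.shape)                                    # torch.Size([1, 24000])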
GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -118,11 +118,11 @@ class TextPreprocessor:
 
     def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
         if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
-            language = language.replace("all_","")
+            # language = language.replace("all_","")
             formattext = text
             while "  " in formattext:
                 formattext = formattext.replace("  ", " ")
-            if language == "zh":
+            if language == "all_zh":
                 if re.search(r'[A-Za-z]', formattext):
                     formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                     formattext = chinese.mix_text_normalize(formattext)
@@ -130,7 +130,7 @@ class TextPreprocessor:
                 else:
                     phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
                     bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
-            elif language == "yue" and re.search(r'[A-Za-z]', formattext):
+            elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
                 formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                 formattext = chinese.mix_text_normalize(formattext)
                 return self.get_phones_and_bert(formattext,"yue",version)
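Note: with the "all_" prefix no longer stripped, pure-language keys and mixed-language keys now reach different branches. A minimal illustration (hypothetical, not from the commit):

for language in ("all_zh", "zh"):
    if language == "all_zh":
        print(language, "-> pure-Chinese normalization path")
    else:
        print(language, "-> mixed-language path")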
GPT_SoVITS/inference_webui_fast.py
@@ -194,10 +194,20 @@ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
 SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
 
+from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
-    tts_pipeline.init_vits_weights(sovits_path)
     global version, dict_language
+    version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
+    # print(sovits_path, version, model_version, if_lora_v3)
+    path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
+
+    if if_lora_v3 and not os.path.exists(path_sovits_v3):
+        info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")  # "SoVITS V3 base model is missing; cannot load the matching LoRA weights"
+        gr.Warning(info)
+        raise FileExistsError(info)
+
+    tts_pipeline.init_vits_weights(sovits_path)
+
     dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
     if prompt_language is not None and text_language is not None:
         if prompt_language in list(dict_language.keys()):
@@ -210,7 +220,13 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
         else:
             text_update = {'__type__':'update', 'value':''}
             text_language_update = {'__type__':'update', 'value':i18n("中文")}
-    return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
+    if model_version == "v3":
+        visible_sample_steps = True
+        visible_inp_refs = False
+    else:
+        visible_sample_steps = False
+        visible_inp_refs = True
+    yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update, {"__type__": "update", "visible": visible_sample_steps}, {"__type__": "update", "visible": visible_inp_refs}, {"__type__": "update", "value": False, "interactive": model_version != "v3"}, {"__type__": "update", "visible": model_version == "v3"}
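Note: switching change_sovits_weights from return to yield makes it a generator, which Gradio treats as a streaming event handler; each yielded tuple updates the bound components in order. A minimal sketch under that assumption (component names and version check hypothetical):

import gradio as gr

def on_weights_change(sovits_path):
    is_v3 = sovits_path.endswith("v3.pth")    # hypothetical version check
    # show the sample-steps control only for v3, the reference list otherwise
    yield gr.update(visible=is_v3), gr.update(visible=not is_v3)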