Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-05 19:41:56 +08:00)
modified: GPT_SoVITS/TTS_infer_pack/TTS.py
modified: GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
modified: GPT_SoVITS/inference_webui_fast.py
commit e9791aff9c (parent 6dd2f72090)

--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -4,6 +4,7 @@ import os, sys, gc
 import random
 import traceback
 
+import torchaudio
 from tqdm import tqdm
 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -26,10 +27,37 @@ from tools.my_utils import load_audio
 from module.mel_processing import spectrogram_torch
 from TTS_infer_pack.text_segmentation_method import splits
 from TTS_infer_pack.TextPreprocessor import TextPreprocessor
+from BigVGAN.bigvgan import BigVGAN
+from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
 language=os.environ.get("language","Auto")
 language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)
+
+
+spec_min = -12
+spec_max = 2
+def norm_spec(x):
+    return (x - spec_min) / (spec_max - spec_min) * 2 - 1
+def denorm_spec(x):
+    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
+mel_fn = lambda x: mel_spectrogram_torch(x, **{
+    "n_fft": 1024,
+    "win_size": 1024,
+    "hop_size": 256,
+    "num_mels": 100,
+    "sampling_rate": 24000,
+    "fmin": 0,
+    "fmax": None,
+    "center": False
+})
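The added constants and helpers linearly map log-mel magnitudes between BigVGAN's [-12, 2] range and the [-1, 1] range the CFM decoder works in (norm_spec is applied before cfm.inference below, denorm_spec after). A minimal standalone sketch of the round-trip (torch only; the fake mel tensor is illustrative):

    import torch

    spec_min, spec_max = -12, 2

    def norm_spec(x: torch.Tensor) -> torch.Tensor:
        # [spec_min, spec_max] -> [-1, 1]
        return (x - spec_min) / (spec_max - spec_min) * 2 - 1

    def denorm_spec(x: torch.Tensor) -> torch.Tensor:
        # [-1, 1] -> [spec_min, spec_max]
        return (x + 1) / 2 * (spec_max - spec_min) + spec_min

    mel = torch.rand(1, 100, 50) * 14 - 12   # fake 100-band mel in [-12, 2]
    assert torch.allclose(denorm_spec(norm_spec(mel)), mel, atol=1e-5)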
 
 # configs/tts_infer.yaml
 """
 custom:
@@ -157,6 +185,8 @@ class TTS_Config:
         self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
         self.languages = self.v2_languages if self.version=="v2" else self.v1_languages
+
+        self.is_v3_synthesizer:bool = False
 
         if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
             self.t2s_weights_path = self.default_configs[default_config_key]['t2s_weights_path']
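The weights-path checks in this hunk fall back to the bundled defaults whenever a configured path is unset or missing on disk. A standalone sketch of that fallback pattern (the function and its names are illustrative, not the module's API):

    import os
    from typing import Optional

    def resolve_weights_path(configured: Optional[str], default: str) -> str:
        # keep the configured path only if it is non-empty and exists on disk
        if configured in (None, "") or not os.path.exists(configured):
            return default
        return configured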
@@ -252,6 +282,7 @@ class TTS:
         self.bert_tokenizer:AutoTokenizer = None
         self.bert_model:AutoModelForMaskedLM = None
         self.cnhuhbert_model:CNHubert = None
+        self.bigvgan_model:BigVGAN = None
 
         self._init_models()
@@ -359,6 +390,19 @@ class TTS:
             if self.configs.is_half and str(self.configs.device)!="cpu":
                 self.t2s_model = self.t2s_model.half()
 
+    def init_bigvgan(self):
+        self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False)  # if True: RuntimeError: Ninja is required to load C++ extensions
+        # remove weight norm from the model and set it to eval mode
+        self.bigvgan_model.remove_weight_norm()
+        self.bigvgan_model = self.bigvgan_model.eval()
+        if self.configs.is_half == True:
+            self.bigvgan_model = self.bigvgan_model.half().to(self.configs.device)
+        else:
+            self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
+
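init_bigvgan follows the usual inference recipe for NVIDIA's BigVGAN package: load a pretrained checkpoint, strip weight normalization, switch to eval mode, then cast and move to the target device. A condensed sketch of the same steps as a free function (ckpt_dir and the flag value mirror the call above; assumes the BigVGAN package used by this repo):

    import torch
    from BigVGAN.bigvgan import BigVGAN

    def load_vocoder(ckpt_dir: str, device: torch.device, half: bool) -> BigVGAN:
        model = BigVGAN.from_pretrained(ckpt_dir, use_cuda_kernel=False)
        model.remove_weight_norm()   # inference-only: fold weight norm away
        model = model.eval()
        return (model.half() if half else model).to(device)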
     def enable_half_precision(self, enable: bool = True, save: bool = True):
         '''
         To enable half precision for the TTS model.
@@ -383,6 +427,8 @@ class TTS:
                 self.bert_model = self.bert_model.half()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.half()
+            if self.bigvgan_model is not None:
+                self.bigvgan_model = self.bigvgan_model.half()
         else:
             if self.t2s_model is not None:
                 self.t2s_model = self.t2s_model.float()
@@ -392,6 +438,8 @@ class TTS:
                 self.bert_model = self.bert_model.float()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.float()
+            if self.bigvgan_model is not None:
+                self.bigvgan_model = self.bigvgan_model.float()
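enable_half_precision now walks four optional sub-models (t2s, bert, cnhuhbert, and the new bigvgan) and casts each loaded one in the requested direction. A reduced sketch of the same walk over a generic list of optional modules (names here are illustrative):

    import torch
    from typing import List, Optional

    def cast_models(models: List[Optional[torch.nn.Module]],
                    half: bool) -> List[Optional[torch.nn.Module]]:
        # cast every loaded module, leaving unloaded (None) slots untouched
        cast = (lambda m: m.half()) if half else (lambda m: m.float())
        return [cast(m) if m is not None else None for m in models]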
 
     def set_device(self, device: torch.device, save: bool = True):
         '''
@@ -728,6 +776,9 @@ class TTS:
         if not no_prompt_text:
             assert prompt_lang in self.configs.languages
 
+        if no_prompt_text and self.configs.is_v3_synthesizer:
+            raise RuntimeError("prompt_text cannot be empty when using SoVITS_V3")
+
         if ref_audio_path in [None, ""] and \
             ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])):
             raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
@@ -1014,6 +1065,61 @@ class TTS:
         return sr, audio
 
+    def v3_synthesis(self,
+                     semantic_tokens:torch.Tensor,
+                     phones:torch.Tensor,
+                     speed:float=1.0,
+                     sample_steps:int=16
+                     ):
+        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
+        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
+        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
+
+        fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
+        ref_audio:torch.Tensor = self.prompt_cache["raw_audio"]
+        ref_sr = self.prompt_cache["raw_sr"]
+        ref_audio = ref_audio.to(self.configs.device).float()
+        if ref_audio.shape[0] == 2:
+            ref_audio = ref_audio.mean(0).unsqueeze(0)
+        if ref_sr != 24000:
+            ref_audio = resample(ref_audio, ref_sr, self.configs.device)
+        # print("ref_audio", ref_audio.abs().mean())
+        mel2 = mel_fn(ref_audio)
+        mel2 = norm_spec(mel2)
+        T_min = min(mel2.shape[2], fea_ref.shape[2])
+        mel2 = mel2[:, :, :T_min]
+        fea_ref = fea_ref[:, :, :T_min]
+        if T_min > 468:
+            mel2 = mel2[:, :, -468:]
+            fea_ref = fea_ref[:, :, -468:]
+            T_min = 468
+        chunk_len = 934 - T_min
+
+        mel2 = mel2.to(self.precision)
+        fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
+
+        cfm_resss = []
+        idx = 0
+        while True:
+            fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
+            if fea_todo_chunk.shape[-1] == 0:
+                break
+            idx += chunk_len
+            fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
+            cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
+            cfm_res = cfm_res[:, :, mel2.shape[2]:]
+            mel2 = cfm_res[:, :, -T_min:]
+            fea_ref = fea_todo_chunk[:, :, -T_min:]
+            cfm_resss.append(cfm_res)
+        cfm_res = torch.cat(cfm_resss, 2)
+        cfm_res = denorm_spec(cfm_res)
+
+        with torch.inference_mode():
+            wav_gen = self.bigvgan_model(cfm_res)
+            audio = wav_gen[0][0]  # .cpu().detach().numpy()
+        return audio
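The synthesis loop is a sliding-window scheme: each CFM call sees T_min frames of already-rendered context (fea_ref/mel2, capped at 468 frames) plus up to 934 - T_min fresh frames, and after each call the context window slides to the tail of what was just generated. A toy reproduction of just the indexing (pure Python, no models):

    def chunk_spans(total_frames: int, t_min: int, window: int = 934):
        # each step consumes at most (window - t_min) new frames,
        # carrying t_min frames of context from the previous step
        chunk_len = window - t_min
        idx = 0
        while idx < total_frames:
            yield idx, min(idx + chunk_len, total_frames)
            idx += chunk_len

    # e.g. 2000 frames of fea_todo with a full 468-frame reference window:
    print(list(chunk_spans(2000, 468)))
    # [(0, 466), (466, 932), (932, 1398), (1398, 1864), (1864, 2000)]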
 
 def speed_change(input_audio:np.ndarray, speed:float, sr:int):
@@ -1036,3 +1142,14 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
     processed_audio = np.frombuffer(out, np.int16)
 
     return processed_audio
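The visible tail of speed_change reads 16-bit PCM back with np.frombuffer, which implies the samples were piped through an external process. One common way to implement such a function is ffmpeg's atempo filter over raw s16le pipes; the sketch below is a hedged reconstruction under that assumption, not necessarily the original's exact invocation:

    import subprocess
    import numpy as np

    def speed_change_sketch(input_audio: np.ndarray, speed: float, sr: int) -> np.ndarray:
        # pipe mono 16-bit PCM through ffmpeg's atempo filter and read it back
        cmd = [
            "ffmpeg", "-loglevel", "error",
            "-f", "s16le", "-ar", str(sr), "-ac", "1", "-i", "pipe:0",
            "-filter:a", f"atempo={speed}",
            "-f", "s16le", "pipe:1",
        ]
        out = subprocess.run(cmd, input=input_audio.astype(np.int16).tobytes(),
                             capture_output=True, check=True).stdout
        return np.frombuffer(out, np.int16)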
+
+
+resample_transform_dict = {}
+def resample(audio_tensor, sr0, device):
+    global resample_transform_dict
+    if sr0 not in resample_transform_dict:
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
+            sr0, 24000
+        ).to(device)
+    return resample_transform_dict[sr0](audio_tensor)
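Keying the cache on the source rate means each distinct reference-audio sample rate builds its torchaudio Resample kernel once and reuses it afterwards. A quick usage check of the helper above (CPU, fake audio):

    import torch

    wav = torch.randn(1, 48000)                      # one second of fake 48 kHz audio
    out = resample(wav, 48000, torch.device("cpu"))  # cached Resample(48000 -> 24000)
    assert out.shape == (1, 24000)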
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -118,11 +118,11 @@ class TextPreprocessor:
 
     def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
         if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
-            language = language.replace("all_","")
+            # language = language.replace("all_","")
             formattext = text
             while "  " in formattext:
                 formattext = formattext.replace("  ", " ")
-            if language == "zh":
+            if language == "all_zh":
                 if re.search(r'[A-Za-z]', formattext):
                     formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                     formattext = chinese.mix_text_normalize(formattext)
@@ -130,7 +130,7 @@ class TextPreprocessor:
                 else:
                     phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
                     bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
-            elif language == "yue" and re.search(r'[A-Za-z]', formattext):
+            elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
                 formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                 formattext = chinese.mix_text_normalize(formattext)
                 return self.get_phones_and_bert(formattext,"yue",version)
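Both changes stop stripping the "all_" prefix before branching, so the pure-language tags ("all_zh", "all_yue") stay distinct from the mixed-language ones ("zh", "yue") all the way through normalization. A toy illustration of the dispatch difference (function and return strings are illustrative):

    def route(language: str) -> str:
        # post-commit: match the raw tag, keeping pure-language inputs
        # ("all_zh"/"all_yue") on their own normalization branches
        if language == "all_zh":
            return "pure-zh normalization"
        if language == "all_yue":
            return "pure-yue normalization"
        return "other/mixed path"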
--- a/GPT_SoVITS/inference_webui_fast.py
+++ b/GPT_SoVITS/inference_webui_fast.py
@@ -194,10 +194,20 @@ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
 SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
 
+from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
-    tts_pipeline.init_vits_weights(sovits_path)
+    global version, dict_language
+    version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
+    # print(sovits_path, version, model_version, if_lora_v3)
+    path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
+
+    if if_lora_v3 and not os.path.exists(path_sovits_v3):
+        # i18n key below: "SoVITS V3 base model is missing; the matching LoRA weights cannot be loaded"
+        info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
+        gr.Warning(info)
+        raise FileNotFoundError(info)
+
+    tts_pipeline.init_vits_weights(sovits_path)
+
     dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
     if prompt_language is not None and text_language is not None:
         if prompt_language in list(dict_language.keys()):
@@ -210,7 +220,13 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
         else:
             text_update = {'__type__':'update', 'value':''}
             text_language_update = {'__type__':'update', 'value':i18n("中文")}
-    return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
+    if model_version == "v3":
+        visible_sample_steps = True
+        visible_inp_refs = False
+    else:
+        visible_sample_steps = False
+        visible_inp_refs = True
+    yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update, {"__type__": "update", "visible": visible_sample_steps}, {"__type__": "update", "visible": visible_inp_refs}, {"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False}, {"__type__": "update", "visible": True if model_version == "v3" else False}
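The callback now yields Gradio update dictionaries; {'__type__': 'update', ...} is the serialized form of gr.update(...), which lets one handler retarget the visibility and interactivity of several components at once. A minimal sketch of the same pattern (component roles mirror the yield above; assumes gradio is installed):

    import gradio as gr

    def on_model_change(model_version: str):
        is_v3 = model_version == "v3"
        yield (
            gr.update(visible=is_v3),                        # sample-steps control
            gr.update(visible=not is_v3),                    # multi-reference input
            gr.update(value=False, interactive=not is_v3),   # e.g. a v3-incompatible toggle
        )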