Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
Add VITS batched inference: GPT_SoVITS/TTS_infer_pack/TTS.py
Fix some bugs: GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py, GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py, GPT_SoVITS/inference_webui.py, GPT_SoVITS/module/models.py
parent 174c4bbab3
commit 3535cfe3b0
GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -1,3 +1,4 @@
+import math
 import os, sys
 now_dir = os.getcwd()
 sys.path.append(now_dir)

@@ -366,6 +367,7 @@ class TTS:
         for batch_idx, index_list in enumerate(batch_index_list):
             item_list = [data[idx] for idx in index_list]
             phones_list = []
+            phones_len_list = []
             # bert_features_list = []
             all_phones_list = []
             all_phones_len_list = []

@@ -375,24 +377,26 @@ class TTS:
             phones_max_len = 0
             for item in item_list:
                 if prompt_data is not None:
-                    all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1)
+                    all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
                     all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
                     phones = torch.LongTensor(item["phones"])
                     # norm_text = prompt_data["norm_text"]+item["norm_text"]
                 else:
                     all_bert_features = item["bert_features"]
                     phones = torch.LongTensor(item["phones"])
-                    all_phones = phones.clone()
+                    all_phones = phones
                     # norm_text = item["norm_text"]

                 bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
                 phones_max_len = max(phones_max_len, phones.shape[-1])

                 phones_list.append(phones)
+                phones_len_list.append(phones.shape[-1])
                 all_phones_list.append(all_phones)
                 all_phones_len_list.append(all_phones.shape[-1])
                 all_bert_features_list.append(all_bert_features)
                 norm_text_batch.append(item["norm_text"])

             phones_batch = phones_list
             max_len = max(bert_max_len, phones_max_len)
             # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)

@@ -406,6 +410,7 @@ class TTS:

             batch = {
                 "phones": phones_batch,
+                "phones_len": torch.LongTensor(phones_len_list),
                 "all_phones": all_phones_batch,
                 "all_phones_len": torch.LongTensor(all_phones_len_list),
                 "all_bert_features": all_bert_features_batch,

@@ -492,6 +497,10 @@ class TTS:
         if split_bucket:
             print(i18n("分桶处理模式已开启"))

+        # if vits_batched_inference:
+        #     print(i18n("VITS批量推理模式已开启"))
+        # else:
+        #     print(i18n("VITS单句推理模式已开启"))

         no_prompt_text = False
         if prompt_text in [None, ""]:

@@ -529,7 +538,6 @@ class TTS:

         ###### text preprocessing ########
         data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
-        audio = []
         t1 = ttime()
         data, batch_index_list = self.to_batch(data,
                                      prompt_data=self.prompt_cache if not no_prompt_text else None,

@@ -538,24 +546,23 @@ class TTS:
                                      split_bucket=split_bucket
                                      )
         t2 = ttime()
-        zero_wav = torch.zeros(
-                        int(self.configs.sampling_rate * 0.3),
-                        dtype=torch.float16 if self.configs.is_half else torch.float32,
-                        device=self.configs.device
-                    )


         ###### inference ######
         t_34 = 0.0
         t_45 = 0.0
+        audio = []
         for item in data:
             t3 = ttime()
             batch_phones = item["phones"]
+            batch_phones_len = item["phones_len"]
             all_phoneme_ids = item["all_phones"]
             all_phoneme_lens = item["all_phones_len"]
             all_bert_features = item["all_bert_features"]
             norm_text = item["norm_text"]

+            # batch_phones = batch_phones.to(self.configs.device)
+            batch_phones_len = batch_phones_len.to(self.configs.device)
             all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
             all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
             all_bert_features = all_bert_features.to(self.configs.device)

@@ -566,7 +573,7 @@ class TTS:
             if no_prompt_text :
                 prompt = None
             else:
-                prompt = self.prompt_cache["prompt_semantic"].clone().repeat(all_phoneme_ids.shape[0], 1).to(self.configs.device)
+                prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)

             with torch.no_grad():
                 pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(

@@ -583,41 +590,52 @@ class TTS:
             t4 = ttime()
             t_34 += t4 - t3

-            refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].clone().to(self.configs.device)
+            refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].to(self.configs.device)
             if self.configs.is_half:
                 refer_audio_spepc = refer_audio_spepc.half()

-            ## 直接对batch进行decode 生成的音频会有问题
+            batch_audio_fragment = []

+            # ## vits并行推理 method 1
             # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+            # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
             # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
             # batch_phones = batch_phones.to(self.configs.device)
-            # batch_audio_fragment =(self.vits_model.decode(
-            #         pred_semantic, batch_phones, refer_audio_spepc
-            #     ).detach()[:, 0, :])
-            # max_audio=torch.abs(batch_audio_fragment).max()#简单防止16bit爆音
-            # if max_audio>1: batch_audio_fragment/=max_audio
-            # batch_audio_fragment = batch_audio_fragment.cpu().numpy()
+            # batch_audio_fragment = (self.vits_model.batched_decode(
+            #         pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
+            #     ))

-            ## 改成串行处理
-            batch_audio_fragment = []
-            for i, idx in enumerate(idx_list):
-                phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-                _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
-                audio_fragment =(self.vits_model.decode(
-                        _pred_semantic, phones, refer_audio_spepc
+            # ## vits并行推理 method 2
+            pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+            upsample_rate = math.prod(self.vits_model.upsample_rates)
+            audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
+            audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
+            all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
+            _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
+            _batch_audio_fragment = (self.vits_model.decode(
+                    all_pred_semantic, _batch_phones,refer_audio_spepc
                     ).detach()[0, 0, :])
-                max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
-                if max_audio>1: audio_fragment/=max_audio
-                audio_fragment = torch.cat([audio_fragment, zero_wav], dim=0)
-                batch_audio_fragment.append(
-                    audio_fragment.cpu().numpy()
-                    ) ###试试重建不带上prompt部分
+            audio_frag_end_idx.insert(0, 0)
+            batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]

+            # ## vits串行推理
+            # for i, idx in enumerate(idx_list):
+            #     phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+            #     _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
+            #     audio_fragment =(self.vits_model.decode(
+            #             _pred_semantic, phones, refer_audio_spepc
+            #         ).detach()[0, 0, :])
+            #     batch_audio_fragment.append(
+            #         audio_fragment
+            #     ) ###试试重建不带上prompt部分

             t5 = ttime()
             t_45 += t5 - t4
             if return_fragment:
                 print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
-                yield self.audio_postprocess(batch_audio_fragment,
+                yield self.audio_postprocess([batch_audio_fragment],
                                                 self.configs.sampling_rate,
                                                 batch_index_list,
                                                 speed_factor,

@@ -626,7 +644,8 @@ class TTS:
                 audio.append(batch_audio_fragment)

             if self.stop_flag:
-                yield self.configs.sampling_rate, (zero_wav.cpu().numpy()).astype(np.int16)
+                yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
+                                                           dtype=np.int16)
                 return

         if not return_fragment:

@@ -640,15 +659,30 @@ class TTS:


     def audio_postprocess(self,
-                          audio:np.ndarray,
+                          audio:List[torch.Tensor],
                           sr:int,
                           batch_index_list:list=None,
                           speed_factor:float=1.0,
                           split_bucket:bool=True)->tuple[int, np.ndarray]:
+        zero_wav = torch.zeros(
+                        int(self.configs.sampling_rate * 0.3),
+                        dtype=torch.float16 if self.configs.is_half else torch.float32,
+                        device=self.configs.device
+                    )
+
+        for i, batch in enumerate(audio):
+            for j, audio_fragment in enumerate(batch):
+                max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
+                if max_audio>1: audio_fragment/=max_audio
+                audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
+                audio[i][j] = audio_fragment.cpu().numpy()


         if split_bucket:
             audio = self.recovery_order(audio, batch_index_list)
         else:
-            audio = [item for batch in audio for item in batch]
+            # audio = [item for batch in audio for item in batch]
+            audio = sum(audio, [])


         audio = np.concatenate(audio, 0)
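Note (not part of the commit): the new parallel path above decodes all sentences of a batch as one long waveform and then slices it back into per-sentence fragments at cumulative sample offsets. A minimal sketch of that bookkeeping, assuming the 25 Hz semantic rate (hence the factor of 2) and an illustrative upsample_rates tuple; the function name is hypothetical:

import math

def slice_batched_audio(pred_semantic_list, full_audio, upsample_rates=(10, 8, 2, 2, 2)):
    # samples produced per latent frame by the decoder (tuple is an assumption)
    samples_per_frame = math.prod(upsample_rates)
    # each 25 Hz semantic token is interpolated to 2 latent frames before decoding
    frag_lens = [p.shape[0] * 2 * samples_per_frame for p in pred_semantic_list]
    ends = [sum(frag_lens[:i + 1]) for i in range(len(frag_lens))]
    ends.insert(0, 0)
    # slice the concatenated waveform back into one fragment per sentence
    return [full_audio[ends[i - 1]:ends[i]] for i in range(1, len(ends))]

Concatenating fragments along the time axis and decoding once sidesteps the padded-batch decode that the removed and commented-out code flags as producing audio artifacts.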
GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py

@@ -10,7 +10,7 @@ from typing import Dict, List, Tuple
 from text.cleaner import clean_text
 from text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
-from TTS_infer_pack.text_segmentation_method import splits, get_method as get_seg_method
+from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method

 # from tools.i18n.i18n import I18nAuto

@@ -39,6 +39,10 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
     return result


+
+
+
+
 class TextPreprocessor:
     def __init__(self, bert_model:AutoModelForMaskedLM,
                  tokenizer:AutoTokenizer, device:torch.device):

@@ -74,12 +78,18 @@ class TextPreprocessor:
         _texts = text.split("\n")
         _texts = merge_short_text_in_array(_texts, 5)
         texts = []
+
+
         for text in _texts:
             # 解决输入目标文本的空行导致报错的问题
             if (len(text.strip()) == 0):
                 continue
             if (text[-1] not in splits): text += "。" if lang != "en" else "."
-            texts.append(text)
+
+            # 解决句子过长导致Bert报错的问题
+            texts.extend(split_big_text(text))
+
+
         return texts

@@ -177,3 +187,7 @@ class TextPreprocessor:
             ).to(self.device)

         return feature
+
+
+
+
GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py

@@ -24,6 +24,32 @@ def register_method(name):

 splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }

+def split_big_text(text, max_len=510):
+    # 定义全角和半角标点符号
+    punctuation = "".join(splits)
+
+    # 切割文本
+    segments = re.split('([' + punctuation + '])', text)
+
+    # 初始化结果列表和当前片段
+    result = []
+    current_segment = ''
+
+    for segment in segments:
+        # 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段
+        if len(current_segment + segment) > max_len:
+            result.append(current_segment)
+            current_segment = segment
+        else:
+            current_segment += segment
+
+    # 将最后一个片段加入结果列表
+    if current_segment:
+        result.append(current_segment)
+
+    return result
+
+
 def split(todo_text):
     todo_text = todo_text.replace("……", "。").replace("——", ",")

@@ -121,6 +147,6 @@ def cut5(inp):


 if __name__ == '__main__':
-    method = get_method("cut1")
+    method = get_method("cut5")
     print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
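Note (not part of the commit): split_big_text cuts at the punctuation characters in splits and keeps accumulating pieces until adding the next one would exceed max_len characters, which is what TextPreprocessor now relies on to avoid the BERT errors on over-long sentences mentioned in its new comment (510 presumably leaves room for special tokens under BERT's 512-token input limit). A small usage sketch with an illustrative input, assuming GPT_SoVITS is on sys.path:

from TTS_infer_pack.text_segmentation_method import split_big_text

# roughly 1200 characters with internal punctuation, far beyond max_len
long_text = "第一句。" * 300
chunks = split_big_text(long_text, max_len=510)
# each chunk here stays at or below 510 characters and ends at a punctuation mark
print(len(chunks), [len(c) for c in chunks])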
GPT_SoVITS/inference_webui.py

@@ -29,9 +29,13 @@ is_share = eval(is_share)
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
 is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_available()
+gpt_path = os.environ.get("gpt_path", None)
+sovits_path = os.environ.get("sovits_path", None)
+cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
+bert_path = os.environ.get("bert_path", None)

 import gradio as gr
 from TTS_infer_pack.TTS import TTS, TTS_Config
-from TTS_infer_pack.text_segmentation_method import cut1, cut2, cut3, cut4, cut5
 from TTS_infer_pack.text_segmentation_method import get_method
 from tools.i18n.i18n import I18nAuto

@@ -65,6 +69,15 @@ cut_method = {
 tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
 tts_config.device = device
 tts_config.is_half = is_half
+if gpt_path is not None:
+    tts_config.t2s_weights_path = gpt_path
+if sovits_path is not None:
+    tts_config.vits_weights_path = sovits_path
+if cnhubert_base_path is not None:
+    tts_config.cnhuhbert_base_path = cnhubert_base_path
+if bert_path is not None:
+    tts_config.bert_base_path = bert_path

 tts_pipline = TTS(tts_config)
 gpt_path = tts_config.t2s_weights_path
 sovits_path = tts_config.vits_weights_path

@@ -169,7 +182,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
         with gr.Row():

             with gr.Column():
-                batch_size = gr.Slider(minimum=1,maximum=20,step=1,label=i18n("batch_size"),value=1,interactive=True)
+                batch_size = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("batch_size"),value=20,interactive=True)
                 speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
                 top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
                 top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)

@@ -181,7 +194,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                     value=i18n("凑四句一切"),
                     interactive=True,
                 )
-            split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
+            with gr.Row():
+                split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
         # with gr.Column():
             output = gr.Audio(label=i18n("输出的语音"))
             with gr.Row():
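Note (not part of the commit): the WebUI now reads optional model paths from environment variables and only overrides the YAML config when a variable is actually set, so tts_infer.yaml keeps its defaults otherwise. A rough sketch of that behaviour; the SimpleNamespace and the paths below are placeholders, not repository defaults:

import os
from types import SimpleNamespace

# stand-in for the TTS_Config object loaded from tts_infer.yaml above
tts_config = SimpleNamespace(
    t2s_weights_path="placeholder/gpt_weights.ckpt",
    vits_weights_path="placeholder/sovits_weights.pth",
)
for env_key, attr in [("gpt_path", "t2s_weights_path"), ("sovits_path", "vits_weights_path")]:
    value = os.environ.get(env_key)
    if value is not None:  # override only when the variable is set
        setattr(tts_config, attr, value)
print(tts_config.t2s_weights_path)

This makes it possible, for example, for a launcher script to point the WebUI at custom GPT and SoVITS weights without editing the config file.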
GPT_SoVITS/module/models.py

@@ -1,5 +1,6 @@
 import copy
 import math
+from typing import List
 import torch
 from torch import nn
 from torch.nn import functional as F

@@ -986,6 +987,55 @@ class SynthesizerTrn(nn.Module):
         o = self.dec((z * y_mask)[:, :, :], g=ge)
         return o


+    @torch.no_grad()
+    def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5):
+        ge = None
+        if refer is not None:
+            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
+            refer_mask = torch.unsqueeze(
+                commons.sequence_mask(refer_lengths, refer.size(2)), 1
+            ).to(refer.dtype)
+            ge = self.ref_enc(refer * refer_mask, refer_mask)
+
+        # y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to(
+        #     codes.dtype
+        # )
+        y_lengths = (y_lengths * 2).long().to(codes.device)
+        text_lengths = text_lengths.long().to(text.device)
+        # y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
+        # text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
+
+        # 假设padding之后再decode没有问题, 影响未知,但听起来好像没问题?
+        quantized = self.quantizer.decode(codes)
+        if self.semantic_frame_rate == "25hz":
+            quantized = F.interpolate(
+                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
+            )
+
+        x, m_p, logs_p, y_mask = self.enc_p(
+            quantized, y_lengths, text, text_lengths, ge
+        )
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+
+        z = self.flow(z_p, y_mask, g=ge, reverse=True)
+        z_masked = (z * y_mask)[:, :, :]
+
+        # 串行。把padding部分去掉再decode
+        o_list:List[torch.Tensor] = []
+        for i in range(z_masked.shape[0]):
+            z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0)
+            o = self.dec(z_slice, g=ge)[0, 0, :].detach()
+            o_list.append(o)
+
+        # 并行(会有问题)。先decode,再把padding的部分去掉
+        # o = self.dec(z_masked, g=ge)
+        # upsample_rate = int(math.prod(self.upsample_rates))
+        # o_lengths = y_lengths*upsample_rate
+        # o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)]
+
+        return o_list
+
     def extract_latent(self, x):
         ssl = self.ssl_proj(x)
         quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
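Note (not part of the commit): batched_decode doubles y_lengths because the 25 Hz semantic codes are interpolated to twice as many latent frames, and each frame then expands by the product of the decoder's upsample factors. A back-of-the-envelope consistency check under assumed values (the real upsample_rates and sample rate come from the model config):

import math

upsample_rates = (10, 8, 2, 2, 2)              # hypothetical decoder upsampling factors
samples_per_frame = math.prod(upsample_rates)  # 640 samples per latent frame here

codes_len = 125                                # 125 semantic tokens at 25 Hz, i.e. 5 s of speech
latent_frames = codes_len * 2                  # the "25hz" branch interpolates codes to 2x frames
expected_samples = latent_frames * samples_per_frame
print(latent_frames, expected_samples)         # 250 frames -> 160000 samples (5 s at 32 kHz)

The per-item loop then decodes only z_masked[i, :, :y_lengths[i]], so padded frames never reach the vocoder; the commented-out parallel variant decodes the padded batch and trims afterwards, which its own comment flags as problematic.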