Merge pull request #741 from ChasonJiang/fast_inference_

Fixed some bugs and optimized some code
RVC-Boss 2024-03-13 17:36:39 +08:00 committed by GitHub
commit 37a895a67d
5 changed files with 365 additions and 220 deletions

View File

@@ -229,10 +229,15 @@ class Text2SemanticDecoder(nn.Module):
ignore_index=self.EOS,
)
if not flash_attn_enabled:
self.enable_flash_attn(flash_attn_enabled)
def enable_flash_attn(self, enable:bool=True):
if not enable:
print("Not Using Flash Attention")
self.infer_panel = self.infer_panel_batch_only
else:
self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
print("Using Flash Attention")
blocks = []
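
Reviewer note: enable_flash_attn works by rebinding the public infer_panel entry point to one of two implementations at runtime. A minimal, self-contained sketch of that dispatch pattern (illustrative names only, not the repo API):

class _DispatchSketch:
    def __init__(self, flash_attn_enabled: bool = True):
        # Rebind the public entry point to the chosen implementation, mirroring
        # how enable_flash_attn() swaps self.infer_panel in the hunk above.
        self.infer_panel = (self._infer_flash if flash_attn_enabled
                            else self._infer_plain)

    def _infer_plain(self, x):
        return f"plain({x})"

    def _infer_flash(self, x):
        return f"flash({x})"

print(_DispatchSketch(flash_attn_enabled=False).infer_panel("tokens"))  # plain(tokens)
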
@@ -497,7 +502,7 @@ class Text2SemanticDecoder(nn.Module):
# shift by one: decoder inputs vs. prediction targets
return targets[:, :-1], targets[:, 1:]
- def infer_panel(
+ def infer_panel_batch_infer_with_flash_attn(
self,
x, ##### all text tokens
x_lens,
@@ -508,8 +513,10 @@ class Text2SemanticDecoder(nn.Module):
early_stop_num: int = -1,
temperature: float = 1.0,
):
+ bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_embedding(x)
- x = x + self.bert_proj(bert_feature.transpose(1, 2))
+ x = x + bert_feature
x = self.ar_text_position(x)
# AR Decoder
@@ -546,30 +553,28 @@ class Text2SemanticDecoder(nn.Module):
y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
y_mask = make_pad_mask(y_lens)
x_mask = make_pad_mask(x_lens)
# (bsz, x_len + y_len)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
- _xy_padding_mask = (
-     xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1)
- )
- x_attn_mask_pad = F.pad(
+ x_mask = F.pad(
x_attn_mask,
(0, y_len), ### extend the all-zero (x, x) block with all-True (x, y) padding -> (x, x+y)
value=True,
)
- y_attn_mask = F.pad( ### pad the upper-triangular (y, y) mask on the left with all-False (y, x) -> (y, x+y)
+ y_mask = F.pad( ### pad the upper-triangular (y, y) mask on the left with all-False (y, x) -> (y, x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
- xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-     x.device
- )
- xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
+ xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).expand(bsz, -1, -1).to(x.device)
+ # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
+ xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
+ xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
+ xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
- new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
- xy_attn_mask = new_attn_mask
+ xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
###### decode #####
y_list = [None]*y.shape[0]
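
Reviewer note on the mask refactor above: the new code builds a single boolean mask combining the text/audio causal structure with per-sample key padding, broadcasts it per head, and converts it to an additive float mask. A self-contained sketch with toy sizes (head dimension omitted; names are illustrative):

import torch
import torch.nn.functional as F

bsz, x_len, y_len = 2, 3, 4
src_len = x_len + y_len

# Causal part: text attends to all text; audio attends to text and to past audio.
x_part = F.pad(torch.zeros(x_len, x_len, dtype=torch.bool), (0, y_len), value=True)
y_part = F.pad(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
               (x_len, 0), value=False)
causal = torch.cat([x_part, y_part], dim=0)              # (src_len, src_len)

# Padding part: sample 0 has its last 2 key positions padded.
pad = torch.zeros(bsz, src_len, dtype=torch.bool)
pad[0, -2:] = True

mask = causal.unsqueeze(0) | pad.unsqueeze(1)            # (bsz, src_len, src_len)
additive = torch.zeros_like(mask, dtype=torch.float32)
additive = additive.masked_fill(mask, float("-inf"))     # ready for attention scores
print(additive.shape)                                    # torch.Size([2, 7, 7])
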
@@ -641,7 +646,7 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype, device=y_emb.device)
if (None in idx_list):
for i in range(x.shape[0]):
if idx_list[i] is None:

View File

@@ -143,7 +143,7 @@ def logits_to_probs(
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
- pivot = v.select(-1, -1).unsqueeze(-1)
+ pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
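
Context for the one-line change: v.select(-1, -1) and v[:, -1] pick the same k-th largest value for 2-D logits; the indexing form just assumes a (batch, vocab) shape. A runnable sketch of the surrounding top-k filter:

import torch

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Everything below the k-th largest logit per row (the "pivot") goes to -inf,
    # so softmax mass is confined to the top-k candidates.
    v, _ = torch.topk(logits, min(k, logits.size(-1)))
    pivot = v[:, -1].unsqueeze(-1)  # smallest kept value; assumes (batch, vocab) logits
    return torch.where(logits < pivot, torch.full_like(logits, float("-inf")), logits)

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
print(top_k_filter(logits, 2))  # tensor([[2.0, -inf, 1.0, -inf]])
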

View File

@@ -1,5 +1,8 @@
from copy import deepcopy
import math
import os, sys
import random
import traceback
now_dir = os.getcwd()
sys.path.append(now_dir)
import ffmpeg
@@ -7,6 +10,7 @@ import os
from typing import Generator, List, Union
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -45,21 +49,20 @@ custom:
"""
# def set_seed(seed):
# random.seed(seed)
# os.environ['PYTHONHASHSEED'] = str(seed)
# np.random.seed(seed)
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.enabled = True
# set_seed(1234)
class TTS_Config:
- def __init__(self, configs: Union[dict, str]):
- configs_base_path:str = "GPT_SoVITS/configs/"
- os.makedirs(configs_base_path, exist_ok=True)
- self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
- if isinstance(configs, str):
- self.configs_path = configs
- configs:dict = self._load_configs(configs)
- # assert isinstance(configs, dict)
- self.default_configs:dict = configs.get("default", None)
- if self.default_configs is None:
- self.default_configs={
+ default_configs={
"device": "cpu",
"is_half": False,
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
@@ -68,15 +71,54 @@ class TTS_Config:
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
"flash_attn_enabled": True
}
- self.configs:dict = configs.get("custom", self.default_configs)
+ configs:dict = None
+ def __init__(self, configs: Union[dict, str]=None):
- self.device = self.configs.get("device")
- self.is_half = self.configs.get("is_half")
- self.t2s_weights_path = self.configs.get("t2s_weights_path")
- self.vits_weights_path = self.configs.get("vits_weights_path")
- self.bert_base_path = self.configs.get("bert_base_path")
- self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
- self.flash_attn_enabled = self.configs.get("flash_attn_enabled")
# set the default config file path
configs_base_path:str = "GPT_SoVITS/configs/"
os.makedirs(configs_base_path, exist_ok=True)
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
if configs in ["", None]:
if not os.path.exists(self.configs_path):
self.save_configs()
print(f"Create default config file at {self.configs_path}")
configs:dict = {"default": deepcopy(self.default_configs)}
if isinstance(configs, str):
self.configs_path = configs
configs:dict = self._load_configs(self.configs_path)
assert isinstance(configs, dict)
default_configs:dict = configs.get("default", None)
if default_configs is not None:
self.default_configs = default_configs
self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
self.device = self.configs.get("device", torch.device("cpu"))
self.is_half = self.configs.get("is_half", False)
self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
self.vits_weights_path = self.configs.get("vits_weights_path", None)
self.bert_base_path = self.configs.get("bert_base_path", None)
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
self.t2s_weights_path = self.default_configs['t2s_weights_path']
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
self.vits_weights_path = self.default_configs['vits_weights_path']
print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
self.bert_base_path = self.default_configs['bert_base_path']
print(f"fall back to default bert_base_path: {self.bert_base_path}")
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
self.update_configs()
self.max_sec = None
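
The four fall-back branches above share one shape: keep the configured path only if it is non-empty and exists, otherwise revert to the default and warn. A hypothetical helper (resolve_path is not in the PR) capturing that pattern:

import os

def resolve_path(key: str, value, defaults: dict) -> str:
    # Hypothetical helper mirroring the four branches above: keep the
    # configured path only if it is non-empty and exists on disk.
    if value in (None, "") or not os.path.exists(value):
        value = defaults[key]
        print(f"fall back to default {key}: {value}")
    return value

defaults = {"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth"}
weights = resolve_path("vits_weights_path", "does_not_exist.pth", defaults)
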
@@ -90,53 +132,48 @@ class TTS_Config:
self.n_speakers:int = 300
self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
- print(self)
+ # print(self)
def _load_configs(self, configs_path: str)->dict:
with open(configs_path, 'r') as f:
configs = yaml.load(f, Loader=yaml.FullLoader)
return configs
def save_configs(self, configs_path:str=None)->None:
configs={
"default": {
"device": "cpu",
"is_half": False,
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
"flash_attn_enabled": True
},
"custom": {
"device": str(self.device),
"is_half": self.is_half,
"t2s_weights_path": self.t2s_weights_path,
"vits_weights_path": self.vits_weights_path,
"bert_base_path": self.bert_base_path,
"cnhuhbert_base_path": self.cnhuhbert_base_path,
"flash_attn_enabled": self.flash_attn_enabled
}
"default":self.default_configs,
}
if self.configs is not None:
configs["custom"] = self.update_configs()
if configs_path is None:
configs_path = self.configs_path
with open(configs_path, 'w') as f:
yaml.dump(configs, f)
def update_configs(self):
self.config = {
"device" : str(self.device),
"is_half" : self.is_half,
"t2s_weights_path" : self.t2s_weights_path,
"vits_weights_path" : self.vits_weights_path,
"bert_base_path" : self.bert_base_path,
"cnhuhbert_base_path": self.cnhuhbert_base_path,
"flash_attn_enabled" : self.flash_attn_enabled
}
return self.config
def __str__(self):
string = "----------------TTS Config--------------\n"
string += "device: {}\n".format(self.device)
string += "is_half: {}\n".format(self.is_half)
string += "bert_base_path: {}\n".format(self.bert_base_path)
string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
string += "vits_weights_path: {}\n".format(self.vits_weights_path)
string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
string += "----------------------------------------\n"
self.configs = self.update_configs()
string = "TTS Config".center(100, '-') + '\n'
for k, v in self.configs.items():
string += f"{str(k).ljust(20)}: {str(v)}\n"
string += "-" * 100 + '\n'
return string
def __repr__(self):
return self.__str__()
class TTS:
@@ -173,34 +210,40 @@ class TTS:
self.stop_flag:bool = False
self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
def _init_models(self,):
self.init_t2s_weights(self.configs.t2s_weights_path)
self.init_vits_weights(self.configs.vits_weights_path)
self.init_bert_weights(self.configs.bert_base_path)
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
# self.enable_half_precision(self.configs.is_half)
def init_cnhuhbert_weights(self, base_path: str):
print(f"Loading CNHuBERT weights from {base_path}")
self.cnhuhbert_model = CNHubert(base_path)
- self.cnhuhbert_model.eval()
- if self.configs.is_half == True:
- self.cnhuhbert_model = self.cnhuhbert_model.half()
+ self.cnhuhbert_model = self.cnhuhbert_model.eval()
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
+ if self.configs.is_half:
+ self.cnhuhbert_model = self.cnhuhbert_model.half()
def init_bert_weights(self, base_path: str):
print(f"Loading BERT weights from {base_path}")
self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
+ self.bert_model = self.bert_model.eval()
+ self.bert_model = self.bert_model.to(self.configs.device)
if self.configs.is_half:
self.bert_model = self.bert_model.half()
- self.bert_model = self.bert_model.to(self.configs.device)
def init_vits_weights(self, weights_path: str):
print(f"Loading VITS weights from {weights_path}")
self.configs.vits_weights_path = weights_path
self.configs.save_configs()
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
@@ -223,15 +266,16 @@ class TTS:
if hasattr(vits_model, "enc_q"):
del vits_model.enc_q
- if self.configs.is_half:
- vits_model = vits_model.half()
vits_model = vits_model.to(self.configs.device)
- vits_model.eval()
+ vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model
+ if self.configs.is_half:
+ self.vits_model = self.vits_model.half()
def init_t2s_weights(self, weights_path: str):
print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path
self.configs.save_configs()
self.configs.hz = 50
@@ -241,11 +285,61 @@ class TTS:
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
flash_attn_enabled=self.configs.flash_attn_enabled)
t2s_model.load_state_dict(dict_s1["weight"])
- if self.configs.is_half:
- t2s_model = t2s_model.half()
t2s_model = t2s_model.to(self.configs.device)
- t2s_model.eval()
+ t2s_model = t2s_model.eval()
self.t2s_model = t2s_model
+ if self.configs.is_half:
+ self.t2s_model = self.t2s_model.half()
def enable_half_precision(self, enable: bool = True):
'''
To enable half precision for the TTS model.
Args:
enable: bool, whether to enable half precision.
'''
if self.configs.device == "cpu" and enable:
print("Half precision is not supported on CPU.")
return
self.configs.is_half = enable
self.precison = torch.float16 if enable else torch.float32
self.configs.save_configs()
if enable:
if self.t2s_model is not None:
self.t2s_model = self.t2s_model.half()
if self.vits_model is not None:
self.vits_model = self.vits_model.half()
if self.bert_model is not None:
self.bert_model = self.bert_model.half()
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.half()
else:
if self.t2s_model is not None:
self.t2s_model = self.t2s_model.float()
if self.vits_model is not None:
self.vits_model = self.vits_model.float()
if self.bert_model is not None:
self.bert_model = self.bert_model.float()
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.float()
def set_device(self, device: torch.device):
'''
To set the device for all models.
Args:
device: torch.device, the device to use for all models.
'''
self.configs.device = device
self.configs.save_configs()
if self.t2s_model is not None:
self.t2s_model = self.t2s_model.to(device)
if self.vits_model is not None:
self.vits_model = self.vits_model.to(device)
if self.bert_model is not None:
self.bert_model = self.bert_model.to(device)
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.to(device)
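
Hypothetical usage of the two new runtime setters, assuming a constructed pipeline object (configure_runtime is not part of the PR, just a sketch):

import torch

def configure_runtime(tts, prefer_half: bool = True) -> None:
    # Hypothetical convenience wrapper around the two setters added above.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    tts.set_device(device)  # moves t2s/vits/bert/cnhuhbert together
    # enable_half_precision(True) is a guarded no-op on CPU (see the check above),
    # but gating here keeps the intent explicit.
    tts.enable_half_precision(prefer_half and device.type != "cpu")
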
def set_ref_audio(self, ref_audio_path:str):
'''
@@ -346,7 +440,7 @@ class TTS:
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
while pos < pos_end:
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
- score=batch[(pos_end-pos)//2]/batch.mean()
+ score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)  # guard against division by zero
if (score>=threshold) or (pos_end-pos==1):
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
batch_index_list_len += len(batch_index)
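
Beyond the division-by-zero guard, the bucketing test above deserves a gloss: items are sorted by length, and a candidate batch is accepted only when its median length is close to its mean (score >= threshold), so padding waste stays bounded; otherwise the batch is shrunk by one. A standalone sketch of the same loop (split_into_buckets is a hypothetical name):

import numpy as np

def split_into_buckets(lengths, batch_size=4, threshold=0.75):
    # Sort by length, then accept the largest prefix batch whose median/mean
    # ratio clears the threshold; shrink otherwise.
    order = np.argsort(lengths)
    sorted_lens = np.asarray(lengths, dtype=np.float32)[order]
    buckets, pos = [], 0
    while pos < len(order):
        pos_end = min(pos + batch_size, len(order))
        while pos < pos_end:
            batch = sorted_lens[pos:pos_end]
            score = batch[(pos_end - pos) // 2] / (batch.mean() + 1e-8)
            if score >= threshold or pos_end - pos == 1:
                buckets.append(order[pos:pos_end].tolist())
                pos = pos_end
                break
            pos_end -= 1  # shrink until the batch is homogeneous enough
    return buckets

print(split_into_buckets([3, 30, 4, 5, 28, 29]))  # -> [[0, 2, 3], [4, 5, 1]]
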
@@ -377,12 +471,14 @@ class TTS:
phones_max_len = 0
for item in item_list:
if prompt_data is not None:
- all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
+ all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
+ .to(dtype=self.precison)
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
phones = torch.LongTensor(item["phones"])
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]
all_bert_features = item["bert_features"]\
.to(dtype=self.precison)
phones = torch.LongTensor(item["phones"])
all_phones = phones
# norm_text = item["norm_text"]
@@ -401,12 +497,10 @@ class TTS:
max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
- all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
- all_bert_features_batch.zero_()
# all_bert_features_batch = all_bert_features_list
+ all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison)
for idx, item in enumerate(all_bert_features_list):
+ if item != None:
all_bert_features_batch[idx, :, : item.shape[-1]] = item
batch = {
"phones": phones_batch,
@@ -458,8 +552,8 @@ class TTS:
"prompt_text": "", # str. prompt text for the reference audio
"prompt_lang": "", # str. language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_p": 0.9, # float. top p sampling
"temperature": 0.6, # float. temperature for sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "", # str. text split method, see text_segmentaion_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
@@ -477,9 +571,9 @@ class TTS:
ref_audio_path:str = inputs.get("ref_audio_path", "")
prompt_text:str = inputs.get("prompt_text", "")
prompt_lang:str = inputs.get("prompt_lang", "")
top_k:int = inputs.get("top_k", 20)
top_p:float = inputs.get("top_p", 0.9)
temperature:float = inputs.get("temperature", 0.6)
top_k:int = inputs.get("top_k", 5)
top_p:float = inputs.get("top_p", 1)
temperature:float = inputs.get("temperature", 1)
text_split_method:str = inputs.get("text_split_method", "")
batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75)
@@ -497,10 +591,6 @@ class TTS:
if split_bucket:
print(i18n("分桶处理模式已开启"))
- # if vits_batched_inference:
- # print(i18n("VITS批量推理模式已开启"))
- # else:
- # print(i18n("VITS单句推理模式已开启"))
no_prompt_text = False
if prompt_text in [None, ""]:
@@ -538,6 +628,11 @@ class TTS:
###### text preprocessing ########
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
if len(data) == 0:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
dtype=np.int16)
return
t1 = ttime()
data, batch_index_list = self.to_batch(data,
prompt_data=self.prompt_cache if not no_prompt_text else None,
@@ -546,118 +641,141 @@ class TTS:
split_bucket=split_bucket
)
t2 = ttime()
try:
print("############ 推理 ############")
###### inference ######
t_34 = 0.0
t_45 = 0.0
audio = []
for item in data:
t3 = ttime()
batch_phones = item["phones"]
batch_phones_len = item["phones_len"]
all_phoneme_ids = item["all_phones"]
all_phoneme_lens = item["all_phones_len"]
all_bert_features = item["all_bert_features"]
norm_text = item["norm_text"]
# batch_phones = batch_phones.to(self.configs.device)
batch_phones_len = batch_phones_len.to(self.configs.device)
all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
all_bert_features = all_bert_features.to(self.configs.device)
if self.configs.is_half:
all_bert_features = all_bert_features.half()
print(i18n("前端处理后的文本(每句):"), norm_text)
if no_prompt_text :
prompt = None
else:
prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
with torch.no_grad():
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_lens,
prompt,
all_bert_features,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
)
t4 = ttime()
t_34 += t4 - t3
- refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].to(self.configs.device)
- if self.configs.is_half:
- refer_audio_spepc = refer_audio_spepc.half()
print(i18n("前端处理后的文本(每句):"), norm_text)
if no_prompt_text :
prompt = None
else:
prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
with torch.no_grad():
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_lens,
prompt,
all_bert_features,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
)
t4 = ttime()
t_34 += t4 - t3
+ batch_audio_fragment = []
+ refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
+ .to(dtype=self.precison, device=self.configs.device)
# ## vits parallel inference, method 1
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
# max_len = 0
# for i in range(0, len(batch_phones)):
# max_len = max(max_len, batch_phones[i].shape[-1])
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
# batch_phones = batch_phones.to(self.configs.device)
# batch_audio_fragment = (self.vits_model.batched_decode(
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
# ))
# ## vits parallel inference, method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = (self.vits_model.decode(
all_pred_semantic, _batch_phones,refer_audio_spepc
).detach()[0, 0, :])
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
# ## vits serial inference (sentence by sentence)
# for i, idx in enumerate(idx_list):
# phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
# _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0) # mq needs one extra unsqueeze
# audio_fragment =(self.vits_model.decode(
# _pred_semantic, phones, refer_audio_spepc
# ).detach()[0, 0, :])
# batch_audio_fragment.append(
# audio_fragment
# ) ### try reconstructing without the prompt part
t5 = ttime()
t_45 += t5 - t4
if return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
yield self.audio_postprocess([batch_audio_fragment],
self.configs.sampling_rate,
batch_index_list,
speed_factor,
split_bucket)
else:
audio.append(batch_audio_fragment)
if self.stop_flag:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
dtype=np.int16)
return
if not return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
yield self.audio_postprocess(audio,
self.configs.sampling_rate,
batch_index_list,
speed_factor,
split_bucket)
except Exception as e:
traceback.print_exc()
# An empty audio clip must be yielded here, otherwise GPU memory is not released.
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
dtype=np.int16)
# Re-initialize the models, otherwise GPU memory is not fully released.
del self.t2s_model
del self.vits_model
self.t2s_model = None
self.vits_model = None
self.init_t2s_weights(self.configs.t2s_weights_path)
self.init_vits_weights(self.configs.vits_weights_path)
finally:
self.empty_cache()
def empty_cache(self):
try:
if str(self.configs.device) == "cuda":
torch.cuda.empty_cache()
elif str(self.configs.device) == "mps":
torch.mps.empty_cache()
except:
pass
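
The new try/except/finally is the behavioral heart of this hunk: on failure the generator still yields one silent clip (keeping the streaming contract with the UI), and the two generative models are rebuilt, which is what actually releases the VRAM held by the failed step. A condensed sketch of the pattern (run_safely is a hypothetical name):

import numpy as np
import torch

def run_safely(pipeline_step, sampling_rate=32000):
    # Condensed, hypothetical version of the recovery pattern above.
    try:
        yield from pipeline_step()
    except Exception:
        # Yield one valid-but-silent clip so consumers of the generator
        # always receive a well-formed (sr, waveform) tuple.
        yield sampling_rate, np.zeros(sampling_rate, dtype=np.int16)
        # The real code deletes and re-initializes the t2s and vits models here;
        # dropping the live modules is what frees the VRAM of the failed step.
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
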
def audio_postprocess(self,
audio:List[torch.Tensor],
sr:int,
@@ -666,7 +784,7 @@ class TTS:
split_bucket:bool=True)->tuple[int, np.ndarray]:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * 0.3),
- dtype=torch.float16 if self.configs.is_half else torch.float32,
+ dtype=self.precison,
device=self.configs.device
)

View File

@@ -1,5 +1,7 @@
import os, sys
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -12,9 +14,9 @@ from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
- # from tools.i18n.i18n import I18nAuto
+ from tools.i18n.i18n import I18nAuto
- # i18n = I18nAuto()
+ i18n = I18nAuto()
def get_first(text:str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
@@ -51,10 +53,14 @@ class TextPreprocessor:
self.device = device
def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
print(i18n("############ 切分文本 ############"))
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
- for text in texts:
+ print(i18n("############ 提取文本Bert特征 ############"))  # "extract BERT features from text"
+ for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
if phones is None:
continue
res={
"phones": phones,
"bert_features": bert_features,
@@ -67,18 +73,18 @@ class TextPreprocessor:
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4):
text = "" + text if lang != "en" else "." + text
# print(i18n("实际输入的目标文本:"), text)
print(i18n("实际输入的目标文本:"))
print(text)
seg_method = get_seg_method(text_split_method)
text = seg_method(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
# print(i18n("实际输入的目标文本(切句后):"), text)
_texts = text.split("\n")
_texts = merge_short_text_in_array(_texts, 5)
texts = []
for text in _texts:
@@ -88,15 +94,21 @@ class TextPreprocessor:
if (text[-1] not in splits): text += "。" if lang != "en" else "."
# Work around BERT errors caused by overly long sentences
- texts.extend(split_big_text(text))
+ if (len(text) > 510):
+ texts.extend(split_big_text(text))
+ else:
+ texts.append(text)
+ print(i18n("实际输入的目标文本(切句后):"))  # "actual input target text (after sentence splitting)"
+ print(texts)
return texts
def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]:
textlist, langlist = self.seg_text(texts, language)
- phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
+ if len(textlist) == 0:
+ return None, None, None
+ phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
return phones, bert_features, norm_text
@@ -105,8 +117,10 @@ class TextPreprocessor:
textlist=[]
langlist=[]
if language in ["auto", "zh", "ja"]:
- # LangSegment.setfilters(["zh","ja","en","ko"])
+ LangSegment.setfilters(["zh","ja","en","ko"])
for tmp in LangSegment.getTexts(text):
if tmp["text"] == "":
continue
if tmp["lang"] == "ko":
langlist.append("zh")
elif tmp["lang"] == "en":
@@ -116,18 +130,22 @@ class TextPreprocessor:
langlist.append(language if language!="auto" else tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
- # LangSegment.setfilters(["en"])
+ LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
while " " in formattext:
formattext = formattext.replace(" ", " ")
- textlist.append(formattext)
- langlist.append("en")
+ if formattext != "":
+ textlist.append(formattext)
+ langlist.append("en")
elif language in ["all_zh","all_ja"]:
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
language = language.replace("all_","")
if text == "":
return [],[]
textlist.append(formattext)
langlist.append(language)
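
For reviewers: the loop above routes each detected segment to a frontend, with two behavior changes in this PR: empty segments are now skipped, and Korean is mapped to the zh frontend. A standalone sketch (route_segments is a hypothetical name; the dict shape mimics LangSegment.getTexts output):

def route_segments(segments, target_lang="zh"):
    # Route each detected segment to a language frontend, skipping empties.
    textlist, langlist = [], []
    for seg in segments:
        if seg["text"] == "":            # new in this PR: skip empty segments
            continue
        if seg["lang"] == "ko":          # Korean is routed through the zh frontend
            langlist.append("zh")
        elif seg["lang"] == "en":
            langlist.append("en")
        else:
            langlist.append(target_lang if target_lang != "auto" else seg["lang"])
        textlist.append(seg["text"])
    return textlist, langlist

print(route_segments([{"lang": "en", "text": "hi"}, {"lang": "ko", "text": ""}]))
# (['hi'], ['en'])
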
@@ -152,8 +170,7 @@ class TextPreprocessor:
bert_feature = torch.cat(bert_feature_list, dim=1)
# phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
- return phones, bert_feature, norm_text
+ return phones_list, bert_feature, norm_text
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:

View File

@@ -45,9 +45,11 @@ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # ensure this is also set when the inference UI is launched directly
if torch.cuda.is_available():
device = "cuda"
# elif torch.backends.mps.is_available():
# device = "mps"
else:
device = "cpu"
dict_language = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
@@ -78,6 +80,7 @@ if cnhubert_base_path is not None:
if bert_path is not None:
tts_config.bert_base_path = bert_path
print(tts_config)
tts_pipline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
@@ -103,10 +106,12 @@ def inference(text, text_lang,
"batch_size":int(batch_size),
"speed_factor":float(speed_factor),
"split_bucket":split_bucket,
"return_fragment":False,
"return_fragment":False
}
- yield next(tts_pipline.run(inputs))
+ for item in tts_pipline.run(inputs):
+ yield item
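
The inference() change above is subtle but important: yield next(...) surfaced only the first item of the pipeline generator, so fragment streaming never progressed past fragment one. A minimal reproduction of the difference:

def fragments():
    yield "frag-1"
    yield "frag-2"

def inference_old():
    yield next(fragments())   # surfaces only the first fragment, then stops

def inference_new():
    for item in fragments():  # forwards every fragment to the UI
        yield item

print(list(inference_old()))  # ['frag-1']
print(list(inference_new()))  # ['frag-1', 'frag-2']
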
def custom_sort_key(s):
# Use a regex to split the string into numeric and non-numeric parts
parts = re.split('(\d+)', s)
@@ -182,7 +187,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
with gr.Row():
with gr.Column():
- batch_size = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("batch_size"),value=20,interactive=True)
+ batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)