修复了一些bug,优化了一些代码

This commit is contained in:
chasonjiang 2024-03-11 17:16:04 +08:00
parent 3535cfe3b0
commit d23f3a62c4
5 changed files with 72 additions and 51 deletions

View File

@ -229,10 +229,15 @@ class Text2SemanticDecoder(nn.Module):
ignore_index=self.EOS,
)
if not flash_attn_enabled:
self.enable_flash_attn(flash_attn_enabled)
def enable_flash_attn(self, enable:bool=True):
if not enable:
print("Not Using Flash Attention")
self.infer_panel = self.infer_panel_batch_only
else:
self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
print("Using Flash Attention")
blocks = []
@ -497,7 +502,7 @@ class Text2SemanticDecoder(nn.Module):
# 错位
return targets[:, :-1], targets[:, 1:]
def infer_panel(
def infer_panel_batch_infer_with_flash_attn(
self,
x, #####全部文本token
x_lens,
@ -508,8 +513,10 @@ class Text2SemanticDecoder(nn.Module):
early_stop_num: int = -1,
temperature: float = 1.0,
):
bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = x + bert_feature
x = self.ar_text_position(x)
# AR Decoder
@ -546,30 +553,28 @@ class Text2SemanticDecoder(nn.Module):
y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
y_mask = make_pad_mask(y_lens)
x_mask = make_pad_mask(x_lens)
# (bsz, x_len + y_len)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
_xy_padding_mask = (
xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1)
)
x_attn_mask_pad = F.pad(
x_mask = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
x.device
)
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
# xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
xy_attn_mask = new_attn_mask
xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
###### decode #####
y_list = [None]*y.shape[0]
@ -641,7 +646,7 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
if (None in idx_list):
for i in range(x.shape[0]):
if idx_list[i] is None:

View File

@ -143,7 +143,7 @@ def logits_to_probs(
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v.select(-1, -1).unsqueeze(-1)
pivot = v[: , -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)

View File

@ -1,5 +1,6 @@
import math
import os, sys
import random
now_dir = os.getcwd()
sys.path.append(now_dir)
import ffmpeg
@ -7,6 +8,7 @@ import os
from typing import Generator, List, Union
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from transformers import AutoModelForMaskedLM, AutoTokenizer
@ -130,11 +132,11 @@ class TTS_Config:
string = "----------------TTS Config--------------\n"
string += "device: {}\n".format(self.device)
string += "is_half: {}\n".format(self.is_half)
string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
string += "bert_base_path: {}\n".format(self.bert_base_path)
string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
string += "vits_weights_path: {}\n".format(self.vits_weights_path)
string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
string += "----------------------------------------\n"
return string
@ -184,7 +186,7 @@ class TTS:
def init_cnhuhbert_weights(self, base_path: str):
self.cnhuhbert_model = CNHubert(base_path)
self.cnhuhbert_model.eval()
self.cnhuhbert_model=self.cnhuhbert_model.eval()
if self.configs.is_half == True:
self.cnhuhbert_model = self.cnhuhbert_model.half()
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
@ -194,6 +196,7 @@ class TTS:
def init_bert_weights(self, base_path: str):
self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
self.bert_model=self.bert_model.eval()
if self.configs.is_half:
self.bert_model = self.bert_model.half()
self.bert_model = self.bert_model.to(self.configs.device)
@ -226,7 +229,7 @@ class TTS:
if self.configs.is_half:
vits_model = vits_model.half()
vits_model = vits_model.to(self.configs.device)
vits_model.eval()
vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model
@ -244,7 +247,7 @@ class TTS:
if self.configs.is_half:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(self.configs.device)
t2s_model.eval()
t2s_model = t2s_model.eval()
self.t2s_model = t2s_model
def set_ref_audio(self, ref_audio_path:str):
@ -377,12 +380,14 @@ class TTS:
phones_max_len = 0
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
.to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
phones = torch.LongTensor(item["phones"])
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]
all_bert_features = item["bert_features"]\
.to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
phones = torch.LongTensor(item["phones"])
all_phones = phones
# norm_text = item["norm_text"]
@ -401,12 +406,10 @@ class TTS:
max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
all_bert_features_batch.zero_()
# all_bert_features_batch = all_bert_features_list
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=torch.float32)
for idx, item in enumerate(all_bert_features_list):
if item != None:
all_bert_features_batch[idx, :, : item.shape[-1]] = item
all_bert_features_batch[idx, :, : item.shape[-1]] = item
batch = {
"phones": phones_batch,
@ -458,8 +461,8 @@ class TTS:
"prompt_text": "", # str. prompt text for the reference audio
"prompt_lang": "", # str. language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_p": 0.9, # float. top p sampling
"temperature": 0.6, # float. temperature for sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "", # str. text split method, see text_segmentaion_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
@ -477,9 +480,9 @@ class TTS:
ref_audio_path:str = inputs.get("ref_audio_path", "")
prompt_text:str = inputs.get("prompt_text", "")
prompt_lang:str = inputs.get("prompt_lang", "")
top_k:int = inputs.get("top_k", 20)
top_p:float = inputs.get("top_p", 0.9)
temperature:float = inputs.get("temperature", 0.6)
top_k:int = inputs.get("top_k", 5)
top_p:float = inputs.get("top_p", 1)
temperature:float = inputs.get("temperature", 1)
text_split_method:str = inputs.get("text_split_method", "")
batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75)
@ -497,10 +500,6 @@ class TTS:
if split_bucket:
print(i18n("分桶处理模式已开启"))
# if vits_batched_inference:
# print(i18n("VITS批量推理模式已开启"))
# else:
# print(i18n("VITS单句推理模式已开启"))
no_prompt_text = False
if prompt_text in [None, ""]:
@ -547,7 +546,7 @@ class TTS:
)
t2 = ttime()
print("############ 推理 ############")
###### inference ######
t_34 = 0.0
t_45 = 0.0
@ -601,6 +600,10 @@ class TTS:
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
# max_len = 0
# for i in range(0, len(batch_phones)):
# max_len = max(max_len, batch_phones[i].shape[-1])
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
# batch_phones = batch_phones.to(self.configs.device)
# batch_audio_fragment = (self.vits_model.batched_decode(
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
@ -654,7 +657,12 @@ class TTS:
self.configs.sampling_rate,
batch_index_list,
speed_factor,
split_bucket)
split_bucket)
try:
torch.cuda.empty_cache()
except:
pass

View File

@ -1,5 +1,7 @@
import os, sys
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
@ -12,9 +14,9 @@ from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
# from tools.i18n.i18n import I18nAuto
from tools.i18n.i18n import I18nAuto
# i18n = I18nAuto()
i18n = I18nAuto()
def get_first(text:str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
@ -51,9 +53,11 @@ class TextPreprocessor:
self.device = device
def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
print(i18n("############ 切分文本 ############"))
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
for text in texts:
print(i18n("############ 提取文本Bert特征 ############"))
for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
res={
"phones": phones,
@ -67,14 +71,16 @@ class TextPreprocessor:
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4):
text = "" + text if lang != "en" else "." + text
# print(i18n("实际输入的目标文本:"), text)
print(i18n("实际输入的目标文本:"))
print(text)
seg_method = get_seg_method(text_split_method)
text = seg_method(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
# print(i18n("实际输入的目标文本(切句后):"), text)
print(i18n("实际输入的目标文本(切句后):"))
print(text)
_texts = text.split("\n")
_texts = merge_short_text_in_array(_texts, 5)
texts = []
@ -105,7 +111,7 @@ class TextPreprocessor:
textlist=[]
langlist=[]
if language in ["auto", "zh", "ja"]:
# LangSegment.setfilters(["zh","ja","en","ko"])
LangSegment.setfilters(["zh","ja","en","ko"])
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "ko":
langlist.append("zh")
@ -116,7 +122,7 @@ class TextPreprocessor:
langlist.append(language if language!="auto" else tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
# LangSegment.setfilters(["en"])
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
while " " in formattext:
formattext = formattext.replace(" ", " ")
@ -153,7 +159,7 @@ class TextPreprocessor:
# phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
return phones, bert_feature, norm_text
return phones_list, bert_feature, norm_text
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:

View File

@ -103,10 +103,12 @@ def inference(text, text_lang,
"batch_size":int(batch_size),
"speed_factor":float(speed_factor),
"split_bucket":split_bucket,
"return_fragment":False,
"return_fragment":False
}
yield next(tts_pipline.run(inputs))
for item in tts_pipline.run(inputs):
yield item
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分
parts = re.split('(\d+)', s)
@ -182,7 +184,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
with gr.Row():
with gr.Column():
batch_size = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("batch_size"),value=20,interactive=True)
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=1,interactive=True)
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)