diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index a3170b9..dfd6eb0 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -229,10 +229,15 @@ class Text2SemanticDecoder(nn.Module):
             ignore_index=self.EOS,
         )
 
-        if not flash_attn_enabled:
+        self.enable_flash_attn(flash_attn_enabled)
+
+    def enable_flash_attn(self, enable:bool=True):
+
+        if not enable:
             print("Not Using Flash Attention")
             self.infer_panel = self.infer_panel_batch_only
         else:
+            self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
             print("Using Flash Attention")
 
         blocks = []
@@ -497,7 +502,7 @@ class Text2SemanticDecoder(nn.Module):
         # 错位
         return targets[:, :-1], targets[:, 1:]
 
-    def infer_panel(
+    def infer_panel_batch_infer_with_flash_attn(
         self,
         x,  #####全部文本token
         x_lens,
@@ -508,8 +513,10 @@ class Text2SemanticDecoder(nn.Module):
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
+
+        bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_embedding(x)
-        x = x + self.bert_proj(bert_feature.transpose(1, 2))
+        x = x + bert_feature
         x = self.ar_text_position(x)
 
         # AR Decoder
@@ -546,30 +553,28 @@ class Text2SemanticDecoder(nn.Module):
         y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
         y_mask = make_pad_mask(y_lens)
         x_mask = make_pad_mask(x_lens)
-        
+
         # (bsz, x_len + y_len)
         xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
-        _xy_padding_mask = (
-            xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1)
-        )
 
-        x_attn_mask_pad = F.pad(
+        x_mask = F.pad(
             x_attn_mask,
             (0, y_len),  ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
             value=True,
         )
-        y_attn_mask = F.pad(  ###yy的右上1扩展到左边xy的0,(y,x+y)
+        y_mask = F.pad(  ###yy的右上1扩展到左边xy的0,(y,x+y)
             torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-            x.device
-        )
-        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
+
+        xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
+        # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
+        xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
+        xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
+        xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
         new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
-        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
-        xy_attn_mask = new_attn_mask
+        xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
 
         ###### decode #####
         y_list = [None]*y.shape[0]
@@ -641,7 +646,7 @@ class Text2SemanticDecoder(nn.Module):
             ####################### update next step ###################################
             y_emb = self.ar_audio_embedding(y[:, -1:])
             xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
-            
+
         if (None in idx_list):
             for i in range(x.shape[0]):
                 if idx_list[i] is None:
diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py
index 34178fe..ce0a98b 100644
--- a/GPT_SoVITS/AR/models/utils.py
+++ b/GPT_SoVITS/AR/models/utils.py
@@ -143,7 +143,7 @@ def logits_to_probs(
 
     if top_k is not None:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
+        pivot = v[: , -1].unsqueeze(-1)
         logits = torch.where(logits < pivot, -float("Inf"), logits)
 
     probs = torch.nn.functional.softmax(logits, dim=-1)
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index c111034..791e0f4 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -1,5 +1,6 @@
 import math
 import os, sys
+import random
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 import ffmpeg
@@ -7,6 +8,7 @@ import os
 from typing import Generator, List, Union
 import numpy as np
 import torch
+import torch.nn.functional as F
 import yaml
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 
@@ -130,11 +132,11 @@ class TTS_Config:
         string = "----------------TTS Config--------------\n"
         string += "device: {}\n".format(self.device)
         string += "is_half: {}\n".format(self.is_half)
+        string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
         string += "bert_base_path: {}\n".format(self.bert_base_path)
         string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
         string += "vits_weights_path: {}\n".format(self.vits_weights_path)
         string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
-        string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
         string += "----------------------------------------\n"
         return string
 
@@ -184,7 +186,7 @@ class TTS:
 
     def init_cnhuhbert_weights(self, base_path: str):
         self.cnhuhbert_model = CNHubert(base_path)
-        self.cnhuhbert_model.eval()
+        self.cnhuhbert_model=self.cnhuhbert_model.eval()
         if self.configs.is_half == True:
             self.cnhuhbert_model = self.cnhuhbert_model.half()
         self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
@@ -194,6 +196,7 @@ class TTS:
     def init_bert_weights(self, base_path: str):
         self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
         self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
+        self.bert_model=self.bert_model.eval()
         if self.configs.is_half:
             self.bert_model = self.bert_model.half()
         self.bert_model = self.bert_model.to(self.configs.device)
@@ -226,7 +229,7 @@ class TTS:
         if self.configs.is_half:
             vits_model = vits_model.half()
         vits_model = vits_model.to(self.configs.device)
-        vits_model.eval()
+        vits_model = vits_model.eval()
         vits_model.load_state_dict(dict_s2["weight"], strict=False)
         self.vits_model = vits_model
 
@@ -244,7 +247,7 @@ class TTS:
         if self.configs.is_half:
             t2s_model = t2s_model.half()
         t2s_model = t2s_model.to(self.configs.device)
-        t2s_model.eval()
+        t2s_model = t2s_model.eval()
         self.t2s_model = t2s_model
 
     def set_ref_audio(self, ref_audio_path:str):
@@ -377,12 +380,14 @@ class TTS:
         phones_max_len = 0
         for item in item_list:
             if prompt_data is not None:
-                all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
+                all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
+                                        .to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
                 all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
                 phones = torch.LongTensor(item["phones"])
                 # norm_text = prompt_data["norm_text"]+item["norm_text"]
             else:
-                all_bert_features = item["bert_features"]
+                all_bert_features = item["bert_features"]\
+                                        .to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
                 phones = torch.LongTensor(item["phones"])
                 all_phones = phones
                 # norm_text = item["norm_text"]
@@ -401,12 +406,10 @@ class TTS:
             max_len = max(bert_max_len, phones_max_len)
             # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
             all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
-            all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
-            all_bert_features_batch.zero_()
-
+            # all_bert_features_batch = all_bert_features_list
+            all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=torch.float32)
             for idx, item in enumerate(all_bert_features_list):
-                if item != None:
-                    all_bert_features_batch[idx, :, : item.shape[-1]] = item
+                all_bert_features_batch[idx, :, : item.shape[-1]] = item
 
             batch = {
                 "phones": phones_batch,
@@ -458,8 +461,8 @@ class TTS:
                     "prompt_text": "",          # str. prompt text for the reference audio
                     "prompt_lang": "",          # str. language of the prompt text for the reference audio
                     "top_k": 5,                 # int. top k sampling
-                    "top_p": 0.9,               # float. top p sampling
-                    "temperature": 0.6,         # float. temperature for sampling
+                    "top_p": 1,                 # float. top p sampling
+                    "temperature": 1,           # float. temperature for sampling
                     "text_split_method": "",    # str. text split method, see text_segmentaion_method.py for details.
                     "batch_size": 1,            # int. batch size for inference
                     "batch_threshold": 0.75,    # float. threshold for batch splitting.
@@ -477,9 +480,9 @@ class TTS:
         ref_audio_path:str = inputs.get("ref_audio_path", "")
         prompt_text:str = inputs.get("prompt_text", "")
         prompt_lang:str = inputs.get("prompt_lang", "")
-        top_k:int = inputs.get("top_k", 20)
-        top_p:float = inputs.get("top_p", 0.9)
-        temperature:float = inputs.get("temperature", 0.6)
+        top_k:int = inputs.get("top_k", 5)
+        top_p:float = inputs.get("top_p", 1)
+        temperature:float = inputs.get("temperature", 1)
         text_split_method:str = inputs.get("text_split_method", "")
         batch_size = inputs.get("batch_size", 1)
         batch_threshold = inputs.get("batch_threshold", 0.75)
@@ -497,10 +500,6 @@ class TTS:
         if split_bucket:
             print(i18n("分桶处理模式已开启"))
 
-        # if vits_batched_inference:
-        #     print(i18n("VITS批量推理模式已开启"))
-        # else:
-        #     print(i18n("VITS单句推理模式已开启"))
 
         no_prompt_text = False
         if prompt_text in [None, ""]:
@@ -547,7 +546,7 @@ class TTS:
             )
             t2 = ttime()
-            
+            print("############ 推理 ############")
             ###### inference ######
             t_34 = 0.0
             t_45 = 0.0
@@ -601,6 +600,10 @@ class TTS:
                 # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                 # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
                 # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
+                # max_len = 0
+                # for i in range(0, len(batch_phones)):
+                #     max_len = max(max_len, batch_phones[i].shape[-1])
+                # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
                 # batch_phones = batch_phones.to(self.configs.device)
                 # batch_audio_fragment = (self.vits_model.batched_decode(
                 #     pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
@@ -654,7 +657,12 @@ class TTS:
                                          self.configs.sampling_rate,
                                          batch_index_list,
                                          speed_factor,
-                                         split_bucket)
+                                         split_bucket)
+
+            try:
+                torch.cuda.empty_cache()
+            except:
+                pass
diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
index 2669bf4..986819f 100644
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -1,5 +1,7 @@
 import os, sys
+
+from tqdm import tqdm
 
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 
@@ -12,9 +14,9 @@ from text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
 
-# from tools.i18n.i18n import I18nAuto
+from tools.i18n.i18n import I18nAuto
 
-# i18n = I18nAuto()
+i18n = I18nAuto()
 
 def get_first(text:str) -> str:
     pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
@@ -51,9 +53,11 @@ class TextPreprocessor:
         self.device = device
 
    def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
+        print(i18n("############ 切分文本 ############"))
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
-        for text in texts:
+        print(i18n("############ 提取文本Bert特征 ############"))
+        for text in tqdm(texts):
             phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
             res={
                 "phones": phones,
@@ -67,14 +71,16 @@ class TextPreprocessor:
         text = text.strip("\n")
         if (text[0] not in splits and len(get_first(text)) < 4):
             text = "。" + text if lang != "en" else "." + text
-        # print(i18n("实际输入的目标文本:"), text)
+        print(i18n("实际输入的目标文本:"))
+        print(text)
 
         seg_method = get_seg_method(text_split_method)
         text = seg_method(text)
 
         while "\n\n" in text:
             text = text.replace("\n\n", "\n")
-        # print(i18n("实际输入的目标文本(切句后):"), text)
+        print(i18n("实际输入的目标文本(切句后):"))
+        print(text)
         _texts = text.split("\n")
         _texts = merge_short_text_in_array(_texts, 5)
         texts = []
@@ -105,7 +111,7 @@ class TextPreprocessor:
         textlist=[]
         langlist=[]
         if language in ["auto", "zh", "ja"]:
-            # LangSegment.setfilters(["zh","ja","en","ko"])
+            LangSegment.setfilters(["zh","ja","en","ko"])
            for tmp in LangSegment.getTexts(text):
                if tmp["lang"] == "ko":
                    langlist.append("zh")
@@ -116,7 +122,7 @@ class TextPreprocessor:
                     langlist.append(language if language!="auto" else tmp["lang"])
                 textlist.append(tmp["text"])
         elif language == "en":
-            # LangSegment.setfilters(["en"])
+            LangSegment.setfilters(["en"])
             formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
             while "  " in formattext:
                 formattext = formattext.replace("  ", " ")
@@ -153,7 +159,7 @@ class TextPreprocessor:
         # phones = sum(phones_list, [])
         norm_text = ''.join(norm_text_list)
 
-        return phones, bert_feature, norm_text
+        return phones_list, bert_feature, norm_text
 
 
     def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 2d223f9..53333c4 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -103,10 +103,12 @@ def inference(text, text_lang,
         "batch_size":int(batch_size),
         "speed_factor":float(speed_factor),
         "split_bucket":split_bucket,
-        "return_fragment":False,
+        "return_fragment":False
     }
-    yield next(tts_pipline.run(inputs))
-    
+
+    for item in tts_pipline.run(inputs):
+        yield item
+
 def custom_sort_key(s):
     # 使用正则表达式提取字符串中的数字部分和非数字部分
     parts = re.split('(\d+)', s)
@@ -182,7 +184,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
 
         with gr.Row():
             with gr.Column():
-                batch_size = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("batch_size"),value=20,interactive=True)
+                batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=1,interactive=True)
                 speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
                 top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
                 top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
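
Note on the mask construction in infer_panel_batch_infer_with_flash_attn above: the patch replaces the old per-sample assembly with a single batched boolean mask (causal text/audio structure OR per-sample padding), expands it per attention head, and converts it to an additive float mask. The following standalone Python sketch reproduces that construction for illustration only; the helper name build_xy_attn_mask and the example shapes are assumptions, not part of the patch.

# Standalone sketch (not part of the patch): batched attention-mask construction.
import torch
import torch.nn.functional as F

def build_xy_attn_mask(x_lens: torch.Tensor, y_len: int, num_head: int) -> torch.Tensor:
    """Combine the text/audio causal mask with per-sample text padding masks.

    x_lens: (bsz,) valid text lengths; y_len: shared audio prefix length.
    Returns a float mask of shape (bsz, num_head, src_len, src_len) with 0 / -inf.
    """
    bsz = x_lens.shape[0]
    x_len = int(x_lens.max())
    src_len = x_len + y_len

    # Text positions attend to all text but not to audio; audio positions are causal over audio.
    x_mask = F.pad(torch.zeros(x_len, x_len, dtype=torch.bool), (0, y_len), value=True)
    y_mask = F.pad(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                   (x_len, 0), value=False)
    xy_mask = torch.cat([x_mask, y_mask], dim=0).view(1, src_len, src_len).expand(bsz, -1, -1)

    # Per-sample padding: key positions beyond each sample's text length are masked out.
    padding = torch.arange(x_len).unsqueeze(0) >= x_lens.unsqueeze(1)   # (bsz, x_len), True = pad
    padding = F.pad(padding, (0, y_len), value=False)                   # (bsz, src_len)
    padding = padding.view(bsz, 1, src_len).expand(-1, src_len, -1)

    bool_mask = (xy_mask | padding).unsqueeze(1).expand(-1, num_head, -1, -1)
    return torch.zeros_like(bool_mask, dtype=torch.float32).masked_fill(bool_mask, float("-inf"))

# Example: two samples with text lengths 3 and 5, an audio prefix of 4 tokens, 2 heads.
mask = build_xy_attn_mask(torch.tensor([3, 5]), y_len=4, num_head=2)
print(mask.shape)  # torch.Size([2, 2, 9, 9])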