diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index 0913c098..4e135a2c 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -234,10 +234,15 @@ class Text2SemanticDecoder(nn.Module):
             ignore_index=self.EOS,
         )
 
-        if not flash_attn_enabled:
+        self.enable_flash_attn(flash_attn_enabled)
+
+    def enable_flash_attn(self, enable:bool=True):
+
+        if not enable:
             print("Not Using Flash Attention")
             self.infer_panel = self.infer_panel_batch_only
         else:
+            self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
             print("Using Flash Attention")
 
         blocks = []
@@ -502,91 +507,7 @@ class Text2SemanticDecoder(nn.Module):
         # 错位
         return targets[:, :-1], targets[:, 1:]
 
-    def infer_one_step(self, x, xy_attn_mask, k_cache, v_cache, cache_seqlens):
-        hidden_dim = x.shape[-1]
-
-        for layer_id in range(self.num_layers):
-            layer = self.h.layers[layer_id]
-
-            q, k, v = F.linear(
-                x,
-                layer.self_attn.in_proj_weight,
-                layer.self_attn.in_proj_bias
-            ).chunk(3, dim=-1)
-
-            batch_size = q.shape[0]
-            q_len = q.shape[1]
-
-            if flash_attn_with_kvcache is None:
-                past_k = k_cache[layer_id]
-                past_v = v_cache[layer_id]
-
-                if past_k is not None:
-                    k = torch.cat([past_k, k], 1)
-                    v = torch.cat([past_v, v], 1)
-                k_cache[layer_id] = k
-                v_cache[layer_id] = v
-                kv_len = k.shape[1]
-
-                q = q.view(batch_size, q_len, layer.self_attn.num_heads, -1).transpose(1, 2)
-                k = k.view(batch_size, kv_len, layer.self_attn.num_heads, -1).transpose(1, 2)
-                v = v.view(batch_size, kv_len, layer.self_attn.num_heads, -1).transpose(1, 2)
-
-                if xy_attn_mask is None:
-                    attn = F.scaled_dot_product_attention(q, k, v)
-                else:
-                    attn = F.scaled_dot_product_attention(q, k, v, ~xy_attn_mask)
-
-                attn = attn.permute(2, 0, 1, 3).reshape(-1, hidden_dim)
-            else:
-                q = q.view(batch_size, q_len, layer.self_attn.num_heads, -1)
-                k = k.view(batch_size, q_len, layer.self_attn.num_heads, -1)
-                v = v.view(batch_size, q_len, layer.self_attn.num_heads, -1)
-
-                if xy_attn_mask is None:
-                    attn = flash_attn_with_kvcache(q, k_cache[layer_id], v_cache[layer_id], k, v, cache_seqlens=cache_seqlens, causal=True)
-                else:
-                    # NOTE: there's a slight difference with the result produced by SDPA.
-                    x_len = (~xy_attn_mask).sum(1)[0].item()
-
-                    attn_x = flash_attn_with_kvcache(
-                        q[:, :x_len],
-                        k_cache[layer_id],
-                        v_cache[layer_id],
-                        k[:, :x_len],
-                        v[:, :x_len],
-                        cache_seqlens=cache_seqlens,
-                        causal=False
-                    )
-
-                    attn_y = flash_attn_with_kvcache(
-                        q[:, x_len:],
-                        k_cache[layer_id],
-                        v_cache[layer_id],
-                        k[:, x_len:],
-                        v[:, x_len:],
-                        cache_seqlens=cache_seqlens + x_len,
-                        causal=True
-                    )
-
-                    attn = torch.cat([attn_x, attn_y], dim=1)
-                attn = attn.view(-1, hidden_dim)
-
-            attn_out = F.linear(attn, layer.self_attn.out_proj.weight, layer.self_attn.out_proj.bias)
-
-            x = layer.norm1(x + attn_out, None)
-
-            x = layer.norm2(x + layer.linear2(F.relu(layer.linear1(x))), None)
-
-        xy_dec = x
-
-        logits = self.ar_predict_layer(
-            xy_dec[:, -1]
-        )
-
-        return logits
-
-    def infer_panel(
+    def infer_panel_batch_infer_with_flash_attn(
         self,
         x,  #####全部文本token
         x_lens,
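A minimal illustrative sketch, separate from the patch and using hypothetical class and stub method names, of the dispatch pattern the first hunk introduces: enable_flash_attn() rebinds self.infer_panel, so existing callers that invoke infer_panel() transparently get either the SDPA batch path or the flash-attention batch path.

# Sketch only: stand-in class, not the real Text2SemanticDecoder.
class Decoder:
    def __init__(self, flash_attn_enabled: bool = True):
        self.enable_flash_attn(flash_attn_enabled)

    def enable_flash_attn(self, enable: bool = True):
        if not enable:
            print("Not Using Flash Attention")
            self.infer_panel = self.infer_panel_batch_only
        else:
            print("Using Flash Attention")
            self.infer_panel = self.infer_panel_batch_infer_with_flash_attn

    def infer_panel_batch_only(self):
        return "sdpa path"

    def infer_panel_batch_infer_with_flash_attn(self):
        return "flash-attn path"

print(Decoder(False).infer_panel())  # -> "sdpa path"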
@@ -597,8 +518,10 @@ class Text2SemanticDecoder(nn.Module):
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
+
+        bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_embedding(x)
-        x = x + self.bert_proj(bert_feature.transpose(1, 2))
+        x = x + bert_feature
         x = self.ar_text_position(x)
 
         # AR Decoder
@@ -635,30 +558,28 @@ class Text2SemanticDecoder(nn.Module):
         y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
         y_mask = make_pad_mask(y_lens)
         x_mask = make_pad_mask(x_lens)
-
+        # (bsz, x_len + y_len)
         xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
-        _xy_padding_mask = (
-            xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1)
-        )
-        x_attn_mask_pad = F.pad(
+        x_mask = F.pad(
             x_attn_mask,
             (0, y_len),  ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
             value=True,
         )
-        y_attn_mask = F.pad(  ###yy的右上1扩展到左边xy的0,(y,x+y)
+        y_mask = F.pad(  ###yy的右上1扩展到左边xy的0,(y,x+y)
             torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-            x.device
-        )
-        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
+
+        xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
+        # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
+        xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
+        xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
+        xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
         new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
-        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
-        xy_attn_mask = new_attn_mask
+        xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
 
         ###### decode #####
         y_list = [None]*y.shape[0]
@@ -730,7 +651,7 @@ class Text2SemanticDecoder(nn.Module):
             ####################### update next step ###################################
             y_emb = self.ar_audio_embedding(y[:, -1:])
             xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
-
+
         if (None in idx_list):
             for i in range(x.shape[0]):
                 if idx_list[i] is None:
diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py
index 34178fea..ce0a98b7 100644
--- a/GPT_SoVITS/AR/models/utils.py
+++ b/GPT_SoVITS/AR/models/utils.py
@@ -143,7 +143,7 @@ def logits_to_probs(
 
     if top_k is not None:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
+        pivot = v[: , -1].unsqueeze(-1)
         logits = torch.where(logits < pivot, -float("Inf"), logits)
 
     probs = torch.nn.functional.softmax(logits, dim=-1)
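An illustrative check, separate from the patch, of the pivot change in logits_to_probs above, assuming 2-D logits of shape (batch, vocab): v.select(-1, -1) and v[:, -1] pick the same k-th largest value per row, so the rewritten line keeps the top-k filtering behaviour.

# Sketch only: verifies the two pivot expressions agree for 2-D logits.
import torch

logits = torch.randn(2, 8)
top_k = 3
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
assert torch.equal(v.select(-1, -1), v[:, -1])
pivot = v[:, -1].unsqueeze(-1)
filtered = torch.where(logits < pivot, -float("Inf"), logits)
print((filtered > -float("Inf")).sum(dim=-1))  # tensor([3, 3]) barring exact ties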
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 0bf5f72b..4f256e49 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -1,5 +1,6 @@
 import math
 import os, sys
+import random
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 import ffmpeg
@@ -7,6 +8,7 @@ import os
 from typing import Generator, List, Union
 import numpy as np
 import torch
+import torch.nn.functional as F
 import yaml
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 
@@ -97,7 +99,6 @@ class TTS_Config:
             configs = yaml.load(f, Loader=yaml.FullLoader)
         return configs
 
-
     def save_configs(self, configs_path:str=None)->None:
         configs={
@@ -110,32 +111,31 @@ class TTS_Config:
                 "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
                 "flash_attn_enabled": True
             },
-            "custom": {
-                "device": str(self.device),
-                "is_half": self.is_half,
-                "t2s_weights_path": self.t2s_weights_path,
-                "vits_weights_path": self.vits_weights_path,
-                "bert_base_path": self.bert_base_path,
-                "cnhuhbert_base_path": self.cnhuhbert_base_path,
-                "flash_attn_enabled": self.flash_attn_enabled
-            }
+            "custom": self.update_configs()
         }
         if configs_path is None:
             configs_path = self.configs_path
         with open(configs_path, 'w') as f:
             yaml.dump(configs, f)
-
+
+    def update_configs(self):
+        config = {
+            "device"             : str(self.device),
+            "is_half"            : self.is_half,
+            "t2s_weights_path"   : self.t2s_weights_path,
+            "vits_weights_path"  : self.vits_weights_path,
+            "bert_base_path"     : self.bert_base_path,
+            "cnhuhbert_base_path": self.cnhuhbert_base_path,
+            "flash_attn_enabled" : self.flash_attn_enabled
+        }
+        return config
 
     def __str__(self):
-        string = "----------------TTS Config--------------\n"
-        string += "device: {}\n".format(self.device)
-        string += "is_half: {}\n".format(self.is_half)
-        string += "bert_base_path: {}\n".format(self.bert_base_path)
-        string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
-        string += "vits_weights_path: {}\n".format(self.vits_weights_path)
-        string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
-        string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
-        string += "----------------------------------------\n"
+        self.configs = self.update_configs()
+        string = "TTS Config".center(100, '-') + '\n'
+        for k, v in self.configs.items():
+            string += f"{str(k).ljust(20)}: {str(v)}\n"
+        string += "-" * 100 + '\n'
         return string
@@ -184,7 +184,7 @@ class TTS:
 
     def init_cnhuhbert_weights(self, base_path: str):
         self.cnhuhbert_model = CNHubert(base_path)
-        self.cnhuhbert_model.eval()
+        self.cnhuhbert_model=self.cnhuhbert_model.eval()
         if self.configs.is_half == True:
             self.cnhuhbert_model = self.cnhuhbert_model.half()
         self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
@@ -194,6 +194,7 @@ class TTS:
     def init_bert_weights(self, base_path: str):
         self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
         self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
+        self.bert_model=self.bert_model.eval()
         if self.configs.is_half:
             self.bert_model = self.bert_model.half()
         self.bert_model = self.bert_model.to(self.configs.device)
@@ -226,7 +227,7 @@ class TTS:
         if self.configs.is_half:
             vits_model = vits_model.half()
         vits_model = vits_model.to(self.configs.device)
-        vits_model.eval()
+        vits_model = vits_model.eval()
         vits_model.load_state_dict(dict_s2["weight"], strict=False)
         self.vits_model = vits_model
@@ -244,7 +245,7 @@ class TTS:
         if self.configs.is_half:
             t2s_model = t2s_model.half()
         t2s_model = t2s_model.to(self.configs.device)
-        t2s_model.eval()
+        t2s_model = t2s_model.eval()
         self.t2s_model = t2s_model
 
     def set_ref_audio(self, ref_audio_path:str):
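A short illustration, separate from the patch, of why the eval() reassignments above are safe: nn.Module.eval() returns the module itself, so writing model = model.eval() is behaviourally identical to calling model.eval() on its own; the assignment just makes the chaining explicit.

# Sketch only: eval() returns self and flips the training flag in place.
import torch.nn as nn

model = nn.Linear(4, 4)
returned = model.eval()
assert returned is model and not model.training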
@@ -377,12 +378,14 @@ class TTS:
         phones_max_len = 0
         for item in item_list:
             if prompt_data is not None:
-                all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
+                all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
+                    .to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
                 all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
                 phones = torch.LongTensor(item["phones"])
                 # norm_text = prompt_data["norm_text"]+item["norm_text"]
             else:
-                all_bert_features = item["bert_features"]
+                all_bert_features = item["bert_features"]\
+                    .to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
                 phones = torch.LongTensor(item["phones"])
                 all_phones = phones
                 # norm_text = item["norm_text"]
@@ -401,12 +404,10 @@ class TTS:
         max_len = max(bert_max_len, phones_max_len)
         # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
         all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
-        all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
-        all_bert_features_batch.zero_()
-
+        # all_bert_features_batch = all_bert_features_list
+        all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=torch.float32)
         for idx, item in enumerate(all_bert_features_list):
-            if item != None:
-                all_bert_features_batch[idx, :, : item.shape[-1]] = item
+            all_bert_features_batch[idx, :, : item.shape[-1]] = item
 
         batch = {
             "phones": phones_batch,
@@ -458,8 +459,8 @@ class TTS:
                     "prompt_text": "",          # str. prompt text for the reference audio
                     "prompt_lang": "",          # str. language of the prompt text for the reference audio
                     "top_k": 5,                 # int. top k sampling
-                    "top_p": 0.9,               # float. top p sampling
-                    "temperature": 0.6,         # float. temperature for sampling
+                    "top_p": 1,                 # float. top p sampling
+                    "temperature": 1,           # float. temperature for sampling
                     "text_split_method": "",    # str. text split method, see text_segmentaion_method.py for details.
                     "batch_size": 1,            # int. batch size for inference
                     "batch_threshold": 0.75,    # float. threshold for batch splitting.
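A minimal sketch, separate from the patch and with placeholder shapes, of the zero-padded batching used for the BERT features above: each item's (1024, T_i) feature tensor is copied into the front of a zeroed (batch, 1024, max_len) buffer, so shorter items end up right-padded with zeros.

# Sketch only: placeholder feature tensors, not real BERT output.
import torch

features = [torch.randn(1024, 7), torch.randn(1024, 11)]
max_len = max(f.shape[-1] for f in features)
batch = torch.zeros(len(features), 1024, max_len, dtype=torch.float32)
for i, f in enumerate(features):
    batch[i, :, :f.shape[-1]] = f
print(batch.shape)  # torch.Size([2, 1024, 11])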
@@ -477,9 +478,9 @@ class TTS:
         ref_audio_path:str = inputs.get("ref_audio_path", "")
         prompt_text:str = inputs.get("prompt_text", "")
         prompt_lang:str = inputs.get("prompt_lang", "")
-        top_k:int = inputs.get("top_k", 20)
-        top_p:float = inputs.get("top_p", 0.9)
-        temperature:float = inputs.get("temperature", 0.6)
+        top_k:int = inputs.get("top_k", 5)
+        top_p:float = inputs.get("top_p", 1)
+        temperature:float = inputs.get("temperature", 1)
         text_split_method:str = inputs.get("text_split_method", "")
         batch_size = inputs.get("batch_size", 1)
         batch_threshold = inputs.get("batch_threshold", 0.75)
@@ -497,10 +498,6 @@ class TTS:
         if split_bucket:
             print(i18n("分桶处理模式已开启"))
 
-        # if vits_batched_inference:
-        #     print(i18n("VITS批量推理模式已开启"))
-        # else:
-        #     print(i18n("VITS单句推理模式已开启"))
 
         no_prompt_text = False
         if prompt_text in [None, ""]:
@@ -547,7 +544,7 @@ class TTS:
                 )
         t2 = ttime()
-
+        print("############ 推理 ############")
         ###### inference ######
         t_34 = 0.0
         t_45 = 0.0
@@ -601,6 +598,10 @@ class TTS:
                 # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                 # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
                 # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
+                # max_len = 0
+                # for i in range(0, len(batch_phones)):
+                #     max_len = max(max_len, batch_phones[i].shape[-1])
+                # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
                 # batch_phones = batch_phones.to(self.configs.device)
                 # batch_audio_fragment = (self.vits_model.batched_decode(
                 #     pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
@@ -654,7 +655,12 @@ class TTS:
                                                self.configs.sampling_rate,
                                                batch_index_list,
                                                speed_factor,
-                                               split_bucket)
+                                               split_bucket)
+
+            try:
+                torch.cuda.empty_cache()
+            except:
+                pass
diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
index 2669bf41..ee61bc30 100644
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -1,5 +1,7 @@
 import os, sys
+
+from tqdm import tqdm
 
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 
@@ -12,9 +14,9 @@ from text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
 
-# from tools.i18n.i18n import I18nAuto
+from tools.i18n.i18n import I18nAuto
 
-# i18n = I18nAuto()
+i18n = I18nAuto()
 
 def get_first(text:str) -> str:
     pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
@@ -51,9 +53,11 @@ class TextPreprocessor:
         self.device = device
 
     def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
+        print(i18n("############ 切分文本 ############"))
        texts = self.pre_seg_text(text, lang, text_split_method)
        result = []
-        for text in texts:
+        print(i18n("############ 提取文本Bert特征 ############"))
+        for text in tqdm(texts):
            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
            res={
                "phones": phones,
@@ -67,14 +71,16 @@ class TextPreprocessor:
         text = text.strip("\n")
         if (text[0] not in splits and len(get_first(text)) < 4):
             text = "。" + text if lang != "en" else "." + text
-        # print(i18n("实际输入的目标文本:"), text)
+        print(i18n("实际输入的目标文本:"))
+        print(text)
 
         seg_method = get_seg_method(text_split_method)
         text = seg_method(text)
 
         while "\n\n" in text:
             text = text.replace("\n\n", "\n")
-        # print(i18n("实际输入的目标文本(切句后):"), text)
+        print(i18n("实际输入的目标文本(切句后):"))
+        print(text)
 
         _texts = text.split("\n")
         _texts = merge_short_text_in_array(_texts, 5)
         texts = []
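A rough sketch, separate from the patch, of the per-segment dictionaries preprocess() builds above; the field names come from the hunk, while the function and values here are placeholders rather than real model output.

# Sketch only: fake_preprocess stands in for TextPreprocessor.preprocess.
from typing import Dict, List
import torch

def fake_preprocess(text: str) -> List[Dict]:
    segments = [s for s in text.split("。") if s]
    result = []
    for seg in segments:
        result.append({
            "phones": [0] * len(seg),                       # phoneme ids (placeholder)
            "bert_features": torch.zeros(1024, len(seg)),   # (1024, T) features (placeholder)
            "norm_text": seg,                               # normalized text
        })
    return result

print(len(fake_preprocess("你好。世界。")))  # 2 segments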
@@ -105,7 +111,7 @@ class TextPreprocessor:
         textlist=[]
         langlist=[]
         if language in ["auto", "zh", "ja"]:
-            # LangSegment.setfilters(["zh","ja","en","ko"])
+            LangSegment.setfilters(["zh","ja","en","ko"])
             for tmp in LangSegment.getTexts(text):
                 if tmp["lang"] == "ko":
                     langlist.append("zh")
@@ -116,7 +122,7 @@ class TextPreprocessor:
                     langlist.append(language if language!="auto" else tmp["lang"])
                 textlist.append(tmp["text"])
         elif language == "en":
-            # LangSegment.setfilters(["en"])
+            LangSegment.setfilters(["en"])
             formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
             while "  " in formattext:
                 formattext = formattext.replace("  ", " ")
@@ -152,8 +158,7 @@ class TextPreprocessor:
             bert_feature = torch.cat(bert_feature_list, dim=1)
             # phones = sum(phones_list, [])
             norm_text = ''.join(norm_text_list)
-
-        return phones, bert_feature, norm_text
+        return phones_list, bert_feature, norm_text
 
     def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 334e2ab6..b3b25c9b 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -45,9 +45,11 @@ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时
 
 if torch.cuda.is_available():
     device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
 else:
     device = "cpu"
-
+
 dict_language = {
     i18n("中文"): "all_zh",#全部按中文识别
     i18n("英文"): "en",#全部按英文识别#######不变
@@ -103,10 +105,12 @@ def inference(text, text_lang,
         "batch_size":int(batch_size),
         "speed_factor":float(speed_factor),
         "split_bucket":split_bucket,
-        "return_fragment":False,
+        "return_fragment":False
     }
-    yield next(tts_pipline.run(inputs))
-
+
+    for item in tts_pipline.run(inputs):
+        yield item
+
 def custom_sort_key(s):
     # 使用正则表达式提取字符串中的数字部分和非数字部分
     parts = re.split('(\d+)', s)
@@ -182,7 +186,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
 
             with gr.Row():
                 with gr.Column():
-                    batch_size = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("batch_size"),value=20,interactive=True)
+                    batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=1,interactive=True)
                     speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
                     top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
                     top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
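An illustrative sketch, separate from the patch and using placeholder names rather than the real pipeline objects, of the streaming pattern the webui switches to above: every fragment yielded by the pipeline generator is forwarded, instead of taking a single next() from it, so multi-fragment output reaches the UI.

# Sketch only: fake_run stands in for tts_pipline.run.
from typing import Generator, Tuple

def fake_run(inputs: dict) -> Generator[Tuple[int, bytes], None, None]:
    for i in range(3):
        yield (32000, b"audio-fragment-%d" % i)

def inference(inputs: dict):
    for item in fake_run(inputs):  # forward every fragment instead of next(...)
        yield item

for sr, chunk in inference({}):
    print(sr, chunk)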