diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module.py b/GPT_SoVITS/AR/models/t2s_lightning_module.py index 2dd3f392..1b602629 100644 --- a/GPT_SoVITS/AR/models/t2s_lightning_module.py +++ b/GPT_SoVITS/AR/models/t2s_lightning_module.py @@ -13,11 +13,11 @@ from AR.modules.lr_schedulers import WarmupCosineLRSchedule from AR.modules.optim import ScaledAdam class Text2SemanticLightningModule(LightningModule): - def __init__(self, config, output_dir, is_train=True): + def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False): super().__init__() self.config = config self.top_k = 3 - self.model = Text2SemanticDecoder(config=config, top_k=self.top_k) + self.model = Text2SemanticDecoder(config=config, top_k=self.top_k,flash_attn_enabled=flash_attn_enabled) pretrained_s1 = config.get("pretrained_s1") if pretrained_s1 and is_train: # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"])) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index ed46b2b1..dfd6eb0d 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -1,7 +1,9 @@ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py # reference: https://github.com/lifeiteng/vall-e +import os, sys +now_dir = os.getcwd() +sys.path.append(now_dir) from typing import List - import torch from tqdm import tqdm @@ -174,7 +176,7 @@ class T2STransformer: class Text2SemanticDecoder(nn.Module): - def __init__(self, config, norm_first=False, top_k=3): + def __init__(self, config, norm_first=False, top_k=3, flash_attn_enabled:bool=False): super(Text2SemanticDecoder, self).__init__() self.model_dim = config["model"]["hidden_dim"] self.embedding_dim = config["model"]["embedding_dim"] @@ -226,37 +228,47 @@ class Text2SemanticDecoder(nn.Module): multidim_average="global", ignore_index=self.EOS, ) - - blocks = [] - - for i in range(self.num_layers): - layer = self.h.layers[i] - t2smlp = T2SMLP( - layer.linear1.weight, - layer.linear1.bias, - layer.linear2.weight, - layer.linear2.bias - ) - - block = T2SBlock( - self.num_head, - self.model_dim, - t2smlp, - layer.self_attn.in_proj_weight, - layer.self_attn.in_proj_bias, - layer.self_attn.out_proj.weight, - layer.self_attn.out_proj.bias, - layer.norm1.weight, - layer.norm1.bias, - layer.norm1.eps, - layer.norm2.weight, - layer.norm2.bias, - layer.norm2.eps - ) - - blocks.append(block) - self.t2s_transformer = T2STransformer(self.num_layers, blocks) + self.enable_flash_attn(flash_attn_enabled) + + def enable_flash_attn(self, enable:bool=True): + + if not enable: + print("Not Using Flash Attention") + self.infer_panel = self.infer_panel_batch_only + else: + self.infer_panel = self.infer_panel_batch_infer_with_flash_attn + print("Using Flash Attention") + blocks = [] + + for i in range(self.num_layers): + layer = self.h.layers[i] + t2smlp = T2SMLP( + layer.linear1.weight, + layer.linear1.bias, + layer.linear2.weight, + layer.linear2.bias + ) + + block = T2SBlock( + self.num_head, + self.model_dim, + t2smlp, + layer.self_attn.in_proj_weight, + layer.self_attn.in_proj_bias, + layer.self_attn.out_proj.weight, + layer.self_attn.out_proj.bias, + layer.norm1.weight, + layer.norm1.bias, + layer.norm1.eps, + layer.norm2.weight, + layer.norm2.bias, + layer.norm2.eps + ) + + blocks.append(block) + + self.t2s_transformer = T2STransformer(self.num_layers, blocks) def make_input_data(self, x, x_lens, y, y_lens, bert_feature): x = self.ar_text_embedding(x) 
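Note on the new `flash_attn_enabled` flag introduced above: it is read from `tts_infer.yaml` by `TTS_Config`, handed to `Text2SemanticLightningModule`, and forwarded to `Text2SemanticDecoder`, whose `enable_flash_attn()` binds `infer_panel` to either the flash-attention batch path or the original batch-only path. A minimal sketch of that wiring (class and path names as in this diff; assumes the GPT_SoVITS package directory is importable):

```python
# Sketch only: how flash_attn_enabled threads through the stack added in this diff.
from TTS_infer_pack.TTS import TTS_Config, TTS

cfg = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")  # reads flash_attn_enabled (defaults to True)
pipeline = TTS(cfg)
# Inside TTS.init_t2s_weights():
#   Text2SemanticLightningModule(config, "****", is_train=False,
#                                flash_attn_enabled=cfg.flash_attn_enabled)
# and the decoder's enable_flash_attn() then binds the public entry point:
#   self.infer_panel = self.infer_panel_batch_infer_with_flash_attn  # flag on
#   self.infer_panel = self.infer_panel_batch_only                   # flag off
```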
@@ -490,7 +502,7 @@ class Text2SemanticDecoder(nn.Module): # 错位 return targets[:, :-1], targets[:, 1:] - def infer_panel( + def infer_panel_batch_infer_with_flash_attn( self, x, #####全部文本token x_lens, @@ -501,8 +513,10 @@ class Text2SemanticDecoder(nn.Module): early_stop_num: int = -1, temperature: float = 1.0, ): + + bert_feature = self.bert_proj(bert_feature.transpose(1, 2)) x = self.ar_text_embedding(x) - x = x + self.bert_proj(bert_feature.transpose(1, 2)) + x = x + bert_feature x = self.ar_text_position(x) # AR Decoder @@ -539,30 +553,28 @@ class Text2SemanticDecoder(nn.Module): y_lens = torch.LongTensor([y_len]*bsz).to(x.device) y_mask = make_pad_mask(y_lens) x_mask = make_pad_mask(x_lens) - + # (bsz, x_len + y_len) xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) - _xy_padding_mask = ( - xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1) - ) - x_attn_mask_pad = F.pad( + x_mask = F.pad( x_attn_mask, (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) value=True, ) - y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) + y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), (x_len, 0), value=False, ) - xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( - x.device - ) - xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) + + xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device) + # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1) + xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len) + xy_attn_mask = xy_mask.logical_or(xy_padding_mask) + xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1) new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) - new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) - xy_attn_mask = new_attn_mask + xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf")) ###### decode ##### y_list = [None]*y.shape[0] @@ -634,6 +646,168 @@ class Text2SemanticDecoder(nn.Module): ####################### update next step ################################### y_emb = self.ar_audio_embedding(y[:, -1:]) xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device) + + if (None in idx_list): + for i in range(x.shape[0]): + if idx_list[i] is None: + idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替 + + if ref_free: + return y_list, [0]*x.shape[0] + return y_list, idx_list + + def infer_panel_batch_only( + self, + x, #####全部文本token + x_lens, + prompts, ####参考音频token + bert_feature, + top_k: int = -100, + top_p: int = 100, + early_stop_num: int = -1, + temperature: float = 1.0, + ): + x = self.ar_text_embedding(x) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) + x = self.ar_text_position(x) + + # AR Decoder + y = prompts + + x_len = x.shape[1] + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + stop = False + # print(1111111,self.num_layers) + cache = { + "all_stage": self.num_layers, + "k": [None] * self.num_layers, ###根据配置自己手写 + "v": [None] * self.num_layers, + # "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了 + "y_emb": None, ##只需要对最新的samples求emb,再拼历史的就行 + # "logits":None,###原版就已经只对结尾求再拼接了,不用管 + # "xy_dec":None,###不需要,本来只需要最后一个做logits + "first_infer": 1, + "stage": 0, + } + ################### first step ########################## + if y is not None: + 
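The mask rework above (in `infer_panel_batch_infer_with_flash_attn`) combines a causal mask over the semantic tokens with each sample's padding mask, then turns the boolean result into an additive float mask. A standalone sketch of that construction with toy sizes, assuming a padding mask of shape `(bsz, src_len)` where `True` marks padded positions; an additive mask in this form is what e.g. `torch.nn.functional.scaled_dot_product_attention` accepts:

```python
import torch
import torch.nn.functional as F

bsz, num_head, x_len, y_len = 2, 4, 5, 3
src_len = x_len + y_len

# text prefix attends freely to itself; the future y block is masked out
x_mask = F.pad(torch.zeros(x_len, x_len, dtype=torch.bool), (0, y_len), value=True)
# semantic part is causal over itself and may look at the whole text prefix
y_mask = F.pad(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
               (x_len, 0), value=False)
xy_mask = torch.cat([x_mask, y_mask], dim=0).view(1, src_len, src_len).expand(bsz, -1, -1)

# per-sample padding (True = pad), broadcast over the query axis
xy_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, -1)

attn_bool = xy_mask.logical_or(xy_padding_mask)                  # (bsz, src_len, src_len)
attn_bool = attn_bool.unsqueeze(1).expand(-1, num_head, -1, -1)  # (bsz, heads, src, src)
attn_mask = torch.zeros_like(attn_bool, dtype=torch.float32)
attn_mask = attn_mask.masked_fill(attn_bool, float("-inf"))      # additive mask, -inf where blocked
```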
y_emb = self.ar_audio_embedding(y) + y_len = y_emb.shape[1] + prefix_len = y.shape[1] + y_pos = self.ar_audio_position(y_emb) + xy_pos = torch.concat([x, y_pos], dim=1) + cache["y_emb"] = y_emb + ref_free = False + else: + y_emb = None + y_len = 0 + prefix_len = 0 + y_pos = None + xy_pos = x + y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device) + ref_free = True + + x_attn_mask_pad = F.pad( + x_attn_mask, + (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) + value=True, + ) + y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( + x.device + ) + + y_list = [None]*y.shape[0] + batch_idx_map = list(range(y.shape[0])) + idx_list = [None]*y.shape[0] + for idx in tqdm(range(1500)): + + xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer( + xy_dec[:, -1] + ) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的 + # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature) + if(idx==0):###第一次跑不能EOS否则没有了 + logits = logits[:, :-1] ###刨除1024终止符号的概率 + samples = sample( + logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature + )[0] + # 本次生成的 semantic_ids 和之前的 y 构成新的 y + # print(samples.shape)#[1,1]#第一个1是bs + y = torch.concat([y, samples], dim=1) + + # 移除已经生成完毕的序列 + reserved_idx_of_batch_for_y = None + if (self.EOS in torch.argmax(logits, dim=-1)) or \ + (self.EOS in samples[:, 0]): ###如果生成到EOS,则停止 + l = samples[:, 0]==self.EOS + removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist() + reserved_idx_of_batch_for_y = torch.where(l==False)[0] + # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y] + for i in removed_idx_of_batch_for_y: + batch_index = batch_idx_map[i] + idx_list[batch_index] = idx - 1 + y_list[batch_index] = y[i, :-1] + + batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()] + + # 只保留未生成完毕的序列 + if reserved_idx_of_batch_for_y is not None: + # index = torch.LongTensor(batch_idx_map).to(y.device) + y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y) + if cache["y_emb"] is not None: + cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y) + if cache["k"] is not None: + for i in range(self.num_layers): + # 因为kv转置了,所以batch dim是1 + cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y) + cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y) + + + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + print("use early stop num:", early_stop_num) + stop = True + + if not (None in idx_list): + # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS) + stop = True + if stop: + # if prompts.shape[1] == y.shape[1]: + # y = torch.concat([y, torch.zeros_like(samples)], dim=1) + # print("bad zero prediction") + if y.shape[1]==0: + y = torch.concat([y, torch.zeros_like(samples)], dim=1) + print("bad zero prediction") + print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + break + + ####################### update next step ################################### + cache["first_infer"] = 0 + if cache["y_emb"] is not None: + y_emb = torch.cat( + [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1 + ) + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + xy_pos = y_pos[:, -1:] + else: + y_emb = 
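The EOS handling in this loop shrinks the running batch: sequences that emit EOS are copied out into `y_list`/`idx_list`, and the survivors are kept with `index_select` on `y`, `cache["y_emb"]`, and the per-layer k/v caches (whose batch dimension is 1 because they are stored transposed). A toy sketch of that pruning step, with made-up shapes:

```python
import torch

num_layers, seq_len, bsz, dim = 2, 7, 3, 8
y = torch.zeros(bsz, seq_len, dtype=torch.long)           # (batch, seq_len)
cache = {
    "k": [torch.zeros(seq_len, bsz, dim) for _ in range(num_layers)],  # (seq_len, batch, dim)
    "v": [torch.zeros(seq_len, bsz, dim) for _ in range(num_layers)],
}

keep = torch.tensor([0, 2])  # indices of sequences that have not hit EOS yet
y = torch.index_select(y, dim=0, index=keep)
for i in range(num_layers):
    # k/v are transposed, so the batch dim is 1 here
    cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=keep)
    cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=keep)
```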
self.ar_audio_embedding(y[:, -1:]) + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + xy_pos = y_pos + y_len = y_pos.shape[1] + + ###最右边一列(是错的) + # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device) + # xy_attn_mask[:,-1]=False + ###最下面一行(是对的) + xy_attn_mask = torch.zeros( + (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device + ) if (None in idx_list): for i in range(x.shape[0]): diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py index 34178fea..ce0a98b7 100644 --- a/GPT_SoVITS/AR/models/utils.py +++ b/GPT_SoVITS/AR/models/utils.py @@ -143,7 +143,7 @@ def logits_to_probs( if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - pivot = v.select(-1, -1).unsqueeze(-1) + pivot = v[: , -1].unsqueeze(-1) logits = torch.where(logits < pivot, -float("Inf"), logits) probs = torch.nn.functional.softmax(logits, dim=-1) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index ba29a03f..62bb2e90 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,4 +1,8 @@ +from copy import deepcopy +import math import os, sys +import random +import traceback now_dir = os.getcwd() sys.path.append(now_dir) import ffmpeg @@ -6,6 +10,7 @@ import os from typing import Generator, List, Union import numpy as np import torch +import torch.nn.functional as F import yaml from transformers import AutoModelForMaskedLM, AutoTokenizer @@ -17,8 +22,8 @@ from time import time as ttime from tools.i18n.i18n import I18nAuto from my_utils import load_audio from module.mel_processing import spectrogram_torch -from .text_segmentation_method import splits -from .TextPreprocessor import TextPreprocessor +from TTS_infer_pack.text_segmentation_method import splits +from TTS_infer_pack.TextPreprocessor import TextPreprocessor i18n = I18nAuto() # configs/tts_infer.yaml @@ -30,6 +35,7 @@ default: cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth + flash_attn_enabled: true custom: device: cuda @@ -38,41 +44,81 @@ custom: cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth - + flash_attn_enabled: true """ - - +# def set_seed(seed): +# random.seed(seed) +# os.environ['PYTHONHASHSEED'] = str(seed) +# np.random.seed(seed) +# torch.manual_seed(seed) +# torch.cuda.manual_seed(seed) +# torch.cuda.manual_seed_all(seed) +# torch.backends.cudnn.deterministic = True +# torch.backends.cudnn.benchmark = False +# torch.backends.cudnn.enabled = True +# set_seed(1234) + class TTS_Config: - def __init__(self, configs: Union[dict, str]): - configs_base_path:str = "GPT_SoVITS/configs/" - os.makedirs(configs_base_path, exist_ok=True) - self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml") - if isinstance(configs, str): - self.configs_path = configs - configs:dict = self._load_configs(configs) - - # assert isinstance(configs, dict) - self.default_configs:dict = configs.get("default", None) - if self.default_configs is None: - self.default_configs={ + default_configs={ "device": "cpu", "is_half": False, "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", "vits_weights_path": 
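On the `utils.py` change above: both the old `v.select(-1, -1)` and the new `v[:, -1]` pick the k-th largest logit of each row as the pivot below which everything is masked out; the slice form assumes 2-D `(batch, vocab)` logits. A small self-contained sketch of that top-k filter as now written in `logits_to_probs`:

```python
import torch

# Toy logits: 2 rows, 10-way vocab; keep only the 3 largest per row before softmax.
logits = torch.randn(2, 10)
top_k = 3
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[:, -1].unsqueeze(-1)  # k-th largest value of each row
logits = torch.where(logits < pivot, torch.full_like(logits, -float("Inf")), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
```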
"GPT_SoVITS/pretrained_models/s2G488k.pth", "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", - "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large" + "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + "flash_attn_enabled": True } - self.configs:dict = configs.get("custom", self.default_configs) + configs:dict = None + def __init__(self, configs: Union[dict, str]=None): - self.device = self.configs.get("device") - self.is_half = self.configs.get("is_half") - self.t2s_weights_path = self.configs.get("t2s_weights_path") - self.vits_weights_path = self.configs.get("vits_weights_path") - self.bert_base_path = self.configs.get("bert_base_path") - self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path") + # 设置默认配置文件路径 + configs_base_path:str = "GPT_SoVITS/configs/" + os.makedirs(configs_base_path, exist_ok=True) + self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml") + + if configs in ["", None]: + if not os.path.exists(self.configs_path): + self.save_configs() + print(f"Create default config file at {self.configs_path}") + configs:dict = {"default": deepcopy(self.default_configs)} + + if isinstance(configs, str): + self.configs_path = configs + configs:dict = self._load_configs(self.configs_path) + + assert isinstance(configs, dict) + default_configs:dict = configs.get("default", None) + if default_configs is not None: + self.default_configs = default_configs + + self.configs:dict = configs.get("custom", deepcopy(self.default_configs)) + + + self.device = self.configs.get("device", torch.device("cpu")) + self.is_half = self.configs.get("is_half", False) + self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True) + self.t2s_weights_path = self.configs.get("t2s_weights_path", None) + self.vits_weights_path = self.configs.get("vits_weights_path", None) + self.bert_base_path = self.configs.get("bert_base_path", None) + self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None) + + + if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)): + self.t2s_weights_path = self.default_configs['t2s_weights_path'] + print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}") + if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)): + self.vits_weights_path = self.default_configs['vits_weights_path'] + print(f"fall back to default vits_weights_path: {self.vits_weights_path}") + if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)): + self.bert_base_path = self.default_configs['bert_base_path'] + print(f"fall back to default bert_base_path: {self.bert_base_path}") + if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)): + self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path'] + print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}") + self.update_configs() self.max_sec = None @@ -86,50 +132,48 @@ class TTS_Config: self.n_speakers:int = 300 self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] - print(self) + # print(self) def _load_configs(self, configs_path: str)->dict: with open(configs_path, 'r') as f: configs = yaml.load(f, Loader=yaml.FullLoader) return configs - def save_configs(self, configs_path:str=None)->None: configs={ - "default": { - "device": "cpu", - "is_half": False, - "t2s_weights_path": 
"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", - "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth", - "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", - "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large" - }, - "custom": { - "device": str(self.device), - "is_half": self.is_half, - "t2s_weights_path": self.t2s_weights_path, - "vits_weights_path": self.vits_weights_path, - "bert_base_path": self.bert_base_path, - "cnhuhbert_base_path": self.cnhuhbert_base_path - } + "default":self.default_configs, } + if self.configs is not None: + configs["custom"] = self.update_configs() + if configs_path is None: configs_path = self.configs_path with open(configs_path, 'w') as f: yaml.dump(configs, f) - + + def update_configs(self): + self.config = { + "device" : str(self.device), + "is_half" : self.is_half, + "t2s_weights_path" : self.t2s_weights_path, + "vits_weights_path" : self.vits_weights_path, + "bert_base_path" : self.bert_base_path, + "cnhuhbert_base_path": self.cnhuhbert_base_path, + "flash_attn_enabled" : self.flash_attn_enabled + } + return self.config def __str__(self): - string = "----------------TTS Config--------------\n" - string += "device: {}\n".format(self.device) - string += "is_half: {}\n".format(self.is_half) - string += "bert_base_path: {}\n".format(self.bert_base_path) - string += "t2s_weights_path: {}\n".format(self.t2s_weights_path) - string += "vits_weights_path: {}\n".format(self.vits_weights_path) - string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path) - string += "----------------------------------------\n" + self.configs = self.update_configs() + string = "TTS Config".center(100, '-') + '\n' + for k, v in self.configs.items(): + string += f"{str(k).ljust(20)}: {str(v)}\n" + string += "-" * 100 + '\n' return string + + def __repr__(self): + return self.__str__() class TTS: @@ -166,34 +210,40 @@ class TTS: self.stop_flag:bool = False + self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32 def _init_models(self,): self.init_t2s_weights(self.configs.t2s_weights_path) self.init_vits_weights(self.configs.vits_weights_path) self.init_bert_weights(self.configs.bert_base_path) self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path) + # self.enable_half_precision(self.configs.is_half) def init_cnhuhbert_weights(self, base_path: str): + print(f"Loading CNHuBERT weights from {base_path}") self.cnhuhbert_model = CNHubert(base_path) - self.cnhuhbert_model.eval() - if self.configs.is_half == True: - self.cnhuhbert_model = self.cnhuhbert_model.half() + self.cnhuhbert_model=self.cnhuhbert_model.eval() self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device) + if self.configs.is_half: + self.cnhuhbert_model = self.cnhuhbert_model.half() def init_bert_weights(self, base_path: str): + print(f"Loading BERT weights from {base_path}") self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path) self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path) + self.bert_model=self.bert_model.eval() + self.bert_model = self.bert_model.to(self.configs.device) if self.configs.is_half: self.bert_model = self.bert_model.half() - self.bert_model = self.bert_model.to(self.configs.device) def init_vits_weights(self, weights_path: str): + print(f"Loading VITS weights from {weights_path}") self.configs.vits_weights_path = weights_path self.configs.save_configs() dict_s2 = torch.load(weights_path, map_location=self.configs.device) @@ -216,28 
+266,80 @@ class TTS: if hasattr(vits_model, "enc_q"): del vits_model.enc_q - if self.configs.is_half: - vits_model = vits_model.half() vits_model = vits_model.to(self.configs.device) - vits_model.eval() + vits_model = vits_model.eval() vits_model.load_state_dict(dict_s2["weight"], strict=False) self.vits_model = vits_model + if self.configs.is_half: + self.vits_model = self.vits_model.half() def init_t2s_weights(self, weights_path: str): + print(f"Loading Text2Semantic weights from {weights_path}") self.configs.t2s_weights_path = weights_path self.configs.save_configs() self.configs.hz = 50 dict_s1 = torch.load(weights_path, map_location=self.configs.device) config = dict_s1["config"] self.configs.max_sec = config["data"]["max_sec"] - t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False, + flash_attn_enabled=self.configs.flash_attn_enabled) t2s_model.load_state_dict(dict_s1["weight"]) - if self.configs.is_half: - t2s_model = t2s_model.half() t2s_model = t2s_model.to(self.configs.device) - t2s_model.eval() + t2s_model = t2s_model.eval() self.t2s_model = t2s_model + if self.configs.is_half: + self.t2s_model = self.t2s_model.half() + + def enable_half_precision(self, enable: bool = True): + ''' + To enable half precision for the TTS model. + Args: + enable: bool, whether to enable half precision. + + ''' + if self.configs.device == "cpu" and enable: + print("Half precision is not supported on CPU.") + return + + self.configs.is_half = enable + self.precison = torch.float16 if enable else torch.float32 + self.configs.save_configs() + if enable: + if self.t2s_model is not None: + self.t2s_model =self.t2s_model.half() + if self.vits_model is not None: + self.vits_model = self.vits_model.half() + if self.bert_model is not None: + self.bert_model =self.bert_model.half() + if self.cnhuhbert_model is not None: + self.cnhuhbert_model = self.cnhuhbert_model.half() + else: + if self.t2s_model is not None: + self.t2s_model = self.t2s_model.float() + if self.vits_model is not None: + self.vits_model = self.vits_model.float() + if self.bert_model is not None: + self.bert_model = self.bert_model.float() + if self.cnhuhbert_model is not None: + self.cnhuhbert_model = self.cnhuhbert_model.float() + + def set_device(self, device: torch.device): + ''' + To set the device for all models. + Args: + device: torch.device, the device to use for all models. 
+ ''' + self.configs.device = device + self.configs.save_configs() + if self.t2s_model is not None: + self.t2s_model = self.t2s_model.to(device) + if self.vits_model is not None: + self.vits_model = self.vits_model.to(device) + if self.bert_model is not None: + self.bert_model = self.bert_model.to(device) + if self.cnhuhbert_model is not None: + self.cnhuhbert_model = self.cnhuhbert_model.to(device) def set_ref_audio(self, ref_audio_path:str): ''' @@ -338,7 +440,7 @@ class TTS: pos_end = min(pos+batch_size,index_and_len_list.shape[0]) while pos < pos_end: batch=index_and_len_list[pos:pos_end, 1].astype(np.float32) - score=batch[(pos_end-pos)//2]/batch.mean() + score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8) if (score>=threshold) or (pos_end-pos==1): batch_index=index_and_len_list[pos:pos_end, 0].tolist() batch_index_list_len += len(batch_index) @@ -359,6 +461,7 @@ class TTS: for batch_idx, index_list in enumerate(batch_index_list): item_list = [data[idx] for idx in index_list] phones_list = [] + phones_len_list = [] # bert_features_list = [] all_phones_list = [] all_phones_len_list = [] @@ -368,37 +471,40 @@ class TTS: phones_max_len = 0 for item in item_list: if prompt_data is not None: - all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1) + all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\ + .to(dtype=self.precison) all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]) phones = torch.LongTensor(item["phones"]) # norm_text = prompt_data["norm_text"]+item["norm_text"] else: - all_bert_features = item["bert_features"] + all_bert_features = item["bert_features"]\ + .to(dtype=self.precison) phones = torch.LongTensor(item["phones"]) - all_phones = phones.clone() + all_phones = phones # norm_text = item["norm_text"] bert_max_len = max(bert_max_len, all_bert_features.shape[-1]) phones_max_len = max(phones_max_len, phones.shape[-1]) phones_list.append(phones) + phones_len_list.append(phones.shape[-1]) all_phones_list.append(all_phones) all_phones_len_list.append(all_phones.shape[-1]) all_bert_features_list.append(all_bert_features) norm_text_batch.append(item["norm_text"]) + phones_batch = phones_list max_len = max(bert_max_len, phones_max_len) # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len) all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len) - all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len) - all_bert_features_batch.zero_() - + # all_bert_features_batch = all_bert_features_list + all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison) for idx, item in enumerate(all_bert_features_list): - if item != None: - all_bert_features_batch[idx, :, : item.shape[-1]] = item + all_bert_features_batch[idx, :, : item.shape[-1]] = item batch = { "phones": phones_batch, + "phones_len": torch.LongTensor(phones_len_list), "all_phones": all_phones_batch, "all_phones_len": torch.LongTensor(all_phones_len_list), "all_bert_features": all_bert_features_batch, @@ -446,8 +552,8 @@ class TTS: "prompt_text": "", # str. prompt text for the reference audio "prompt_lang": "", # str. language of the prompt text for the reference audio "top_k": 5, # int. top k sampling - "top_p": 0.9, # float. top p sampling - "temperature": 0.6, # float. temperature for sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. 
temperature for sampling "text_split_method": "", # str. text split method, see text_segmentaion_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. @@ -465,9 +571,9 @@ class TTS: ref_audio_path:str = inputs.get("ref_audio_path", "") prompt_text:str = inputs.get("prompt_text", "") prompt_lang:str = inputs.get("prompt_lang", "") - top_k:int = inputs.get("top_k", 20) - top_p:float = inputs.get("top_p", 0.9) - temperature:float = inputs.get("temperature", 0.6) + top_k:int = inputs.get("top_k", 5) + top_p:float = inputs.get("top_p", 1) + temperature:float = inputs.get("temperature", 1) text_split_method:str = inputs.get("text_split_method", "") batch_size = inputs.get("batch_size", 1) batch_threshold = inputs.get("batch_threshold", 0.75) @@ -484,7 +590,7 @@ class TTS: if split_bucket: print(i18n("分桶处理模式已开启")) - + no_prompt_text = False if prompt_text in [None, ""]: @@ -522,7 +628,11 @@ class TTS: ###### text preprocessing ######## data = self.text_preprocessor.preprocess(text, text_lang, text_split_method) - audio = [] + if len(data) == 0: + yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3), + dtype=np.int16) + return + t1 = ttime() data, batch_index_list = self.to_batch(data, prompt_data=self.prompt_cache if not no_prompt_text else None, @@ -531,117 +641,166 @@ class TTS: split_bucket=split_bucket ) t2 = ttime() - zero_wav = torch.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=torch.float16 if self.configs.is_half else torch.float32, - device=self.configs.device + try: + print("############ 推理 ############") + ###### inference ###### + t_34 = 0.0 + t_45 = 0.0 + audio = [] + for item in data: + t3 = ttime() + batch_phones = item["phones"] + batch_phones_len = item["phones_len"] + all_phoneme_ids = item["all_phones"] + all_phoneme_lens = item["all_phones_len"] + all_bert_features = item["all_bert_features"] + norm_text = item["norm_text"] + + # batch_phones = batch_phones.to(self.configs.device) + batch_phones_len = batch_phones_len.to(self.configs.device) + all_phoneme_ids = all_phoneme_ids.to(self.configs.device) + all_phoneme_lens = all_phoneme_lens.to(self.configs.device) + all_bert_features = all_bert_features.to(self.configs.device) + if self.configs.is_half: + all_bert_features = all_bert_features.half() + + print(i18n("前端处理后的文本(每句):"), norm_text) + if no_prompt_text : + prompt = None + else: + prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device) + + with torch.no_grad(): + pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_lens, + prompt, + all_bert_features, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=self.configs.hz * self.configs.max_sec, ) - - - ###### inference ###### - t_34 = 0.0 - t_45 = 0.0 - for item in data: - t3 = ttime() - batch_phones = item["phones"] - all_phoneme_ids = item["all_phones"] - all_phoneme_lens = item["all_phones_len"] - all_bert_features = item["all_bert_features"] - norm_text = item["norm_text"] - - all_phoneme_ids = all_phoneme_ids.to(self.configs.device) - all_phoneme_lens = all_phoneme_lens.to(self.configs.device) - all_bert_features = all_bert_features.to(self.configs.device) - if self.configs.is_half: - all_bert_features = all_bert_features.half() - - print(i18n("前端处理后的文本(每句):"), norm_text) - if no_prompt_text : - prompt = None - else: - prompt = 
self.prompt_cache["prompt_semantic"].clone().repeat(all_phoneme_ids.shape[0], 1).to(self.configs.device) - - with torch.no_grad(): - pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( - all_phoneme_ids, - all_phoneme_lens, - prompt, - all_bert_features, - # prompt_phone_len=ph_offset, - top_k=top_k, - top_p=top_p, - temperature=temperature, - early_stop_num=self.configs.hz * self.configs.max_sec, - ) - t4 = ttime() - t_34 += t4 - t3 - - refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].clone().to(self.configs.device) - if self.configs.is_half: - refer_audio_spepc = refer_audio_spepc.half() - - ## 直接对batch进行decode 生成的音频会有问题 - # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) - # batch_phones = batch_phones.to(self.configs.device) - # batch_audio_fragment =(self.vits_model.decode( - # pred_semantic, batch_phones, refer_audio_spepc - # ).detach()[:, 0, :]) - # max_audio=torch.abs(batch_audio_fragment).max()#简单防止16bit爆音 - # if max_audio>1: batch_audio_fragment/=max_audio - # batch_audio_fragment = batch_audio_fragment.cpu().numpy() - - ## 改成串行处理 - batch_audio_fragment = [] - for i, idx in enumerate(idx_list): - phones = batch_phones[i].unsqueeze(0).to(self.configs.device) - _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次 - audio_fragment =(self.vits_model.decode( - _pred_semantic, phones, refer_audio_spepc + t4 = ttime() + t_34 += t4 - t3 + + refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\ + .to(dtype=self.precison, device=self.configs.device) + + batch_audio_fragment = [] + + # ## vits并行推理 method 1 + # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device) + # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0) + # max_len = 0 + # for i in range(0, len(batch_phones)): + # max_len = max(max_len, batch_phones[i].shape[-1]) + # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len) + # batch_phones = batch_phones.to(self.configs.device) + # batch_audio_fragment = (self.vits_model.batched_decode( + # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc + # )) + + # ## vits并行推理 method 2 + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + upsample_rate = math.prod(self.vits_model.upsample_rates) + audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] + audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] + all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) + _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) + _batch_audio_fragment = (self.vits_model.decode( + all_pred_semantic, _batch_phones,refer_audio_spepc ).detach()[0, 0, :]) - max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音 - if max_audio>1: audio_fragment/=max_audio - audio_fragment = torch.cat([audio_fragment, zero_wav], dim=0) - batch_audio_fragment.append( - audio_fragment.cpu().numpy() - ) ###试试重建不带上prompt部分 - - t5 = ttime() - t_45 += t5 - t4 - if return_fragment: - print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) - yield 
self.audio_postprocess(batch_audio_fragment, + audio_frag_end_idx.insert(0, 0) + batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] + + + # ## vits串行推理 + # for i, idx in enumerate(idx_list): + # phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + # _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次 + # audio_fragment =(self.vits_model.decode( + # _pred_semantic, phones, refer_audio_spepc + # ).detach()[0, 0, :]) + # batch_audio_fragment.append( + # audio_fragment + # ) ###试试重建不带上prompt部分 + + t5 = ttime() + t_45 += t5 - t4 + if return_fragment: + print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) + yield self.audio_postprocess([batch_audio_fragment], + self.configs.sampling_rate, + batch_index_list, + speed_factor, + split_bucket) + else: + audio.append(batch_audio_fragment) + + if self.stop_flag: + yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3), + dtype=np.int16) + return + + if not return_fragment: + print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) + yield self.audio_postprocess(audio, self.configs.sampling_rate, batch_index_list, speed_factor, - split_bucket) - else: - audio.append(batch_audio_fragment) - - if self.stop_flag: - yield self.configs.sampling_rate, (zero_wav.cpu().numpy()).astype(np.int16) - return - - if not return_fragment: - print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) - yield self.audio_postprocess(audio, - self.configs.sampling_rate, - batch_index_list, - speed_factor, - split_bucket) - - - + split_bucket) + except Exception as e: + traceback.print_exc() + # 必须返回一个空音频, 否则会导致显存不释放。 + yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), + dtype=np.int16) + # 重置模型, 否则会导致显存释放不完全。 + del self.t2s_model + del self.vits_model + self.t2s_model = None + self.vits_model = None + self.init_t2s_weights(self.configs.t2s_weights_path) + self.init_vits_weights(self.configs.vits_weights_path) + finally: + self.empty_cache() + + def empty_cache(self): + try: + if str(self.configs.device) == "cuda": + torch.cuda.empty_cache() + elif str(self.configs.device) == "mps": + torch.mps.empty_cache() + except: + pass + def audio_postprocess(self, - audio:np.ndarray, + audio:List[torch.Tensor], sr:int, batch_index_list:list=None, speed_factor:float=1.0, split_bucket:bool=True)->tuple[int, np.ndarray]: + zero_wav = torch.zeros( + int(self.configs.sampling_rate * 0.3), + dtype=self.precison, + device=self.configs.device + ) + + for i, batch in enumerate(audio): + for j, audio_fragment in enumerate(batch): + max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音 + if max_audio>1: audio_fragment/=max_audio + audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) + audio[i][j] = audio_fragment.cpu().numpy() + + if split_bucket: audio = self.recovery_order(audio, batch_index_list) else: - audio = [item for batch in audio for item in batch] + # audio = [item for batch in audio for item in batch] + audio = sum(audio, []) audio = np.concatenate(audio, 0) diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 1504a534..58b2678c 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -1,4 +1,9 @@ +import os, sys + +from tqdm import tqdm +now_dir = os.getcwd() +sys.path.append(now_dir) import re import torch @@ -7,11 
+12,11 @@ from typing import Dict, List, Tuple from text.cleaner import clean_text from text import cleaned_text_to_sequence from transformers import AutoModelForMaskedLM, AutoTokenizer -from .text_segmentation_method import splits, get_method as get_seg_method +from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method -# from tools.i18n.i18n import I18nAuto +from tools.i18n.i18n import I18nAuto -# i18n = I18nAuto() +i18n = I18nAuto() def get_first(text:str) -> str: pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]" @@ -36,6 +41,10 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list: return result + + + + class TextPreprocessor: def __init__(self, bert_model:AutoModelForMaskedLM, tokenizer:AutoTokenizer, device:torch.device): @@ -44,10 +53,14 @@ class TextPreprocessor: self.device = device def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]: + print(i18n("############ 切分文本 ############")) texts = self.pre_seg_text(text, lang, text_split_method) result = [] - for text in texts: + print(i18n("############ 提取文本Bert特征 ############")) + for text in tqdm(texts): phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang) + if phones is None: + continue res={ "phones": phones, "bert_features": bert_features, @@ -60,30 +73,42 @@ class TextPreprocessor: text = text.strip("\n") if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if lang != "en" else "." + text - # print(i18n("实际输入的目标文本:"), text) + print(i18n("实际输入的目标文本:")) + print(text) seg_method = get_seg_method(text_split_method) text = seg_method(text) while "\n\n" in text: text = text.replace("\n\n", "\n") - # print(i18n("实际输入的目标文本(切句后):"), text) + _texts = text.split("\n") _texts = merge_short_text_in_array(_texts, 5) texts = [] + + for text in _texts: # 解决输入目标文本的空行导致报错的问题 if (len(text.strip()) == 0): continue if (text[-1] not in splits): text += "。" if lang != "en" else "." 
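The hunks that follow also cap each segment at 510 characters so over-long sentences no longer make the BERT feature extraction fail, cutting on punctuation with the new `split_big_text` helper from `text_segmentation_method.py`. A minimal usage sketch of that guard (the 510 limit matches the check added in `pre_seg_text`; the input string is a placeholder):

```python
from TTS_infer_pack.text_segmentation_method import split_big_text

text = "某个超过510个字符的长句……"  # placeholder over-long sentence
texts = []
if len(text) > 510:
    # split on punctuation, keeping each chunk at most 510 characters
    texts.extend(split_big_text(text, max_len=510))
else:
    texts.append(text)
```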
- texts.append(text) + # 解决句子过长导致Bert报错的问题 + if (len(text) > 510): + texts.extend(split_big_text(text)) + else: + texts.append(text) + + print(i18n("实际输入的目标文本(切句后):")) + print(texts) return texts def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]: textlist, langlist = self.seg_text(texts, language) - phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist) + if len(textlist) == 0: + return None, None, None + phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist) return phones, bert_features, norm_text @@ -92,8 +117,10 @@ class TextPreprocessor: textlist=[] langlist=[] if language in ["auto", "zh", "ja"]: - # LangSegment.setfilters(["zh","ja","en","ko"]) + LangSegment.setfilters(["zh","ja","en","ko"]) for tmp in LangSegment.getTexts(text): + if tmp["text"] == "": + continue if tmp["lang"] == "ko": langlist.append("zh") elif tmp["lang"] == "en": @@ -103,18 +130,22 @@ class TextPreprocessor: langlist.append(language if language!="auto" else tmp["lang"]) textlist.append(tmp["text"]) elif language == "en": - # LangSegment.setfilters(["en"]) + LangSegment.setfilters(["en"]) formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text)) while " " in formattext: formattext = formattext.replace(" ", " ") - textlist.append(formattext) - langlist.append("en") + if formattext != "": + textlist.append(formattext) + langlist.append("en") elif language in ["all_zh","all_ja"]: + formattext = text while " " in formattext: formattext = formattext.replace(" ", " ") language = language.replace("all_","") + if text == "": + return [],[] textlist.append(formattext) langlist.append(language) @@ -139,8 +170,7 @@ class TextPreprocessor: bert_feature = torch.cat(bert_feature_list, dim=1) # phones = sum(phones_list, []) norm_text = ''.join(norm_text_list) - - return phones, bert_feature, norm_text + return phones_list, bert_feature, norm_text def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor: @@ -173,4 +203,8 @@ class TextPreprocessor: dtype=torch.float32, ).to(self.device) - return feature \ No newline at end of file + return feature + + + + diff --git a/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py b/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py index 7bc6b009..2a182b29 100644 --- a/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py +++ b/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py @@ -24,6 +24,32 @@ def register_method(name): splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } +def split_big_text(text, max_len=510): + # 定义全角和半角标点符号 + punctuation = "".join(splits) + + # 切割文本 + segments = re.split('([' + punctuation + '])', text) + + # 初始化结果列表和当前片段 + result = [] + current_segment = '' + + for segment in segments: + # 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段 + if len(current_segment + segment) > max_len: + result.append(current_segment) + current_segment = segment + else: + current_segment += segment + + # 将最后一个片段加入结果列表 + if current_segment: + result.append(current_segment) + + return result + + def split(todo_text): todo_text = todo_text.replace("……", "。").replace("——", ",") @@ -121,6 +147,6 @@ def cut5(inp): if __name__ == '__main__': - method = get_method("cut1") + method = get_method("cut5") print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。")) \ No newline at end of file diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml index 5f56a4ec..c772f295 100644 --- a/GPT_SoVITS/configs/tts_infer.yaml +++ 
b/GPT_SoVITS/configs/tts_infer.yaml @@ -2,6 +2,7 @@ custom: bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base device: cuda + flash_attn_enabled: true is_half: true t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth @@ -9,6 +10,7 @@ default: bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base device: cpu + flash_attn_enabled: true is_half: false t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 47a226ed..b21b26b5 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -20,7 +20,6 @@ logging.getLogger("charset_normalizer").setLevel(logging.ERROR) logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) import pdb import torch -# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py infer_ttswebui = os.environ.get("infer_ttswebui", 9872) @@ -28,20 +27,27 @@ infer_ttswebui = int(infer_ttswebui) if "_CUDA_VISIBLE_DEVICES" in os.environ: os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] from config import is_half,is_share +gpt_path = os.environ.get("gpt_path", None) +sovits_path = os.environ.get("sovits_path", None) +cnhubert_base_path = os.environ.get("cnhubert_base_path", None) +bert_path = os.environ.get("bert_path", None) + import gradio as gr from TTS_infer_pack.TTS import TTS, TTS_Config -from TTS_infer_pack.text_segmentation_method import cut1, cut2, cut3, cut4, cut5 -from tools.i18n.i18n import I18nAuto from TTS_infer_pack.text_segmentation_method import get_method +from tools.i18n.i18n import I18nAuto + i18n = I18nAuto() os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。 if torch.cuda.is_available(): device = "cuda" +# elif torch.backends.mps.is_available(): +# device = "mps" else: device = "cpu" - + dict_language = { i18n("中文"): "all_zh",#全部按中文识别 i18n("英文"): "en",#全部按英文识别#######不变 @@ -63,6 +69,16 @@ cut_method = { tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml") tts_config.device = device tts_config.is_half = is_half +if gpt_path is not None: + tts_config.t2s_weights_path = gpt_path +if sovits_path is not None: + tts_config.vits_weights_path = sovits_path +if cnhubert_base_path is not None: + tts_config.cnhuhbert_base_path = cnhubert_base_path +if bert_path is not None: + tts_config.bert_base_path = bert_path + +print(tts_config) tts_pipline = TTS(tts_config) gpt_path = tts_config.t2s_weights_path sovits_path = tts_config.vits_weights_path @@ -88,10 +104,12 @@ def inference(text, text_lang, "batch_size":int(batch_size), "speed_factor":float(speed_factor), "split_bucket":split_bucket, - "return_fragment":False, + "return_fragment":False } - yield next(tts_pipline.run(inputs)) - + + for item in tts_pipline.run(inputs): + yield item + def custom_sort_key(s): # 使用正则表达式提取字符串中的数字部分和非数字部分 parts = re.split('(\d+)', s) @@ -167,7 +185,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: with gr.Row(): with gr.Column(): - batch_size = gr.Slider(minimum=1,maximum=20,step=1,label=i18n("batch_size"),value=1,interactive=True) + batch_size = 
gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True) speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True) top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True) top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True) @@ -179,7 +197,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: value=i18n("凑四句一切"), interactive=True, ) - split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True) + with gr.Row(): + split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True) # with gr.Column(): output = gr.Audio(label=i18n("输出的语音")) with gr.Row(): diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index a4d22352..75bc6177 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -1,5 +1,6 @@ import copy import math +from typing import List import torch from torch import nn from torch.nn import functional as F @@ -985,6 +986,55 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o + + + @torch.no_grad() + def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5): + ge = None + if refer is not None: + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze( + commons.sequence_mask(refer_lengths, refer.size(2)), 1 + ).to(refer.dtype) + ge = self.ref_enc(refer * refer_mask, refer_mask) + + # y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to( + # codes.dtype + # ) + y_lengths = (y_lengths * 2).long().to(codes.device) + text_lengths = text_lengths.long().to(text.device) + # y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) + # text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) + + # 假设padding之后再decode没有问题, 影响未知,但听起来好像没问题? + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate( + quantized, size=int(quantized.shape[-1] * 2), mode="nearest" + ) + + x, m_p, logs_p, y_mask = self.enc_p( + quantized, y_lengths, text, text_lengths, ge + ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + + z = self.flow(z_p, y_mask, g=ge, reverse=True) + z_masked = (z * y_mask)[:, :, :] + + # 串行。把padding部分去掉再decode + o_list:List[torch.Tensor] = [] + for i in range(z_masked.shape[0]): + z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0) + o = self.dec(z_slice, g=ge)[0, 0, :].detach() + o_list.append(o) + + # 并行(会有问题)。先decode,再把padding的部分去掉 + # o = self.dec(z_masked, g=ge) + # upsample_rate = int(math.prod(self.upsample_rates)) + # o_lengths = y_lengths*upsample_rate + # o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)] + + return o_list def extract_latent(self, x): ssl = self.ssl_proj(x)
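A note on recovering per-sentence audio after the new parallel VITS decode (method 2 in `TTS.run`) and the serial slicing in `batched_decode` above: each semantic token maps to `2 * prod(upsample_rates)` waveform samples (the factor of 2 comes from the 25 Hz to 50 Hz interpolation in `SynthesizerTrn`), so the concatenated output can be cut back into fragments from the token counts alone. A toy sketch of that bookkeeping, with made-up lengths and example upsample rates:

```python
import math
import torch

upsample_rates = [10, 8, 2, 2, 2]                  # example rates; the real ones come from the vits config
samples_per_token = 2 * math.prod(upsample_rates)  # 2x from the 25hz -> 50hz interpolation

token_counts = [37, 52, 41]                        # semantic tokens per sentence in the batch
frag_sizes = [n * samples_per_token for n in token_counts]
wav = torch.randn(sum(frag_sizes))                 # stand-in for the single decoded waveform

ends = [sum(frag_sizes[:i + 1]) for i in range(len(frag_sizes))]
starts = [0] + ends[:-1]
fragments = [wav[s:e] for s, e in zip(starts, ends)]
assert [f.shape[0] for f in fragments] == frag_sizes
```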