diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module.py b/GPT_SoVITS/AR/models/t2s_lightning_module.py index 1b602629..2dd3f392 100644 --- a/GPT_SoVITS/AR/models/t2s_lightning_module.py +++ b/GPT_SoVITS/AR/models/t2s_lightning_module.py @@ -13,11 +13,11 @@ from AR.modules.lr_schedulers import WarmupCosineLRSchedule from AR.modules.optim import ScaledAdam class Text2SemanticLightningModule(LightningModule): - def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False): + def __init__(self, config, output_dir, is_train=True): super().__init__() self.config = config self.top_k = 3 - self.model = Text2SemanticDecoder(config=config, top_k=self.top_k,flash_attn_enabled=flash_attn_enabled) + self.model = Text2SemanticDecoder(config=config, top_k=self.top_k) pretrained_s1 = config.get("pretrained_s1") if pretrained_s1 and is_train: # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"])) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index dad24405..e12bb11b 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -85,15 +85,22 @@ class T2SBlock: self.norm_b2 = norm_b2 self.norm_eps2 = norm_eps2 - def process_prompt(self, x, attn_mask : torch.Tensor): - q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) + @torch.jit.ignore + def to_mask(self, x, padding_mask): + return x*padding_mask if padding_mask is not None else x + + def process_prompt(self, x, attn_mask : torch.Tensor, padding_mask:torch.Tensor=None): + + + q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1) batch_size = q.shape[0] q_len = q.shape[1] kv_len = k.shape[1] - - k_cache = k - v_cache = v + + q = self.to_mask(q, padding_mask) + k_cache = self.to_mask(k, padding_mask) + v_cache = self.to_mask(v, padding_mask) q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) @@ -103,13 +110,15 @@ class T2SBlock: attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim) attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) - attn = F.linear(attn, self.out_w, self.out_b) + attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b) + x = self.to_mask(x + attn, padding_mask) x = F.layer_norm( - x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1 + x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1 ) + x = self.to_mask(x + self.mlp.forward(self.to_mask(x, padding_mask)), padding_mask) x = F.layer_norm( - x + self.mlp.forward(x), + x, [self.hidden_dim], self.norm_w2, self.norm_b2, @@ -138,11 +147,13 @@ class T2SBlock: attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) attn = F.linear(attn, self.out_w, self.out_b) + x = x + attn x = F.layer_norm( - x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1 + x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1 ) + x = x + self.mlp.forward(x) x = F.layer_norm( - x + self.mlp.forward(x), + x, [self.hidden_dim], self.norm_w2, self.norm_b2, @@ -158,11 +169,13 @@ class T2STransformer: self.blocks = blocks def process_prompt( - self, x, attn_mask : torch.Tensor): + self, x, attn_mask : torch.Tensor, + padding_mask : torch.Tensor=None, + ): k_cache : List[torch.Tensor] = [] v_cache : List[torch.Tensor] = [] for i in range(self.num_blocks): - x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask) + x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask) k_cache.append(k_cache_) v_cache.append(v_cache_) return x, k_cache, v_cache @@ -176,7 +189,7 @@ class T2STransformer: class Text2SemanticDecoder(nn.Module): - def __init__(self, config, norm_first=False, top_k=3, flash_attn_enabled:bool=False): + def __init__(self, config, norm_first=False, top_k=3): super(Text2SemanticDecoder, self).__init__() self.model_dim = config["model"]["hidden_dim"] self.embedding_dim = config["model"]["embedding_dim"] @@ -228,47 +241,37 @@ class Text2SemanticDecoder(nn.Module): multidim_average="global", ignore_index=self.EOS, ) + + blocks = [] + + for i in range(self.num_layers): + layer = self.h.layers[i] + t2smlp = T2SMLP( + layer.linear1.weight, + layer.linear1.bias, + layer.linear2.weight, + layer.linear2.bias + ) + + block = T2SBlock( + self.num_head, + self.model_dim, + t2smlp, + layer.self_attn.in_proj_weight, + layer.self_attn.in_proj_bias, + layer.self_attn.out_proj.weight, + layer.self_attn.out_proj.bias, + layer.norm1.weight, + layer.norm1.bias, + layer.norm1.eps, + layer.norm2.weight, + layer.norm2.bias, + layer.norm2.eps + ) + + blocks.append(block) - self.enable_flash_attn(flash_attn_enabled) - - def enable_flash_attn(self, enable:bool=True): - - if not enable: - print("Not Using Flash Attention") - self.infer_panel = self.infer_panel_batch_only - else: - self.infer_panel = self.infer_panel_batch_infer_with_flash_attn - print("Using Flash Attention") - blocks = [] - - for i in range(self.num_layers): - layer = self.h.layers[i] - t2smlp = T2SMLP( - layer.linear1.weight, - layer.linear1.bias, - layer.linear2.weight, - layer.linear2.bias - ) - - block = T2SBlock( - self.num_head, - self.model_dim, - t2smlp, - layer.self_attn.in_proj_weight, - layer.self_attn.in_proj_bias, - layer.self_attn.out_proj.weight, - layer.self_attn.out_proj.bias, - layer.norm1.weight, - layer.norm1.bias, - layer.norm1.eps, - layer.norm2.weight, - layer.norm2.bias, - layer.norm2.eps - ) - - blocks.append(block) - - self.t2s_transformer = T2STransformer(self.num_layers, blocks) + self.t2s_transformer = T2STransformer(self.num_layers, blocks) def make_input_data(self, x, x_lens, y, y_lens, bert_feature): x = self.ar_text_embedding(x) @@ -297,8 +300,7 @@ class Text2SemanticDecoder(nn.Module): (0, y_len), value=True, ) - # 取消对y[0]的mask,以防止复读,详见https://github.com/RVC-Boss/GPT-SoVITS/issues/965 - x_attn_mask[:, x_len]=False + # x_attn_mask[:, x_len]=False y_attn_mask = F.pad( torch.triu( torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), @@ -394,8 +396,7 @@ class Text2SemanticDecoder(nn.Module): (0, y_len), value=True, ) - # 取消对y[0]的mask,以防止复读,详见https://github.com/RVC-Boss/GPT-SoVITS/issues/965 - x_attn_mask[:, x_len]=False + # x_attn_mask[:, x_len]=False y_attn_mask = F.pad( torch.triu( torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), @@ -461,7 +462,7 @@ class Text2SemanticDecoder(nn.Module): value=True, ) y_attn_mask = F.pad( - torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),# diagonal必须为0,否则会导致batch_size>1时的复读情况 + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), (x_len, 0), value=False, ) @@ -507,29 +508,39 @@ class Text2SemanticDecoder(nn.Module): def infer_panel_batch_infer_with_flash_attn( self, - x:torch.LongTensor, #####全部文本token + x:List[torch.LongTensor], #####全部文本token x_lens:torch.LongTensor, prompts:torch.LongTensor, ####参考音频token - bert_feature:torch.LongTensor, + bert_feature:List[torch.LongTensor], top_k: int = -100, top_p: int = 100, early_stop_num: int = -1, temperature: float = 1.0, + repetition_penalty: float = 1.35, + **kwargs, ): - ## 先对phones进行embedding、对bert_features进行project,再pad到相同长度(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略) - # max_len = 0 + # # fp16 会对结果产生影响(和没pad相比) + # bert_feature_dtype = bert_feature[0].dtype + # if not hasattr(self.bert_proj, "dtype"): + # self.bert_proj.dtype = torch.float32 + # self.bert_proj=self.bert_proj.float() + + ## 先对phones进行embedding、对bert_features进行project,再pad到相同长度(padding策略会影响T2S模型生成的结果。) + ## pad之后再进行Linear会有误差(和没pad相比),就离谱。。。 + max_len = kwargs.get("max_len",x_lens.max()) # for x_item, bert_item in zip(x, bert_feature): # max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) - # x_list = [self.ar_text_embedding(item) for item in x] - # x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]1时的复读情况 + y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), (x_len, 0), value=False, ) xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device) - # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1) - xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len) - xy_attn_mask = xy_mask.logical_or(xy_padding_mask) + # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device) + _xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len) + xy_attn_mask = xy_mask.logical_or(_xy_padding_mask) xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1) new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf")) - + + xy_padding_mask = ~xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim) + xy_padding_mask = xy_padding_mask.to(dtype=x.dtype) + ###### decode ##### y_list = [None]*y.shape[0] batch_idx_map = list(range(y.shape[0])) idx_list = [None]*y.shape[0] for idx in tqdm(range(1500)): if idx == 0: - xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask) + xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask) else: xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache) @@ -609,7 +623,7 @@ class Text2SemanticDecoder(nn.Module): logits = logits[:, :-1] samples = sample( - logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature + logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature )[0] y = torch.concat([y, samples], dim=1) @@ -659,7 +673,7 @@ class Text2SemanticDecoder(nn.Module): ####################### update next step ################################### y_emb = self.ar_audio_embedding(y[:, -1:]) - xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device) + xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device) if (None in idx_list): for i in range(x.shape[0]): @@ -670,7 +684,37 @@ class Text2SemanticDecoder(nn.Module): return y_list, [0]*x.shape[0] return y_list, idx_list - def infer_panel_batch_only( + def infer_panel_0307(self, + x:List[torch.LongTensor], #####全部文本token + x_lens:torch.LongTensor, + prompts:torch.LongTensor, ####参考音频token + bert_feature:List[torch.LongTensor], + top_k: int = -100, + top_p: int = 100, + early_stop_num: int = -1, + temperature: float = 1.0, + repetition_penalty: float = 1.35, + **kwargs + ): + y_list = [] + idx_list = [] + for i in range(len(x)): + y, idx = self.infer_panel_with_flash_attn_only(x[i].unsqueeze(0), + x_lens[i], + prompts[i].unsqueeze(0), + bert_feature[i].unsqueeze(0), + top_k, + top_p, + early_stop_num, + temperature, + repetition_penalty, + **kwargs) + y_list.append(y[0]) + idx_list.append(idx) + + return y_list, idx_list + + def infer_panel_with_flash_attn_only( self, x:torch.LongTensor, #####全部文本token x_lens:torch.LongTensor, @@ -680,22 +724,11 @@ class Text2SemanticDecoder(nn.Module): top_p: int = 100, early_stop_num: int = -1, temperature: float = 1.0, + repetition_penalty: float = 1.35, + **kwargs ): - ## 先对phones进行embedding、对bert_features进行project,再pad到相同长度(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略) - # max_len = 0 - # for x_item, bert_item in zip(x, bert_feature): - # max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) - # x_list = [self.ar_text_embedding(item) for item in x] - # x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]1时的复读情况 + y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), (x_len, 0), value=False, ) - - xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz*self.num_head, -1, -1).to(x.device) - # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1) - xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(bsz, src_len, src_len).repeat(self.num_head, 1, 1) - xy_attn_mask = xy_mask.logical_or(xy_padding_mask) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).unsqueeze(0).expand(bsz*self.num_head, -1, -1).view(bsz, self.num_head, src_len, src_len).to(x.device) new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf")) - - y_list = [None]*y.shape[0] - batch_idx_map = list(range(y.shape[0])) - idx_list = [None]*y.shape[0] for idx in tqdm(range(1500)): - - xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache) + if xy_attn_mask is not None: + xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None) + else: + xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache) + logits = self.ar_predict_layer( xy_dec[:, -1] - ) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的 - # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature) - if(idx==0):###第一次跑不能EOS否则没有了 - logits = logits[:, :-1] ###刨除1024终止符号的概率 - samples = sample( - logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature - )[0] - # 本次生成的 semantic_ids 和之前的 y 构成新的 y - # print(samples.shape)#[1,1]#第一个1是bs - y = torch.concat([y, samples], dim=1) + ) + + if idx == 0: + xy_attn_mask = None + logits = logits[:, :-1] + + samples = sample( + logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature + )[0] + + y = torch.concat([y, samples], dim=1) - # 移除已经生成完毕的序列 - reserved_idx_of_batch_for_y = None - if (self.EOS in torch.argmax(logits, dim=-1)) or \ - (self.EOS in samples[:, 0]): ###如果生成到EOS,则停止 - l = samples[:, 0]==self.EOS - removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist() - reserved_idx_of_batch_for_y = torch.where(l==False)[0] - # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y] - for i in removed_idx_of_batch_for_y: - batch_index = batch_idx_map[i] - idx_list[batch_index] = idx - 1 - y_list[batch_index] = y[i, :-1] - - batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()] - - # 只保留未生成完毕的序列 - if reserved_idx_of_batch_for_y is not None: - # index = torch.LongTensor(batch_idx_map).to(y.device) - y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y) - if cache["y_emb"] is not None: - cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y) - if cache["k"] is not None: - for i in range(self.num_layers): - # 因为kv转置了,所以batch dim是1 - cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y) - cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y) - - if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: print("use early stop num:", early_stop_num) stop = True - - if not (None in idx_list): - # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS) + + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: stop = True if stop: - # if prompts.shape[1] == y.shape[1]: - # y = torch.concat([y, torch.zeros_like(samples)], dim=1) - # print("bad zero prediction") if y.shape[1]==0: y = torch.concat([y, torch.zeros_like(samples)], dim=1) print("bad zero prediction") print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") break - - ####################### update next step ################################### - cache["first_infer"] = 0 - if cache["y_emb"] is not None: - y_emb = torch.cat( - [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1 - ) - cache["y_emb"] = y_emb - y_pos = self.ar_audio_position(y_emb) - xy_pos = y_pos[:, -1:] - else: - y_emb = self.ar_audio_embedding(y[:, -1:]) - cache["y_emb"] = y_emb - y_pos = self.ar_audio_position(y_emb) - xy_pos = y_pos - y_len = y_pos.shape[1] - ###最右边一列(是错的) - # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device) - # xy_attn_mask[:,-1]=False - ###最下面一行(是对的) - xy_attn_mask = torch.zeros( - (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device - ) - - if (None in idx_list): - for i in range(x.shape[0]): - if idx_list[i] is None: - idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替 - + ####################### update next step ################################### + y_emb = self.ar_audio_embedding(y[:, -1:]) + xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device) + if ref_free: - return y_list, [0]*x.shape[0] - return y_list, idx_list \ No newline at end of file + return y[:, :-1], 0 + return y[:, :-1], idx - 1 diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index a6f25415..4befc0c4 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -37,7 +37,6 @@ default: cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth - flash_attn_enabled: true custom: device: cuda @@ -46,7 +45,6 @@ custom: cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth - flash_attn_enabled: true """ @@ -66,6 +64,9 @@ def set_seed(seed:int): # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = False # torch.backends.cudnn.enabled = True + # 开启后会影响精度 + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False except: pass return seed @@ -78,7 +79,6 @@ class TTS_Config: "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth", "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", - "flash_attn_enabled": True } configs:dict = None def __init__(self, configs: Union[dict, str]=None): @@ -108,7 +108,6 @@ class TTS_Config: self.device = self.configs.get("device", torch.device("cpu")) self.is_half = self.configs.get("is_half", False) - self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True) self.t2s_weights_path = self.configs.get("t2s_weights_path", None) self.vits_weights_path = self.configs.get("vits_weights_path", None) self.bert_base_path = self.configs.get("bert_base_path", None) @@ -141,7 +140,7 @@ class TTS_Config: self.n_speakers:int = 300 self.languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] - # print(self) + def _load_configs(self, configs_path: str)->dict: with open(configs_path, 'r') as f: @@ -169,7 +168,6 @@ class TTS_Config: "vits_weights_path" : self.vits_weights_path, "bert_base_path" : self.bert_base_path, "cnhuhbert_base_path": self.cnhuhbert_base_path, - "flash_attn_enabled" : self.flash_attn_enabled } return self.config @@ -289,8 +287,7 @@ class TTS: dict_s1 = torch.load(weights_path, map_location=self.configs.device) config = dict_s1["config"] self.configs.max_sec = config["data"]["max_sec"] - t2s_model = Text2SemanticLightningModule(config, "****", is_train=False, - flash_attn_enabled=self.configs.flash_attn_enabled) + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) t2s_model.load_state_dict(dict_s1["weight"]) t2s_model = t2s_model.to(self.configs.device) t2s_model = t2s_model.eval() @@ -435,8 +432,6 @@ class TTS: device:torch.device=torch.device("cpu"), precision:torch.dtype=torch.float32, ): - # 但是这里不能套,反而会负优化 - # with torch.no_grad(): _data:list = [] index_and_len_list = [] for idx, item in enumerate(data): @@ -484,8 +479,6 @@ class TTS: norm_text_batch = [] bert_max_len = 0 phones_max_len = 0 - # 但是这里也不能套,反而会负优化 - # with torch.no_grad(): for item in item_list: if prompt_data is not None: all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\ @@ -518,11 +511,11 @@ class TTS: max_len = max(bert_max_len, phones_max_len) # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len) #### 直接对phones和bert_features进行pad。(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略) - all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len) - all_bert_features_batch = all_bert_features_list - all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device) - for idx, item in enumerate(all_bert_features_list): - all_bert_features_batch[idx, :, : item.shape[-1]] = item + # all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len) + # all_bert_features_batch = all_bert_features_list + # all_bert_features_batch = torch.zeros((len(all_bert_features_list), 1024, max_len), dtype=precision, device=device) + # for idx, item in enumerate(all_bert_features_list): + # all_bert_features_batch[idx, :, : item.shape[-1]] = item # #### 先对phones进行embedding、对bert_features进行project,再pad到相同长度,(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略) # all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list] @@ -539,7 +532,8 @@ class TTS: "all_phones": all_phones_batch, "all_phones_len": torch.LongTensor(all_phones_len_list).to(device), "all_bert_features": all_bert_features_batch, - "norm_text": norm_text_batch + "norm_text": norm_text_batch, + "max_len": max_len, } _data.append(batch) @@ -569,7 +563,6 @@ class TTS: ''' self.stop_flag = True - # 使用装饰器 @torch.no_grad() def run(self, inputs:dict): """ @@ -594,6 +587,8 @@ class TTS: "speed_factor":1.0, # float. control the speed of the synthesized audio. "fragment_interval":0.3, # float. to control the interval of the audio fragment. "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35 # float. repetition penalty for T2S model. } returns: tuple[int, np.ndarray]: sampling rate and audio data. @@ -618,9 +613,17 @@ class TTS: seed = inputs.get("seed", -1) seed = -1 if seed in ["", None] else seed actual_seed = set_seed(seed) + parallel_infer = inputs.get("parallel_infer", True) + repetition_penalty = inputs.get("repetition_penalty", 1.35) + + if parallel_infer: + print(i18n("并行推理模式已开启")) + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn + else: + print(i18n("并行推理模式已关闭")) + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_0307 if return_fragment: - # split_bucket = False print(i18n("分段返回模式已开启")) if split_bucket: split_bucket = False @@ -740,12 +743,13 @@ class TTS: all_phoneme_lens:torch.LongTensor = item["all_phones_len"] all_bert_features:torch.LongTensor = item["all_bert_features"] norm_text:str = item["norm_text"] + max_len = item["max_len"] print(i18n("前端处理后的文本(每句):"), norm_text) if no_prompt_text : prompt = None else: - prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device) + prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( @@ -758,6 +762,8 @@ class TTS: top_p=top_p, temperature=temperature, early_stop_num=self.configs.hz * self.configs.max_sec, + max_len=max_len, + repetition_penalty=repetition_penalty, ) t4 = ttime() t_34 += t4 - t3 diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml index c772f295..5f56a4ec 100644 --- a/GPT_SoVITS/configs/tts_infer.yaml +++ b/GPT_SoVITS/configs/tts_infer.yaml @@ -2,7 +2,6 @@ custom: bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base device: cuda - flash_attn_enabled: true is_half: true t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth @@ -10,7 +9,6 @@ default: bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base device: cpu - flash_attn_enabled: true is_half: false t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 2394d300..16ac1469 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -93,7 +93,8 @@ def inference(text, text_lang, text_split_method, batch_size, speed_factor, ref_text_free, split_bucket,fragment_interval, - seed, keep_random + seed, keep_random, parallel_infer, + repetition_penalty ): seed = -1 if keep_random else seed @@ -114,6 +115,8 @@ def inference(text, text_lang, "return_fragment":False, "fragment_interval":fragment_interval, "seed":actual_seed, + "parallel_infer": parallel_infer, + "repetition_penalty": repetition_penalty, } for item in tts_pipeline.run(inputs): yield item, actual_seed @@ -199,6 +202,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True) top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True) temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True) + repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True) with gr.Column(): how_to_cut = gr.Radio( label=i18n("怎么切"), @@ -207,9 +211,11 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: interactive=True, ) with gr.Row(): - split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True) + parallel_infer = gr.Checkbox(label=i18n("并行推理(速度更快,但可能增大复读概率)"), value=True, interactive=True, show_label=True) + split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True) seed = gr.Number(label=i18n("随机种子"),value=-1) keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True) + # with gr.Column(): output = gr.Audio(label=i18n("输出的语音")) with gr.Row(): @@ -226,7 +232,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: how_to_cut, batch_size, speed_factor, ref_text_free, split_bucket,fragment_interval, - seed, keep_random + seed, keep_random, parallel_infer, + repetition_penalty ], [output, seed], ) diff --git a/api_v2.py b/api_v2.py index 50180595..9f45ac53 100644 --- a/api_v2.py +++ b/api_v2.py @@ -22,7 +22,7 @@ POST: ```json { "text": "", # str.(required) text to be synthesized - "text_lang": "", # str.(required) language of the text to be synthesized + "text_lang": "", # str.(required) language of the text to be synthesized "ref_audio_path": "", # str.(required) reference audio path. "prompt_text": "", # str.(optional) prompt text for the reference audio "prompt_lang": "", # str.(required) language of the prompt text for the reference audio @@ -32,12 +32,14 @@ POST: "text_split_method": "cut5", # str.(optional) text split method, see text_segmentation_method.py for details. "batch_size": 1, # int.(optional) batch size for inference "batch_threshold": 0.75, # float.(optional) threshold for batch splitting. - "split_bucket": true, # bool.(optional) whether to split the batch into multiple buckets. + "split_bucket": true, # bool.(optional) whether to split the batch into multiple buckets. "speed_factor":1.0, # float.(optional) control the speed of the synthesized audio. "fragment_interval":0.3, # float.(optional) to control the interval of the audio fragment. "seed": -1, # int.(optional) random seed for reproducibility. "media_type": "wav", # str.(optional) media type of the output audio, support "wav", "raw", "ogg", "aac". "streaming_mode": false, # bool.(optional) whether to return a streaming response. + "parallel_infer": True, # bool.(optional) whether to use parallel inference. + "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. } ``` @@ -159,6 +161,8 @@ class TTS_Request(BaseModel): seed:int = -1 media_type:str = "wav" streaming_mode:bool = False + parallel_infer:bool = True + repetition_penalty:float = 1.35 ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int): @@ -287,6 +291,8 @@ async def tts_handle(req:dict): "seed": -1, # int. random seed for reproducibility. "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac". "streaming_mode": False, # bool. whether to return a streaming response. + "parallel_infer": True, # bool.(optional) whether to use parallel inference. + "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. } returns: StreamingResponse: audio stream response. @@ -354,6 +360,8 @@ async def tts_get_endpoint( seed:int = -1, media_type:str = "wav", streaming_mode:bool = False, + parallel_infer:bool = True, + repetition_penalty:float = 1.35 ): req = { "text": text, @@ -373,6 +381,8 @@ async def tts_get_endpoint( "seed":seed, "media_type":media_type, "streaming_mode":streaming_mode, + "parallel_infer":parallel_infer, + "repetition_penalty":float(repetition_penalty) } return await tts_handle(req)