From 5fe98420694193933306d922596038d5c36c9f58 Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Fri, 12 Apr 2024 00:49:44 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=86=E4=B8=80=E4=B8=AAbatch=E4=B8=AD?= =?UTF-8?q?=E7=9A=84padding=E7=AD=96=E7=95=A5=EF=BC=8C=E4=BB=8Epadding=20o?= =?UTF-8?q?n=20right=E6=94=B9=E4=B8=BA=E4=BA=86padding=20on=20left?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/AR/data/data_module.py | 1 + GPT_SoVITS/AR/data/dataset.py | 107 ++-- GPT_SoVITS/AR/models/t2s_lightning_module.py | 16 +- GPT_SoVITS/AR/models/t2s_model.py | 511 ++++++++++++++++--- GPT_SoVITS/AR/models/utils.py | 8 +- GPT_SoVITS/TTS_infer_pack/TTS.py | 33 +- GPT_SoVITS/inference_webui.py | 17 +- GPT_SoVITS/s1_train.py | 1 + 8 files changed, 569 insertions(+), 125 deletions(-) diff --git a/GPT_SoVITS/AR/data/data_module.py b/GPT_SoVITS/AR/data/data_module.py index cb947959..431918fa 100644 --- a/GPT_SoVITS/AR/data/data_module.py +++ b/GPT_SoVITS/AR/data/data_module.py @@ -32,6 +32,7 @@ class Text2SemanticDataModule(LightningDataModule): semantic_path=self.train_semantic_path, max_sec=self.config["data"]["max_sec"], pad_val=self.config["data"]["pad_val"], + padding_on_left=self.config["train"]["padding_on_left"], ) self._dev_dataset = self._train_dataset # self._dev_dataset = Text2SemanticDataset( diff --git a/GPT_SoVITS/AR/data/dataset.py b/GPT_SoVITS/AR/data/dataset.py index 1a2ffef1..85011adf 100644 --- a/GPT_SoVITS/AR/data/dataset.py +++ b/GPT_SoVITS/AR/data/dataset.py @@ -55,9 +55,10 @@ class Text2SemanticDataset(Dataset): min_ps_ratio: int = 3, # max value of phoneme/sec max_ps_ratio: int = 25, + padding_on_left:bool=False, ) -> None: super().__init__() - + self.padding_on_left=padding_on_left self.semantic_data = pd.read_csv( semantic_path, delimiter="\t", encoding="utf-8" ) @@ -164,7 +165,9 @@ class Text2SemanticDataset(Dataset): # if len(semantic_ids) > 1000:###########3 # num_deleted_bigger += 1 # continue - + if (len(semantic_ids)+len(phoneme_ids)) > 1000:###########3 + num_deleted_bigger += 1 + continue ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz) if ( @@ -173,7 +176,8 @@ class Text2SemanticDataset(Dataset): num_deleted_ps += 1 # print(item_name) continue - + idx_len=[] + self.semantic_phoneme.append((semantic_ids, phoneme_ids)) idx += 1 self.item_names.append(item_name) @@ -253,46 +257,73 @@ class Text2SemanticDataset(Dataset): phoneme_ids_lens: List[int] = [] semantic_ids: List[torch.Tensor] = [] semantic_ids_lens: List[int] = [] - # return - for item in examples: - sample_index.append(item["idx"]) - phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64)) - semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64)) - phoneme_ids_lens.append(item["phoneme_ids_len"]) - semantic_ids_lens.append(item["semantic_ids_len"]) + if not self.padding_on_left: + for item in examples: + sample_index.append(item["idx"]) + phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64)) + semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64)) + phoneme_ids_lens.append(item["phoneme_ids_len"]) + semantic_ids_lens.append(item["semantic_ids_len"]) - # pad 0 - phoneme_ids = batch_sequences(phoneme_ids) - semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD) + # pad 0 + phoneme_ids = batch_sequences(phoneme_ids) + semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD) - # # convert each batch to torch.tensor - phoneme_ids = torch.tensor(phoneme_ids) 
- semantic_ids = torch.tensor(semantic_ids) - phoneme_ids_lens = torch.tensor(phoneme_ids_lens) - semantic_ids_lens = torch.tensor(semantic_ids_lens) - bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens)) - bert_padded.zero_() + # # convert each batch to torch.tensor + phoneme_ids = torch.tensor(phoneme_ids) + semantic_ids = torch.tensor(semantic_ids) + phoneme_ids_lens = torch.tensor(phoneme_ids_lens) + semantic_ids_lens = torch.tensor(semantic_ids_lens) + bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens)) + bert_padded.zero_() - for idx, item in enumerate(examples): - bert = item["bert_feature"] - if bert != None: - bert_padded[idx, :, : bert.shape[-1]] = bert + for idx, item in enumerate(examples): + bert = item["bert_feature"] + if bert != None: + bert_padded[idx, :, : bert.shape[-1]] = bert - return { - # List[int] - "ids": sample_index, - # torch.Tensor (B, max_phoneme_length) - "phoneme_ids": phoneme_ids, - # torch.Tensor (B) - "phoneme_ids_len": phoneme_ids_lens, - # torch.Tensor (B, max_semantic_ids_length) - "semantic_ids": semantic_ids, - # torch.Tensor (B) - "semantic_ids_len": semantic_ids_lens, - # torch.Tensor (B, 1024, max_phoneme_length) - "bert_feature": bert_padded, - } + return { + # List[int] + "ids": sample_index, + # torch.Tensor (B, max_phoneme_length) + "phoneme_ids": phoneme_ids, + # torch.Tensor (B) + "phoneme_ids_len": phoneme_ids_lens, + # torch.Tensor (B, max_semantic_ids_length) + "semantic_ids": semantic_ids, + # torch.Tensor (B) + "semantic_ids_len": semantic_ids_lens, + # torch.Tensor (B, 1024, max_phoneme_length) + "bert_feature": bert_padded, + } + + else: + for item in examples: + sample_index.append(item["idx"]) + phoneme_ids.append(torch.LongTensor(np.array(item["phoneme_ids"], dtype=np.int64))) + semantic_ids.append(torch.LongTensor(np.array(item["semantic_ids"], dtype=np.int64))) + phoneme_ids_lens.append(item["phoneme_ids_len"]) + semantic_ids_lens.append(item["semantic_ids_len"]) + + phoneme_ids_lens = torch.tensor(phoneme_ids_lens) + semantic_ids_lens = torch.tensor(semantic_ids_lens) + bert_features: List[torch.Tensor] = [item["bert_feature"] for item in examples] + + return { + # List[int] + "ids": sample_index, + # List[torch.Tensor] (B, max_phoneme_length) + "phoneme_ids": phoneme_ids, + # torch.Tensor (B) + "phoneme_ids_len": phoneme_ids_lens, + # List[torch.Tensor] (B, max_semantic_ids_length) + "semantic_ids": semantic_ids, + # torch.Tensor (B) + "semantic_ids_len": semantic_ids_lens, + # List[torch.Tensor] (B, 1024, max_phoneme_length) + "bert_feature": bert_features, + } if __name__ == "__main__": diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module.py b/GPT_SoVITS/AR/models/t2s_lightning_module.py index 1b602629..19d7872d 100644 --- a/GPT_SoVITS/AR/models/t2s_lightning_module.py +++ b/GPT_SoVITS/AR/models/t2s_lightning_module.py @@ -14,7 +14,7 @@ from AR.modules.optim import ScaledAdam class Text2SemanticLightningModule(LightningModule): def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False): - super().__init__() + super(Text2SemanticLightningModule,self).__init__() self.config = config self.top_k = 3 self.model = Text2SemanticDecoder(config=config, top_k=self.top_k,flash_attn_enabled=flash_attn_enabled) @@ -35,7 +35,14 @@ class Text2SemanticLightningModule(LightningModule): def training_step(self, batch: Dict, batch_idx: int): opt = self.optimizers() scheduler = self.lr_schedulers() - forward=self.model.forward if 
self.config["train"].get("if_dpo",False)==True else self.model.forward_old + forward=None + if self.config["train"].get("if_dpo",False): + forward=self.model.forward + elif self.config["train"].get("padding_on_left",False): + forward=self.model.forward_old_padding_on_left + else: + forward=self.model.forward_old + loss, acc = forward( batch["phoneme_ids"], batch["phoneme_ids_len"], @@ -56,6 +63,7 @@ class Text2SemanticLightningModule(LightningModule): on_epoch=True, prog_bar=True, sync_dist=True, + batch_size=batch["phoneme_ids_len"].shape[0], ) self.log( "lr", @@ -63,6 +71,7 @@ class Text2SemanticLightningModule(LightningModule): on_epoch=True, prog_bar=True, sync_dist=True, + batch_size=batch["phoneme_ids_len"].shape[0], ) self.log( f"top_{self.top_k}_acc", @@ -71,7 +80,10 @@ class Text2SemanticLightningModule(LightningModule): on_epoch=True, prog_bar=True, sync_dist=True, + batch_size=batch["phoneme_ids_len"].shape[0], ) + if torch.cuda.is_available(): + torch.cuda.empty_cache() def validation_step(self, batch: Dict, batch_idx: int): return diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index b49bcfb5..878af64e 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -1,5 +1,6 @@ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py # reference: https://github.com/lifeiteng/vall-e +import math import os, sys now_dir = os.getcwd() sys.path.append(now_dir) @@ -38,7 +39,6 @@ default_config = { "EOS": 1024, } - @torch.jit.script class T2SMLP: def __init__(self, w1, b1, w2, b2): @@ -362,7 +362,8 @@ class Text2SemanticDecoder(nn.Module): loss = loss_1 + loss_2 return loss, acc - + + #padding on right def forward_old(self, x, x_lens, y, y_lens, bert_feature): """ x: phoneme_ids @@ -424,6 +425,91 @@ class Text2SemanticDecoder(nn.Module): loss = F.cross_entropy(logits, targets, reduction="sum") acc = self.ar_accuracy_metric(logits.detach(), targets).item() return loss, acc + + def forward_old_padding_on_left(self, + x:List[torch.Tensor], + x_lens:torch.LongTensor, + y:List[torch.Tensor], + y_lens:torch.LongTensor, + bert_feature:List[torch.Tensor]): + """ + x: phoneme_ids + y: semantic_ids + """ + device = x[0].device + x_len = x_lens.max() + y_len = y_lens.max() + batch_size = len(x) + + xy_pos = torch.zeros((batch_size, x_len+y_len, self.embedding_dim)).to(device) + targets:List[torch.LongTensor] = [] + xy_attn_mask_list = [] + for i in range(batch_size): + padding_len = (x_len-x_lens[i])+(y_len-y_lens[i]) + + x_item=self.ar_text_embedding(x[i].unsqueeze(0)) + if bert_feature[i] is not None: + x_item = x_item + self.bert_proj(bert_feature[i].transpose(0, 1).unsqueeze(0)) + + # x_item = F.pad(x_item, (0, 0, padding_len, 0), value=0) + x_item = self.ar_text_position(x_item).squeeze(0) + y_item = self.ar_audio_position(self.ar_audio_embedding(y[i].unsqueeze(0))).squeeze(0) + + xy_pos[i, padding_len:padding_len+x_lens[i],:] = x_item + xy_pos[i, padding_len+x_lens[i]:,:] = y_item + target = torch.zeros(y_lens[i], dtype=torch.long).to(device) + target[:-1] = y[i][1:] + target[-1] = self.EOS + targets.append(target.unsqueeze(0)) + + x_attn_mask = torch.zeros((x_len+(y_len-y_lens[i]), x_len+y_len), dtype=torch.bool).to(device) + x_attn_mask[:, -y_lens[i]:] = True + y_attn_mask = F.pad( + torch.triu( + torch.ones(y_lens[i], y_lens[i], dtype=torch.bool).to(device), + diagonal=1, + ), + (x_len+(y_len-y_lens[i]), 0), + value=False, + ) + attn_mask = torch.concat([x_attn_mask, 
y_attn_mask], dim=0) + if padding_len>0: + attn_mask[:, :padding_len] = True + xy_attn_mask_list.append(attn_mask) + + xy_attn_mask = torch.stack(xy_attn_mask_list, dim=0) + new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=xy_pos.dtype) + new_attn_mask.masked_fill_(xy_attn_mask, torch.finfo(xy_pos.dtype).min) + xy_attn_mask = new_attn_mask + xy_attn_mask = (xy_attn_mask.view(batch_size, 1, x_len+y_len, x_len+y_len) + .expand(-1, self.num_head, -1, -1) + .reshape(batch_size * self.num_head, x_len+y_len, x_len+y_len)) + + + # x 和完整的 y 一次性输入模型 + # xy_pos = torch.concat([x, y_pos], dim=1) + xy_dec, _ = self.h( + (xy_pos, None), + mask=xy_attn_mask, + ) + logits = [self.ar_predict_layer(xy_dec[i, -y_lens[i]:, :].unsqueeze(0)).permute(0, 2, 1) for i in range(batch_size)] + + # loss + # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum + loss = None + acc = None + for i in range(batch_size): + if loss is None: + loss = F.cross_entropy(logits[i], targets[i], reduction="sum") + acc = self.ar_accuracy_metric(logits[i].detach(), targets[i].detach()).item() + else: + loss += F.cross_entropy(logits[i], targets[i], reduction="sum") + acc += self.ar_accuracy_metric(logits[i].detach(), targets[i].detach()).item() + acc /= batch_size + + # loss = F.cross_entropy(logits, targets, reduction="sum") + # acc = self.ar_accuracy_metric(logits.detach(), targets).item() + return loss, acc # 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么 def infer( @@ -512,85 +598,80 @@ class Text2SemanticDecoder(nn.Module): top_p: int = 100, early_stop_num: int = -1, temperature: float = 1.0, + repetition_penalty: float = 1.35, + dtype:torch.dtype = torch.float32, ): - # 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读) - max_len = 0 - for x_item, bert_item in zip(x, bert_feature): - max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) - x_list = [self.ar_text_embedding(item) for item in x] - x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]0: + attn_mask[:, :padding_len] = True + xy_attn_mask_list.append(attn_mask) + + xy_attn_mask = torch.stack(xy_attn_mask_list, dim=0) + new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=xy_pos.dtype) + new_attn_mask.masked_fill_(xy_attn_mask, torch.finfo(xy_pos.dtype).min) + xy_attn_mask = new_attn_mask + xy_attn_mask = (xy_attn_mask.view(batch_size, 1, x_len+y_len, x_len+y_len) + .expand(-1, self.num_head, -1, -1)) + ###### decode ##### - y_list = [None]*y.shape[0] - batch_idx_map = list(range(y.shape[0])) - idx_list = [None]*y.shape[0] + y_list = [None]*batch_size + batch_idx_map = list(range(batch_size)) + idx_list = [None]*batch_size for idx in tqdm(range(1500)): if idx == 0: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask) @@ -606,7 +687,7 @@ class Text2SemanticDecoder(nn.Module): logits = logits[:, :-1] samples = sample( - logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature + logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature )[0] y = torch.concat([y, samples], dim=1) @@ -659,12 +740,12 @@ class Text2SemanticDecoder(nn.Module): xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device) if (None in idx_list): - for i in range(x.shape[0]): + for i in range(batch_size): if idx_list[i] is None: idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替 if ref_free: - return y_list, [0]*x.shape[0] + return y_list, 
[0]*batch_size return y_list, idx_list def infer_panel_batch_only( @@ -677,6 +758,8 @@ class Text2SemanticDecoder(nn.Module): top_p: int = 100, early_stop_num: int = -1, temperature: float = 1.0, + repetition_penalty: float = 1.35, + **kwargs ): # 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读) max_len = 0 @@ -772,7 +855,7 @@ class Text2SemanticDecoder(nn.Module): if(idx==0):###第一次跑不能EOS否则没有了 logits = logits[:, :-1] ###刨除1024终止符号的概率 samples = sample( - logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature + logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature )[0] # 本次生成的 semantic_ids 和之前的 y 构成新的 y # print(samples.shape)#[1,1]#第一个1是bs @@ -854,4 +937,298 @@ class Text2SemanticDecoder(nn.Module): if ref_free: return y_list, [0]*x.shape[0] - return y_list, idx_list \ No newline at end of file + return y_list, idx_list + + # padding on right + def infer_panel_batch_infer_with_flash_attn_old( + self, + x:List[torch.LongTensor], #####全部文本token + x_lens:torch.LongTensor, + prompts:torch.LongTensor, ####参考音频token + bert_feature:List[torch.LongTensor], + top_k: int = -100, + top_p: int = 100, + early_stop_num: int = -1, + temperature: float = 1.0, + repetition_penalty: float = 1.35, + **kwargs + ): + # 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读) + max_len = 0 + for x_item, bert_item in zip(x, bert_feature): + max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) + x_list = [self.ar_text_embedding(item) for item in x] + x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0] early_stop_num) or idx==1499: + print("use early stop num:", early_stop_num) + stop = True + for i, batch_index in enumerate(batch_idx_map): + batch_index = batch_idx_map[i] + idx_list[batch_index] = idx + y_list[batch_index] = y[i, :-1] + + if not (None in idx_list): + stop = True + + if stop: + if y.shape[1]==0: + y = torch.concat([y, torch.zeros_like(samples)], dim=1) + print("bad zero prediction") + print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + break + + ####################### update next step ################################### + y_emb = self.ar_audio_embedding(y[:, -1:]) + xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device) + + if (None in idx_list): + for i in range(x.shape[0]): + if idx_list[i] is None: + idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替 + + if ref_free: + return y_list, [0]*x.shape[0] + return y_list, idx_list + + def infer_panel_old( + self, + x, #####全部文本token + x_lens, + prompts, ####参考音频token + bert_feature, + top_k: int = -100, + top_p: int = 100, + early_stop_num: int = -1, + temperature: float = 1.0, + ): + x = self.ar_text_embedding(x) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) + x = self.ar_text_position(x) + + # AR Decoder + y = prompts + + x_len = x.shape[1] + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + stop = False + # print(1111111,self.num_layers) + cache = { + "all_stage": self.num_layers, + "k": [None] * self.num_layers, ###根据配置自己手写 + "v": [None] * self.num_layers, + # "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了 + "y_emb": None, ##只需要对最新的samples求emb,再拼历史的就行 + # "logits":None,###原版就已经只对结尾求再拼接了,不用管 + # "xy_dec":None,###不需要,本来只需要最后一个做logits + "first_infer": 1, + "stage": 0, + } + ################### first step 
########################## + if y is not None: + y_emb = self.ar_audio_embedding(y) + y_len = y_emb.shape[1] + prefix_len = y.shape[1] + y_pos = self.ar_audio_position(y_emb) + xy_pos = torch.concat([x, y_pos], dim=1) + cache["y_emb"] = y_emb + ref_free = False + else: + y_emb = None + y_len = 0 + prefix_len = 0 + y_pos = None + xy_pos = x + y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device) + ref_free = True + + x_attn_mask_pad = F.pad( + x_attn_mask, + (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) + value=True, + ) + y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( + x.device + ) + + + for idx in tqdm(range(1500)): + + xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer( + xy_dec[:, -1] + ) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的 + # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature) + if(idx==0):###第一次跑不能EOS否则没有了 + logits = logits[:, :-1] ###刨除1024终止符号的概率 + samples = sample( + logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature + )[0].unsqueeze(0) + # 本次生成的 semantic_ids 和之前的 y 构成新的 y + # print(samples.shape)#[1,1]#第一个1是bs + y = torch.concat([y, samples], dim=1) + + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + print("use early stop num:", early_stop_num) + stop = True + + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS) + stop = True + if stop: + # if prompts.shape[1] == y.shape[1]: + # y = torch.concat([y, torch.zeros_like(samples)], dim=1) + # print("bad zero prediction") + if y.shape[1]==0: + y = torch.concat([y, torch.zeros_like(samples)], dim=1) + print("bad zero prediction") + print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") + break + + ####################### update next step ################################### + cache["first_infer"] = 0 + if cache["y_emb"] is not None: + y_emb = torch.cat( + [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1 + ) + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + xy_pos = y_pos[:, -1:] + else: + y_emb = self.ar_audio_embedding(y[:, -1:]) + cache["y_emb"] = y_emb + y_pos = self.ar_audio_position(y_emb) + xy_pos = y_pos + y_len = y_pos.shape[1] + + ###最右边一列(是错的) + # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device) + # xy_attn_mask[:,-1]=False + ###最下面一行(是对的) + xy_attn_mask = torch.zeros( + (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device + ) + if ref_free: + return y[:, :-1], 0 + return y[:, :-1], idx-1 \ No newline at end of file diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py index ce0a98b7..d2a78566 100644 --- a/GPT_SoVITS/AR/models/utils.py +++ b/GPT_SoVITS/AR/models/utils.py @@ -11,7 +11,7 @@ def sequence_mask(length, max_length=None): return x.unsqueeze(0) < length.unsqueeze(1) -def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0, padding_left:bool=False) -> torch.Tensor: """ Args: lengths: @@ -35,8 +35,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: n = lengths.size(0) seq_range = torch.arange(0, max_len, device=lengths.device) expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) - - return 
expaned_lengths >= lengths.unsqueeze(-1) + if padding_left: + return expaned_lengths < (max_len-lengths.unsqueeze(-1)) + else: + return expaned_lengths >= lengths.unsqueeze(-1) # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index b8751519..b55f9d2e 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -63,7 +63,7 @@ def set_seed(seed:int): if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - # torch.backends.cudnn.deterministic = True + torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = False # torch.backends.cudnn.enabled = True except: @@ -435,8 +435,7 @@ class TTS: device:torch.device=torch.device("cpu"), precision:torch.dtype=torch.float32, ): - # 但是这里不能套,反而会负优化 - # with torch.no_grad(): + _data:list = [] index_and_len_list = [] for idx, item in enumerate(data): @@ -484,8 +483,7 @@ class TTS: norm_text_batch = [] bert_max_len = 0 phones_max_len = 0 - # 但是这里也不能套,反而会负优化 - # with torch.no_grad(): + for item in item_list: if prompt_data is not None: all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\ @@ -533,6 +531,12 @@ class TTS: # all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list] # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0) + #### padding on left + # all_phones_list = [F.pad(item,(max_len-item.shape[0],0),value=0) for item in all_phones_list] + # all_phones_batch = torch.stack(all_phones_list, dim=0) + # all_bert_features_list = [F.pad(item,(max_len-item.shape[1],0,0,0), value=0) for item in all_bert_features_list] + # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0) + batch = { "phones": phones_batch, "phones_len": torch.LongTensor(phones_len_list).to(device), @@ -569,7 +573,6 @@ class TTS: ''' self.stop_flag = True - # 使用装饰器 @torch.no_grad() def run(self, inputs:dict): """ @@ -586,9 +589,10 @@ class TTS: "top_k": 5, # int. top k sampling "top_p": 1, # float. top p sampling "temperature": 1, # float. temperature for sampling + "repetition_penalty": 1.35, # float. repetition penalty for sampling of T2S model. "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference - "batch_threshold": 0.75, # float. threshold for batch splitting. + "batch_threshold": 1, # float. threshold for batch splitting. "split_bucket: True, # bool. whether to split the batch into multiple buckets. "return_fragment": False, # bool. step by step return the audio fragment. "speed_factor":1.0, # float. control the speed of the synthesized audio. 
@@ -608,6 +612,7 @@ class TTS: top_k:int = inputs.get("top_k", 5) top_p:float = inputs.get("top_p", 1) temperature:float = inputs.get("temperature", 1) + repetition_penalty: float = inputs.get("repetition_penalty", 1.35) text_split_method:str = inputs.get("text_split_method", "cut0") batch_size = inputs.get("batch_size", 1) batch_threshold = inputs.get("batch_threshold", 0.75) @@ -618,9 +623,16 @@ class TTS: seed = inputs.get("seed", -1) seed = -1 if seed in ["", None] else seed actual_seed = set_seed(seed) + padding_on_left = inputs.get("padding_on_left", False) + if padding_on_left: + print("padding on left") + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn + else: + print("padding on right") + self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn_old + if return_fragment: - # split_bucket = False print(i18n("分段返回模式已开启")) if split_bucket: split_bucket = False @@ -745,8 +757,7 @@ class TTS: prompt = None else: prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) - - + pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( all_phoneme_ids, all_phoneme_lens, @@ -756,7 +767,9 @@ class TTS: top_k=top_k, top_p=top_p, temperature=temperature, + repetition_penalty = repetition_penalty, early_stop_num=self.configs.hz * self.configs.max_sec, + dtype = self.precision, ) t4 = ttime() t_34 += t4 - t3 diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index ff72c269..01908039 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -89,12 +89,14 @@ sovits_path = tts_config.vits_weights_path def inference(text, text_lang, ref_audio_path, prompt_text, prompt_lang, top_k, - top_p, temperature, + top_p, temperature, repetition_penalty, text_split_method, batch_size, speed_factor, ref_text_free, - split_bucket,fragment_interval, - seed, + split_bucket, fragment_interval, + seed, keep_random, padding_on_left ): + if keep_random: + seed = random.randrange(1 << 32) actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32) inputs={ "text": text, @@ -105,6 +107,7 @@ def inference(text, text_lang, "top_k": top_k, "top_p": top_p, "temperature": temperature, + "repetition_penalty": repetition_penalty, "text_split_method": cut_method[text_split_method], "batch_size":int(batch_size), "speed_factor":float(speed_factor), @@ -112,6 +115,7 @@ def inference(text, text_lang, "return_fragment":False, "fragment_interval":fragment_interval, "seed":actual_seed, + "padding_on_left":padding_on_left } for item in tts_pipeline.run(inputs): yield item, actual_seed @@ -197,6 +201,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True) top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True) temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True) + repetition_penalty = gr.Slider(minimum=0.5,maximum=4.0,step=0.01,label=i18n("repetition_penalty"),value=1.35,interactive=True) with gr.Column(): how_to_cut = gr.Radio( label=i18n("怎么切"), @@ -207,6 +212,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: with gr.Row(): split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True) seed = gr.Number(label=i18n("随机种子"),value=-1) + keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, 
interactive=True, show_label=True) + padding_on_left = gr.Checkbox(label=i18n("左侧补齐"), value=True, interactive=True, show_label=True) # with gr.Column(): output = gr.Audio(label=i18n("输出的语音")) with gr.Row(): @@ -219,11 +226,11 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: [ text,text_language, inp_ref, prompt_text, prompt_language, - top_k, top_p, temperature, + top_k, top_p, temperature, repetition_penalty, how_to_cut, batch_size, speed_factor, ref_text_free, split_bucket,fragment_interval, - seed + seed, keep_random, padding_on_left ], [output, seed], ) diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py index 43cfa19a..1ab16b68 100644 --- a/GPT_SoVITS/s1_train.py +++ b/GPT_SoVITS/s1_train.py @@ -126,6 +126,7 @@ def main(args): benchmark=False, fast_dev_run=False, strategy = DDPStrategy( + find_unused_parameters=True, process_group_backend="nccl" if platform.system() != "Windows" else "gloo" ) if torch.cuda.is_available() else "auto", precision=config["train"]["precision"],
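
For readers skimming the diff, here is a minimal, self-contained sketch of the two ideas the patch combines: left-padding a batch of token sequences to a common length, and marking those positions via the `padding_left=True` branch added to `make_pad_mask` in `GPT_SoVITS/AR/models/utils.py`. The helper `left_pad_batch` and the demo tensors are illustrative only and are not part of the repository.

```python
import torch
import torch.nn.functional as F
from typing import List


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0, padding_left: bool = False) -> torch.Tensor:
    # True marks padded positions; with padding_left=True the padding sits at the start of each row,
    # mirroring the behaviour the patch adds (expanded_lengths < max_len - lengths).
    max_len = max(max_len, int(lengths.max().item()))
    seq_range = torch.arange(max_len, device=lengths.device).unsqueeze(0).expand(lengths.size(0), -1)
    if padding_left:
        return seq_range < (max_len - lengths.unsqueeze(-1))
    return seq_range >= lengths.unsqueeze(-1)


def left_pad_batch(seqs: List[torch.Tensor], pad_val: int = 0) -> torch.Tensor:
    # Stack 1-D token sequences into a (B, T_max) tensor, padding on the left (hypothetical helper).
    max_len = max(s.shape[0] for s in seqs)
    return torch.stack([F.pad(s, (max_len - s.shape[0], 0), value=pad_val) for s in seqs], dim=0)


if __name__ == "__main__":
    seqs = [torch.tensor([7, 8]), torch.tensor([1, 2, 3, 4])]
    lens = torch.tensor([s.shape[0] for s in seqs])
    print(left_pad_batch(seqs, pad_val=0))
    # tensor([[0, 0, 7, 8],
    #         [1, 2, 3, 4]])
    print(make_pad_mask(lens, padding_left=True))
    # tensor([[ True,  True, False, False],
    #         [False, False, False, False]])
```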