From 2c481ce917a061df967340c64bf5045a65cd47c2 Mon Sep 17 00:00:00 2001
From: ChasonJiang <1440499136@qq.com>
Date: Thu, 15 Aug 2024 21:17:09 +0800
Subject: [PATCH] Remove unused code

---
 GPT_SoVITS/AR/models/t2s_model.py | 45 -------------------------------
 1 file changed, 45 deletions(-)

diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index b089d8e0..cb952e6f 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -182,51 +182,6 @@ class T2SBlock:
             )
         return x, k_cache, v_cache
-
-
-    # def process_prompt(self, x, attn_mask : torch.Tensor, padding_mask:torch.Tensor=None):
-
-
-    #     q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
-
-    #     batch_size = q.shape[0]
-    #     q_len = q.shape[1]
-    #     kv_len = k.shape[1]
-
-    #     k_cache = k
-    #     v_cache = v
-
-    #     q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
-    #     k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
-    #     v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
-
-    #     attn = scaled_dot_product_attention(q, k, v, attn_mask)
-
-    #     attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
-    #     attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
-    #     attn = F.linear(attn, self.out_w, self.out_b)
-
-    #     for i in range(batch_size):
-    #         # mask = padding_mask[i,:,0]
-    #         idx = torch.where(padding_mask[i,:,0]==False)[0]
-    #         x_item = x[i,idx,:].unsqueeze(0)
-    #         attn_item = attn[i,idx,:].unsqueeze(0)
-    #         x_item = x_item + attn_item
-    #         x_item = F.layer_norm(
-    #             x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-    #         )
-    #         x_item = x_item + self.mlp.forward(x_item)
-    #         x_item = F.layer_norm(
-    #             x_item,
-    #             [self.hidden_dim],
-    #             self.norm_w2,
-    #             self.norm_b2,
-    #             self.norm_eps2,
-    #         )
-    #         x[i,idx,:] = x_item.squeeze(0)
-    #     # x = self.to_mask(x, padding_mask)
-    #     return x, k_cache, v_cache
-
     def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None):
         q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
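
Note for reviewers: the block deleted above was a commented-out, padding-aware draft of `process_prompt` that post-processed the attention output one batch item at a time. For reference, a minimal standalone sketch of that per-item pattern follows; the function name `masked_residual_norm`, the tensor shapes, and the eps value are illustrative assumptions, not part of the codebase. It shows the residual-add plus layer norm applied only at non-padded positions, which is what the dead code's inner loop did:

import torch
import torch.nn.functional as F

def masked_residual_norm(x, attn, padding_mask, norm_w, norm_b, eps=1e-5):
    # x, attn: [batch, seq_len, hidden_dim]; padding_mask: [batch, seq_len, 1],
    # True where a position is padding. Shapes and eps are assumptions for
    # illustration, not the model's real configuration.
    hidden_dim = x.shape[-1]
    for i in range(x.shape[0]):
        # Select only the real (non-padded) positions of sequence i.
        idx = torch.where(~padding_mask[i, :, 0])[0]
        item = x[i, idx, :] + attn[i, idx, :]  # residual connection
        # Normalize just those positions, then write them back in place.
        x[i, idx, :] = F.layer_norm(item, [hidden_dim], norm_w, norm_b, eps)
    return x

# Usage: mark the last two positions of the second sequence as padding.
x = torch.randn(2, 5, 8)
attn = torch.randn(2, 5, 8)
mask = torch.zeros(2, 5, 1, dtype=torch.bool)
mask[1, 3:] = True
out = masked_residual_norm(x, attn, mask, torch.ones(8), torch.zeros(8))

Looping over the batch this way skips normalizing padded positions but serializes the batch dimension, which is presumably why the vectorized `process_prompt` kept above was preferred and this draft was never enabled.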