From 4ed0b8bdcc0769997fefa4572928e095a83f0636 Mon Sep 17 00:00:00 2001
From: csh <458761603@qq.com>
Date: Tue, 24 Sep 2024 15:48:15 +0800
Subject: [PATCH] Restore `t2s_model.py`; move the changes into
 `export_torch_script.py`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/AR/models/t2s_model.py |  10 +-
 GPT_SoVITS/export_torch_script.py | 328 ++++++++++++++++++++++++------
 2 files changed, 267 insertions(+), 71 deletions(-)

diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index 12afca21..fb528914 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -83,7 +83,7 @@ class T2SMLP:
 class T2SBlock:
     def __init__(
         self,
-        num_heads: int,
+        num_heads,
         hidden_dim: int,
         mlp: T2SMLP,
         qkv_w,
@@ -92,12 +92,12 @@ class T2SBlock:
         out_b,
         norm_w1,
         norm_b1,
-        norm_eps1: float,
+        norm_eps1,
         norm_w2,
         norm_b2,
-        norm_eps2: float,
+        norm_eps2,
     ):
-        self.num_heads:int = num_heads
+        self.num_heads = num_heads
         self.mlp = mlp
         self.hidden_dim: int = hidden_dim
         self.qkv_w = qkv_w
@@ -266,7 +266,7 @@ class Text2SemanticDecoder(nn.Module):
         self.norm_first = norm_first
         self.vocab_size = config["model"]["vocab_size"]
         self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
-        self.p_dropout = float(config["model"]["dropout"])
+        self.p_dropout = config["model"]["dropout"]
         self.EOS = config["model"]["EOS"]
         self.norm_first = norm_first
         assert self.EOS == self.vocab_size - 1
diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py
index eb359f40..c7f1306c 100644
--- a/GPT_SoVITS/export_torch_script.py
+++ b/GPT_SoVITS/export_torch_script.py
@@ -34,6 +34,7 @@ default_config = {
 
 def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
     config = dict_s1["config"]
+    config["model"]["dropout"] = float(config["model"]["dropout"])
     t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
     t2s_model.load_state_dict(dict_s1["weight"])
     t2s_model = t2s_model.eval()
@@ -105,7 +106,7 @@ def sample(
 
 
 @torch.jit.script
-def spectrogram_torch(y, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
+def spectrogram_torch(y:Tensor, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
     hann_window = torch.hann_window(win_size,device=y.device,dtype=y.dtype)
     y = torch.nn.functional.pad(
         y.unsqueeze(1),
@@ -155,7 +156,180 @@ class DictToAttrRecursive(dict):
             del self[item]
         except KeyError:
             raise AttributeError(f"Attribute {item} not found")
+
+@torch.jit.script
+class T2SMLP:
+    def __init__(self, w1, b1, w2, b2):
+        self.w1 = w1
+        self.b1 = b1
+        self.w2 = w2
+        self.b2 = b2
+
+    def forward(self, x):
+        x = F.relu(F.linear(x, self.w1, self.b1))
+        x = F.linear(x, self.w2, self.b2)
+        return x
+
+@torch.jit.script
+class T2SBlock:
+    def __init__(
+        self,
+        num_heads: int,
+        hidden_dim: int,
+        mlp: T2SMLP,
+        qkv_w,
+        qkv_b,
+        out_w,
+        out_b,
+        norm_w1,
+        norm_b1,
+        norm_eps1: float,
+        norm_w2,
+        norm_b2,
+        norm_eps2: float,
+    ):
+        self.num_heads = num_heads
+        self.mlp = mlp
+        self.hidden_dim: int = hidden_dim
+        self.qkv_w = qkv_w
+        self.qkv_b = qkv_b
+        self.out_w = out_w
+        self.out_b = out_b
+        self.norm_w1 = norm_w1
+        self.norm_b1 = norm_b1
+        self.norm_eps1 = norm_eps1
+        self.norm_w2 = norm_w2
+        self.norm_b2 = norm_b2
+        self.norm_eps2 = norm_eps2
+
+        self.false = torch.tensor(False, dtype=torch.bool)
+
+    @torch.jit.ignore
+    def to_mask(self, 
x:torch.Tensor, padding_mask:Optional[torch.Tensor]): + if padding_mask is None: + return x + if padding_mask.dtype == torch.bool: + return x.masked_fill(padding_mask, 0) + else: + return x * padding_mask + + def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None): + q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1) + + batch_size = q.shape[0] + q_len = q.shape[1] + kv_len = k.shape[1] + + q = self.to_mask(q, padding_mask) + k_cache = self.to_mask(k, padding_mask) + v_cache = self.to_mask(v, padding_mask) + + q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) + k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + + attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask) + + attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim) + attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b) + + if padding_mask is not None: + for i in range(batch_size): + # mask = padding_mask[i,:,0] + if self.false.device!= padding_mask.device: + self.false = self.false.to(padding_mask.device) + idx = torch.where(padding_mask[i,:,0]==self.false)[0] + x_item = x[i,idx,:].unsqueeze(0) + attn_item = attn[i,idx,:].unsqueeze(0) + x_item = x_item + attn_item + x_item = F.layer_norm( + x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1 + ) + x_item = x_item + self.mlp.forward(x_item) + x_item = F.layer_norm( + x_item, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) + x[i,idx,:] = x_item.squeeze(0) + x = self.to_mask(x, padding_mask) + else: + x = x + attn + x = F.layer_norm( + x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1 + ) + x = x + self.mlp.forward(x) + x = F.layer_norm( + x, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) + return x, k_cache, v_cache + + def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor): + q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) + + k_cache = torch.cat([k_cache, k], dim=1) + v_cache = torch.cat([v_cache, v], dim=1) + + batch_size = q.shape[0] + q_len = q.shape[1] + kv_len = k_cache.shape[1] + + q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) + k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + + attn = F.scaled_dot_product_attention(q, k, v) + + attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim) + attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + attn = F.linear(attn, self.out_w, self.out_b) + + x = x + attn + x = F.layer_norm( + x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1 + ) + x = x + self.mlp.forward(x) + x = F.layer_norm( + x, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) + return x, k_cache, v_cache + +@torch.jit.script +class T2STransformer: + def __init__(self, num_blocks : int, blocks: list[T2SBlock]): + self.num_blocks : int = num_blocks + self.blocks = blocks + + def process_prompt( + self, x:torch.Tensor, attn_mask : torch.Tensor,padding_mask : Optional[torch.Tensor]=None): + k_cache : list[torch.Tensor] = [] + v_cache : list[torch.Tensor] = [] + for i in range(self.num_blocks): + x, k_cache_, v_cache_ = 
self.blocks[i].process_prompt(x, attn_mask, padding_mask) + k_cache.append(k_cache_) + v_cache.append(v_cache_) + return x, k_cache, v_cache + + def decode_next_token( + self, x:torch.Tensor, + k_cache: list[torch.Tensor], + v_cache: list[torch.Tensor]): + for i in range(self.num_blocks): + x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i]) + return x, k_cache, v_cache + class VitsModel(nn.Module): def __init__(self, vits_path): super().__init__() @@ -189,34 +363,19 @@ class VitsModel(nn.Module): return self.vq_model(pred_semantic, text_seq, refer)[0, 0] class T2SModel(nn.Module): - def __init__(self, config,raw_t2s:Text2SemanticLightningModule, norm_first=False, top_k=3): + def __init__(self,raw_t2s:Text2SemanticLightningModule): super(T2SModel, self).__init__() - self.model_dim = config["model"]["hidden_dim"] - self.embedding_dim = config["model"]["embedding_dim"] - self.num_head = config["model"]["head"] - self.num_layers = config["model"]["n_layer"] - self.vocab_size = config["model"]["vocab_size"] - self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"] - self.p_dropout = float(config["model"]["dropout"]) - self.EOS:int = config["model"]["EOS"] - self.norm_first = norm_first + self.model_dim = raw_t2s.model.model_dim + self.embedding_dim = raw_t2s.model.embedding_dim + self.num_head = raw_t2s.model.num_head + self.num_layers = raw_t2s.model.num_layers + self.vocab_size = raw_t2s.model.vocab_size + self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size + # self.p_dropout = float(raw_t2s.model.p_dropout) + self.EOS:int = int(raw_t2s.model.EOS) + self.norm_first = raw_t2s.model.norm_first assert self.EOS == self.vocab_size - 1 self.hz = 50 - self.config = config - - # self.bert_proj = nn.Linear(1024, self.embedding_dim) - # self.ar_text_embedding = TokenEmbedding( - # self.embedding_dim, self.phoneme_vocab_size, self.p_dropout - # ) - # self.ar_text_position = SinePositionalEmbedding( - # self.embedding_dim, dropout=0.1, scale=False, alpha=True - # ) - # self.ar_audio_embedding = TokenEmbedding( - # self.embedding_dim, self.vocab_size, self.p_dropout - # ) - # self.ar_audio_position = SinePositionalEmbedding( - # self.embedding_dim, dropout=0.1, scale=False, alpha=True - # ) self.bert_proj = raw_t2s.model.bert_proj self.ar_text_embedding = raw_t2s.model.ar_text_embedding @@ -225,13 +384,45 @@ class T2SModel(nn.Module): self.ar_audio_position = raw_t2s.model.ar_audio_position # self.t2s_transformer = T2STransformer(self.num_layers, blocks) - self.t2s_transformer = raw_t2s.model.t2s_transformer + # self.t2s_transformer = raw_t2s.model.t2s_transformer + + blocks = [] + h = raw_t2s.model.h + + for i in range(self.num_layers): + layer = h.layers[i] + t2smlp = T2SMLP( + layer.linear1.weight, + layer.linear1.bias, + layer.linear2.weight, + layer.linear2.bias + ) + + block = T2SBlock( + self.num_head, + self.model_dim, + t2smlp, + layer.self_attn.in_proj_weight, + layer.self_attn.in_proj_bias, + layer.self_attn.out_proj.weight, + layer.self_attn.out_proj.bias, + layer.norm1.weight, + layer.norm1.bias, + layer.norm1.eps, + layer.norm2.weight, + layer.norm2.bias, + layer.norm2.eps + ) + + blocks.append(block) + + self.t2s_transformer = T2STransformer(self.num_layers, blocks) # self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False) self.ar_predict_layer = raw_t2s.model.ar_predict_layer # self.loss_fct = nn.CrossEntropyLoss(reduction="sum") - self.max_sec = self.config["data"]["max_sec"] - self.top_k = 
int(self.config["inference"]["top_k"])
+        self.max_sec = raw_t2s.config["data"]["max_sec"]
+        self.top_k = int(raw_t2s.config["inference"]["top_k"])
         self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
 
     def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor):
@@ -296,6 +487,10 @@ class T2SModel(nn.Module):
             # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
             xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
             logits = self.ar_predict_layer(xy_dec[:, -1])
+
+            if idx < 11:  # require at least 10 predicted tokens before allowing a stop (0.4s)
+                logits = logits[:, :-1]
+
             samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
 
             y = torch.concat([y, samples], dim=1)
@@ -305,13 +500,13 @@ class T2SModel(nn.Module):
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                 stop = True
             if stop:
+                if y.shape[1] == 0:
+                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                 break
 
             y_emb = self.ar_audio_embedding(y[:, -1:])
             xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
 
-        y[0, -1] = 0
-
         return y[:, -idx:].unsqueeze(0)
 
 bert_path = os.environ.get(
@@ -362,20 +557,19 @@ class ExportSSLModel(torch.nn.Module):
         audio = resamplex(ref_audio,src_sr,dst_sr).float()
         return audio
 
-def export_bert(tokenizer,ref_text,word2ph):
-    ref_bert_inputs = tokenizer(ref_text, return_tensors="pt")
+def export_bert(ref_bert_inputs):
     ref_bert_inputs = {
-        'input_ids': torch.jit.annotate(torch.Tensor,ref_bert_inputs['input_ids']),
-        'attention_mask': torch.jit.annotate(torch.Tensor,ref_bert_inputs['attention_mask']),
-        'token_type_ids': torch.jit.annotate(torch.Tensor,ref_bert_inputs['token_type_ids']),
-        'word2ph': torch.jit.annotate(torch.Tensor,word2ph)
+        'input_ids': ref_bert_inputs['input_ids'],
+        'attention_mask': ref_bert_inputs['attention_mask'],
+        'token_type_ids': ref_bert_inputs['token_type_ids'],
+        'word2ph': ref_bert_inputs['word2ph']
     }
 
     bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
     my_bert_model = MyBertModel(bert_model)
 
     my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
     my_bert_model.save("onnx/bert_model.pt")
-    print('exported bert')
+    print('#### exported bert ####')
 
 def export(gpt_path, vits_path):
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
@@ -392,18 +586,14 @@ def export(gpt_path, vits_path):
 
     bert = MyBertModel(bert_model)
 
-    # export_bert(tokenizer,"声音,是有温度的.夜晚的声音,会发光",ref_bert_inputs['word2ph'])
+    # export_bert(ref_bert_inputs)
 
     ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
     ssl = SSLModel()
-    s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(torch.jit.annotate(torch.Tensor,ref_audio))))
+    s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
     torch.jit.script(s).save("onnx/xw/ssl_model.pt")
+    print('#### exported ssl ####')
 
-    print('exported ssl')
-
-    # ref_seq = torch.LongTensor([cleaned_text_to_sequence(["zh", "ai4", "ch", "an1","j" ,"ia1","r","ua4","s","i3","t","e3","ch","un1","w","an3","d","e1", "sh", "i2", "h", "ou4", "y", "ou3", "r", "en2","w","en4","l","e1","zh","e4","y","ang4","y","i2","g","e4","w","en4","t","i2"],version='v2')])
-    # ref_seq = 
torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
-    # text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
 
     ref_bert = bert(**ref_bert_inputs)
     text_bert = bert(**text_berf_inputs)
     ssl_content = ssl(ref_audio)
@@ -415,24 +605,30 @@ def export(gpt_path, vits_path):
     # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
     dict_s1 = torch.load(gpt_path, map_location="cpu")
     raw_t2s = get_raw_t2s_model(dict_s1)
-    t2s_m = T2SModel(dict_s1['config'],raw_t2s,top_k=3)
+    t2s_m = T2SModel(raw_t2s)
     t2s_m.eval()
     t2s = torch.jit.script(t2s_m)
-    print('exported t2s_m')
-
-    gpt_sovits = GPT_SoVITS(t2s,vits)
-    ref_audio_sr = ssl.resample(ref_audio,16000,32000)
-
-    # audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)
+    print('#### script t2s_m ####')
 
-    torch.jit.trace(gpt_sovits,example_inputs=(
-        torch.jit.annotate(torch.Tensor,ssl_content),
-        torch.jit.annotate(torch.Tensor,ref_audio_sr),
-        torch.jit.annotate(torch.Tensor,ref_seq),
-        torch.jit.annotate(torch.Tensor,text_seq),
-        torch.jit.annotate(torch.Tensor,ref_bert),
-        torch.jit.annotate(torch.Tensor,text_bert))).save("onnx/xw/gpt_sovits_model.pt")
-    print('exported vits')
+    print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
+    gpt_sovits = GPT_SoVITS(t2s,vits)
+    gpt_sovits.eval()
+    ref_audio_sr = s.resample(ref_audio,16000,32000)
+    print('ref_audio_sr:',ref_audio_sr.shape)
+
+    gpt_sovits_export = torch.jit.trace(
+        gpt_sovits,
+        example_inputs=(
+            ssl_content,
+            ref_audio_sr,
+            ref_seq,
+            text_seq,
+            ref_bert,
+            text_bert),
+        check_trace=False)  # check_trace defaults to True, but the check may run the trace on randomly generated inputs of odd shapes and raise spurious errors
+
+    gpt_sovits_export.save("onnx/xw/gpt_sovits_model.pt")
+    print('#### exported gpt_sovits ####')
 
 @torch.jit.script
 def parse_audio(ref_audio):
@@ -459,20 +655,20 @@ class GPT_SoVITS(nn.Module):
         audio = self.vits(text_seq, pred_semantic, ref_audio_sr)
         return audio
 
-def test():
+def test(gpt_path, vits_path):
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
     bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
     bert = MyBertModel(bert_model)
     # bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
 
-    gpt_path = "GPT_weights_v2/xw-e15.ckpt"
+    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
     dict_s1 = torch.load(gpt_path, map_location="cpu")
     raw_t2s = get_raw_t2s_model(dict_s1)
-    t2s = T2SModel(dict_s1['config'],raw_t2s,top_k=3)
+    t2s = T2SModel(raw_t2s)
     t2s.eval()
     # t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')
 
-    vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
+    # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
     vits = VitsModel(vits_path)
     vits.eval()
 
@@ -506,7 +702,7 @@ def test():
         text_berf_inputs['word2ph']) #[1,N]
 
-    ref_audio = torch.tensor(load_audio("output/denoise_opt/xw.mp3_0000000000_0000156480.wav", 16000)).float().unsqueeze(0)
+    ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
     print('ref_audio:',ref_audio.shape)
 
     ref_audio_sr = ssl.resample(ref_audio,16000,32000)
@@ -537,5 +733,5 @@ def export_symbel(version='v2'):
 
 if __name__ == "__main__":
     export(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
-    # test()
+    # test(gpt_path="GPT_weights_v2/chen1-e15.ckpt", 
vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
     # export_symbel()
\ No newline at end of file
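
Usage note (not part of the patch): a minimal sketch of how the exported TorchScript artifacts could be loaded and run, assuming the output paths written by export() above and that ExportSSLModel.resample survives scripting as a callable method; preparation of the phoneme and BERT inputs is elided (see test()).

    import torch

    # Load the modules saved by export() above (paths assumed from the patch).
    ssl = torch.jit.load("onnx/xw/ssl_model.pt", map_location="cpu")
    gpt_sovits = torch.jit.load("onnx/xw/gpt_sovits_model.pt", map_location="cpu")

    # Placeholder 16 kHz reference waveform of shape [1, T]; normally produced
    # by load_audio(..., 16000) as in export().
    ref_audio = torch.randn(1, 16000 * 3)
    ssl_content = ssl(ref_audio)                          # SSL features of the reference audio
    ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)  # assumed exported method

    # ref_seq/text_seq (phoneme id tensors) and ref_bert/text_bert (per-phoneme
    # BERT features) must be prepared as in test(); elided here.
    # audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)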