double inference speed

double inference speed
This commit is contained in:
RVC-Boss 2024-07-07 21:34:28 +08:00 committed by GitHub
parent 6638e66294
commit b0786f2998
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,8 +1,11 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import torch
from tqdm import tqdm
import random
import numpy as np
from tqdm import tqdm
from typing import List
from AR.models.utils import make_pad_mask
from AR.models.utils import (
topk_sampling,
@ -10,7 +13,7 @@ from AR.models.utils import (
logits_to_probs,
multinomial_sample_one_no_sync,
dpo_loss,
make_reject_y,
make_reject_y,
get_batch_logps
)
from AR.modules.embedding import SinePositionalEmbedding
@ -35,6 +38,139 @@ default_config = {
}
@torch.jit.script
class T2SMLP:
def __init__(self, w1, b1, w2, b2):
self.w1 = w1
self.b1 = b1
self.w2 = w2
self.b2 = b2
def forward(self, x):
x = F.relu(F.linear(x, self.w1, self.b1))
x = F.linear(x, self.w2, self.b2)
return x
@torch.jit.script
class T2SBlock:
def __init__(
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
):
self.num_heads = num_heads
self.mlp = mlp
self.hidden_dim: int = hidden_dim
self.qkv_w = qkv_w
self.qkv_b = qkv_b
self.out_w = out_w
self.out_b = out_b
self.norm_w1 = norm_w1
self.norm_b1 = norm_b1
self.norm_eps1 = norm_eps1
self.norm_w2 = norm_w2
self.norm_b2 = norm_b2
self.norm_eps2 = norm_eps2
def process_prompt(self, x, attn_mask: torch.Tensor):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k.shape[1]
k_cache = k
v_cache = v
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
attn = F.linear(attn, self.out_w, self.out_b)
x = F.layer_norm(
x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(
x + self.mlp.forward(x),
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
def decode_next_token(self, x, k_cache, v_cache):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
v_cache = torch.cat([v_cache, v], dim=1)
kv_len = k_cache.shape[1]
batch_size = q.shape[0]
q_len = q.shape[1]
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
attn = F.linear(attn, self.out_w, self.out_b)
x = F.layer_norm(
x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(
x + self.mlp.forward(x),
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
@torch.jit.script
class T2STransformer:
def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
self.num_blocks: int = num_blocks
self.blocks = blocks
def process_prompt(
self, x, attn_mask: torch.Tensor):
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask)
k_cache.append(k_cache_)
v_cache.append(v_cache_)
return x, k_cache, v_cache
def decode_next_token(
self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
return x, k_cache, v_cache
class Text2SemanticDecoder(nn.Module):
def __init__(self, config, norm_first=False, top_k=3):
super(Text2SemanticDecoder, self).__init__()
@ -89,6 +225,37 @@ class Text2SemanticDecoder(nn.Module):
ignore_index=self.EOS,
)
blocks = []
for i in range(self.num_layers):
layer = self.h.layers[i]
t2smlp = T2SMLP(
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
)
# (layer.self_attn.in_proj_weight, layer.self_attn.in_proj_bias)
block = T2SBlock(
self.num_head,
self.model_dim,
t2smlp,
layer.self_attn.in_proj_weight,
layer.self_attn.in_proj_bias,
layer.self_attn.out_proj.weight,
layer.self_attn.out_proj.bias,
layer.norm1.weight,
layer.norm1.bias,
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
)
blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@ -116,7 +283,7 @@ class Text2SemanticDecoder(nn.Module):
(0, y_len),
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@ -177,7 +344,7 @@ class Text2SemanticDecoder(nn.Module):
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
loss = loss_1 + loss_2
return loss, acc
@ -246,14 +413,14 @@ class Text2SemanticDecoder(nn.Module):
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
def infer(
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@ -322,15 +489,15 @@ class Text2SemanticDecoder(nn.Module):
return targets[:, :-1], targets[:, 1:]
def infer_panel(
self,
x, #####全部文本token
x_lens,
prompts, ####参考音频token
bert_feature,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
self,
x, #####全部文本token
x_lens,
prompts, ####参考音频token
bert_feature,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@ -338,22 +505,14 @@ class Text2SemanticDecoder(nn.Module):
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
cache = {
"all_stage": self.num_layers,
"k": [None] * self.num_layers, ###根据配置自己手写
"v": [None] * self.num_layers,
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
"y_emb": None, ##只需要对最新的samples求emb再拼历史的就行
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
# "xy_dec":None,###不需要本来只需要最后一个做logits
"first_infer": 1,
"stage": 0,
}
k_cache = None
v_cache = None
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
@ -361,7 +520,6 @@ class Text2SemanticDecoder(nn.Module):
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
cache["y_emb"] = y_emb
ref_free = False
else:
y_emb = None
@ -373,10 +531,10 @@ class Text2SemanticDecoder(nn.Module):
ref_free = True
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
@ -385,64 +543,43 @@ class Text2SemanticDecoder(nn.Module):
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
x.device
)
for idx in tqdm(range(1500)):
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
if xy_attn_mask is not None:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
) ##不用改如果用了cache的默认就是只有一帧取最后一帧一样的
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
if(idx==0):###第一次跑不能EOS否则没有了
logits = logits[:, :-1] ###刨除1024终止符号的概率
)
if idx == 0:
xy_attn_mask = None
logits = logits[:, :-1]
samples = sample(
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0].unsqueeze(0)
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
# print(samples.shape)#[1,1]#第一个1是bs
y = torch.concat([y, samples], dim=1)
y = torch.concat([y, samples], dim=1)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
# if prompts.shape[1] == y.shape[1]:
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
# print("bad zero prediction")
if y.shape[1]==0:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
####################### update next step ###################################
cache["first_infer"] = 0
if cache["y_emb"] is not None:
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos[:, -1:]
else:
y_emb = self.ar_audio_embedding(y[:, -1:])
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos
y_len = y_pos.shape[1]
###最右边一列(是错的)
# xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
# xy_attn_mask[:,-1]=False
###最下面一行(是对的)
xy_attn_mask = torch.zeros(
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
)
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, prompts.shape[1] + idx]
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx-1
return y[:, :-1], idx - 1