[fast_inference] torch.jit.script compatibility (#1489)

* Dropped @torch.jit.script from the t2s model to ensure compatibility across PyTorch environments

* Optimized the @torch.jit.script strategy (see the sketch below)
ChasonJiang 2024-08-16 16:55:36 +08:00 committed by GitHub
parent 089636424b
commit 5dfce9a3f0

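As background for the two bullets above, here is a minimal sketch, not from this commit, of the compatibility trade-off: @torch.jit.script compiles a function eagerly at definition time, so a scripting failure in one PyTorch environment breaks the module outright, while a guarded fallback keeps the eager implementation usable everywhere. The function name is made up for illustration.

import torch


def _gelu_like(x: torch.Tensor) -> torch.Tensor:
    # Fully annotated so TorchScript can compile it where supported.
    return x * torch.sigmoid(1.702 * x)


try:
    gelu_like = torch.jit.script(_gelu_like)  # compiled where scripting works
except Exception:
    gelu_like = _gelu_like  # eager fallback elsewhere

print(gelu_like(torch.randn(3)))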

@@ -39,7 +39,7 @@ default_config = {
     "EOS": 1024,
 }
 
-# @torch.jit.script
+@torch.jit.script
 # Efficient implementation equivalent to the following:
 def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
     B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
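The hunk above re-enables scripting on the hand-rolled attention function. A hedged, self-contained sketch of why such a function compiles: every argument, including the optional mask, carries an explicit annotation, which torch.jit.script requires. The names below are illustrative, not the repository's.

import math
from typing import Optional

import torch


@torch.jit.script
def sdpa_sketch(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Plain softmax(QK^T / sqrt(d)) V; the Optional annotation lets
    # TorchScript type-check the `is not None` branch.
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask, -1e9)
    return torch.softmax(scores, dim=-1) @ value


q = k = v = torch.randn(1, 2, 4, 8)
print(sdpa_sketch(q, k, v).shape)  # torch.Size([1, 2, 4, 8])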
@@ -82,7 +82,7 @@ class T2SMLP:
         return x
 
 
-# @torch.jit.script
+@torch.jit.script
 class T2SBlock:
     def __init__(
         self,
@@ -114,6 +114,8 @@ class T2SBlock:
         self.norm_b2 = norm_b2
         self.norm_eps2 = norm_eps2
+        self.false = torch.tensor(False, dtype=torch.bool)
 
+    @torch.jit.ignore
     def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
         if padding_mask is None:
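A minimal sketch of the @torch.jit.ignore pattern added above (class and method names are made up): the ignored method stays plain Python and its body is never compiled, yet scripted methods can still call it, mirroring how T2SBlock keeps to_mask out of the compiler's reach.

import torch


@torch.jit.script
class BiasAdd:
    def __init__(self, bias: float):
        self.bias = bias

    @torch.jit.ignore
    def log_shape(self, x: torch.Tensor):
        # Arbitrary Python is allowed here; the compiler skips this body.
        print("input shape:", list(x.shape))

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        self.log_shape(x)  # runs in eager Python even when called from scripted code
        return x + self.bias


m = BiasAdd(1.0)
print(m.apply(torch.zeros(2)))  # tensor([1., 1.])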
@@ -150,7 +152,9 @@ class T2SBlock:
         if padding_mask is not None:
             for i in range(batch_size):
                 # mask = padding_mask[i,:,0]
-                idx = torch.where(padding_mask[i,:,0]==False)[0]
+                if self.false.device != padding_mask.device:
+                    self.false = self.false.to(padding_mask.device)
+                idx = torch.where(padding_mask[i,:,0]==self.false)[0]
                 x_item = x[i,idx,:].unsqueeze(0)
                 attn_item = attn[i,idx,:].unsqueeze(0)
                 x_item = x_item + attn_item
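The replacement above compares against the cached self.false tensor instead of the Python literal False, moving it to the mask's device on first use. A standalone sketch of that lazy device-sync trick, with illustrative names:

import torch

false_tensor = torch.tensor(False, dtype=torch.bool)  # allocated once, on CPU


def unpadded_indices(padding_mask: torch.Tensor) -> torch.Tensor:
    global false_tensor
    if false_tensor.device != padding_mask.device:
        # One-time move so the == below never compares across devices.
        false_tensor = false_tensor.to(padding_mask.device)
    return torch.where(padding_mask == false_tensor)[0]


mask = torch.tensor([False, True, False])
print(unpadded_indices(mask))  # tensor([0, 2])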
@@ -218,7 +222,7 @@ class T2SBlock:
         return x, k_cache, v_cache
 
 
-# @torch.jit.script
+@torch.jit.script
 class T2STransformer:
     def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
         self.num_blocks : int = num_blocks
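Scripting T2STransformer works because a TorchScript class may hold a typed list of other TorchScript classes and iterate over it, mirroring the List[T2SBlock] attribute above. A minimal sketch with made-up names:

from typing import List

import torch


@torch.jit.script
class ScaleBlock:
    def __init__(self, factor: float):
        self.factor = factor

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor


@torch.jit.script
class BlockChain:
    def __init__(self, blocks: List[ScaleBlock]):
        self.num_blocks: int = len(blocks)
        self.blocks = blocks

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        # Looping over a typed List of scripted classes compiles cleanly.
        for block in self.blocks:
            x = block.apply(x)
        return x


chain = BlockChain([ScaleBlock(2.0), ScaleBlock(3.0)])
print(chain.apply(torch.ones(2)))  # tensor([6., 6.])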