# GPT-SoVITS/GPT_SoVITS/AR/models/t2s_model_onnx.py
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import torch
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm

from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer

default_config = {
"embedding_dim": 512,
"hidden_dim": 512,
"num_head": 8,
"num_layers": 12,
"num_codebook": 8,
"p_dropout": 0.0,
"vocab_size": 1024 + 1,
"phoneme_vocab_size": 512,
"EOS": 1024,
}
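
# Note: vocab_size is the 1024 codebook entries plus the EOS token (id 1024).
# Text2SemanticDecoder below reads differently named keys ("head", "n_layer",
# "dropout") from a nested config["model"] dict, so default_config documents
# the expected values rather than serving as a drop-in config.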
# Filler value for masked-out logits; a module-level tensor so ONNX export captures it.
inf_tensor_value = torch.FloatTensor([-float("Inf")])

def logits_to_probs(
    logits,
    previous_tokens=None,  # always passed in practice (see sample() below)
    temperature=torch.FloatTensor([1.0]),  # scalar tensors rather than Python
    top_k=torch.LongTensor([15]),          # numbers keep the ops below traceable
    top_p=torch.FloatTensor([1.0]),
    repetition_penalty=torch.FloatTensor([1.0]),
):
previous_tokens = previous_tokens.squeeze()
# if previous_tokens is not None and repetition_penalty != 1.0: # Always captured by onnx
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
# if top_p is not None and top_p < 1.0: #To be captured by onnx
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(
sorted_logits,
dim=-1,
),
dim=-1,
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
logits = logits / torch.max(temperature, torch.tensor(1e-5, device=logits.device, dtype=torch.float))
# if top_k is not None: # To be captured by onnx
v, _ = torch.topk(logits, top_k)
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, inf_tensor_value, logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
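
# A minimal usage sketch (names and shapes here are illustrative assumptions,
# not part of this module): logits is the 1-D output of one decoder step and
# previous_tokens holds the codes emitted so far, as in sample() below.
#
#   logits = torch.randn(1025)                   # vocab_size = 1024 + EOS
#   history = torch.randint(0, 1024, (1, 32))    # previously sampled codes
#   probs = logits_to_probs(
#       logits,
#       previous_tokens=history,
#       temperature=torch.FloatTensor([1.0]),
#       top_k=torch.LongTensor([15]),
#       top_p=torch.FloatTensor([1.0]),
#       repetition_penalty=torch.FloatTensor([1.35]),
#   )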

def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    # Exponential(1) noise via inverse-CDF (-log U, U ~ Uniform(0, 1)) keeps the
    # argmax trick a true multinomial draw while staying ONNX-exportable;
    # randn_like here would inject sign-flipping Gaussian noise and bias samples.
    q = -torch.log(torch.rand_like(probs_sort).clamp_min(1e-10))
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def sample(
logits,
previous_tokens,
**sampling_kwargs,
):
probs = logits_to_probs(
logits=logits,
previous_tokens=previous_tokens,
**sampling_kwargs,
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
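
# OnnxEncoder wraps the text-side front end (phoneme embedding + BERT-feature
# projection + sinusoidal positions) as a standalone module so it can be
# exported to ONNX separately from the autoregressive decoder.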
class OnnxEncoder(nn.Module):
def __init__(self, ar_text_embedding, bert_proj, ar_text_position):
super().__init__()
self.ar_text_embedding = ar_text_embedding
self.bert_proj = bert_proj
self.ar_text_position = ar_text_position
def forward(self, x, bert_feature):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
return self.ar_text_position(x)
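
# T2SStageDecoder performs one autoregressive decoding step with the k/v
# caches passed in and out as plain tensors, keeping the graph exportable.
# `first_infer` switches, via index arithmetic instead of Python branches,
# between a prefill pass over the whole x/y prefix and incremental one-token
# steps.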
class T2SStageDecoder(nn.Module):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
self.h = h
self.ar_predict_layer = ar_predict_layer
self.loss_fct = loss_fct
self.ar_accuracy_metric = ar_accuracy_metric
self.early_stop_num = early_stop_num
self.num_layers = num_layers

    def forward(self, x, y, k, v, y_emb, top_k=None, top_p=None,
                repetition_penalty=None, temperature=None, first_infer=None,
                x_seq_len=None, y_seq_len=None):
if top_k is None:
top_k = torch.LongTensor([15]).to(device=y.device)
if top_p is None:
top_p = torch.FloatTensor([1.0]).to(device=y.device)
if repetition_penalty is None:
repetition_penalty = torch.FloatTensor([1.0]).to(device=y.device)
if temperature is None:
temperature = torch.FloatTensor([1.0]).to(device=y.device)
minus_one = torch.tensor([-1]).to(y.device).to(torch.int64)
cache = {
"all_stage": self.num_layers,
"k": k,
"v": v,
"y_emb": y_emb,
"first_infer": first_infer,
"stage": 0,
"x_seq_len": x_seq_len,
"y_seq_len": y_seq_len,
}
        # Decide at run time whether to embed only the last y token or all of y,
        # so the first pass and subsequent incremental passes are both handled.
multipled = minus_one * first_infer * y_seq_len
index_offset = torch.min(minus_one, multipled)
y_to_emb = y[:, index_offset:]
        # Embed the y (semantic token) input
y_emb = torch.cat(
[
cache["y_emb"],
self.ar_audio_embedding(y_to_emb),
],
1,
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
        # Concatenate with the x input in preparation for attention
xy_pos = torch.concat([x, y_pos], dim=1)
        # Decide at run time whether self-attention runs over only the last
        # xy_pos position or over all of xy_pos
multipled = minus_one * first_infer * (x_seq_len + y_seq_len) # xy_pos = 1 or x_seq_len + y_seq_len
index_offset = torch.min(minus_one, multipled)
xy_pos = xy_pos[:, index_offset:]
        # Build the attention mask for xy
x_attn_mask = torch.zeros((x_seq_len, x_seq_len)).bool()
y_attn_mask = torch.ones((y_seq_len, y_seq_len)).to(torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones(
(y_seq_len, 1),
dtype=torch.int64,
),
dim=0,
)
y_attn_mask = y_attn_mask > 0
x_y_pad = torch.ones((x_seq_len, y_seq_len)).to(torch.bool)
y_x_pad = torch.zeros((y_seq_len, x_seq_len)).to(torch.bool)
x_attn_mask_pad = torch.cat([x_attn_mask, x_y_pad], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
        # Decide at run time whether to use only the last row of the attention
        # mask or the whole mask
multipled = minus_one * first_infer * (x_seq_len + y_seq_len)
index_offset = torch.min(minus_one, multipled)
xy_attn_mask = xy_attn_mask[index_offset:, :]
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
        samples = sample(
            logits[0],
            y,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
        )[0].unsqueeze(0)
y = torch.concat([y, samples], dim=1)
return y, cache["k"], cache["v"], cache["y_emb"], logits, samples
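
# Text2SemanticDecoder owns the full text-to-semantic model; init_onnx() splits
# it into the two exportable modules above, while forward() drives the sampling
# loop with tensorized caches and infer() with per-layer Python-list caches.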
class Text2SemanticDecoder(nn.Module):
def __init__(self, config, norm_first=False, top_k=3):
super(Text2SemanticDecoder, self).__init__()
self.model_dim = config["model"]["hidden_dim"]
self.embedding_dim = config["model"]["embedding_dim"]
self.num_head = config["model"]["head"]
self.num_layers = config["model"]["n_layer"]
self.norm_first = norm_first
self.vocab_size = config["model"]["vocab_size"]
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
self.p_dropout = float(config["model"]["dropout"])
self.EOS = config["model"]["EOS"]
assert self.EOS == self.vocab_size - 1
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
self.h = TransformerEncoder(
TransformerEncoderLayer(
d_model=self.model_dim,
nhead=self.num_head,
dim_feedforward=self.model_dim * 4,
dropout=0.1,
batch_first=True,
norm_first=norm_first,
),
num_layers=self.num_layers,
norm=LayerNorm(self.model_dim) if norm_first else None,
)
self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
self.ar_accuracy_metric = MulticlassAccuracy(
self.vocab_size,
top_k=top_k,
average="micro",
multidim_average="global",
ignore_index=self.EOS,
)
self.top_k = torch.LongTensor([1])
self.early_stop_num = torch.LongTensor([-1])
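        # Sampling defaults used when callers pass top_k=None: top_k = 1 keeps
        # only the highest logit (effectively greedy), and early_stop_num = -1
        # disables the early-stop length check.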
def init_onnx(self):
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
self.stage_decoder = T2SStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.early_stop_num,
self.num_layers,
)

    def forward(self, x, prompts, bert_feature, top_k=None):
        if top_k is None:
            top_k = self.top_k
        early_stop_num = self.early_stop_num
prefix_len = prompts.shape[1]
x = self.onnx_encoder(x, bert_feature)
x_seq_len = x.shape[1]
y_seq_len = prompts.shape[1]
        # Per-layer k/v caches sized for the full (text + prompt) prefix;
        # 512 matches the model's hidden_dim.
        init_k = torch.zeros((x_seq_len + y_seq_len, self.num_layers, 512), dtype=torch.float)
        init_v = torch.zeros((x_seq_len + y_seq_len, self.num_layers, 512), dtype=torch.float)
        empty_tensor = torch.empty((1, 0, 512)).to(torch.float)
        y, k, v, y_emb, logits, samples = self.stage_decoder(
            x,
            prompts,
            init_k,
            init_v,
            empty_tensor,
            top_k=top_k,
            first_infer=torch.LongTensor([1]),
            x_seq_len=x_seq_len,
            y_seq_len=y_seq_len,
        )
stop = False
for idx in tqdm(range(1, 1500)):
k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
y_seq_len = y.shape[1]
            y, k, v, y_emb, logits, samples = self.stage_decoder(
                empty_tensor,
                y,
                k,
                v,
                y_emb,
                top_k=top_k,
                first_infer=torch.LongTensor([0]),
                x_seq_len=x_seq_len,
                y_seq_len=y_seq_len,
            )
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
            if stop:
                y = y[:, :-1]
                break
        return y, idx

    def infer(self, x, prompts, bert_feature, top_k=None):
        if top_k is None:
            top_k = self.top_k
        early_stop_num = self.early_stop_num
x = self.onnx_encoder(x, bert_feature)
y = prompts
prefix_len = y.shape[1]
x_len = x.shape[1]
        # Build an all-False (x_len, x_len) mask through tensor ops so its shape
        # stays tied to x dynamically under tracing/export.
        x_example = x[:, :, 0] * 0.0
        x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
        x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
stop = False
cache = {
"all_stage": self.num_layers,
"k": [None] * self.num_layers,
"v": [None] * self.num_layers,
"y_emb": None,
"first_infer": 1,
"stage": 0,
}
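        # Unlike T2SStageDecoder's tensor caches, cache["k"]/cache["v"] here are
        # per-layer Python lists filled in by self.h; "first_infer" marks the
        # prefill pass over x plus the prompt tokens.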
for idx in tqdm(range(1500)):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
else:
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
for i in range(len(cache["k"])):
cache["k"][i] = torch.nn.functional.pad(cache["k"][i], (0, 0, 0, 0, 0, 1))
cache["v"][i] = torch.nn.functional.pad(cache["v"][i], (0, 0, 0, 0, 0, 1))
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
if cache["first_infer"] == 1:
xy_pos = torch.concat([x, y_pos], dim=1)
else:
xy_pos = y_pos[:, -1:]
y_len = y_pos.shape[1]
if cache["first_infer"] == 1:
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
else:
xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
            samples = sample(
                logits[0],
                y,
                top_k=top_k,
                top_p=1.0,
                repetition_penalty=1.35,
                temperature=torch.Tensor([1.0]),
            )[0].unsqueeze(0)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
break
y = torch.concat([y, samples], dim=1)
cache["first_infer"] = 0
return y, idx
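
# A minimal end-to-end sketch (hypothetical shapes; the nested config below only
# carries the "model" keys this class actually reads):
#
#   config = {"model": {"hidden_dim": 512, "embedding_dim": 512, "head": 8,
#                       "n_layer": 12, "vocab_size": 1025,
#                       "phoneme_vocab_size": 512, "dropout": 0.0, "EOS": 1024}}
#   model = Text2SemanticDecoder(config).eval()
#   model.init_onnx()
#   x = torch.randint(0, 512, (1, 20))           # phoneme ids, (B, T_text)
#   prompts = torch.randint(0, 1024, (1, 30))    # prompt semantic tokens
#   bert = torch.randn(1, 1024, 20)              # BERT features, (B, 1024, T_text)
#   with torch.no_grad():
#       y, idx = model.infer(x, prompts, bert)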