mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 08:49:59 +08:00
feat: successfully unified first step and following step
This commit is contained in:
parent d413a4f5b1
commit c85ee3d521
@@ -105,110 +105,6 @@ class OnnxEncoder(nn.Module):
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         return self.ar_text_position(x)
 
 
-class T2SFirstStageDecoder(nn.Module):
-    def __init__(
-        self,
-        ar_audio_embedding,
-        ar_audio_position,
-        h,
-        ar_predict_layer,
-        loss_fct,
-        ar_accuracy_metric,
-        early_stop_num,
-        num_layers,
-    ):
-        super().__init__()
-        self.ar_audio_embedding = ar_audio_embedding
-        self.ar_audio_position = ar_audio_position
-        self.h = h
-        self.ar_predict_layer = ar_predict_layer
-        self.loss_fct = loss_fct
-        self.ar_accuracy_metric = ar_accuracy_metric
-        self.early_stop_num = early_stop_num
-        self.num_layers = num_layers
-
-    def forward(self, x, prompt, y_emb, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None, x_seq_len = None, y_seq_len = None):
-        if top_k is None:
-            top_k = torch.LongTensor([15]).to(device=x.device)
-        if top_p is None:
-            top_p = torch.FloatTensor([1.0]).to(device=x.device)
-        if repetition_penalty is None:
-            repetition_penalty = torch.FloatTensor([1.0]).to(device=x.device)
-        if temperature is None:
-            temperature = torch.FloatTensor([1.0]).to(device=x.device)
-        minus_one = torch.tensor([-1]).to(x.device).to(torch.int64)
-
-        y = prompt
-        x_example = x[:, :, 0] * 0.0
-        # N, 1, 512
-        cache = {
-            "all_stage": self.num_layers,
-            "k": None,
-            "v": None,
-            "y_emb": y_emb,
-            "first_infer": first_infer,
-            "stage": 0,
-        }
-
-        # Decide at runtime whether to embed only the last y token or all of y, so the first pass and later passes are both handled correctly
-        multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(y)[1]
-        index_offset = torch.min(minus_one, multipled)
-        y_to_emb = y[:, index_offset:]
-
-        print("y_emb shape:", y_emb.shape)
-
-        y_emb = torch.cat(
-            [
-                cache["y_emb"],
-                self.ar_audio_embedding(y_to_emb),
-            ],
-            1,
-        )
-        cache["y_emb"] = y_emb
-
-        y_pos = self.ar_audio_position(y_emb)
-
-        xy_pos = torch.concat([x, y_pos], dim=1)
-
-        # Decide at runtime whether self-attention runs over only the last xy_pos position or all of xy_pos
-        multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_pos)[1]
-        index_offset = torch.min(minus_one, multipled)
-        xy_pos = xy_pos[:, index_offset:]
-
-        # Build the attention mask for xy
-        x_attn_mask = torch.zeros((x_seq_len, x_seq_len)).bool()
-        y_attn_mask = torch.ones((y_seq_len, y_seq_len)).to(torch.int64)
-        y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
-            torch.ones(
-                (y_seq_len, 1),
-                dtype=torch.int64,
-            ),
-            dim=0,
-        )
-        y_attn_mask = y_attn_mask > 0
-
-        x_y_pad = torch.ones((x_seq_len, y_seq_len)).to(torch.bool)
-        y_x_pad = torch.zeros((y_seq_len, x_seq_len)).to(torch.bool)
-
-        x_attn_mask_pad = torch.cat([x_attn_mask, x_y_pad], dim=1)
-        y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
-        print("first iter xy_attn_mask shape:", xy_attn_mask.shape)
-        cache["k"] = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
-        cache["v"] = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
-
-        print("first iter cache k shape:", cache["k"].shape, 'cache v shape:', cache["v"].shape)
-
-        xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
-        logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)
-
-        y = torch.concat([y, samples], dim=1)
-
-        return y, cache["k"], cache["v"], cache["y_emb"], x_example
-
-
 class T2SStageDecoder(nn.Module):
     def __init__(
         self,
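The index_offset pattern in the deleted class above, reused by the surviving T2SStageDecoder below, is the heart of the unification: torch.min(minus_one, -first_infer * seq_len) evaluates to -seq_len on the first pass and to -1 afterwards, so one traced graph slices either the whole sequence or just the newest position with no Python branch for ONNX export to specialize away. A standalone sketch of the trick (all names here are demo stand-ins):

    import torch

    # offset = min(-1, -first_infer * seq_len): first_infer == 1 gives -seq_len
    # (take everything), first_infer == 0 gives min(-1, 0) = -1 (take the last).
    minus_one = torch.tensor([-1], dtype=torch.int64)
    y = torch.arange(5).unsqueeze(0)  # stand-in for the semantic-token sequence
    for first_infer in (torch.LongTensor([1]), torch.LongTensor([0])):
        multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(y)[1]
        index_offset = torch.min(minus_one, multipled)
        print(first_infer.item(), y[:, index_offset:])
        # 1 tensor([[0, 1, 2, 3, 4]])  -> embed the whole prompt on the first pass
        # 0 tensor([[4]])              -> embed only the newly sampled token after that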
@@ -231,7 +127,7 @@ class T2SStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
 
-    def forward(self, x, y, k, v, y_emb, x_example, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None):
+    def forward(self, x, y, k, v, y_emb, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None, x_seq_len = None, y_seq_len = None):
         if top_k is None:
             top_k = torch.LongTensor([15]).to(device=y.device)
         if top_p is None:
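The dropped x_example argument was only a length carrier: a (1, x_seq_len) tensor of zeros whose shape the decoder read back via shape_as_tensor. Passing x_seq_len and y_seq_len directly removes that dummy input. A small check that both routes recover the same number (toy shapes, not the repo's exporter):

    import torch

    # x_example carried the text length as a tensor shape; the new signature
    # passes the length itself.
    x = torch.randn(1, 7, 512)                 # toy encoder output
    x_example = x[:, :, 0] * 0.0               # old dummy input, shape (1, 7)
    old_len = torch.onnx.operators.shape_as_tensor(x_example)[1]
    new_len = x.shape[1]                       # new explicit argument
    print(old_len.item(), new_len)             # 7 7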
@@ -244,8 +140,8 @@ class T2SStageDecoder(nn.Module):
 
         cache = {
             "all_stage": self.num_layers,
-            "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
-            "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
+            "k": k,
+            "v": v,
             "y_emb": y_emb,
             "first_infer": first_infer,
             "stage": 0,
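With "k": k replacing the F.pad call, the cache no longer grows inside forward; the decode loop in the last hunk below pads before each call instead. The pad tuple is consumed from the last dimension backwards, so (0, 0, 0, 0, 0, 1) appends one zero slot along dim 1, the sequence axis of the cache. A quick shape check (the layer count is an assumption for the demo):

    import torch
    import torch.nn.functional as F

    # KV cache layout in this file: (num_layers, seq_len, batch, 512).
    # 24 layers is assumed here; the code reads self.num_layers.
    k = torch.zeros((24, 10, 1, 512))
    k = F.pad(k, (0, 0, 0, 0, 0, 1))  # pairs apply to dims -1, -2, -3 in turn
    print(k.shape)                    # torch.Size([24, 11, 1, 512])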
@@ -273,10 +169,29 @@ class T2SStageDecoder(nn.Module):
         index_offset = torch.min(minus_one, multipled)
         xy_pos = xy_pos[:, index_offset:]
 
-        x_example_len = torch.onnx.operators.shape_as_tensor(x_example)[1]
-        y_example_len = torch.onnx.operators.shape_as_tensor(y_pos)[1]
-        xy_attn_mask = torch.zeros((1, x_example_len + y_example_len), dtype=torch.bool)
-        print('xy_attn_mask shape:', xy_attn_mask.shape)
+        # Build the attention mask for xy
+        x_attn_mask = torch.zeros((x_seq_len, x_seq_len)).bool()
+        y_attn_mask = torch.ones((y_seq_len, y_seq_len)).to(torch.int64)
+        y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
+            torch.ones(
+                (y_seq_len, 1),
+                dtype=torch.int64,
+            ),
+            dim=0,
+        )
+        y_attn_mask = y_attn_mask > 0
+
+        x_y_pad = torch.ones((x_seq_len, y_seq_len)).to(torch.bool)
+        y_x_pad = torch.zeros((y_seq_len, x_seq_len)).to(torch.bool)
+
+        x_attn_mask_pad = torch.cat([x_attn_mask, x_y_pad], dim=1)
+        y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
+        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
+
+        # Decide at runtime whether to use only the last row of the attention mask or the whole mask
+        multipled = minus_one * first_infer * torch.onnx.operators.shape_as_tensor(xy_attn_mask)[0]
+        index_offset = torch.min(minus_one, multipled)
+        xy_attn_mask = xy_attn_mask[index_offset:, :]
 
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
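The cumsum construction above is an export-friendly causal mask: entry (i, j) of the y block becomes j - i, and "> 0" flags strictly future positions as masked. Stitched together, text positions (x) see all of x and none of y, while semantic positions (y) see all of x plus their own past; later passes then keep just the last row via the same min(-1, ...) trick. A toy build with x_seq_len=2, y_seq_len=3:

    import torch

    x_seq_len, y_seq_len = 2, 3
    y_attn_mask = torch.ones((y_seq_len, y_seq_len)).to(torch.int64)
    y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
        torch.ones((y_seq_len, 1), dtype=torch.int64), dim=0)
    y_attn_mask = y_attn_mask > 0  # same result as torch.triu(..., diagonal=1)

    x_attn_mask = torch.zeros((x_seq_len, x_seq_len)).bool()
    x_y_pad = torch.ones((x_seq_len, y_seq_len)).to(torch.bool)   # x never attends to y
    y_x_pad = torch.zeros((y_seq_len, x_seq_len)).to(torch.bool)  # y always attends to x
    xy_attn_mask = torch.concat([
        torch.cat([x_attn_mask, x_y_pad], dim=1),
        torch.cat([y_x_pad, y_attn_mask], dim=1),
    ], dim=0)
    print(xy_attn_mask.int())
    # tensor([[0, 0, 1, 1, 1],
    #         [0, 0, 1, 1, 1],
    #         [0, 0, 0, 1, 1],
    #         [0, 0, 0, 0, 1],
    #         [0, 0, 0, 0, 0]], dtype=torch.int32)   True (1) = blocked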
@@ -332,16 +247,6 @@ class Text2SemanticDecoder(nn.Module):
 
     def init_onnx(self):
         self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
-        self.first_stage_decoder = T2SFirstStageDecoder(
-            self.ar_audio_embedding,
-            self.ar_audio_position,
-            self.h,
-            self.ar_predict_layer,
-            self.loss_fct,
-            self.ar_accuracy_metric,
-            self.early_stop_num,
-            self.num_layers,
-        )
         self.stage_decoder = T2SStageDecoder(
             self.ar_audio_embedding,
             self.ar_audio_position,
@@ -362,14 +267,21 @@ class Text2SemanticDecoder(nn.Module):
         prefix_len = prompts.shape[1]
 
         x = self.onnx_encoder(x, bert_feature)
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, torch.empty((1,0,512)).to(torch.float), top_k=top_k,
-                                                             first_infer=torch.LongTensor([1]),
-                                                             x_seq_len=x.shape[1], y_seq_len=prompts.shape[1])
+        init_k = torch.zeros((self.num_layers, (x.shape[1] + prompts.shape[1]), 1, 512), dtype=torch.float)
+        init_v = torch.zeros((self.num_layers, (x.shape[1] + prompts.shape[1]), 1, 512), dtype=torch.float)
+
+        y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompts, init_k, init_v,
+                                                             torch.empty((1,0,512)).to(torch.float), top_k=top_k,
+                                                             first_infer=torch.LongTensor([1]),
+                                                             x_seq_len=x.shape[1], y_seq_len=prompts.shape[1])
 
         stop = False
         for idx in tqdm(range(1, 1500)):
-            enco = self.stage_decoder( torch.empty((1,0,512)).to(torch.float) ,y, k, v, y_emb, x_example, top_k=top_k,
-                                      first_infer=torch.LongTensor([0]))
+            k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
+            v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
+            enco = self.stage_decoder( torch.empty((1,0,512)).to(torch.float) ,y, k, v, y_emb, top_k=top_k,
+                                      first_infer=torch.LongTensor([0]), x_seq_len=x.shape[1], y_seq_len=y.shape[1])
             y, k, v, y_emb, logits, samples = enco
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
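Taken together, the new driver makes every step the same call: zero-initialized caches plus first_infer=1 for the prefill, then pad-by-one plus first_infer=0 per token. A shape-level sketch with a stub in place of the real T2SStageDecoder (the stub only appends a dummy token; module names and shapes are stand-ins, not the repo's API):

    import torch

    class StubStageDecoder(torch.nn.Module):
        """Stand-in for T2SStageDecoder: appends one dummy token per call."""
        def forward(self, x, y, k, v, y_emb, first_infer=None, x_seq_len=None, y_seq_len=None):
            samples = torch.zeros((1, 1), dtype=torch.int64)  # real code samples from logits
            y = torch.concat([y, samples], dim=1)
            y_emb = torch.cat([y_emb, torch.zeros((1, 1, 512))], 1)
            return y, k, v, y_emb, None, samples

    num_layers = 24                      # assumed; the repo reads self.num_layers
    stage_decoder = StubStageDecoder()
    x = torch.randn(1, 7, 512)           # stand-in for the encoded text
    prompts = torch.zeros((1, 3), dtype=torch.int64)

    # Prefill: zero caches sized for the whole (x + prompt) context, first_infer=1.
    k = torch.zeros((num_layers, x.shape[1] + prompts.shape[1], 1, 512))
    v = torch.zeros_like(k)
    y, k, v, y_emb, logits, samples = stage_decoder(
        x, prompts, k, v, torch.empty((1, 0, 512)),
        first_infer=torch.LongTensor([1]), x_seq_len=x.shape[1], y_seq_len=prompts.shape[1])

    # Incremental steps: grow the cache by one slot, then decode one token.
    for idx in range(1, 4):              # the real loop runs up to 1500 steps
        k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
        v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
        y, k, v, y_emb, logits, samples = stage_decoder(
            torch.empty((1, 0, 512)), y, k, v, y_emb,
            first_infer=torch.LongTensor([0]), x_seq_len=x.shape[1], y_seq_len=y.shape[1])
    print(y.shape, k.shape)              # torch.Size([1, 7]) torch.Size([24, 13, 1, 512])

Keeping the pad on the caller's side means the cache growth is explicit in the decoder's inputs rather than hidden in its graph, which appears to be what lets the first and following steps share a single exported decoder.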