fix: solve the T2S ending problem, verify that infer and forward produce identical output under deterministic randomness, and pin top_k to 15

zpeng11 2025-08-23 03:31:01 -04:00
parent 63cbb6efa7
commit 3ccd1c0ea3
2 changed files with 21 additions and 13 deletions
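The determinism claim in the message refers to the commented-out seeding calls added around forward and infer below: with a fixed seed and deterministic algorithms enabled, both code paths should sample the identical token sequence. A minimal sketch of such a check, assuming a model object exposing both methods with the signatures shown in this diff (check_forward_matches_infer and all shapes are hypothetical, not part of the commit):

import torch

def check_forward_matches_infer(model, x, prompts, bert_feature):
    # Same seed as the commented-out lines in the diff below.
    torch.manual_seed(42)
    torch.use_deterministic_algorithms(True)  # fail loudly on nondeterministic ops
    y_fwd, idx_fwd = model.forward(x, prompts, bert_feature)

    # Reset the RNG so sampling in infer draws the identical noise sequence.
    torch.manual_seed(42)
    y_inf, idx_inf = model.infer(x, prompts, bert_feature)
    torch.use_deterministic_algorithms(False)

    assert idx_fwd == idx_inf, "forward and infer stopped at different steps"
    assert torch.equal(y_fwd, y_inf), "forward and infer produced different tokens"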


@@ -7,6 +7,7 @@ from torchmetrics.classification import MulticlassAccuracy
 from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
 from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
+from tqdm import tqdm
 
 default_config = {
     "embedding_dim": 512,
@@ -179,7 +180,7 @@ class T2SFirstStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
         y = torch.concat([y, samples], dim=1)
@@ -239,7 +240,7 @@ class T2SStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
         y = torch.concat([y, samples], dim=1)
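Both decoder stages now hard-code top_k=15 rather than reading self.top_k. The repo's sample() helper is not part of this diff; the sketch below is only a plausible reading of its call signature (top-k filtering plus repetition penalty, with top_p=1.0 disabling nucleus filtering), not the actual implementation:

import torch

def sample_sketch(logits, prev_ids, top_k=15, top_p=1.0, repetition_penalty=1.35):
    # logits: 1-D [vocab]; prev_ids: [1, T] tokens generated so far (assumed shapes).
    scores = logits.clone()
    # Penalize already-emitted tokens so the decoder does not loop.
    seen = scores[prev_ids[0]]
    scores[prev_ids[0]] = torch.where(
        seen < 0, seen * repetition_penalty, seen / repetition_penalty
    )
    # Keep only the top_k highest-scoring tokens.
    kth = torch.topk(scores, top_k).values[-1]
    scores[scores < kth] = -float("inf")
    probs = torch.softmax(scores, dim=-1)
    # top_p=1.0 means no nucleus cut-off, so sample directly.
    return torch.multinomial(probs, num_samples=1)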
@@ -315,26 +316,31 @@ class Text2SemanticDecoder(nn.Module):
         )
 
     def forward(self, x, prompts, bert_feature):
+        # torch.manual_seed(42)
+        # torch.use_deterministic_algorithms(True)
         early_stop_num = self.early_stop_num
         prefix_len = prompts.shape[1]
         x = self.onnx_encoder(x, bert_feature)
-        y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts)
+        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
         stop = False
-        for idx in range(1, 1500):
-            enco = self.stage_decoder(y, k, v, y_emb, stage, x_example)
-            y, k, v, y_emb, stage, logits, samples = enco
+        for idx in tqdm(range(1, 1500)):
+            enco = self.stage_decoder(y, k, v, y_emb, x_example)
+            y, k, v, y_emb, logits, samples = enco
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                 stop = True
             if stop:
+                y = y[:,:-1]
                 break
-        y[0, -1] = 0
+        # torch.use_deterministic_algorithms(False)
         return y, idx
 
     def infer(self, x, prompts, bert_feature):
+        # torch.manual_seed(42)
+        # torch.use_deterministic_algorithms(True)
         top_k = self.top_k
         early_stop_num = self.early_stop_num
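This hunk is the ending fix itself: the old code overwrote the last generated token with 0 after the loop, leaving a bogus token where the sampled EOS (id 1024 in this vocabulary, per the playground script below) used to be; the new code drops the EOS entirely with y = y[:,:-1] before breaking. A toy illustration (token values made up):

import torch

y = torch.tensor([[487, 301, 92, 1024]])  # sequence ending with the sampled EOS
old = y.clone()
old[0, -1] = 0       # old behavior: EOS replaced by a spurious token 0
new = y[:, :-1]      # new behavior: EOS removed entirely
print(old)           # tensor([[487, 301,  92,   0]])
print(new)           # tensor([[487, 301,  92]])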
@@ -356,11 +362,14 @@ class Text2SemanticDecoder(nn.Module):
             "first_infer": 1,
             "stage": 0,
         }
-        for idx in range(1500):
+        for idx in tqdm(range(1500)):
             if cache["first_infer"] == 1:
                 y_emb = self.ar_audio_embedding(y)
             else:
                 y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
+                for i in range(len(cache["k"])):
+                    cache["k"][i] = torch.nn.functional.pad(cache["k"][i], (0, 0, 0, 0, 0, 1))
+                    cache["v"][i] = torch.nn.functional.pad(cache["v"][i], (0, 0, 0, 0, 0, 1))
             cache["y_emb"] = y_emb
             y_pos = self.ar_audio_position(y_emb)
             if cache["first_infer"] == 1:
@@ -380,15 +389,14 @@ class Text2SemanticDecoder(nn.Module):
                 xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
             xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
             logits = self.ar_predict_layer(xy_dec[:, -1])
-            samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+            samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                 stop = True
             if stop:
                 if prompts.shape[1] == y.shape[1]:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                 break
             y = torch.concat([y, samples], dim=1)
             cache["first_infer"] = 0
+        # torch.use_deterministic_algorithms(False)
         return y, idx
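One pre-existing detail worth noting in infer's stop branch: if decoding stops on the very first step, y is still exactly the prompt, so a single zero token is appended to guarantee the caller receives at least one new semantic token. A toy rendering of that guard (shapes illustrative):

import torch

prompts = torch.randint(0, 1024, (1, 8))
y, samples = prompts, torch.tensor([[1024]])   # stopped immediately on EOS
if prompts.shape[1] == y.shape[1]:             # nothing was generated
    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print(y.shape)                                  # torch.Size([1, 9])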


@@ -68,7 +68,7 @@ def preprocess_text(text:str):
 # ref_phones = np.load("playground/ref/ref_phones.npy")
 # ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
-[ref_phones, ref_bert] = preprocess_text("日江苏苏州荷花市集开张热闹与浪漫交织")
+[ref_phones, ref_bert] = preprocess_text("日江苏苏州荷花市集开张热闹与浪漫交织")
 [audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
@@ -113,7 +113,7 @@ for idx in tqdm(range(1, 1500)):
     })
     if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024:  # 1024 is the EOS token
         break
-y[0, -1] = 0
+y = y[:,:-1]
 pred_semantic = np.expand_dims(y[:, -idx:], axis=0)
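After the EOS is trimmed, the playground script keeps only the idx newly generated tokens (y still carries the prompt prefix) and prepends a batch axis. A small numpy sketch of that shape flow, with made-up lengths:

import numpy as np

prompt_len, idx = 8, 25                  # illustrative lengths
y = np.zeros((1, prompt_len + idx))      # prompt + generated tokens, EOS already trimmed
pred_semantic = np.expand_dims(y[:, -idx:], axis=0)
print(pred_semantic.shape)               # (1, 1, 25): keep only the generated tail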