mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 00:30:15 +08:00
fix: solve the t2s ending problem, verify that infer and forward produce the same output under deterministic randomness, and fix top_k to 15
This commit is contained in:
parent 63cbb6efa7
commit 3ccd1c0ea3
@@ -7,6 +7,7 @@ from torchmetrics.classification import MulticlassAccuracy
 from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
 from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
+from tqdm import tqdm

 default_config = {
     "embedding_dim": 512,
@@ -179,7 +180,7 @@ class T2SFirstStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)

         y = torch.concat([y, samples], dim=1)
@@ -239,7 +240,7 @@ class T2SStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)

         y = torch.concat([y, samples], dim=1)
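Both hunks above pin the sampling call to top_k=15 instead of self.top_k, matching the value the infer path now uses. The sample helper itself is not part of this diff; what follows is a minimal sketch, assuming 1-D logits of shape (vocab,) and a tensor of previously generated token ids, of what top-k sampling with a repetition penalty typically does:

import torch

def sample_sketch(logits, previous_tokens, top_k=15, top_p=1.0, repetition_penalty=1.35):
    # Down-weight tokens that already occur in the generated history.
    prev = previous_tokens.reshape(-1)
    score = torch.gather(logits, dim=0, index=prev)
    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
    logits = logits.scatter(dim=0, index=prev, src=score)

    # Mask everything outside the top_k highest logits, then sample.
    # top_p=1.0 disables nucleus filtering, so it is accepted but unused here.
    v, _ = torch.topk(logits, top_k)
    logits = torch.where(logits < v[-1], torch.full_like(logits, float("-inf")), logits)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

With top_p left at 1.0 the only narrowing comes from the fixed top_k=15 window, so hard-coding the constant keeps both decoder stages sampling from identical distributions.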
@@ -315,26 +316,31 @@ class Text2SemanticDecoder(nn.Module):
         )

     def forward(self, x, prompts, bert_feature):
+        # torch.manual_seed(42)
+        # torch.use_deterministic_algorithms(True)
         early_stop_num = self.early_stop_num
         prefix_len = prompts.shape[1]

         x = self.onnx_encoder(x, bert_feature)
-        y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts)
+        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)

         stop = False
-        for idx in range(1, 1500):
-            enco = self.stage_decoder(y, k, v, y_emb, stage, x_example)
-            y, k, v, y_emb, stage, logits, samples = enco
+        for idx in tqdm(range(1, 1500)):
+            enco = self.stage_decoder(y, k, v, y_emb, x_example)
+            y, k, v, y_emb, logits, samples = enco
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                 stop = True
             if stop:
+                y = y[:,:-1]
                 break
-        y[0, -1] = 0
+        # torch.use_deterministic_algorithms(False)
         return y, idx

     def infer(self, x, prompts, bert_feature):
+        # torch.manual_seed(42)
+        # torch.use_deterministic_algorithms(True)
         top_k = self.top_k
         early_stop_num = self.early_stop_num
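The commented-out seeding lines in both forward and infer reflect how the commit message's claim (infer and forward produce the same output under deterministic randomness) can be checked: seed identically before each path, enable deterministic algorithms, and compare. A hypothetical harness (check_forward_matches_infer is not part of the repo), assuming both methods return (y, idx), might look like:

import torch

def check_forward_matches_infer(model, x, prompts, bert_feature):
    torch.use_deterministic_algorithms(True)

    torch.manual_seed(42)                # seed the first path
    y_fwd, _ = model.forward(x, prompts, bert_feature)

    torch.manual_seed(42)                # re-seed so sampling draws are identical
    y_inf, _ = model.infer(x, prompts, bert_feature)

    torch.use_deterministic_algorithms(False)
    assert torch.equal(y_fwd, y_inf), "forward and infer diverged"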
@@ -356,11 +362,14 @@ class Text2SemanticDecoder(nn.Module):
             "first_infer": 1,
             "stage": 0,
         }
-        for idx in range(1500):
+        for idx in tqdm(range(1500)):
             if cache["first_infer"] == 1:
                 y_emb = self.ar_audio_embedding(y)
             else:
                 y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
+                for i in range(len(cache["k"])):
+                    cache["k"][i] = torch.nn.functional.pad(cache["k"][i], (0, 0, 0, 0, 0, 1))
+                    cache["v"][i] = torch.nn.functional.pad(cache["v"][i], (0, 0, 0, 0, 0, 1))
             cache["y_emb"] = y_emb
             y_pos = self.ar_audio_position(y_emb)
             if cache["first_infer"] == 1:
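The new pad calls grow each cached key/value tensor by one slot per decoding step. torch.nn.functional.pad consumes its pad tuple from the last dimension backwards in (before, after) pairs, so (0, 0, 0, 0, 0, 1) leaves the last two dimensions untouched and appends one zero-filled slot along dim 0, presumably the sequence axis the attention layers then write the current step into. The layout below is an assumption for illustration; only the pad semantics matter:

import torch
import torch.nn.functional as F

# Assumed cache layout: (seq_len, batch, dim).
k = torch.randn(10, 1, 512)
k = F.pad(k, (0, 0, 0, 0, 0, 1))  # pads dim 0 on the "after" side with zeros
print(k.shape)                    # torch.Size([11, 1, 512])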
@@ -380,15 +389,14 @@ class Text2SemanticDecoder(nn.Module):
                 xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
             xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
             logits = self.ar_predict_layer(xy_dec[:, -1])
-            samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+            samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                 stop = True
             if stop:
                 if prompts.shape[1] == y.shape[1]:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                 break
             y = torch.concat([y, samples], dim=1)
             cache["first_infer"] = 0
+        # torch.use_deterministic_algorithms(False)
         return y, idx
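One subtlety in the stop branch: if EOS is sampled on the very first step, y still equals the prompt, so a single zero token is appended to guarantee at least one generated frame for whatever consumes the output. A small illustration with made-up token ids:

import torch

prompts = torch.tensor([[11, 22, 33, 44]])
y = prompts.clone()                 # nothing generated before EOS appeared
samples = torch.tensor([[1024]])    # EOS sampled immediately

if prompts.shape[1] == y.shape[1]:  # no new tokens were accepted
    y = torch.concat([y, torch.zeros_like(samples)], dim=1)

print(y)  # tensor([[11, 22, 33, 44,  0]])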
@@ -68,7 +68,7 @@ def preprocess_text(text:str):

 # ref_phones = np.load("playground/ref/ref_phones.npy")
 # ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
-[ref_phones, ref_bert] = preprocess_text("今日江苏苏州荷花市集开张热闹与浪漫交织")
+[ref_phones, ref_bert] = preprocess_text("近日江苏苏州荷花市集开张热闹与浪漫交织")

 [audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")

(The reference sentence means "The lotus flower market in Suzhou, Jiangsu has opened; liveliness and romance intertwine"; the edit only swaps 今日 "today" for 近日 "recently".)
@@ -113,7 +113,7 @@ for idx in tqdm(range(1, 1500)):
     })
     if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024:  # 1024 is the EOS token
         break
-y[0, -1] = 0
+y = y[:,:-1]

 pred_semantic = np.expand_dims(y[:, -idx:], axis=0)