mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 17:10:02 +08:00
feat:updated fsdecode and decoder interface
This commit is contained in:
parent
b45cbc3561
commit
5982080939
@ -61,7 +61,7 @@ def logits_to_probs(
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
logits = logits / torch.max(temperature, torch.tensor(1e-5, device=temperature.device, dtype=temperature.dtype))
|
||||
logits = logits / torch.max(temperature, torch.tensor(1e-5, device=logits.device, dtype=torch.float))
|
||||
|
||||
# if top_k is not None: # To be captured by onnx
|
||||
v, _ = torch.topk(logits, top_k)
|
||||
@ -115,7 +115,6 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
ar_predict_layer,
|
||||
loss_fct,
|
||||
ar_accuracy_metric,
|
||||
top_k,
|
||||
early_stop_num,
|
||||
num_layers,
|
||||
):
|
||||
@ -126,7 +125,6 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
self.ar_predict_layer = ar_predict_layer
|
||||
self.loss_fct = loss_fct
|
||||
self.ar_accuracy_metric = ar_accuracy_metric
|
||||
self.top_k = top_k
|
||||
self.early_stop_num = early_stop_num
|
||||
self.num_layers = num_layers
|
||||
|
||||
@ -205,7 +203,6 @@ class T2SStageDecoder(nn.Module):
|
||||
ar_predict_layer,
|
||||
loss_fct,
|
||||
ar_accuracy_metric,
|
||||
top_k,
|
||||
early_stop_num,
|
||||
num_layers,
|
||||
):
|
||||
@ -216,7 +213,6 @@ class T2SStageDecoder(nn.Module):
|
||||
self.ar_predict_layer = ar_predict_layer
|
||||
self.loss_fct = loss_fct
|
||||
self.ar_accuracy_metric = ar_accuracy_metric
|
||||
self.top_k = top_k
|
||||
self.early_stop_num = early_stop_num
|
||||
self.num_layers = num_layers
|
||||
|
||||
@ -317,7 +313,6 @@ class Text2SemanticDecoder(nn.Module):
|
||||
self.ar_predict_layer,
|
||||
self.loss_fct,
|
||||
self.ar_accuracy_metric,
|
||||
self.top_k,
|
||||
self.early_stop_num,
|
||||
self.num_layers,
|
||||
)
|
||||
@ -328,23 +323,24 @@ class Text2SemanticDecoder(nn.Module):
|
||||
self.ar_predict_layer,
|
||||
self.loss_fct,
|
||||
self.ar_accuracy_metric,
|
||||
self.top_k,
|
||||
self.early_stop_num,
|
||||
self.num_layers,
|
||||
)
|
||||
|
||||
def forward(self, x, prompts, bert_feature):
|
||||
def forward(self, x, prompts, bert_feature, top_k = None):
|
||||
# torch.manual_seed(42)
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
if top_k is None:
|
||||
top_k = self.top_k
|
||||
early_stop_num = self.early_stop_num
|
||||
prefix_len = prompts.shape[1]
|
||||
|
||||
x = self.onnx_encoder(x, bert_feature)
|
||||
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, top_k=top_k)
|
||||
|
||||
stop = False
|
||||
for idx in tqdm(range(1, 1500)):
|
||||
enco = self.stage_decoder(y, k, v, y_emb, x_example)
|
||||
enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k)
|
||||
y, k, v, y_emb, logits, samples = enco
|
||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
stop = True
|
||||
@ -356,10 +352,11 @@ class Text2SemanticDecoder(nn.Module):
|
||||
# torch.use_deterministic_algorithms(False)
|
||||
return y, idx
|
||||
|
||||
def infer(self, x, prompts, bert_feature):
|
||||
def infer(self, x, prompts, bert_feature, top_k=None):
|
||||
# torch.manual_seed(42)
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
top_k = self.top_k
|
||||
if top_k is None:
|
||||
top_k = self.top_k
|
||||
early_stop_num = self.early_stop_num
|
||||
|
||||
x = self.onnx_encoder(x, bert_feature)
|
||||
@ -407,7 +404,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
|
||||
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=1.0)[0].unsqueeze(0)
|
||||
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=torch.Tensor([1.0]))[0].unsqueeze(0)
|
||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
stop = True
|
||||
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
||||
|
Loading…
x
Reference in New Issue
Block a user