Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
feat: updated fsdecode and decoder interface

commit 5982080939 (parent b45cbc3561)
@@ -61,7 +61,7 @@ def logits_to_probs(
         )
         logits = logits.masked_fill(indices_to_remove, -float("Inf"))

-    logits = logits / torch.max(temperature, torch.tensor(1e-5, device=temperature.device, dtype=temperature.dtype))
+    logits = logits / torch.max(temperature, torch.tensor(1e-5, device=logits.device, dtype=torch.float))

     # if top_k is not None: # To be captured by onnx
     v, _ = torch.topk(logits, top_k)
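The functional change in this hunk is only where the clamp's device and dtype come from: the 1e-5 floor is now built from logits and plain float32 instead of being read off the temperature tensor. Below is a small standalone sketch (illustrative values, not repository code) of what the clamp itself buys:

import torch

logits = torch.tensor([2.0, 1.0, 0.5])
temperature = torch.tensor(0.0)  # degenerate value a caller could pass

# Mirrors the updated line: the floor tensor takes its device from logits and a plain
# float dtype, so nothing is read off temperature itself.
safe_temperature = torch.max(temperature, torch.tensor(1e-5, device=logits.device, dtype=torch.float))
scaled = logits / safe_temperature      # finite, just very peaked
print(torch.softmax(scaled, dim=-1))    # ~ one-hot on the largest logit, no inf/nan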
@@ -115,7 +115,6 @@ class T2SFirstStageDecoder(nn.Module):
         ar_predict_layer,
         loss_fct,
         ar_accuracy_metric,
-        top_k,
         early_stop_num,
         num_layers,
     ):
@@ -126,7 +125,6 @@ class T2SFirstStageDecoder(nn.Module):
         self.ar_predict_layer = ar_predict_layer
         self.loss_fct = loss_fct
         self.ar_accuracy_metric = ar_accuracy_metric
-        self.top_k = top_k
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers

@@ -205,7 +203,6 @@ class T2SStageDecoder(nn.Module):
         ar_predict_layer,
         loss_fct,
         ar_accuracy_metric,
-        top_k,
         early_stop_num,
         num_layers,
     ):
@@ -216,7 +213,6 @@ class T2SStageDecoder(nn.Module):
         self.ar_predict_layer = ar_predict_layer
         self.loss_fct = loss_fct
         self.ar_accuracy_metric = ar_accuracy_metric
-        self.top_k = top_k
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers

@@ -317,7 +313,6 @@ class Text2SemanticDecoder(nn.Module):
             self.ar_predict_layer,
             self.loss_fct,
             self.ar_accuracy_metric,
-            self.top_k,
             self.early_stop_num,
             self.num_layers,
         )
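The hunks above remove top_k from the constructors of T2SFirstStageDecoder and T2SStageDecoder and from the argument lists that build them; as the next hunk shows, the value instead travels with each call, and self.top_k survives on Text2SemanticDecoder only as a fallback default. A minimal, self-contained sketch of that pattern follows (a toy class with both roles collapsed into one module, not the repository decoders):

import torch
from torch import nn

class TinyDecoder(nn.Module):
    def __init__(self, vocab_size, default_top_k):
        super().__init__()
        self.proj = nn.Linear(8, vocab_size)
        self.default_top_k = default_top_k   # kept only as a fallback, not baked into the graph

    def forward(self, hidden, top_k=None):
        if top_k is None:                    # same fallback idiom the diff adds to forward/infer
            top_k = self.default_top_k
        logits = self.proj(hidden)
        _, idx = torch.topk(logits, top_k)
        return idx

decoder = TinyDecoder(vocab_size=32, default_top_k=5)
hidden = torch.randn(1, 8)
print(decoder(hidden).shape)           # torch.Size([1, 5]) -- default
print(decoder(hidden, top_k=3).shape)  # torch.Size([1, 3]) -- per-call override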
@@ -328,23 +323,24 @@ class Text2SemanticDecoder(nn.Module):
             self.ar_predict_layer,
             self.loss_fct,
             self.ar_accuracy_metric,
-            self.top_k,
             self.early_stop_num,
             self.num_layers,
         )

-    def forward(self, x, prompts, bert_feature):
+    def forward(self, x, prompts, bert_feature, top_k = None):
         # torch.manual_seed(42)
         # torch.use_deterministic_algorithms(True)
+        if top_k is None:
+            top_k = self.top_k
         early_stop_num = self.early_stop_num
         prefix_len = prompts.shape[1]

         x = self.onnx_encoder(x, bert_feature)
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, top_k=top_k)

         stop = False
         for idx in tqdm(range(1, 1500)):
-            enco = self.stage_decoder(y, k, v, y_emb, x_example)
+            enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k)
             y, k, v, y_emb, logits, samples = enco
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
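With this hunk, callers can override the sampling breadth per call, and omitting the argument keeps the old behaviour via self.top_k, which forward then threads into both first_stage_decoder and stage_decoder. A hypothetical call-site sketch (model, x, prompts and bert_feature stand for objects the caller already has; they are not defined here):

# model is an already-built Text2SemanticDecoder; x, prompts, bert_feature as before
y, idx = model.forward(x, prompts, bert_feature)            # falls back to model.top_k
y, idx = model.forward(x, prompts, bert_feature, top_k=15)  # explicit per-call override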
@@ -356,9 +352,10 @@ class Text2SemanticDecoder(nn.Module):
         # torch.use_deterministic_algorithms(False)
         return y, idx

-    def infer(self, x, prompts, bert_feature):
+    def infer(self, x, prompts, bert_feature, top_k=None):
         # torch.manual_seed(42)
         # torch.use_deterministic_algorithms(True)
-        top_k = self.top_k
+        if top_k is None:
+            top_k = self.top_k
         early_stop_num = self.early_stop_num

@@ -407,7 +404,7 @@ class Text2SemanticDecoder(nn.Module):
             xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
             xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
             logits = self.ar_predict_layer(xy_dec[:, -1])
-            samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=1.0)[0].unsqueeze(0)
+            samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=torch.Tensor([1.0]))[0].unsqueeze(0)
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
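The temperature change here pairs with the first hunk: logits_to_probs now clamps via torch.max(temperature, torch.tensor(1e-5, ...)), and torch.max expects a tensor in that position, so the call site passes a 1-element tensor rather than the Python float 1.0; presumably this also keeps the temperature an ordinary tensor input for the ONNX capture this file targets. A standalone sketch (illustrative shapes, not repository code):

import torch

logits = torch.randn(512)
temperature = torch.Tensor([1.0])   # mirrors temperature=torch.Tensor([1.0]) in the updated call
floor = torch.tensor(1e-5, device=logits.device, dtype=torch.float)
scaled = logits / torch.max(temperature, floor)   # the 1-element tensor broadcasts against the floor and the logits
print(scaled.shape)                               # torch.Size([512])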