mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
feat: allow fsdec and sdec to have sampling params
commit 9ed42daa88
parent 3ccd1c0ea3
@@ -28,45 +28,45 @@ def logits_to_probs(
     logits,
     previous_tokens=None,
     temperature: float = 1.0,
-    top_k=None,
-    top_p=None,
+    top_k=15,
+    top_p=1.0,
     repetition_penalty: float = 1.0,
 ):
     previous_tokens = previous_tokens.squeeze()
-    if previous_tokens is not None and repetition_penalty != 1.0:
+    # if previous_tokens is not None and repetition_penalty != 1.0:  # Always captured by onnx
     previous_tokens = previous_tokens.long()
     score = torch.gather(logits, dim=0, index=previous_tokens)
     score = torch.where(
         score < 0,
         score * repetition_penalty,
         score / repetition_penalty,
     )
     logits.scatter_(dim=0, index=previous_tokens, src=score)
 
-    if top_p is not None and top_p < 1.0:
+    # if top_p is not None and top_p < 1.0:  # To be captured by onnx
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
     cum_probs = torch.cumsum(
         torch.nn.functional.softmax(
             sorted_logits,
             dim=-1,
         ),
         dim=-1,
     )
     sorted_indices_to_remove = cum_probs > top_p
     sorted_indices_to_remove[0] = False  # keep at least one option
     indices_to_remove = sorted_indices_to_remove.scatter(
         dim=0,
         index=sorted_indices,
         src=sorted_indices_to_remove,
     )
     logits = logits.masked_fill(indices_to_remove, -float("Inf"))
 
     logits = logits / max(temperature, 1e-5)
 
-    if top_k is not None:
+    # if top_k is not None:  # To be captured by onnx
     v, _ = torch.topk(logits, top_k)
     pivot = v.select(-1, -1).unsqueeze(-1)
     logits = torch.where(logits < pivot, inf_tensor_value, logits)
 
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
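The commented-out guards are the point of this hunk: ONNX tracing only records the ops that actually execute, so leaving the `if` statements in place would either bake the branch in or drop it from the exported graph (hence the "Always captured" / "To be captured by onnx" comments). The sketch below restates the resulting unconditional pipeline: repetition penalty, then top-p filtering, then temperature scaling, then top-k cutoff, as a self-contained function. `inf_tensor_value` stands in for the module-level constant the file defines, and the 1-D shapes are assumptions for illustration, not the repo's API.

import torch

# Stand-in for the file's module-level constant (assumed value).
inf_tensor_value = torch.FloatTensor([-float("Inf")])

def logits_to_probs_sketch(logits, previous_tokens, temperature=1.0,
                           top_k=15, top_p=1.0, repetition_penalty=1.0):
    # Repetition penalty: shrink positive scores of already-emitted
    # tokens, amplify negative ones (mutates logits in place).
    previous_tokens = previous_tokens.squeeze().long()
    score = torch.gather(logits, dim=0, index=previous_tokens)
    score = torch.where(score < 0, score * repetition_penalty,
                        score / repetition_penalty)
    logits.scatter_(dim=0, index=previous_tokens, src=score)

    # Top-p (nucleus) filtering: mask the probability tail beyond
    # top_p, always keeping the single best token.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(
        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[0] = False  # keep at least one option
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=0, index=sorted_indices, src=sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    # Temperature scaling, then top-k: everything below the k-th
    # largest logit is pushed to -inf before the final softmax.
    logits = logits / max(temperature, 1e-5)
    v, _ = torch.topk(logits, top_k)
    pivot = v.select(-1, -1).unsqueeze(-1)
    logits = torch.where(logits < pivot, inf_tensor_value, logits)
    return torch.nn.functional.softmax(logits, dim=-1)

probs = logits_to_probs_sketch(torch.randn(64), torch.tensor([3, 7]),
                               temperature=0.8, top_p=0.9,
                               repetition_penalty=1.35)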
@@ -130,7 +130,16 @@ class T2SFirstStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
 
-    def forward(self, x, prompt):
+    def forward(self, x, prompt, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
+        if top_k is None:
+            top_k = torch.LongTensor([15]).to(device=x.device)
+        if top_p is None:
+            top_p = torch.FloatTensor([1.0]).to(device=x.device)
+        if repetition_penalty is None:
+            repetition_penalty = torch.FloatTensor([1.0]).to(device=x.device)
+        if temperature is None:
+            temperature = torch.FloatTensor([1.0]).to(device=x.device)
+
         y = prompt
         x_example = x[:, :, 0] * 0.0
         # N, 1, 512
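Note the shape of the defaults: each sampling knob is materialized as a one-element tensor on the input's device rather than left as a Python scalar, so a traced export can expose it as a real graph input instead of folding a constant into the model. A toy illustration of the pattern; the class and names here are hypothetical, not from the repo.

import torch
from torch import nn

class TensorDefaultDemo(nn.Module):
    # Hypothetical module mirroring the optional-tensor-argument
    # pattern used by the patched forward().
    def forward(self, x, temperature=None):
        if temperature is None:
            # A tensor constant rather than a float literal, matching
            # the style of the diff above.
            temperature = torch.FloatTensor([1.0]).to(device=x.device)
        return x / temperature

demo = TensorDefaultDemo()
x = torch.randn(4)
print(demo(x))                                    # falls back to 1.0
print(demo(x, temperature=torch.tensor([0.7])))   # caller-supplied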
@@ -180,7 +189,7 @@ class T2SFirstStageDecoder(nn.Module):
 
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)
 
         y = torch.concat([y, samples], dim=1)
 
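With this hunk, the previously hardcoded sample(..., top_k=15, top_p=1.0, repetition_penalty=1.35) call now consumes whatever the caller passed into forward. One behavioural wrinkle worth flagging: when the new arguments are omitted, repetition_penalty falls back to 1.0, not the 1.35 the call used to hardcode, so callers wanting the old behaviour must pass it explicitly. A hypothetical caller-side sketch; fsdec, x, and prompt are placeholders, not verified against the repo's export script.

import torch

device = torch.device("cpu")
sampling = dict(
    top_k=torch.LongTensor([15]).to(device),
    top_p=torch.FloatTensor([1.0]).to(device),
    repetition_penalty=torch.FloatTensor([1.35]).to(device),
    temperature=torch.FloatTensor([1.0]).to(device),
)
# Reproduces the previously hardcoded behaviour:
# outputs = fsdec(x, prompt, **sampling)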
@@ -211,7 +220,16 @@ class T2SStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
 
-    def forward(self, y, k, v, y_emb, x_example):
+    def forward(self, y, k, v, y_emb, x_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
+        if top_k is None:
+            top_k = torch.LongTensor([15]).to(device=y.device)
+        if top_p is None:
+            top_p = torch.FloatTensor([1.0]).to(device=y.device)
+        if repetition_penalty is None:
+            repetition_penalty = torch.FloatTensor([1.0]).to(device=y.device)
+        if temperature is None:
+            temperature = torch.FloatTensor([1.0]).to(device=y.device)
+
         cache = {
             "all_stage": self.num_layers,
             "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
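A side note on the unchanged cache lines in this hunk: torch.nn.functional.pad consumes padding pairs from the last dimension inward, so (0, 0, 0, 0, 0, 1) appends one zero slot along the first dimension, growing the key/value cache by one time step per decode call. A quick check, assuming a [seq, batch, hidden] cache layout (the layout is an assumption for illustration):

import torch
import torch.nn.functional as F

k = torch.randn(5, 1, 512)              # assumed [seq, batch, hidden]
k_padded = F.pad(k, (0, 0, 0, 0, 0, 1))  # pads dim 0 by one at the end
print(k_padded.shape)                    # torch.Size([6, 1, 512])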
@@ -240,7 +258,7 @@ class T2SStageDecoder(nn.Module):
 
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)
 
         y = torch.concat([y, samples], dim=1)
 
@@ -389,7 +407,7 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=1.0)[0].unsqueeze(0)
         if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
             stop = True
         if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
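Unlike the two ONNX decoder modules, this Text2SemanticDecoder path only threads top_k through: top_p, repetition_penalty, and temperature stay pinned at 1.0, 1.35, and 1.0. The stopping rule in the context lines halts decoding when either the argmax token or the sampled token is EOS; a toy reproduction, with the EOS id assumed for illustration rather than taken from the repo:

import torch

EOS = 1024                            # assumed id, for illustration
logits = torch.zeros(1, EOS + 1)
logits[0, EOS] = 10.0                 # make EOS the argmax
samples = torch.tensor([[EOS]])
stop = torch.argmax(logits, dim=-1)[0] == EOS or samples[0, 0] == EOS
print(bool(stop))                     # True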