Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-09-29 17:10:02 +08:00)
feat: allow fsdec and sdec to have sampling params
This commit is contained in:
parent
3ccd1c0ea3
commit
9ed42daa88
@@ -28,12 +28,12 @@ def logits_to_probs(
     logits,
     previous_tokens=None,
     temperature: float = 1.0,
-    top_k=None,
-    top_p=None,
+    top_k=15,
+    top_p=1.0,
     repetition_penalty: float = 1.0,
 ):
     previous_tokens = previous_tokens.squeeze()
-    if previous_tokens is not None and repetition_penalty != 1.0:
+    # if previous_tokens is not None and repetition_penalty != 1.0:  # Always captured by onnx
     previous_tokens = previous_tokens.long()
     score = torch.gather(logits, dim=0, index=previous_tokens)
     score = torch.where(
@@ -43,7 +43,7 @@ def logits_to_probs(
     )
     logits.scatter_(dim=0, index=previous_tokens, src=score)
 
-    if top_p is not None and top_p < 1.0:
+    # if top_p is not None and top_p < 1.0:  # To be captured by onnx
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
     cum_probs = torch.cumsum(
         torch.nn.functional.softmax(
@@ -63,7 +63,7 @@ def logits_to_probs(
 
     logits = logits / max(temperature, 1e-5)
 
-    if top_k is not None:
+    # if top_k is not None:  # To be captured by onnx
     v, _ = torch.topk(logits, top_k)
     pivot = v.select(-1, -1).unsqueeze(-1)
     logits = torch.where(logits < pivot, inf_tensor_value, logits)
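Why the guards are commented out rather than kept: ONNX export of this module goes through tracing, and a Python `if` is evaluated once with the example inputs, so only the branch actually taken is recorded in the graph (as the diff's own comments note). The commit therefore applies repetition penalty, top-p, and top-k unconditionally and bakes usable defaults into the signature. A minimal sketch of that tracing behavior (illustrative only, not part of this commit):

import torch

def guarded_topk(logits, apply_topk):
    # Python-level control flow: torch.jit.trace (and tracing-based
    # torch.onnx.export) records only the branch taken with the
    # example inputs, not the `if` itself.
    if apply_topk:
        v, _ = torch.topk(logits, 15)
        logits = torch.where(logits < v[..., -1:], torch.full_like(logits, float("-inf")), logits)
    return logits

# This trace bakes the top-k masking into the graph permanently;
# tracing with apply_topk=False would drop it permanently instead.
traced = torch.jit.trace(lambda x: guarded_topk(x, True), torch.randn(4, 32))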
@@ -130,7 +130,16 @@ class T2SFirstStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
 
-    def forward(self, x, prompt):
+    def forward(self, x, prompt, top_k = None, top_p = None, repetition_penalty = None, temperature = None):
+        if top_k is None:
+            top_k = torch.LongTensor([15]).to(device=x.device)
+        if top_p is None:
+            top_p = torch.FloatTensor([1.0]).to(device=x.device)
+        if repetition_penalty is None:
+            repetition_penalty = torch.FloatTensor([1.0]).to(device=x.device)
+        if temperature is None:
+            temperature = torch.FloatTensor([1.0]).to(device=x.device)
+
         y = prompt
         x_example = x[:, :, 0] * 0.0
         # N, 1, 512
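The fallbacks are built as 1-element tensors rather than Python scalars so the sampling knobs can exist as real graph inputs after export, while the `None` defaults keep eager-mode callers working unchanged. A self-contained sketch of the same pattern (TinyDecoder is a toy stand-in, not a class from this repo):

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    def forward(self, x, top_k=None, temperature=None):
        # mirror of the diff's pattern: optional tensor-valued knobs,
        # with tensor defaults materialized on the input's device
        if top_k is None:
            top_k = torch.LongTensor([15]).to(device=x.device)
        if temperature is None:
            temperature = torch.FloatTensor([1.0]).to(device=x.device)
        # top_k is unused in this toy; only the default pattern matters here
        return x / temperature

m = TinyDecoder()
out_default = m(torch.randn(2, 8))                # defaults kick in
out_fed = m(torch.randn(2, 8), torch.LongTensor([5]), torch.FloatTensor([0.7]))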
@@ -180,7 +189,7 @@ class T2SFirstStageDecoder(nn.Module):
 
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)
 
         y = torch.concat([y, samples], dim=1)
 
@@ -211,7 +220,16 @@ class T2SStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
 
-    def forward(self, y, k, v, y_emb, x_example):
+    def forward(self, y, k, v, y_emb, x_example, top_k = None, top_p = None, repetition_penalty = None, temperature = None):
+        if top_k is None:
+            top_k = torch.LongTensor([15]).to(device=y.device)
+        if top_p is None:
+            top_p = torch.FloatTensor([1.0]).to(device=y.device)
+        if repetition_penalty is None:
+            repetition_penalty = torch.FloatTensor([1.0]).to(device=y.device)
+        if temperature is None:
+            temperature = torch.FloatTensor([1.0]).to(device=y.device)
+
         cache = {
             "all_stage": self.num_layers,
             "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
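Same pattern as the first-stage decoder, except the fallback tensors are materialized on y.device, since y is the first tensor argument this decoder receives.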
@@ -240,7 +258,7 @@ class T2SStageDecoder(nn.Module):
 
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)
 
         y = torch.concat([y, samples], dim=1)
 
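With both forward signatures extended, caller-supplied values now reach sample() in place of the previously hardcoded top_k=15, top_p=1.0, repetition_penalty=1.35. One behavioral consequence worth noting: when callers pass nothing, the new defaults (top_k=15, top_p=1.0, repetition_penalty=1.0, temperature=1.0) lower the effective repetition penalty from 1.35 to 1.0.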
@@ -389,7 +407,7 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=1.0)[0].unsqueeze(0)
         if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
             stop = True
         if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
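In Text2SemanticDecoder, by contrast, only top_k is threaded through to sample(); top_p, repetition_penalty, and temperature remain pinned at 1.0, 1.35, and 1.0 in this call.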