mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 17:10:02 +08:00
68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
from typing import Protocol
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
# Short alias for torch.Tensor, used in the type annotations of the samplers below.
Tensor = torch.Tensor
|
|
|
|
|
|
class SampleProtocol(Protocol):
    """Structural type describing a callable token sampler.

    Any object whose ``__call__`` accepts the six arguments below and returns
    a tensor satisfies this protocol.
    """

    @staticmethod
    def __call__(
        logits: Tensor,
        previous_tokens: Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
    ) -> Tensor:
        """Select the next token(s) from ``logits``.

        Args:
            logits: Unnormalized scores to sample from.
            previous_tokens: Tokens already produced, made available to the
                sampler (e.g. for a repetition penalty).
            temperature: Sampling temperature.
            top_k: Top-k filtering parameter.
            top_p: Top-p (nucleus) filtering parameter.
            repetition_penalty: Penalty factor for previously seen tokens.

        Returns:
            A tensor holding the sampled token indices.
        """
        ...
|
|
|
|
|
|
class sample_naive(SampleProtocol):
    """Sampler applying repetition penalty, top-p, temperature and top-k.

    The filters are applied in that order, after which one token per row is
    drawn from the resulting distribution.
    """

    @staticmethod
    def __call__(
        logits: Tensor,
        previous_tokens: Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
    ) -> Tensor:
        """Sample one token index per row of ``logits``.

        The caller's ``logits`` tensor is never modified.

        Args:
            logits: 2-D tensor of unnormalized scores; the ``dim=1`` gathers
                and ``[:, 0]`` indexing below require exactly two dimensions
                (presumably ``(batch, vocab)`` -- confirm against the caller).
            previous_tokens: Integer tensor of token indices into ``dim=1`` of
                ``logits`` whose scores are penalized. Cast to ``long``.
            temperature: Values ``<= 1e-5`` select greedy argmax decoding.
                Values below ``1.0`` divide the logits; values ``>= 1.0``
                leave the logits unscaled.
            top_k: Keep only the ``top_k`` highest logits per row. Clamped to
                the size of the last dimension; ``top_k <= 0`` disables the
                filter.
            top_p: Nucleus threshold; only applied when ``< 1.0``.
            repetition_penalty: Only applied when ``!= 1.0``. Negative scores
                are multiplied by it, non-negative scores divided by it.

        Returns:
            ``int32`` tensor with a trailing dimension of size 1 holding the
            chosen index for each row.
        """
        # Greedy path: returns before any penalty or filtering is applied.
        if temperature <= 1e-5:
            probs = F.softmax(logits, dim=-1)
            return torch.argmax(probs, dim=-1, keepdim=True).to(dtype=torch.int32)

        if repetition_penalty != 1.0:
            previous_tokens = previous_tokens.long()
            score = torch.gather(logits, dim=1, index=previous_tokens)
            score = torch.where(
                score < 0,
                score * repetition_penalty,
                score / repetition_penalty,
            )
            # Out-of-place scatter so the caller's tensor is left untouched.
            logits = logits.scatter(dim=1, index=previous_tokens, src=score)

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Guard against floating-point drift pushing the running sum above 1.
            cum_probs[cum_probs > 1] = 1
            sorted_indices_to_remove = cum_probs > top_p
            sorted_indices_to_remove[:, 0] = False  # keep at least one option
            # Map the removal mask from sorted order back to vocabulary order.
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits = logits.masked_fill(indices_to_remove, -float("Inf"))

        if temperature < 1.0:
            # Out-of-place division: with top_p >= 1.0 and no penalty,
            # ``logits`` may still be the caller's tensor at this point.
            logits = logits / temperature

        # topk raises when k exceeds the dimension size, so clamp it.
        k = min(top_k, logits.size(-1))
        if k > 0:
            v, _ = torch.topk(logits, k)
            pivot = v[:, -1].unsqueeze(-1)
            logits = logits.masked_fill(logits < pivot, -float("Inf"))

        probs = F.softmax(logits, dim=-1)
        # q is Exp(1) noise; argmax(probs / q) draws an index with probability
        # proportional to probs (exponential-race sampling).
        q = -torch.log(torch.rand_like(probs))
        idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)

        return idx_next
|