# 2025-09-06 22:58:58 +08:00
#
# 68 lines
# 2.1 KiB
# Python

from typing import Protocol
import torch
import torch.nn.functional as F
Tensor = torch.Tensor
class SampleProtocol(Protocol):
@staticmethod
def __call__(
logits: Tensor,
previous_tokens: Tensor,
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
) -> Tensor: ...
class sample_naive(SampleProtocol):
@staticmethod
def __call__(
logits: Tensor,
previous_tokens: Tensor,
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
):
if temperature <= 1e-5:
probs = F.softmax(logits, dim=-1)
return torch.argmax(probs, dim=-1, keepdim=True).to(dtype=torch.int32)
if repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
cum_probs[cum_probs > 1] = 1
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
if temperature < 1.0:
logits /= temperature
v, _ = torch.topk(logits, top_k)
pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = F.softmax(logits, dim=-1)
q = -torch.log(torch.rand_like(probs))
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
return idx_next