mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-29 01:18:13 +08:00
105 lines
2.9 KiB
Python
105 lines
2.9 KiB
Python
from typing import Callable, Protocol, TypeVar, cast
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from typing_extensions import ParamSpec
|
|
|
|
P = ParamSpec("P")
|
|
R = TypeVar("R")
|
|
Tensor = torch.Tensor
|
|
|
|
|
|
def script(fn: Callable[P, R]) -> Callable[P, R]:
|
|
scripted = torch.jit.script(fn)
|
|
return cast(Callable[P, R], scripted)
|
|
|
|
|
|
@script
|
|
def apply_repetition_penalty(logits: Tensor, previous_tokens: Tensor, repetition_penalty: float):
|
|
previous_tokens = previous_tokens.long()
|
|
score = torch.gather(logits, dim=1, index=previous_tokens)
|
|
score = torch.where(
|
|
score < 0,
|
|
score * repetition_penalty,
|
|
score / repetition_penalty,
|
|
)
|
|
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
|
return logits
|
|
|
|
|
|
@script
|
|
def apply_greedy_sampling(logits: Tensor):
|
|
return torch.argmax(logits, dim=-1, keepdim=True).to(dtype=torch.int32)
|
|
|
|
|
|
@script
|
|
def apply_temperature(logits: Tensor, temperature: float):
|
|
return logits / temperature
|
|
|
|
|
|
@script
|
|
def apply_top_k(logits: Tensor, top_k: int):
|
|
v, _ = torch.topk(logits, top_k)
|
|
pivot = v[:, -1].unsqueeze(-1)
|
|
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
|
return logits
|
|
|
|
|
|
@script
|
|
def apply_top_p(logits: Tensor, top_p: float):
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
|
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
|
cum_probs[cum_probs > 1] = 1
|
|
sorted_indices_to_remove = cum_probs > top_p
|
|
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
|
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
|
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
|
return logits
|
|
|
|
|
|
@script
|
|
def apply_sampling(logits: Tensor):
|
|
probs = F.softmax(logits, dim=-1)
|
|
q = -torch.log(torch.rand_like(probs))
|
|
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
|
|
return idx_next
|
|
|
|
|
|
class SampleProtocol(Protocol):
|
|
@staticmethod
|
|
def __call__(
|
|
logits: Tensor,
|
|
previous_tokens: Tensor,
|
|
repetition_penalty: float,
|
|
temperature: float,
|
|
top_k: int,
|
|
top_p: float,
|
|
) -> Tensor: ...
|
|
|
|
|
|
class sample_naive(SampleProtocol):
|
|
@staticmethod
|
|
def __call__(
|
|
logits: Tensor,
|
|
previous_tokens: Tensor,
|
|
repetition_penalty: float = 1.35,
|
|
temperature: float = 1.0,
|
|
top_k: int = 15,
|
|
top_p: float = 1.0,
|
|
):
|
|
if repetition_penalty != 1.0:
|
|
logits = apply_repetition_penalty(logits, previous_tokens, repetition_penalty)
|
|
|
|
if temperature <= 1e-5:
|
|
return apply_greedy_sampling(logits)
|
|
elif temperature < 1.0:
|
|
logits = apply_temperature(logits, temperature)
|
|
|
|
if top_k < 1025:
|
|
logits = apply_top_k(logits, top_k)
|
|
|
|
if top_p < 1.0:
|
|
logits = apply_top_p(logits, top_p)
|
|
|
|
return apply_sampling(logits)
|