mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-15 06:28:11 +08:00
Enhance sampling functions in TTS by adding support for previous token masks in logits_to_probs. Implement batch processing for sampling with padded token sequences and contiguous sampling groups. Refactor sampling logic in T2S scheduler to utilize new functionalities, improving efficiency and flexibility in token generation.
This commit is contained in:
parent
845b181360
commit
a45e171ff5
@ -147,6 +147,7 @@ def multinomial_sample_one_no_sync(
|
||||
def logits_to_probs(
|
||||
logits,
|
||||
previous_tokens: Optional[torch.Tensor] = None,
|
||||
previous_token_mask: Optional[torch.Tensor] = None,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
@ -158,13 +159,27 @@ def logits_to_probs(
|
||||
# pdb.set_trace()
|
||||
if previous_tokens is not None and repetition_penalty != 1.0:
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
if previous_token_mask is None:
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
else:
|
||||
previous_token_mask = previous_token_mask.to(dtype=torch.bool, device=logits.device)
|
||||
if previous_token_mask.any():
|
||||
batch_index = torch.arange(logits.size(0), device=logits.device).unsqueeze(1).expand_as(previous_tokens)
|
||||
valid_batch_index = batch_index[previous_token_mask]
|
||||
valid_token_index = previous_tokens[previous_token_mask]
|
||||
score = logits[valid_batch_index, valid_token_index]
|
||||
score = torch.where(
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits[valid_batch_index, valid_token_index] = score
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
@ -192,9 +207,15 @@ def logits_to_probs(
|
||||
def sample(
|
||||
logits,
|
||||
previous_tokens: Optional[torch.Tensor] = None,
|
||||
previous_token_mask: Optional[torch.Tensor] = None,
|
||||
**sampling_kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
|
||||
probs = logits_to_probs(
|
||||
logits=logits,
|
||||
previous_tokens=previous_tokens,
|
||||
previous_token_mask=previous_token_mask,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from AR.models.utils import make_pad_mask_left, sample
|
||||
from AR.models.utils import logits_to_probs, make_pad_mask_left, multinomial_sample_one_no_sync, sample
|
||||
|
||||
|
||||
def _sync_device(device: Any) -> None:
|
||||
@ -277,6 +277,90 @@ def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device:
|
||||
)
|
||||
|
||||
|
||||
def _pad_token_sequences(
|
||||
token_sequences: Sequence[torch.LongTensor],
|
||||
) -> Tuple[torch.LongTensor, torch.BoolTensor]:
|
||||
if not token_sequences:
|
||||
raise ValueError("token_sequences 不能为空")
|
||||
device = token_sequences[0].device
|
||||
max_len = max(int(sequence.shape[0]) for sequence in token_sequences)
|
||||
padded = torch.zeros((len(token_sequences), max_len), dtype=token_sequences[0].dtype, device=device)
|
||||
mask = torch.zeros((len(token_sequences), max_len), dtype=torch.bool, device=device)
|
||||
for row_index, sequence in enumerate(token_sequences):
|
||||
seq_len = int(sequence.shape[0])
|
||||
padded[row_index, :seq_len] = sequence
|
||||
mask[row_index, :seq_len] = True
|
||||
return padded, mask
|
||||
|
||||
|
||||
def _sampling_group_key(
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
repetition_penalty: float,
|
||||
trim_eos: bool,
|
||||
) -> Tuple[int, float, float, float, bool]:
|
||||
return (
|
||||
int(top_k),
|
||||
float(top_p),
|
||||
float(temperature),
|
||||
float(repetition_penalty),
|
||||
bool(trim_eos),
|
||||
)
|
||||
|
||||
|
||||
def _iter_contiguous_sampling_groups(
|
||||
sampling_keys: Sequence[Tuple[int, float, float, float, bool]],
|
||||
) -> List[Tuple[Tuple[int, float, float, float, bool], List[int]]]:
|
||||
groups: List[Tuple[Tuple[int, float, float, float, bool], List[int]]] = []
|
||||
if not sampling_keys:
|
||||
return groups
|
||||
current_key = sampling_keys[0]
|
||||
current_indices: List[int] = [0]
|
||||
for index in range(1, len(sampling_keys)):
|
||||
key = sampling_keys[index]
|
||||
if key == current_key:
|
||||
current_indices.append(index)
|
||||
continue
|
||||
groups.append((current_key, current_indices))
|
||||
current_key = key
|
||||
current_indices = [index]
|
||||
groups.append((current_key, current_indices))
|
||||
return groups
|
||||
|
||||
|
||||
def _batched_sample_by_group(
|
||||
logits: torch.Tensor,
|
||||
histories: Sequence[torch.LongTensor],
|
||||
sampling_keys: Sequence[Tuple[int, float, float, float, bool]],
|
||||
) -> Tuple[List[torch.Tensor], List[int]]:
|
||||
sampled_list: List[Optional[torch.Tensor]] = [None] * len(histories)
|
||||
argmax_list: List[Optional[int]] = [None] * len(histories)
|
||||
for group_key, group_indices in _iter_contiguous_sampling_groups(sampling_keys):
|
||||
top_k, top_p, temperature, repetition_penalty, trim_eos = group_key
|
||||
index_tensor = torch.tensor(group_indices, dtype=torch.long, device=logits.device)
|
||||
group_logits = torch.index_select(logits, dim=0, index=index_tensor)
|
||||
if trim_eos:
|
||||
group_logits = group_logits[:, :-1]
|
||||
group_histories = [histories[index] for index in group_indices]
|
||||
padded_histories, history_mask = _pad_token_sequences(group_histories)
|
||||
probs = logits_to_probs(
|
||||
logits=group_logits,
|
||||
previous_tokens=padded_histories,
|
||||
previous_token_mask=history_mask,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
temperature=temperature,
|
||||
)
|
||||
argmax_tokens = torch.argmax(group_logits, dim=-1)
|
||||
for local_index, global_index in enumerate(group_indices):
|
||||
sampled_list[global_index] = multinomial_sample_one_no_sync(probs[local_index : local_index + 1])
|
||||
argmax_list[global_index] = int(argmax_tokens[local_index].item())
|
||||
|
||||
return [item for item in sampled_list if item is not None], [int(item) for item in argmax_list if item is not None]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch:
|
||||
x_items: List[torch.Tensor] = []
|
||||
@ -360,19 +444,26 @@ def _sample_per_request(
|
||||
updated_sequences: List[torch.LongTensor] = []
|
||||
|
||||
step_idx = active_batch.step_idx
|
||||
for batch_index, state in enumerate(active_batch.states):
|
||||
logits_i = logits[batch_index : batch_index + 1].clone()
|
||||
current_history = active_batch.y_sequences[batch_index]
|
||||
sampled = sample(
|
||||
logits_i,
|
||||
current_history.unsqueeze(0),
|
||||
sampling_keys = [
|
||||
_sampling_group_key(
|
||||
top_k=state.top_k,
|
||||
top_p=state.top_p,
|
||||
repetition_penalty=state.repetition_penalty,
|
||||
temperature=state.temperature,
|
||||
)[0]
|
||||
repetition_penalty=state.repetition_penalty,
|
||||
trim_eos=False,
|
||||
)
|
||||
for state in active_batch.states
|
||||
]
|
||||
sampled_items, argmax_tokens = _batched_sample_by_group(
|
||||
logits=logits,
|
||||
histories=active_batch.y_sequences,
|
||||
sampling_keys=sampling_keys,
|
||||
)
|
||||
for batch_index, state in enumerate(active_batch.states):
|
||||
current_history = active_batch.y_sequences[batch_index]
|
||||
sampled = sampled_items[batch_index]
|
||||
sampled_token = int(sampled[0, 0].item())
|
||||
argmax_token = int(torch.argmax(logits[batch_index], dim=-1).item())
|
||||
argmax_token = argmax_tokens[batch_index]
|
||||
new_history = torch.cat([current_history, sampled.view(-1)], dim=0)
|
||||
|
||||
finish_reason: Optional[str] = None
|
||||
@ -507,25 +598,30 @@ def run_prefill_step(
|
||||
if len(states) == 1 and not decode_attn_mask.any().item():
|
||||
decode_attn_mask = None
|
||||
logits = model.ar_predict_layer(xy_dec[:, -1])
|
||||
sampling_keys = [
|
||||
_sampling_group_key(
|
||||
top_k=state.top_k,
|
||||
top_p=state.top_p,
|
||||
temperature=state.temperature,
|
||||
repetition_penalty=state.repetition_penalty,
|
||||
trim_eos=True,
|
||||
)
|
||||
for state in states
|
||||
]
|
||||
sampled_items, argmax_tokens = _batched_sample_by_group(
|
||||
logits=logits,
|
||||
histories=active_batch.y_sequences,
|
||||
sampling_keys=sampling_keys,
|
||||
)
|
||||
|
||||
running_requests: List[T2SRunningRequest] = []
|
||||
finished_items: List[T2SFinishedItem] = []
|
||||
|
||||
for batch_index, state in enumerate(states):
|
||||
logits_i = logits[batch_index : batch_index + 1].clone()
|
||||
if 0 < 11:
|
||||
logits_i = logits_i[:, :-1]
|
||||
current_history = active_batch.y_sequences[batch_index]
|
||||
sampled = sample(
|
||||
logits_i,
|
||||
current_history.unsqueeze(0),
|
||||
top_k=state.top_k,
|
||||
top_p=state.top_p,
|
||||
repetition_penalty=state.repetition_penalty,
|
||||
temperature=state.temperature,
|
||||
)[0]
|
||||
sampled = sampled_items[batch_index]
|
||||
sampled_token = int(sampled[0, 0].item())
|
||||
argmax_token = int(torch.argmax(logits_i[0], dim=-1).item())
|
||||
argmax_token = argmax_tokens[batch_index]
|
||||
new_history = torch.cat([current_history, sampled.view(-1)], dim=0)
|
||||
prefix_len = int(active_batch.prefix_lens[batch_index].item())
|
||||
|
||||
@ -624,25 +720,31 @@ def run_decode_step_for_running(
|
||||
batched_decode_attn_mask,
|
||||
)
|
||||
logits = model.ar_predict_layer(xy_dec[:, -1])
|
||||
sampling_keys = [
|
||||
_sampling_group_key(
|
||||
top_k=running_request.state.top_k,
|
||||
top_p=running_request.state.top_p,
|
||||
temperature=running_request.state.temperature,
|
||||
repetition_penalty=running_request.state.repetition_penalty,
|
||||
trim_eos=running_request.step_idx < 11,
|
||||
)
|
||||
for running_request in running_requests
|
||||
]
|
||||
histories = [running_request.y_sequence for running_request in running_requests]
|
||||
sampled_items, argmax_tokens = _batched_sample_by_group(
|
||||
logits=logits,
|
||||
histories=histories,
|
||||
sampling_keys=sampling_keys,
|
||||
)
|
||||
|
||||
next_running: List[T2SRunningRequest] = []
|
||||
finished_items: List[T2SFinishedItem] = []
|
||||
|
||||
for batch_index, running_request in enumerate(running_requests):
|
||||
current_idx = running_request.step_idx
|
||||
logits_i = logits[batch_index : batch_index + 1].clone()
|
||||
if current_idx < 11:
|
||||
logits_i = logits_i[:, :-1]
|
||||
sampled = sample(
|
||||
logits_i,
|
||||
running_request.y_sequence.unsqueeze(0),
|
||||
top_k=running_request.state.top_k,
|
||||
top_p=running_request.state.top_p,
|
||||
repetition_penalty=running_request.state.repetition_penalty,
|
||||
temperature=running_request.state.temperature,
|
||||
)[0]
|
||||
sampled = sampled_items[batch_index]
|
||||
sampled_token = int(sampled[0, 0].item())
|
||||
argmax_token = int(torch.argmax(logits_i[0], dim=-1).item())
|
||||
argmax_token = argmax_tokens[batch_index]
|
||||
new_history = torch.cat([running_request.y_sequence, sampled.view(-1)], dim=0)
|
||||
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user