Enhance the sampling functions in TTS by adding support for a previous-token mask in logits_to_probs. Implement batched sampling over padded token sequences with contiguous sampling groups. Refactor the sampling logic in the T2S scheduler to use the new helpers, improving efficiency and flexibility in token generation.

This commit is contained in:
baicai-1145 2026-03-09 21:24:16 +08:00
parent 845b181360
commit a45e171ff5
2 changed files with 165 additions and 42 deletions
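For orientation, a minimal usage sketch of the mask-aware path added below (the import path matches the second file's diff; the tensor shapes, vocabulary size, and sampling values are invented for illustration). previous_token_mask marks which entries of a right-padded history are real tokens, so the repetition penalty is never applied to padding slots:

import torch
from AR.models.utils import logits_to_probs

logits = torch.randn(2, 1025)                       # (batch, vocab), toy values
histories = [torch.tensor([3, 7, 7]), torch.tensor([9])]
max_len = max(h.numel() for h in histories)
previous_tokens = torch.zeros(2, max_len, dtype=torch.long)
previous_token_mask = torch.zeros(2, max_len, dtype=torch.bool)
for row, h in enumerate(histories):
    previous_tokens[row, : h.numel()] = h           # right-pad with 0
    previous_token_mask[row, : h.numel()] = True    # True = real token, False = padding

probs = logits_to_probs(
    logits,
    previous_tokens=previous_tokens,
    previous_token_mask=previous_token_mask,
    temperature=1.0,
    top_k=15,
    top_p=1.0,
    repetition_penalty=1.35,
)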

View File

@ -147,6 +147,7 @@ def multinomial_sample_one_no_sync(
def logits_to_probs(
logits,
previous_tokens: Optional[torch.Tensor] = None,
previous_token_mask: Optional[torch.Tensor] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
@ -158,13 +159,27 @@ def logits_to_probs(
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if previous_token_mask is None:
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
else:
previous_token_mask = previous_token_mask.to(dtype=torch.bool, device=logits.device)
if previous_token_mask.any():
batch_index = torch.arange(logits.size(0), device=logits.device).unsqueeze(1).expand_as(previous_tokens)
valid_batch_index = batch_index[previous_token_mask]
valid_token_index = previous_tokens[previous_token_mask]
score = logits[valid_batch_index, valid_token_index]
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits[valid_batch_index, valid_token_index] = score
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@ -192,9 +207,15 @@ def logits_to_probs(
def sample(
logits,
previous_tokens: Optional[torch.Tensor] = None,
previous_token_mask: Optional[torch.Tensor] = None,
**sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
probs = logits_to_probs(
logits=logits,
previous_tokens=previous_tokens,
previous_token_mask=previous_token_mask,
**sampling_kwargs,
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
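To make the masked repetition-penalty branch above concrete, a tiny standalone walk-through with invented tensors (the shapes and the 1.35 penalty are arbitrary): only the (row, token) pairs flagged True are rescaled, so the pad value 0 in the token columns is left untouched.

import torch

logits = torch.zeros(2, 6)
logits[0, 3] = 2.0
logits[1, 1] = -1.0
previous_tokens = torch.tensor([[3, 0], [1, 0]])                  # trailing 0s are padding
previous_token_mask = torch.tensor([[True, False], [True, False]])

batch_index = torch.arange(2).unsqueeze(1).expand_as(previous_tokens)
rows = batch_index[previous_token_mask]
cols = previous_tokens[previous_token_mask]
score = logits[rows, cols]
logits[rows, cols] = torch.where(score < 0, score * 1.35, score / 1.35)
# logits[0, 3] -> 2.0 / 1.35, logits[1, 1] -> -1.0 * 1.35, and column 0 is unchanged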

View File

@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple
import torch
import torch.nn.functional as F
from AR.models.utils import make_pad_mask_left, sample
from AR.models.utils import logits_to_probs, make_pad_mask_left, multinomial_sample_one_no_sync, sample
def _sync_device(device: Any) -> None:
@ -277,6 +277,90 @@ def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device:
)
def _pad_token_sequences(
token_sequences: Sequence[torch.LongTensor],
) -> Tuple[torch.LongTensor, torch.BoolTensor]:
if not token_sequences:
raise ValueError("token_sequences must not be empty")
device = token_sequences[0].device
max_len = max(int(sequence.shape[0]) for sequence in token_sequences)
padded = torch.zeros((len(token_sequences), max_len), dtype=token_sequences[0].dtype, device=device)
mask = torch.zeros((len(token_sequences), max_len), dtype=torch.bool, device=device)
for row_index, sequence in enumerate(token_sequences):
seq_len = int(sequence.shape[0])
padded[row_index, :seq_len] = sequence
mask[row_index, :seq_len] = True
return padded, mask
def _sampling_group_key(
top_k: int,
top_p: float,
temperature: float,
repetition_penalty: float,
trim_eos: bool,
) -> Tuple[int, float, float, float, bool]:
return (
int(top_k),
float(top_p),
float(temperature),
float(repetition_penalty),
bool(trim_eos),
)
def _iter_contiguous_sampling_groups(
sampling_keys: Sequence[Tuple[int, float, float, float, bool]],
) -> List[Tuple[Tuple[int, float, float, float, bool], List[int]]]:
groups: List[Tuple[Tuple[int, float, float, float, bool], List[int]]] = []
if not sampling_keys:
return groups
current_key = sampling_keys[0]
current_indices: List[int] = [0]
for index in range(1, len(sampling_keys)):
key = sampling_keys[index]
if key == current_key:
current_indices.append(index)
continue
groups.append((current_key, current_indices))
current_key = key
current_indices = [index]
groups.append((current_key, current_indices))
return groups
def _batched_sample_by_group(
logits: torch.Tensor,
histories: Sequence[torch.LongTensor],
sampling_keys: Sequence[Tuple[int, float, float, float, bool]],
) -> Tuple[List[torch.Tensor], List[int]]:
sampled_list: List[Optional[torch.Tensor]] = [None] * len(histories)
argmax_list: List[Optional[int]] = [None] * len(histories)
for group_key, group_indices in _iter_contiguous_sampling_groups(sampling_keys):
top_k, top_p, temperature, repetition_penalty, trim_eos = group_key
index_tensor = torch.tensor(group_indices, dtype=torch.long, device=logits.device)
group_logits = torch.index_select(logits, dim=0, index=index_tensor)
if trim_eos:
group_logits = group_logits[:, :-1]
group_histories = [histories[index] for index in group_indices]
padded_histories, history_mask = _pad_token_sequences(group_histories)
probs = logits_to_probs(
logits=group_logits,
previous_tokens=padded_histories,
previous_token_mask=history_mask,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
temperature=temperature,
)
argmax_tokens = torch.argmax(group_logits, dim=-1)
for local_index, global_index in enumerate(group_indices):
sampled_list[global_index] = multinomial_sample_one_no_sync(probs[local_index : local_index + 1])
argmax_list[global_index] = int(argmax_tokens[local_index].item())
return [item for item in sampled_list if item is not None], [int(item) for item in argmax_list if item is not None]
@torch.inference_mode()
def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch:
x_items: List[torch.Tensor] = []
@ -360,19 +444,26 @@ def _sample_per_request(
updated_sequences: List[torch.LongTensor] = []
step_idx = active_batch.step_idx
for batch_index, state in enumerate(active_batch.states):
logits_i = logits[batch_index : batch_index + 1].clone()
current_history = active_batch.y_sequences[batch_index]
sampled = sample(
logits_i,
current_history.unsqueeze(0),
sampling_keys = [
_sampling_group_key(
top_k=state.top_k,
top_p=state.top_p,
repetition_penalty=state.repetition_penalty,
temperature=state.temperature,
)[0]
repetition_penalty=state.repetition_penalty,
trim_eos=False,
)
for state in active_batch.states
]
sampled_items, argmax_tokens = _batched_sample_by_group(
logits=logits,
histories=active_batch.y_sequences,
sampling_keys=sampling_keys,
)
for batch_index, state in enumerate(active_batch.states):
current_history = active_batch.y_sequences[batch_index]
sampled = sampled_items[batch_index]
sampled_token = int(sampled[0, 0].item())
argmax_token = int(torch.argmax(logits[batch_index], dim=-1).item())
argmax_token = argmax_tokens[batch_index]
new_history = torch.cat([current_history, sampled.view(-1)], dim=0)
finish_reason: Optional[str] = None
@ -507,25 +598,30 @@ def run_prefill_step(
if len(states) == 1 and not decode_attn_mask.any().item():
decode_attn_mask = None
logits = model.ar_predict_layer(xy_dec[:, -1])
sampling_keys = [
_sampling_group_key(
top_k=state.top_k,
top_p=state.top_p,
temperature=state.temperature,
repetition_penalty=state.repetition_penalty,
trim_eos=True,
)
for state in states
]
sampled_items, argmax_tokens = _batched_sample_by_group(
logits=logits,
histories=active_batch.y_sequences,
sampling_keys=sampling_keys,
)
running_requests: List[T2SRunningRequest] = []
finished_items: List[T2SFinishedItem] = []
for batch_index, state in enumerate(states):
logits_i = logits[batch_index : batch_index + 1].clone()
if 0 < 11:
logits_i = logits_i[:, :-1]
current_history = active_batch.y_sequences[batch_index]
sampled = sample(
logits_i,
current_history.unsqueeze(0),
top_k=state.top_k,
top_p=state.top_p,
repetition_penalty=state.repetition_penalty,
temperature=state.temperature,
)[0]
sampled = sampled_items[batch_index]
sampled_token = int(sampled[0, 0].item())
argmax_token = int(torch.argmax(logits_i[0], dim=-1).item())
argmax_token = argmax_tokens[batch_index]
new_history = torch.cat([current_history, sampled.view(-1)], dim=0)
prefix_len = int(active_batch.prefix_lens[batch_index].item())
@ -624,25 +720,31 @@ def run_decode_step_for_running(
batched_decode_attn_mask,
)
logits = model.ar_predict_layer(xy_dec[:, -1])
sampling_keys = [
_sampling_group_key(
top_k=running_request.state.top_k,
top_p=running_request.state.top_p,
temperature=running_request.state.temperature,
repetition_penalty=running_request.state.repetition_penalty,
trim_eos=running_request.step_idx < 11,
)
for running_request in running_requests
]
histories = [running_request.y_sequence for running_request in running_requests]
sampled_items, argmax_tokens = _batched_sample_by_group(
logits=logits,
histories=histories,
sampling_keys=sampling_keys,
)
next_running: List[T2SRunningRequest] = []
finished_items: List[T2SFinishedItem] = []
for batch_index, running_request in enumerate(running_requests):
current_idx = running_request.step_idx
logits_i = logits[batch_index : batch_index + 1].clone()
if current_idx < 11:
logits_i = logits_i[:, :-1]
sampled = sample(
logits_i,
running_request.y_sequence.unsqueeze(0),
top_k=running_request.state.top_k,
top_p=running_request.state.top_p,
repetition_penalty=running_request.state.repetition_penalty,
temperature=running_request.state.temperature,
)[0]
sampled = sampled_items[batch_index]
sampled_token = int(sampled[0, 0].item())
argmax_token = int(torch.argmax(logits_i[0], dim=-1).item())
argmax_token = argmax_tokens[batch_index]
new_history = torch.cat([running_request.y_sequence, sampled.view(-1)], dim=0)
finish_reason: Optional[str] = None
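As a closing note on the grouping helper added above, a toy walk-through (the key values are invented): groups are formed only over contiguous runs of identical sampling settings, so a request whose settings match an earlier, non-adjacent group still gets its own group rather than being re-sorted.

keys = [
    (15, 1.0, 1.0, 1.35, False),  # request 0
    (15, 1.0, 1.0, 1.35, False),  # request 1: same settings, joins request 0's group
    (5, 0.8, 0.9, 1.0, True),     # request 2: different settings, starts a new group
    (15, 1.0, 1.0, 1.35, False),  # request 3: matches group 0 but is not adjacent
]
# _iter_contiguous_sampling_groups(keys) returns:
#   [((15, 1.0, 1.0, 1.35, False), [0, 1]),
#    ((5, 0.8, 0.9, 1.0, True), [2]),
#    ((15, 1.0, 1.0, 1.35, False), [3])]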