From a45e171ff50a372f1ac1ce567fd95659050b453b Mon Sep 17 00:00:00 2001
From: baicai-1145 <3423714059@qq.com>
Date: Mon, 9 Mar 2026 21:24:16 +0800
Subject: [PATCH] Enhance sampling functions in TTS by adding support for
 previous token masks in logits_to_probs

Implement batch processing for sampling with padded token sequences and
contiguous sampling groups. Refactor the sampling logic in the T2S
scheduler to use the new helpers, improving efficiency and flexibility in
token generation.

---
 GPT_SoVITS/AR/models/utils.py              |  37 ++++-
 GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 170 ++++++++++++++++-----
 2 files changed, 165 insertions(+), 42 deletions(-)

diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py
index cc4f24d8..4b564ed8 100644
--- a/GPT_SoVITS/AR/models/utils.py
+++ b/GPT_SoVITS/AR/models/utils.py
@@ -147,6 +147,7 @@ def multinomial_sample_one_no_sync(
 def logits_to_probs(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
+    previous_token_mask: Optional[torch.Tensor] = None,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
     top_p: Optional[int] = None,
@@ -158,13 +159,27 @@
     # pdb.set_trace()
     if previous_tokens is not None and repetition_penalty != 1.0:
         previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=1, index=previous_tokens)
-        score = torch.where(
-            score < 0,
-            score * repetition_penalty,
-            score / repetition_penalty,
-        )
-        logits.scatter_(dim=1, index=previous_tokens, src=score)
+        if previous_token_mask is None:
+            score = torch.gather(logits, dim=1, index=previous_tokens)
+            score = torch.where(
+                score < 0,
+                score * repetition_penalty,
+                score / repetition_penalty,
+            )
+            logits.scatter_(dim=1, index=previous_tokens, src=score)
+        else:
+            previous_token_mask = previous_token_mask.to(dtype=torch.bool, device=logits.device)
+            if previous_token_mask.any():
+                batch_index = torch.arange(logits.size(0), device=logits.device).unsqueeze(1).expand_as(previous_tokens)
+                valid_batch_index = batch_index[previous_token_mask]
+                valid_token_index = previous_tokens[previous_token_mask]
+                score = logits[valid_batch_index, valid_token_index]
+                score = torch.where(
+                    score < 0,
+                    score * repetition_penalty,
+                    score / repetition_penalty,
+                )
+                logits[valid_batch_index, valid_token_index] = score
 
     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -192,9 +207,15 @@
 def sample(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
+    previous_token_mask: Optional[torch.Tensor] = None,
     **sampling_kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
+    probs = logits_to_probs(
+        logits=logits,
+        previous_tokens=previous_tokens,
+        previous_token_mask=previous_token_mask,
+        **sampling_kwargs,
+    )
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
 
diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
index de498573..b7118a72 100644
--- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
+++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple
 import torch
 import torch.nn.functional as F
 
-from AR.models.utils import make_pad_mask_left, sample
+from AR.models.utils import logits_to_probs, make_pad_mask_left, multinomial_sample_one_no_sync, sample
 
 
 def _sync_device(device: Any) -> None:
@@ -277,6 +277,90 @@ def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device:
     )
 
 
+def _pad_token_sequences(
+    token_sequences: Sequence[torch.LongTensor],
+) -> Tuple[torch.LongTensor, torch.BoolTensor]:
+    if not token_sequences:
+        raise ValueError("token_sequences must not be empty")
+    device = token_sequences[0].device
+    max_len = max(int(sequence.shape[0]) for sequence in token_sequences)
+    padded = torch.zeros((len(token_sequences), max_len), dtype=token_sequences[0].dtype, device=device)
+    mask = torch.zeros((len(token_sequences), max_len), dtype=torch.bool, device=device)
+    for row_index, sequence in enumerate(token_sequences):
+        seq_len = int(sequence.shape[0])
+        padded[row_index, :seq_len] = sequence
+        mask[row_index, :seq_len] = True
+    return padded, mask
+
+
+def _sampling_group_key(
+    top_k: int,
+    top_p: float,
+    temperature: float,
+    repetition_penalty: float,
+    trim_eos: bool,
+) -> Tuple[int, float, float, float, bool]:
+    return (
+        int(top_k),
+        float(top_p),
+        float(temperature),
+        float(repetition_penalty),
+        bool(trim_eos),
+    )
+
+
+def _iter_contiguous_sampling_groups(
+    sampling_keys: Sequence[Tuple[int, float, float, float, bool]],
+) -> List[Tuple[Tuple[int, float, float, float, bool], List[int]]]:
+    groups: List[Tuple[Tuple[int, float, float, float, bool], List[int]]] = []
+    if not sampling_keys:
+        return groups
+    current_key = sampling_keys[0]
+    current_indices: List[int] = [0]
+    for index in range(1, len(sampling_keys)):
+        key = sampling_keys[index]
+        if key == current_key:
+            current_indices.append(index)
+            continue
+        groups.append((current_key, current_indices))
+        current_key = key
+        current_indices = [index]
+    groups.append((current_key, current_indices))
+    return groups
+
+
+def _batched_sample_by_group(
+    logits: torch.Tensor,
+    histories: Sequence[torch.LongTensor],
+    sampling_keys: Sequence[Tuple[int, float, float, float, bool]],
+) -> Tuple[List[torch.Tensor], List[int]]:
+    sampled_list: List[Optional[torch.Tensor]] = [None] * len(histories)
+    argmax_list: List[Optional[int]] = [None] * len(histories)
+    for group_key, group_indices in _iter_contiguous_sampling_groups(sampling_keys):
+        top_k, top_p, temperature, repetition_penalty, trim_eos = group_key
+        index_tensor = torch.tensor(group_indices, dtype=torch.long, device=logits.device)
+        group_logits = torch.index_select(logits, dim=0, index=index_tensor)
+        if trim_eos:
+            group_logits = group_logits[:, :-1]
+        group_histories = [histories[index] for index in group_indices]
+        padded_histories, history_mask = _pad_token_sequences(group_histories)
+        probs = logits_to_probs(
+            logits=group_logits,
+            previous_tokens=padded_histories,
+            previous_token_mask=history_mask,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+        )
+        argmax_tokens = torch.argmax(group_logits, dim=-1)
+        for local_index, global_index in enumerate(group_indices):
+            sampled_list[global_index] = multinomial_sample_one_no_sync(probs[local_index : local_index + 1])
+            argmax_list[global_index] = int(argmax_tokens[local_index].item())
+
+    return [item for item in sampled_list if item is not None], [int(item) for item in argmax_list if item is not None]
+
+
 @torch.inference_mode()
 def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch:
     x_items: List[torch.Tensor] = []
@@ -360,19 +444,26 @@ def _sample_per_request(
     updated_sequences: List[torch.LongTensor] = []
     step_idx = active_batch.step_idx
 
-    for batch_index, state in enumerate(active_batch.states):
-        logits_i = logits[batch_index : batch_index + 1].clone()
-        current_history = active_batch.y_sequences[batch_index]
-        sampled = sample(
-            logits_i,
-            current_history.unsqueeze(0),
+    sampling_keys = [
+        _sampling_group_key(
             top_k=state.top_k,
             top_p=state.top_p,
-            repetition_penalty=state.repetition_penalty,
             temperature=state.temperature,
-        )[0]
+            repetition_penalty=state.repetition_penalty,
+            trim_eos=False,
+        )
+        for state in active_batch.states
+    ]
+    sampled_items, argmax_tokens = _batched_sample_by_group(
+        logits=logits,
+        histories=active_batch.y_sequences,
+        sampling_keys=sampling_keys,
+    )
+    for batch_index, state in enumerate(active_batch.states):
+        current_history = active_batch.y_sequences[batch_index]
+        sampled = sampled_items[batch_index]
         sampled_token = int(sampled[0, 0].item())
-        argmax_token = int(torch.argmax(logits[batch_index], dim=-1).item())
+        argmax_token = argmax_tokens[batch_index]
         new_history = torch.cat([current_history, sampled.view(-1)], dim=0)
 
         finish_reason: Optional[str] = None
@@ -507,25 +598,30 @@ def run_prefill_step(
     if len(states) == 1 and not decode_attn_mask.any().item():
         decode_attn_mask = None
     logits = model.ar_predict_layer(xy_dec[:, -1])
+    sampling_keys = [
+        _sampling_group_key(
+            top_k=state.top_k,
+            top_p=state.top_p,
+            temperature=state.temperature,
+            repetition_penalty=state.repetition_penalty,
+            trim_eos=True,
+        )
+        for state in states
+    ]
+    sampled_items, argmax_tokens = _batched_sample_by_group(
+        logits=logits,
+        histories=active_batch.y_sequences,
+        sampling_keys=sampling_keys,
+    )
 
     running_requests: List[T2SRunningRequest] = []
     finished_items: List[T2SFinishedItem] = []
 
     for batch_index, state in enumerate(states):
-        logits_i = logits[batch_index : batch_index + 1].clone()
-        if 0 < 11:
-            logits_i = logits_i[:, :-1]
         current_history = active_batch.y_sequences[batch_index]
-        sampled = sample(
-            logits_i,
-            current_history.unsqueeze(0),
-            top_k=state.top_k,
-            top_p=state.top_p,
-            repetition_penalty=state.repetition_penalty,
-            temperature=state.temperature,
-        )[0]
+        sampled = sampled_items[batch_index]
         sampled_token = int(sampled[0, 0].item())
-        argmax_token = int(torch.argmax(logits_i[0], dim=-1).item())
+        argmax_token = argmax_tokens[batch_index]
         new_history = torch.cat([current_history, sampled.view(-1)], dim=0)
 
         prefix_len = int(active_batch.prefix_lens[batch_index].item())
@@ -624,25 +720,31 @@ def run_decode_step_for_running(
         batched_decode_attn_mask,
     )
     logits = model.ar_predict_layer(xy_dec[:, -1])
+    sampling_keys = [
+        _sampling_group_key(
+            top_k=running_request.state.top_k,
+            top_p=running_request.state.top_p,
+            temperature=running_request.state.temperature,
+            repetition_penalty=running_request.state.repetition_penalty,
+            trim_eos=running_request.step_idx < 11,
+        )
+        for running_request in running_requests
+    ]
+    histories = [running_request.y_sequence for running_request in running_requests]
+    sampled_items, argmax_tokens = _batched_sample_by_group(
+        logits=logits,
+        histories=histories,
+        sampling_keys=sampling_keys,
+    )
 
     next_running: List[T2SRunningRequest] = []
     finished_items: List[T2SFinishedItem] = []
 
     for batch_index, running_request in enumerate(running_requests):
         current_idx = running_request.step_idx
-        logits_i = logits[batch_index : batch_index + 1].clone()
-        if current_idx < 11:
-            logits_i = logits_i[:, :-1]
-        sampled = sample(
-            logits_i,
-            running_request.y_sequence.unsqueeze(0),
-            top_k=running_request.state.top_k,
-            top_p=running_request.state.top_p,
-            repetition_penalty=running_request.state.repetition_penalty,
-            temperature=running_request.state.temperature,
-        )[0]
+        sampled = sampled_items[batch_index]
         sampled_token = int(sampled[0, 0].item())
-        argmax_token = int(torch.argmax(logits_i[0], dim=-1).item())
+        argmax_token = argmax_tokens[batch_index]
         new_history = torch.cat([running_request.y_sequence, sampled.view(-1)], dim=0)
 
         finish_reason: Optional[str] = None
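
Note (reviewer aid, not part of the patch): the sketch below is a minimal, self-contained illustration of the masked repetition-penalty path that logits_to_probs gains above. It uses only PyTorch; the helper name apply_repetition_penalty_masked and the toy histories are illustrative and do not exist in the repository. It pads two histories of different lengths, marks the padding with a boolean mask, applies the penalty through the same index-based update as the new else-branch, and checks the result against the original per-request gather/scatter path.

import torch


def apply_repetition_penalty_masked(logits, previous_tokens, previous_token_mask, repetition_penalty=1.35):
    # Same idea as the masked branch added to logits_to_probs: only positions
    # where the mask is True (real history, not padding) get penalised.
    previous_tokens = previous_tokens.long()
    previous_token_mask = previous_token_mask.to(dtype=torch.bool, device=logits.device)
    if previous_token_mask.any():
        batch_index = (
            torch.arange(logits.size(0), device=logits.device)
            .unsqueeze(1)
            .expand_as(previous_tokens)
        )
        valid_batch_index = batch_index[previous_token_mask]
        valid_token_index = previous_tokens[previous_token_mask]
        score = logits[valid_batch_index, valid_token_index]
        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
        logits[valid_batch_index, valid_token_index] = score
    return logits


# Two requests with histories of different lengths, padded to a common width.
histories = [torch.tensor([3, 5, 3]), torch.tensor([7])]
max_len = max(h.numel() for h in histories)
padded = torch.zeros((len(histories), max_len), dtype=torch.long)
mask = torch.zeros((len(histories), max_len), dtype=torch.bool)
for row, history in enumerate(histories):
    padded[row, : history.numel()] = history
    mask[row, : history.numel()] = True

logits = torch.randn(len(histories), 10)
batched = apply_repetition_penalty_masked(logits.clone(), padded, mask)

# Reference: the pre-existing unpadded path, one request at a time.
for row, history in enumerate(histories):
    single = logits[row : row + 1].clone()
    score = torch.gather(single, dim=1, index=history.unsqueeze(0))
    score = torch.where(score < 0, score * 1.35, score / 1.35)
    single.scatter_(dim=1, index=history.unsqueeze(0), src=score)
    assert torch.allclose(batched[row], single[0])
print("masked batched penalty matches the per-request path")

Without the mask, the zero-padding in the shorter rows would be treated as real history and token 0 would be penalised for every padded request; previous_token_mask prevents that, which is what lets _batched_sample_by_group penalise a whole contiguous sampling group in one call instead of looping per request.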