Support streaming inference with fixed chunk length; optimize the SOLA algorithm

ChasonJiang 2025-11-26 14:41:42 +08:00
parent 9b147cd24a
commit 6bce575d69
3 changed files with 50 additions and 42 deletions

View File

@@ -940,11 +940,12 @@ class Text2SemanticDecoder(nn.Module):
elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length):
token_counter = chunk_length  # clamp to the fixed chunk length
yield y[:, -token_counter:], False
curr_ptr += token_counter
token_counter = 0
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
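
In the branch above, passing `mute_emb_sim_matrix=None` is what activates the fixed-length mode: the decoder no longer waits for a silence-aligned split point and instead emits every `chunk_length` tokens. A minimal standalone sketch of that emission pattern (the generator and its names are illustrative, not the decoder's actual interface):

```python
def fixed_length_chunks(token_stream, chunk_length=16):
    # Emit (chunk, is_final) pairs: a chunk as soon as chunk_length tokens
    # have accumulated, then whatever remains as the final chunk.
    buffer = []
    for token in token_stream:
        buffer.append(token)
        if len(buffer) >= chunk_length:
            yield buffer[-chunk_length:], False
            buffer = []
    if buffer:
        yield buffer, True

# tokens 0..9 with chunk_length=4 -> [0,1,2,3], [4,5,6,7], then [8,9] as final
for chunk, is_final in fixed_length_chunks(range(10), chunk_length=4):
    print(chunk, is_final)
```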

View File

@@ -1014,18 +1014,19 @@ class TTS:
"text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"streaming_mode": False, # bool. return audio chunk by chunk.
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
}
returns:
Tuple[int, np.ndarray]: sampling rate and audio data.
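
For orientation, a hypothetical call that exercises the new streaming parameters (the pipeline handle and paths are placeholders):

```python
# tts_pipeline is assumed to be an already-initialized TTS instance.
inputs = {
    "text": "Streaming synthesis with fixed-length chunks.",
    "text_lang": "en",
    "ref_audio_path": "path/to/ref.wav",
    "prompt_lang": "en",
    "streaming_mode": True,       # return audio chunk by chunk
    "overlap_length": 2,          # semantic-token overlap between chunks
    "min_chunk_length": 16,       # minimum chunk size in semantic tokens
    "fixed_length_chunk": True,   # faster first chunk, lower quality
}
for sampling_rate, audio_chunk in tts_pipeline.run(inputs):
    ...  # hand each fragment to the consumer as it arrives
```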
@@ -1058,6 +1059,7 @@ class TTS:
streaming_mode = inputs.get("streaming_mode", False)
overlap_length = inputs.get("overlap_length", 2)
min_chunk_length = inputs.get("min_chunk_length", 16)
fixed_length_chunk = inputs.get("fixed_length_chunk", False)
chunk_split_thershold = 0.0 # cosine-similarity threshold between a semantic token and the mute token; a token above this threshold is treated as a valid split point.
if parallel_infer and not streaming_mode:
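
The threshold drives the default (variable-length) chunking: a generated semantic token whose embedding is close enough to the mute token marks a safe place to cut. A rough sketch of that test, with names and shapes assumed rather than taken from the model code:

```python
import torch
import torch.nn.functional as F

def is_split_point(token_emb: torch.Tensor, mute_emb: torch.Tensor,
                   threshold: float = 0.0) -> bool:
    # Cosine similarity between the candidate token embedding and the
    # mute-token embedding; above the threshold counts as a chunk boundary.
    sim = F.cosine_similarity(token_emb, mute_emb, dim=-1)
    return bool(sim.max() > threshold)
```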
@@ -1367,7 +1369,7 @@ class TTS:
repetition_penalty=repetition_penalty,
streaming_mode=True,
chunk_length=min_chunk_length,
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix,
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix if not fixed_length_chunk else None,
chunk_split_thershold=chunk_split_thershold,
)
t4 = time.perf_counter()
@@ -1456,11 +1458,6 @@ class TTS:
else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:]
)
# audio_chunk_ = (
# audio_chunk_[overlap_size:-overlap_size] if not is_final \
# else audio_chunk_[overlap_size:]
# )
last_latent = latent
last_audio_chunk = audio_chunk
yield self.audio_postprocess(
@@ -1785,30 +1782,35 @@ class TTS:
self,
audio_fragments: List[torch.Tensor],
overlap_len: int,
search_len: int = 320
):
for i in range(len(audio_fragments) - 1):
f1 = audio_fragments[i]
f2 = audio_fragments[i + 1]
w1 = f1[-overlap_len:]
w2 = f2[:overlap_len]
w2 = w2[-w2.shape[-1]//2:]
# assert w1.shape == w2.shape
corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1)).view(-1)
# overlap_len-=search_len
squared_sum = F.conv1d(w1.view(1, 1, -1)**2, torch.ones_like(w2).view(1, 1, -1)).view(-1)+ 1e-8
idx = (corr/squared_sum.sqrt()).argmax()
dtype = audio_fragments[0].dtype
for i in range(len(audio_fragments) - 1):
f1 = audio_fragments[i].float()
f2 = audio_fragments[i + 1].float()
w1 = f1[-overlap_len:]
w2 = f2[:overlap_len+search_len]
# w2 = w2[-w2.shape[-1]//2:]
# assert w1.shape == w2.shape
corr_norm = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1)
corr_den = F.conv1d(w2.view(1, 1, -1)**2, torch.ones_like(w1).view(1, 1, -1)).view(-1) + 1e-8
idx = (corr_norm / corr_den.sqrt()).argmax()
print(f"seg_idx: {idx}")
# idx = corr.argmax()
f1_ = f1[: -(overlap_len - idx)]
f1_ = f1[: -overlap_len]
audio_fragments[i] = f1_
f2_ = f2[idx:]
window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype)
f2_[: (overlap_len - idx)] = (
window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)]
+ window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :]
window = torch.hann_window((overlap_len) * 2, device=f1.device, dtype=f1.dtype)
f2_[: overlap_len] = (
window[: overlap_len] * f2_[: overlap_len]
+ window[overlap_len :] * f1[-overlap_len :]
)
# window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx))
@@ -1819,4 +1821,4 @@ class TTS:
audio_fragments[i + 1] = f2_
return torch.cat(audio_fragments, 0)
return torch.cat(audio_fragments, 0).to(dtype)
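
Summarizing the new SOLA path: the tail of fragment i (`overlap_len` samples) is slid over the first `overlap_len + search_len` samples of fragment i+1, the offset with the highest normalized cross-correlation picks the join point, and a full-length Hann crossfade blends the seam before the result is cast back to the original dtype. A self-contained sketch for two 1-D fragments (a reconstruction of the logic above, not a drop-in for the class method):

```python
import torch
import torch.nn.functional as F

def sola_merge(f1: torch.Tensor, f2: torch.Tensor,
               overlap_len: int, search_len: int = 320) -> torch.Tensor:
    dtype = f1.dtype
    f1, f2 = f1.float(), f2.float()
    w1 = f1[-overlap_len:]                    # template: tail of fragment 1
    w2 = f2[:overlap_len + search_len]        # search region: head of fragment 2
    # Normalized cross-correlation: slide w1 over w2, divide by local energy.
    corr = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1)
    energy = F.conv1d(w2.view(1, 1, -1) ** 2,
                      torch.ones_like(w1).view(1, 1, -1)).view(-1) + 1e-8
    idx = int((corr / energy.sqrt()).argmax())  # best alignment offset into f2
    f1_, f2_ = f1[:-overlap_len], f2[idx:].clone()
    # Crossfade over the overlap region with a Hann window.
    window = torch.hann_window(overlap_len * 2, device=f1.device, dtype=f1.dtype)
    f2_[:overlap_len] = (window[:overlap_len] * f2_[:overlap_len]
                         + window[overlap_len:] * f1[-overlap_len:])
    return torch.cat([f1_, f2_], 0).to(dtype)
```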

View File

@@ -35,15 +35,17 @@ POST:
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"streaming_mode": False, # bool. whether to return a streaming response.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
}
```
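
A hypothetical client for these options against the v2 API (host and port follow the api_v2.py defaults; adjust to your deployment):

```python
import requests

payload = {
    "text": "Hello from the streaming endpoint.",
    "text_lang": "en",
    "ref_audio_path": "path/to/ref.wav",
    "prompt_lang": "en",
    "media_type": "wav",
    "streaming_mode": True,
    "min_chunk_length": 16,
    "fixed_length_chunk": True,
}
# stream=True lets the client consume chunks as the server yields them.
with requests.post("http://127.0.0.1:9880/tts", json=payload, stream=True) as resp:
    with open("out.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=4096):
            f.write(chunk)
```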
@@ -176,6 +178,7 @@ class TTS_Request(BaseModel):
overlap_length: int = 2
min_chunk_length: int = 16
return_fragment: bool = False
fixed_length_chunk: bool = False
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
@@ -313,7 +316,7 @@ async def tts_handle(req: dict):
"text": "", # str.(required) text to be synthesized
"text_lang: "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
@@ -322,19 +325,19 @@ async def tts_handle(req: dict):
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
"streaming_mode": False, # bool. whether to return a streaming response.
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
}
returns:
StreamingResponse: audio stream response.
@@ -402,7 +405,7 @@ async def tts_get_endpoint(
top_k: int = 5,
top_p: float = 1,
temperature: float = 1,
text_split_method: str = "cut0",
text_split_method: str = "cut5",
batch_size: int = 1,
batch_threshold: float = 0.75,
split_bucket: bool = True,
@@ -410,14 +413,15 @@ async def tts_get_endpoint(
fragment_interval: float = 0.3,
seed: int = -1,
media_type: str = "wav",
streaming_mode: bool = False,
parallel_infer: bool = True,
repetition_penalty: float = 1.35,
sample_steps: int = 32,
super_sampling: bool = False,
return_fragment: bool = False,
streaming_mode: bool = False,
overlap_length: int = 2,
min_chunk_length: int = 16,
return_fragment: bool = False,
fixed_length_chunk: bool = False,
):
req = {
"text": text,
@@ -445,6 +449,7 @@ async def tts_get_endpoint(
"overlap_length": int(overlap_length),
"min_chunk_length": int(min_chunk_length),
"return_fragment": return_fragment,
"fixed_length_chunk": fixed_length_chunk
}
return await tts_handle(req)
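
The GET route accepts the same new query parameters, so a quick smoke test can be as simple as (hypothetical host/port):

```python
import requests

resp = requests.get(
    "http://127.0.0.1:9880/tts",
    params={
        "text": "Quick latency test.",
        "text_lang": "en",
        "ref_audio_path": "path/to/ref.wav",
        "prompt_lang": "en",
        "streaming_mode": True,
        "fixed_length_chunk": True,
    },
    stream=True,
)
```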