mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-12-17 01:59:08 +08:00
支持固定chunk长度的流式推理,优化sola算法
This commit is contained in:
parent
9b147cd24a
commit
6bce575d69
@ -940,11 +940,12 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
|
||||
elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length):
|
||||
token_counter == chunk_length
|
||||
yield y[:, -token_counter:], False
|
||||
curr_ptr+=token_counter
|
||||
token_counter = 0
|
||||
|
||||
|
||||
|
||||
####################### update next step ###################################
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
|
||||
@ -1014,18 +1014,19 @@ class TTS:
|
||||
"text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details.
|
||||
"batch_size": 1, # int. batch size for inference
|
||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
||||
"return_fragment": False, # bool. step by step return the audio fragment.
|
||||
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"streaming_mode": False, # bool. return audio chunk by chunk.
|
||||
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
|
||||
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
|
||||
}
|
||||
returns:
|
||||
Tuple[int, np.ndarray]: sampling rate and audio data.
|
||||
@ -1058,6 +1059,7 @@ class TTS:
|
||||
streaming_mode = inputs.get("streaming_mode", False)
|
||||
overlap_length = inputs.get("overlap_length", 2)
|
||||
min_chunk_length = inputs.get("min_chunk_length", 16)
|
||||
fixed_length_chunk = inputs.get("fixed_length_chunk", False)
|
||||
chunk_split_thershold = 0.0 # 该值代表语义token与mute token的余弦相似度阈值,若大于该阈值,则视为可切分点。
|
||||
|
||||
if parallel_infer and not streaming_mode:
|
||||
@ -1367,7 +1369,7 @@ class TTS:
|
||||
repetition_penalty=repetition_penalty,
|
||||
streaming_mode=True,
|
||||
chunk_length=min_chunk_length,
|
||||
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix,
|
||||
mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix if not fixed_length_chunk else None,
|
||||
chunk_split_thershold=chunk_split_thershold,
|
||||
)
|
||||
t4 = time.perf_counter()
|
||||
@ -1456,11 +1458,6 @@ class TTS:
|
||||
else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:]
|
||||
)
|
||||
|
||||
# audio_chunk_ = (
|
||||
# audio_chunk_[overlap_size:-overlap_size] if not is_final \
|
||||
# else audio_chunk_[overlap_size:]
|
||||
# )
|
||||
|
||||
last_latent = latent
|
||||
last_audio_chunk = audio_chunk
|
||||
yield self.audio_postprocess(
|
||||
@ -1785,30 +1782,35 @@ class TTS:
|
||||
self,
|
||||
audio_fragments: List[torch.Tensor],
|
||||
overlap_len: int,
|
||||
search_len:int= 320
|
||||
):
|
||||
for i in range(len(audio_fragments) - 1):
|
||||
f1 = audio_fragments[i]
|
||||
f2 = audio_fragments[i + 1]
|
||||
w1 = f1[-overlap_len:]
|
||||
w2 = f2[:overlap_len]
|
||||
w2 = w2[-w2.shape[-1]//2:]
|
||||
# assert w1.shape == w2.shape
|
||||
corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1)).view(-1)
|
||||
# overlap_len-=search_len
|
||||
|
||||
squared_sum = F.conv1d(w1.view(1, 1, -1)**2, torch.ones_like(w2).view(1, 1, -1)).view(-1)+ 1e-8
|
||||
idx = (corr/squared_sum.sqrt()).argmax()
|
||||
dtype = audio_fragments[0].dtype
|
||||
|
||||
for i in range(len(audio_fragments) - 1):
|
||||
f1 = audio_fragments[i].float()
|
||||
f2 = audio_fragments[i + 1].float()
|
||||
w1 = f1[-overlap_len:]
|
||||
w2 = f2[:overlap_len+search_len]
|
||||
# w2 = w2[-w2.shape[-1]//2:]
|
||||
# assert w1.shape == w2.shape
|
||||
corr_norm = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1)
|
||||
|
||||
corr_den = F.conv1d(w2.view(1, 1, -1)**2, torch.ones_like(w1).view(1, 1, -1)).view(-1)+ 1e-8
|
||||
idx = (corr_norm/corr_den.sqrt()).argmax()
|
||||
|
||||
print(f"seg_idx: {idx}")
|
||||
|
||||
# idx = corr.argmax()
|
||||
f1_ = f1[: -(overlap_len - idx)]
|
||||
f1_ = f1[: -overlap_len]
|
||||
audio_fragments[i] = f1_
|
||||
|
||||
f2_ = f2[idx:]
|
||||
window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype)
|
||||
f2_[: (overlap_len - idx)] = (
|
||||
window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)]
|
||||
+ window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :]
|
||||
window = torch.hann_window((overlap_len) * 2, device=f1.device, dtype=f1.dtype)
|
||||
f2_[: overlap_len] = (
|
||||
window[: overlap_len] * f2_[: overlap_len]
|
||||
+ window[overlap_len :] * f1[-overlap_len :]
|
||||
)
|
||||
|
||||
# window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx))
|
||||
@ -1819,4 +1821,4 @@ class TTS:
|
||||
|
||||
audio_fragments[i + 1] = f2_
|
||||
|
||||
return torch.cat(audio_fragments, 0)
|
||||
return torch.cat(audio_fragments, 0).to(dtype)
|
||||
|
||||
33
api_v2.py
33
api_v2.py
@ -35,15 +35,17 @@ POST:
|
||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"streaming_mode": False, # bool. whether to return a streaming response.
|
||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
|
||||
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
"return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
|
||||
}
|
||||
```
|
||||
|
||||
@ -176,6 +178,7 @@ class TTS_Request(BaseModel):
|
||||
overlap_length: int = 2
|
||||
min_chunk_length: int = 16
|
||||
return_fragment: bool = False
|
||||
fixed_length_chunk: bool = False
|
||||
|
||||
|
||||
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
|
||||
@ -313,7 +316,7 @@ async def tts_handle(req: dict):
|
||||
"text": "", # str.(required) text to be synthesized
|
||||
"text_lang: "", # str.(required) language of the text to be synthesized
|
||||
"ref_audio_path": "", # str.(required) reference audio path
|
||||
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
|
||||
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
||||
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
||||
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
||||
"top_k": 5, # int. top k sampling
|
||||
@ -322,19 +325,19 @@ async def tts_handle(req: dict):
|
||||
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
|
||||
"batch_size": 1, # int. batch size for inference
|
||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
||||
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
|
||||
"streaming_mode": False, # bool. whether to return a streaming response.
|
||||
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
|
||||
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"return_fragment": False, # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
|
||||
"streaming_mode": False, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
"return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
"fixed_length_chunk": False, # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
|
||||
}
|
||||
returns:
|
||||
StreamingResponse: audio stream response.
|
||||
@ -402,7 +405,7 @@ async def tts_get_endpoint(
|
||||
top_k: int = 5,
|
||||
top_p: float = 1,
|
||||
temperature: float = 1,
|
||||
text_split_method: str = "cut0",
|
||||
text_split_method: str = "cut5",
|
||||
batch_size: int = 1,
|
||||
batch_threshold: float = 0.75,
|
||||
split_bucket: bool = True,
|
||||
@ -410,14 +413,15 @@ async def tts_get_endpoint(
|
||||
fragment_interval: float = 0.3,
|
||||
seed: int = -1,
|
||||
media_type: str = "wav",
|
||||
streaming_mode: bool = False,
|
||||
parallel_infer: bool = True,
|
||||
repetition_penalty: float = 1.35,
|
||||
sample_steps: int = 32,
|
||||
super_sampling: bool = False,
|
||||
return_fragment: bool = False,
|
||||
streaming_mode: bool = False,
|
||||
overlap_length: int = 2,
|
||||
min_chunk_length: int = 16,
|
||||
return_fragment: bool = False,
|
||||
fixed_length_chunk: bool = False,
|
||||
):
|
||||
req = {
|
||||
"text": text,
|
||||
@ -445,6 +449,7 @@ async def tts_get_endpoint(
|
||||
"overlap_length": int(overlap_length),
|
||||
"min_chunk_length": int(min_chunk_length),
|
||||
"return_fragment": return_fragment,
|
||||
"fixed_length_chunk": fixed_length_chunk
|
||||
}
|
||||
return await tts_handle(req)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user