mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-12-17 01:59:08 +08:00)

Support streaming inference with a fixed chunk length; optimize the SOLA algorithm

commit 6bce575d69 (parent 9b147cd24a)
@@ -940,9 +940,10 @@ class Text2SemanticDecoder(nn.Module):
 
 
                 elif streaming_mode and (mute_emb_sim_matrix is None) and (token_counter >= chunk_length):
-                    token_counter == chunk_length
                     yield y[:, -token_counter:], False
+                    curr_ptr += token_counter
                     token_counter = 0
+
 
 
                 ####################### update next step ###################################
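The dropped `token_counter == chunk_length` line was a no-op comparison (`==` where no statement was needed at all); with a fixed chunk length the decoder simply yields every `chunk_length` tokens, advances the new `curr_ptr` cursor, and resets the counter. A minimal sketch of that yielding pattern, independent of the real decoder (the generator and its list-of-ints input are illustrative stand-ins for the autoregressive sampling loop):

```
# Sketch only: emit_chunks stands in for Text2SemanticDecoder's sampling loop.
from typing import Iterator, List, Tuple

def emit_chunks(tokens: List[int], chunk_length: int) -> Iterator[Tuple[List[int], bool]]:
    buffer: List[int] = []
    token_counter = 0   # tokens accumulated since the last yield
    curr_ptr = 0        # absolute index of the last emitted chunk boundary
    for tok in tokens:  # in the real model, each tok is a sampled semantic token
        buffer.append(tok)
        token_counter += 1
        if token_counter >= chunk_length:
            yield buffer[-token_counter:], False  # fixed-size, non-final chunk
            curr_ptr += token_counter
            token_counter = 0
    yield buffer[curr_ptr:], True  # remainder (possibly shorter) as the final chunk

for chunk, is_final in emit_chunks(list(range(10)), chunk_length=4):
    print(len(chunk), is_final)  # -> 4 False, 4 False, 2 True
```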
@@ -1014,18 +1014,19 @@ class TTS:
                     "text_split_method": "cut1", # str. text split method, see text_segmentation_method.py for details.
                     "batch_size": 1, # int. batch size for inference
                     "batch_threshold": 0.75, # float. threshold for batch splitting.
-                    "split_bucket: True, # bool. whether to split the batch into multiple buckets.
-                    "return_fragment": False, # bool. step by step return the audio fragment.
+                    "split_bucket": True, # bool. whether to split the batch into multiple buckets.
                     "speed_factor":1.0, # float. control the speed of the synthesized audio.
                     "fragment_interval":0.3, # float. to control the interval of the audio fragment.
                     "seed": -1, # int. random seed for reproducibility.
                     "parallel_infer": True, # bool. whether to use parallel inference.
-                    "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
+                    "repetition_penalty": 1.35, # float. repetition penalty for T2S model.
                     "sample_steps": 32, # int. number of sampling steps for VITS model V3.
                     "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
+                    "return_fragment": False, # bool. step by step return the audio fragment. (best quality, slowest response; the old streaming mode)
+                    "streaming_mode": False, # bool. return audio chunk by chunk. (medium quality, slow response)
                     "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
-                    "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
+                    "min_chunk_length": 16, # int. the minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
+                    "fixed_length_chunk": False, # bool. stream fixed-length chunks. (lower quality, faster response)
                 }
         returns:
             Tuple[int, np.ndarray]: sampling rate and audio data.
@@ -1058,6 +1059,7 @@ class TTS:
         streaming_mode = inputs.get("streaming_mode", False)
         overlap_length = inputs.get("overlap_length", 2)
         min_chunk_length = inputs.get("min_chunk_length", 16)
+        fixed_length_chunk = inputs.get("fixed_length_chunk", False)
         chunk_split_thershold = 0.0  # cosine-similarity threshold between a semantic token and the mute token; a token above this threshold is treated as a valid split point.
 
         if parallel_infer and not streaming_mode:
@@ -1367,7 +1369,7 @@ class TTS:
                         repetition_penalty=repetition_penalty,
                         streaming_mode=True,
                         chunk_length=min_chunk_length,
-                        mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix,
+                        mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix if not fixed_length_chunk else None,
                         chunk_split_thershold=chunk_split_thershold,
                     )
                 t4 = time.perf_counter()
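These two hunks wire the new flag through: with adaptive chunking the decoder cuts a chunk only at silence-like tokens, judged by cosine similarity between a sampled semantic token and the mute token (anything above `chunk_split_thershold` counts as a cut point). When `fixed_length_chunk` is on, `mute_emb_sim_matrix` is passed as `None`, which disables that search and falls back to cutting every `min_chunk_length` tokens. A hedged sketch of the similarity test, computing cosine similarity directly rather than reading the precomputed `mute_emb_sim_matrix` (all tensors below are made-up stand-ins):

```
# Sketch only: the real decoder reads a precomputed mute_emb_sim_matrix.
import torch
import torch.nn.functional as F

def is_split_point(token_emb: torch.Tensor, mute_emb: torch.Tensor,
                   threshold: float = 0.0) -> bool:
    # high similarity to the "mute" (silence) embedding => safe place to cut
    sim = F.cosine_similarity(token_emb.unsqueeze(0), mute_emb.unsqueeze(0))
    return bool(sim.item() > threshold)

token_emb = torch.randn(512)  # embedding of the token just sampled
mute_emb = torch.randn(512)   # embedding of the mute/silence token
print(is_split_point(token_emb, mute_emb, threshold=0.0))
```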
@@ -1456,11 +1458,6 @@ class TTS:
                         else audio_chunk_[last_audio_chunk.shape[0]-overlap_size:]
                     )
 
-                    # audio_chunk_ = (
-                    #     audio_chunk_[overlap_size:-overlap_size] if not is_final \
-                    #     else audio_chunk_[overlap_size:]
-                    # )
-
                     last_latent = latent
                     last_audio_chunk = audio_chunk
                     yield self.audio_postprocess(
@@ -1785,30 +1782,35 @@ class TTS:
         self,
         audio_fragments: List[torch.Tensor],
         overlap_len: int,
+        search_len: int = 320
     ):
-        for i in range(len(audio_fragments) - 1):
-            f1 = audio_fragments[i]
-            f2 = audio_fragments[i + 1]
-            w1 = f1[-overlap_len:]
-            w2 = f2[:overlap_len]
-            w2 = w2[-w2.shape[-1]//2:]
-            # assert w1.shape == w2.shape
-            corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1)).view(-1)
-
-            squared_sum = F.conv1d(w1.view(1, 1, -1)**2, torch.ones_like(w2).view(1, 1, -1)).view(-1) + 1e-8
-            idx = (corr / squared_sum.sqrt()).argmax()
+        # overlap_len -= search_len
+        dtype = audio_fragments[0].dtype
+        for i in range(len(audio_fragments) - 1):
+            f1 = audio_fragments[i].float()
+            f2 = audio_fragments[i + 1].float()
+            w1 = f1[-overlap_len:]
+            w2 = f2[:overlap_len + search_len]
+            # w2 = w2[-w2.shape[-1]//2:]
+            # assert w1.shape == w2.shape
+            corr_norm = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1)
+
+            corr_den = F.conv1d(w2.view(1, 1, -1)**2, torch.ones_like(w1).view(1, 1, -1)).view(-1) + 1e-8
+            idx = (corr_norm / corr_den.sqrt()).argmax()
 
             print(f"seg_idx: {idx}")
 
             # idx = corr.argmax()
-            f1_ = f1[: -(overlap_len - idx)]
+            f1_ = f1[:-overlap_len]
             audio_fragments[i] = f1_
 
             f2_ = f2[idx:]
-            window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype)
-            f2_[: (overlap_len - idx)] = (
-                window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)]
-                + window[(overlap_len - idx):] * f1[-(overlap_len - idx):]
+            window = torch.hann_window(overlap_len * 2, device=f1.device, dtype=f1.dtype)
+            f2_[:overlap_len] = (
+                window[:overlap_len] * f2_[:overlap_len]
+                + window[overlap_len:] * f1[-overlap_len:]
             )
 
             # window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx))
@@ -1819,4 +1821,4 @@ class TTS:
 
             audio_fragments[i + 1] = f2_
 
-        return torch.cat(audio_fragments, 0)
+        return torch.cat(audio_fragments, 0).to(dtype)
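The rewrite turns the fragment joiner into a proper SOLA (synchronized overlap-add) step. `F.conv1d` in PyTorch is cross-correlation, so sliding `w1` (the fixed `overlap_len`-sample tail of fragment i) across `w2` (the head of fragment i+1, extended by `search_len` samples) scores every candidate offset; dividing by the sliding energy of `w2` normalizes the score so the `argmax` picks the best waveform alignment rather than just the loudest region. Fragment i+1 is then re-anchored at that offset and Hann-crossfaded against the previous tail, and the `.float()` / `.to(dtype)` round trip keeps the correlation out of half precision. A self-contained sketch of the same idea for two 1-D tensors (the function name and test signals are ours, not the repo's):

```
# Sketch of SOLA-style joining under the same assumptions as the code above.
import torch
import torch.nn.functional as F

def sola_join(f1: torch.Tensor, f2: torch.Tensor,
              overlap_len: int = 2048, search_len: int = 320) -> torch.Tensor:
    w1 = f1[-overlap_len:]              # template: tail of the previous fragment
    w2 = f2[:overlap_len + search_len]  # search region: head of the next fragment
    # cross-correlation of w1 at every offset inside w2 (one score per offset)
    corr = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1)
    # sliding energy of w2 under the template, to normalize the score
    energy = F.conv1d(w2.view(1, 1, -1) ** 2,
                      torch.ones_like(w1).view(1, 1, -1)).view(-1) + 1e-8
    idx = int((corr / energy.sqrt()).argmax())  # best-aligned start inside f2
    # Hann crossfade: first half fades the new head in, second half fades the old tail out
    window = torch.hann_window(overlap_len * 2, device=f1.device, dtype=f1.dtype)
    faded = window[:overlap_len] * f2[idx: idx + overlap_len] + window[overlap_len:] * w1
    return torch.cat([f1[:-overlap_len], faded, f2[idx + overlap_len:]], dim=0)

a = torch.sin(torch.linspace(0.0, 200.0, 48000))
b = torch.sin(torch.linspace(190.0, 400.0, 48000))
print(sola_join(a, b).shape)
```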
api_v2.py (33 changed lines)
@@ -35,15 +35,17 @@ POST:
     "batch_threshold": 0.75, # float. threshold for batch splitting.
     "split_bucket": True, # bool. whether to split the batch into multiple buckets.
     "speed_factor":1.0, # float. control the speed of the synthesized audio.
-    "streaming_mode": False, # bool. whether to return a streaming response.
+    "fragment_interval":0.3, # float. to control the interval of the audio fragment.
     "seed": -1, # int. random seed for reproducibility.
     "parallel_infer": True, # bool. whether to use parallel inference.
     "repetition_penalty": 1.35, # float. repetition penalty for T2S model.
     "sample_steps": 32, # int. number of sampling steps for VITS model V3.
     "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
+    "return_fragment": False, # bool. step by step return the audio fragment. (best quality, slowest response; the old streaming mode)
+    "streaming_mode": False, # bool. return audio chunk by chunk. (medium quality, slow response)
     "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
-    "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
-    "return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
+    "min_chunk_length": 16, # int. the minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
+    "fixed_length_chunk": False, # bool. stream fixed-length chunks. (lower quality, faster response)
 }
 ```
 
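A client-side sketch of the options documented above (it assumes api_v2.py's default bind address of 127.0.0.1:9880; the reference-audio path and texts are placeholders):

```
# Sketch only: adjust host, port, and file paths to your deployment.
import requests

payload = {
    "text": "Streaming synthesis test.",
    "text_lang": "en",
    "ref_audio_path": "ref.wav",            # placeholder reference audio
    "prompt_text": "reference transcript",  # placeholder prompt text
    "prompt_lang": "en",
    "media_type": "wav",
    "streaming_mode": True,        # chunk-by-chunk response
    "overlap_length": 2,
    "min_chunk_length": 16,
    "fixed_length_chunk": True,    # new flag: faster first chunk, lower quality
}

with requests.post("http://127.0.0.1:9880/tts", json=payload, stream=True) as resp:
    resp.raise_for_status()
    with open("out.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=4096):
            f.write(chunk)  # audio arrives incrementally
```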
@@ -176,6 +178,7 @@ class TTS_Request(BaseModel):
     overlap_length: int = 2
     min_chunk_length: int = 16
     return_fragment: bool = False
+    fixed_length_chunk: bool = False
 
 
 ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
@@ -313,7 +316,7 @@ async def tts_handle(req: dict):
     "text": "", # str.(required) text to be synthesized
     "text_lang: "", # str.(required) language of the text to be synthesized
     "ref_audio_path": "", # str.(required) reference audio path
-    "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
+    "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
     "prompt_text": "", # str.(optional) prompt text for the reference audio
     "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
     "top_k": 5, # int. top k sampling
@@ -322,19 +325,19 @@ async def tts_handle(req: dict):
     "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
     "batch_size": 1, # int. batch size for inference
     "batch_threshold": 0.75, # float. threshold for batch splitting.
-    "split_bucket: True, # bool. whether to split the batch into multiple buckets.
+    "split_bucket": True, # bool. whether to split the batch into multiple buckets.
     "speed_factor":1.0, # float. control the speed of the synthesized audio.
     "fragment_interval":0.3, # float. to control the interval of the audio fragment.
     "seed": -1, # int. random seed for reproducibility.
-    "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
-    "streaming_mode": False, # bool. whether to return a streaming response.
-    "parallel_infer": True, # bool.(optional) whether to use parallel inference.
-    "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
+    "parallel_infer": True, # bool. whether to use parallel inference.
+    "repetition_penalty": 1.35, # float. repetition penalty for T2S model.
     "sample_steps": 32, # int. number of sampling steps for VITS model V3.
     "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
+    "return_fragment": False, # bool. step by step return the audio fragment. (best quality, slowest response; the old streaming mode)
+    "streaming_mode": False, # bool. return audio chunk by chunk. (medium quality, slow response)
     "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
-    "min_chunk_length: 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
-    "return_fragment": False, # bool. step by step return the audio fragment. (old version of streaming mode)
+    "min_chunk_length": 16, # int. the minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
+    "fixed_length_chunk": False, # bool. stream fixed-length chunks. (lower quality, faster response)
 }
 returns:
     StreamingResponse: audio stream response.
@@ -402,7 +405,7 @@ async def tts_get_endpoint(
     top_k: int = 5,
     top_p: float = 1,
     temperature: float = 1,
-    text_split_method: str = "cut0",
+    text_split_method: str = "cut5",
     batch_size: int = 1,
     batch_threshold: float = 0.75,
     split_bucket: bool = True,
@@ -410,14 +413,15 @@ async def tts_get_endpoint(
     fragment_interval: float = 0.3,
     seed: int = -1,
     media_type: str = "wav",
-    streaming_mode: bool = False,
     parallel_infer: bool = True,
     repetition_penalty: float = 1.35,
     sample_steps: int = 32,
     super_sampling: bool = False,
+    return_fragment: bool = False,
+    streaming_mode: bool = False,
     overlap_length: int = 2,
     min_chunk_length: int = 16,
-    return_fragment: bool = False,
+    fixed_length_chunk: bool = False,
 ):
     req = {
         "text": text,
@@ -445,6 +449,7 @@ async def tts_get_endpoint(
         "overlap_length": int(overlap_length),
         "min_chunk_length": int(min_chunk_length),
         "return_fragment": return_fragment,
+        "fixed_length_chunk": fixed_length_chunk
     }
     return await tts_handle(req)
 
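The same parameters are accepted on the GET endpoint, mirroring `tts_get_endpoint`'s signature; a sketch with the same assumed URL and placeholder paths:

```
# Sketch only: query-string variant of the POST request shown earlier.
import requests

params = {
    "text": "Streaming synthesis test.",
    "text_lang": "en",
    "ref_audio_path": "ref.wav",  # placeholder reference audio
    "prompt_lang": "en",
    "streaming_mode": True,
    "overlap_length": 2,
    "min_chunk_length": 16,
    "fixed_length_chunk": True,
}
resp = requests.get("http://127.0.0.1:9880/tts", params=params, stream=True)
resp.raise_for_status()
```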