Adapt parallel inference for sovits_v3 (#2241)

* Adapt parallel inference for sovits_v3

* Clean up unused code
commit 03b662a769 (parent 6c468583c5)
ChasonJiang, 2025-03-31 11:56:05 +08:00, committed by GitHub

@@ -3,7 +3,7 @@ import math
 import os, sys, gc
 import random
 import traceback
+import time
 import torchaudio
 from tqdm import tqdm
 now_dir = os.getcwd()
@@ -908,11 +908,14 @@ class TTS:
             split_bucket = False
             print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))

-        if split_bucket and speed_factor==1.0:
+        if split_bucket and speed_factor==1.0 and not (self.configs.is_v3_synthesizer and parallel_infer):
             print(i18n("分桶处理模式已开启"))
         elif speed_factor!=1.0:
             print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理"))
             split_bucket = False
+        elif self.configs.is_v3_synthesizer and parallel_infer:
+            print(i18n("当开启并行推理模式时，SoVits V3模型不支持分桶处理，已自动关闭分桶处理"))
+            split_bucket = False
         else:
             print(i18n("分桶处理模式已关闭"))
@@ -936,7 +939,7 @@ class TTS:
             raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")

         ###### setting reference audio and prompt text preprocessing ########
-        t0 = ttime()
+        t0 = time.perf_counter()
         if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
             if not os.path.exists(ref_audio_path):
                 raise ValueError(f"{ref_audio_path} not exists")
@@ -975,7 +978,7 @@ class TTS:
         ###### text preprocessing ########
-        t1 = ttime()
+        t1 = time.perf_counter()
         data:list = None
         if not return_fragment:
             data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
@@ -1027,7 +1030,7 @@ class TTS:
                 return batch[0]

-        t2 = ttime()
+        t2 = time.perf_counter()
         try:
             print("############ 推理 ############")
             ###### inference ######
@@ -1036,7 +1039,7 @@ class TTS:
             audio = []
             output_sr = self.configs.sampling_rate if not self.configs.is_v3_synthesizer else 24000
             for item in data:
-                t3 = ttime()
+                t3 = time.perf_counter()
                 if return_fragment:
                     item = make_batch(item)
                     if item is None:
@@ -1071,7 +1074,7 @@ class TTS:
                     max_len=max_len,
                     repetition_penalty=repetition_penalty,
                 )
-                t4 = ttime()
+                t4 = time.perf_counter()
                 t_34 += t4 - t3

                 refer_audio_spec:torch.Tensor = [item.to(dtype=self.precision, device=self.configs.device) for item in self.prompt_cache["refer_spec"]]
@@ -1094,6 +1097,7 @@ class TTS:
                 print(f"############ {i18n('合成音频')} ############")
                 if not self.configs.is_v3_synthesizer:
                     if speed_factor == 1.0:
+                        print(f"{i18n('并行合成中')}...")
                         # ## vits并行推理 method 2
                         pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                         upsample_rate = math.prod(self.vits_model.upsample_rates)
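For orientation on the upsample_rate = math.prod(...) context line above: the vocoder upsamples each latent frame by the product of its per-stage factors, which is what lets the code convert frame counts into sample counts. A toy with made-up factors (the real values live in the model config):

import math

upsample_rates = [10, 8, 2, 2]              # hypothetical per-stage vocoder factors
upsample_rate = math.prod(upsample_rates)   # 320 output samples per latent frame
num_frames = 150
print(num_frames * upsample_rate)           # 48000 samples, i.e. 1.5 s at 32 kHz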
@@ -1118,17 +1122,28 @@ class TTS:
                             audio_fragment
                         ) ###试试重建不带上prompt部分
                 else:
-                    for i, idx in enumerate(tqdm(idx_list)):
-                        phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-                        _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
-                        audio_fragment = self.v3_synthesis(
-                            _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
-                        )
-                        batch_audio_fragment.append(
-                            audio_fragment
-                        )
+                    if parallel_infer:
+                        print(f"{i18n('并行合成中')}...")
+                        audio_fragments = self.v3_synthesis_batched_infer(
+                            idx_list,
+                            pred_semantic_list,
+                            batch_phones,
+                            speed=speed_factor,
+                            sample_steps=sample_steps
+                        )
+                        batch_audio_fragment.extend(audio_fragments)
+                    else:
+                        for i, idx in enumerate(tqdm(idx_list)):
+                            phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+                            _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
+                            audio_fragment = self.v3_synthesis(
+                                _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
+                            )
+                            batch_audio_fragment.append(
+                                audio_fragment
+                            )

-                t5 = ttime()
+                t5 = time.perf_counter()
                 t_45 += t5 - t4
                 if return_fragment:
                     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
@@ -1219,13 +1234,13 @@ class TTS:
             if super_sampling:
                 print(f"############ {i18n('音频超采样')} ############")
-                t1 = ttime()
+                t1 = time.perf_counter()
                 self.init_sr_model()
                 if not self.sr_model_not_exist:
                     audio,sr=self.sr_model(audio.unsqueeze(0),sr)
                     max_audio=np.abs(audio).max()
                     if max_audio > 1: audio /= max_audio
-                t2 = ttime()
+                t2 = time.perf_counter()
                 print(f"超采样用时:{t2-t1:.3f}s")
             else:
                 audio = audio.cpu().numpy()
@@ -1260,7 +1275,7 @@ class TTS:
             ref_audio = ref_audio.mean(0).unsqueeze(0)
         if ref_sr!=24000:
             ref_audio=resample(ref_audio, ref_sr, self.configs.device)
-        # print("ref_audio",ref_audio.abs().mean())
         mel2 = mel_fn(ref_audio)
         mel2 = norm_spec(mel2)
         T_min = min(mel2.shape[2], fea_ref.shape[2])
@@ -1285,15 +1300,156 @@ class TTS:
             cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
             cfm_res = cfm_res[:, :, mel2.shape[2]:]

             mel2 = cfm_res[:, :, -T_min:]
             fea_ref = fea_todo_chunk[:, :, -T_min:]

             cfm_resss.append(cfm_res)
-        cmf_res = torch.cat(cfm_resss, 2)
-        cmf_res = denorm_spec(cmf_res)
+        cfm_res = torch.cat(cfm_resss, 2)
+        cfm_res = denorm_spec(cfm_res)
         with torch.inference_mode():
-            wav_gen = self.bigvgan_model(cmf_res)
+            wav_gen = self.bigvgan_model(cfm_res)
             audio=wav_gen[0][0]#.cpu().detach().numpy()
         return audio
+
+    def v3_synthesis_batched_infer(self,
+                                   idx_list:List[int],
+                                   semantic_tokens_list:List[torch.Tensor],
+                                   batch_phones:List[torch.Tensor],
+                                   speed:float=1.0,
+                                   sample_steps:int=32
+                                   )->List[torch.Tensor]:
+
+        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
+        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
+        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
+
+        fea_ref,ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
+        ref_audio:torch.Tensor = self.prompt_cache["raw_audio"]
+        ref_sr = self.prompt_cache["raw_sr"]
+        ref_audio=ref_audio.to(self.configs.device).float()
+        if (ref_audio.shape[0] == 2):
+            ref_audio = ref_audio.mean(0).unsqueeze(0)
+        if ref_sr!=24000:
+            ref_audio=resample(ref_audio, ref_sr, self.configs.device)
+
+        mel2 = mel_fn(ref_audio)
+        mel2 = norm_spec(mel2)
+        T_min = min(mel2.shape[2], fea_ref.shape[2])
+        mel2 = mel2[:, :, :T_min]
+        fea_ref = fea_ref[:, :, :T_min]
+        if (T_min > 468):
+            mel2 = mel2[:, :, -468:]
+            fea_ref = fea_ref[:, :, -468:]
+            T_min = 468
+        chunk_len = 934 - T_min
+        mel2=mel2.to(self.precision)
+
+        # #### batched inference
+        overlapped_len = 12
+        feat_chunks = []
+        feat_lens = []
+        feat_list = []
+
+        for i, idx in enumerate(idx_list):
+            phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+            semantic_tokens = semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
+            feat, _ = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
+            feat_list.append(feat)
+            feat_lens.append(feat.shape[2])
+
+        feats = torch.cat(feat_list, 2)
+        feats_padded = F.pad(feats, (overlapped_len,0), "constant", 0)
+        pos = 0
+        padding_len = 0
+        while True:
+            if pos == 0:
+                chunk = feats_padded[:, :, pos:pos + chunk_len]
+            else:
+                pos = pos - overlapped_len
+                chunk = feats_padded[:, :, pos:pos + chunk_len]
+            pos += chunk_len
+            if (chunk.shape[-1] == 0): break
+
+            # padding for the last chunk
+            padding_len = chunk_len - chunk.shape[2]
+            if padding_len != 0:
+                chunk = F.pad(chunk, (0,padding_len), "constant", 0)
+            feat_chunks.append(chunk)
+
+        feat_chunks = torch.cat(feat_chunks, 0)
+        bs = feat_chunks.shape[0]
+        fea_ref = fea_ref.repeat(bs,1,1)
+        fea = torch.cat([fea_ref, feat_chunks], 2).transpose(2, 1)
+        pred_spec = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
+        pred_spec = pred_spec[:, :, -chunk_len:]
+        dd = pred_spec.shape[1]
+        pred_spec = pred_spec.permute(1, 0, 2).contiguous().view(dd, -1).unsqueeze(0)
+        # pred_spec = pred_spec[..., :-padding_len]
+        pred_spec = denorm_spec(pred_spec)
+
+        with torch.no_grad():
+            wav_gen = self.bigvgan_model(pred_spec)
+            audio = wav_gen[0][0]#.cpu().detach().numpy()
+
+        audio_fragments = []
+        upsample_rate = 256
+        pos = 0
+
+        while pos < audio.shape[-1]:
+            audio_fragment = audio[pos:pos+chunk_len*upsample_rate]
+            audio_fragments.append(audio_fragment)
+            pos += chunk_len*upsample_rate
+
+        audio = self.sola_algorithm(audio_fragments, overlapped_len*upsample_rate)
+        audio = audio[overlapped_len*upsample_rate:-padding_len*upsample_rate]
+
+        audio_fragments = []
+        for feat_len in feat_lens:
+            audio_fragment = audio[:feat_len*upsample_rate]
+            audio_fragments.append(audio_fragment)
+            audio = audio[feat_len*upsample_rate:]
+
+        return audio_fragments
+
+    def sola_algorithm(self,
+                       audio_fragments:List[torch.Tensor],
+                       overlap_len:int,
+                       ):
+
+        for i in range(len(audio_fragments)-1):
+            f1 = audio_fragments[i]
+            f2 = audio_fragments[i+1]
+            w1 = f1[-overlap_len:]
+            w2 = f2[:overlap_len]
+            assert w1.shape == w2.shape
+            corr = F.conv1d(w1.view(1,1,-1), w2.view(1,1,-1),padding=w2.shape[-1]//2).view(-1)[:-1]
+            idx = corr.argmax()
+            f1_ = f1[:-(overlap_len-idx)]
+            audio_fragments[i] = f1_
+            f2_ = f2[idx:]
+            window = torch.hann_window((overlap_len-idx)*2, device=f1.device, dtype=f1.dtype)
+            f2_[:(overlap_len-idx)] = window[:(overlap_len-idx)]*f2_[:(overlap_len-idx)] + window[(overlap_len-idx):]*f1[-(overlap_len-idx):]
+            audio_fragments[i+1] = f2_
+
+        return torch.cat(audio_fragments, 0)
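The heart of v3_synthesis_batched_infer is the fixed-size chunking with a small left overlap, which turns one long feature sequence into a batch the CFM model can process in a single pass. A self-contained sketch of just that step (the function name and the example sizes are ours; chunk_len = 934 - T_min and overlapped_len = 12 mirror the method above):

import torch
import torch.nn.functional as F

def split_into_overlapped_chunks(feats: torch.Tensor, chunk_len: int, overlap: int):
    """feats: (1, C, T) -> list of (1, C, chunk_len) chunks + right-padding length."""
    feats = F.pad(feats, (overlap, 0), "constant", 0)  # left zero-pad, as in the method
    chunks, pos, padding_len = [], 0, 0
    while True:
        if pos != 0:
            pos -= overlap                      # step back so chunks share `overlap` frames
        chunk = feats[:, :, pos:pos + chunk_len]
        pos += chunk_len
        if chunk.shape[-1] == 0:
            break
        padding_len = chunk_len - chunk.shape[-1]  # nonzero only for the last chunk
        if padding_len:
            chunk = F.pad(chunk, (0, padding_len), "constant", 0)
        chunks.append(chunk)
    return chunks, padding_len

feats = torch.randn(1, 4, 1000)   # e.g. 4 feature channels, 1000 frames
chunks, pad = split_into_overlapped_chunks(feats, chunk_len=466, overlap=12)
print(len(chunks), chunks[0].shape, pad)   # 3 chunks of (1, 4, 466), pad = 362

One caveat worth noting: padding_len is 0 when the padded length divides evenly into chunks, in which case a trailing trim of the form audio[...:-padding_len*upsample_rate] would slice to -0 and return an empty tensor, so the toy returns padding_len explicitly for the caller to guard.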
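And a standalone toy of the SOLA merge performed by sola_algorithm: F.conv1d does not flip its kernel, so applying it to the two shared windows computes a cross-correlation whose argmax is the lag where adjacent fragments line up best, and a Hann window then crossfades the seam (sola_merge and the toy signal are ours):

import torch
import torch.nn.functional as F

def sola_merge(fragments, overlap: int) -> torch.Tensor:
    for i in range(len(fragments) - 1):
        f1, f2 = fragments[i], fragments[i + 1]
        w1, w2 = f1[-overlap:], f2[:overlap]
        # cross-correlate the shared windows; argmax = best alignment lag
        corr = F.conv1d(w1.view(1, 1, -1), w2.view(1, 1, -1),
                        padding=w2.shape[-1] // 2).view(-1)[:-1]
        idx = int(corr.argmax())
        fade = overlap - idx
        fragments[i] = f1[:-fade]     # drop f1's tail beyond the join point
        f2 = f2[idx:].clone()         # skip f2's misaligned head
        win = torch.hann_window(fade * 2, device=f1.device, dtype=f1.dtype)
        # rising half fades f2 in, falling half fades f1's old tail out
        f2[:fade] = win[:fade] * f2[:fade] + win[fade:] * f1[-fade:]
        fragments[i + 1] = f2
    return torch.cat(fragments, 0)

a = torch.sin(torch.linspace(0, 50, 3072))
merged = sola_merge([a[:1600].clone(), a[1588:].clone()], overlap=12)
assert merged.shape[-1] == a.shape[-1]   # the 12 overlapped samples collapse away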