Adapt the parallel inference version to v4 (#2307)

* Adapt to v4

* Adapt to v4

* modified:   GPT_SoVITS/inference_webui_fast.py

* Merge main branch

* fallback config

* modified:   GPT_SoVITS/TTS_infer_pack/TTS.py

* fix bug

* modified:   GPT_SoVITS/TTS_infer_pack/TTS.py

* modified:   GPT_SoVITS/inference_webui_fast.py
Author: ChasonJiang, 2025-04-21 23:20:20 +08:00 (committed by GitHub)
Parent: bc2fe5ec86
Commit: e0f2818df7
3 changed files with 213 additions and 131 deletions
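At a glance: the parallel-inference backend (TTS_infer_pack/TTS.py) renames is_v3_synthesizer to use_vocoder, renames vocoder_model to vocoder, and moves the vocoder-dependent constants out of hard-coded v3/v4 branches into a vocoder_configs dict that init_vocoder(version) fills in when it loads BigVGAN (v3) or the HiFi-GAN-style Generator (v4). A condensed view of the two parameter sets, with values copied from the init_vocoder() hunk below; the dict-of-dicts layout is only an illustration, the code assigns the keys one by one:

# Summary for illustration only; values are taken from the diff below.
VOCODER_CONFIGS = {
    "v3": {  # BigVGAN, 24 kHz output
        "sr": 24000, "T_ref": 468, "T_chunk": 934,
        "upsample_rate": 256, "overlapped_len": 12,
    },
    "v4": {  # HiFi-GAN-style Generator, 48 kHz output
        "sr": 48000, "T_ref": 500, "T_chunk": 1000,
        "upsample_rate": 480, "overlapped_len": 12,
    },
}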

GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -25,7 +25,7 @@ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from BigVGAN.bigvgan import BigVGAN
from feature_extractor.cnhubert import CNHubert
from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
-from module.models import SynthesizerTrn, SynthesizerTrnV3,Generator
+from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
from peft import LoraConfig, get_peft_model
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -66,6 +66,7 @@ mel_fn = lambda x: mel_spectrogram_torch(
"center": False,
},
)
mel_fn_v4 = lambda x: mel_spectrogram_torch(
x,
**{
@@ -105,7 +106,7 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
resample_transform_dict = {}
-def resample(audio_tensor, sr0,sr1, device):
+def resample(audio_tensor, sr0, sr1, device):
global resample_transform_dict
key="%s-%s"%(sr0,sr1)
if key not in resample_transform_dict:
@@ -144,6 +145,52 @@ class DictToAttrRecursive(dict):
class NO_PROMPT_ERROR(Exception):
pass
# configs/tts_infer.yaml
"""
custom:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
version: v2
v1:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
version: v1
v2:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
version: v2
v3:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
version: v3
v4:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v4
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth
"""
def set_seed(seed: int):
seed = int(seed)
seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
@@ -201,10 +248,11 @@ class TTS_Config:
"is_half": False,
"version": "v4",
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
-"vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth",
+"vits_weights_path": "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
}
configs: dict = None
v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@@ -261,7 +309,7 @@
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
self.languages = self.v1_languages if self.version == "v1" else self.v2_languages
-self.is_v3_synthesizer: bool = False
+self.use_vocoder: bool = False
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
self.t2s_weights_path = self.default_configs[version]["t2s_weights_path"]
@@ -341,7 +389,7 @@
def __eq__(self, other):
return isinstance(other, TTS_Config) and self.configs_path == other.configs_path
-from inference_webui import v3v4set
class TTS:
def __init__(self, configs: Union[dict, str, TTS_Config]):
if isinstance(configs, TTS_Config):
@@ -354,10 +402,18 @@
self.bert_tokenizer: AutoTokenizer = None
self.bert_model: AutoModelForMaskedLM = None
self.cnhuhbert_model: CNHubert = None
-self.vocoder_model = None
+self.vocoder = None
self.sr_model: AP_BWE = None
self.sr_model_not_exist: bool = False
self.vocoder_configs: dict = {
"sr": None,
"T_ref": None,
"T_chunk": None,
"upsample_rate": None,
"overlapped_len": None,
}
self._init_models()
self.text_preprocessor: TextPreprocessor = TextPreprocessor(
@@ -376,6 +432,7 @@
"aux_ref_audio_paths": [],
}
self.stop_flag: bool = False
self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32
@@ -408,7 +465,6 @@
def init_vits_weights(self, weights_path: str):
self.configs.vits_weights_path = weights_path
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
-print(self.configs.default_configs)
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
if if_lora_v3 == True and os.path.exists(path_sovits) == False:
@@ -442,23 +498,23 @@
# print(f"model_version:{model_version}")
# print(f'hps["model"]["version"]:{hps["model"]["version"]}')
-if model_version not in v3v4set:
+if model_version not in ["v3", "v4"]:
vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers,
**kwargs,
)
-self.configs.is_v3_synthesizer = False
+self.configs.use_vocoder = False
else:
-self.configs.is_v3_synthesizer = model_version
+kwargs["version"] = model_version
vits_model = SynthesizerTrnV3(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers,
**kwargs,
)
-self.init_vocoder()
+self.configs.use_vocoder = True
+self.init_vocoder(model_version)
if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
del vits_model.enc_q
@@ -507,36 +563,64 @@
if self.configs.is_half and str(self.configs.device) != "cpu":
self.t2s_model = self.t2s_model.half()
-def init_vocoder(self):
-if self.vocoder_model is not None:
-return
-if self.configs.is_v3_synthesizer=="v3":
-self.vocoder_model = BigVGAN.from_pretrained(
-"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
-use_cuda_kernel=False,
-) # if True, RuntimeError: Ninja is required to load C++ extensions
-# remove weight norm in the model and set to eval mode
-self.vocoder_model.remove_weight_norm()
-self.vocoder_model = self.vocoder_model.eval()
-else:
-self.vocoder_model = Generator(
-initial_channel=100,
-resblock="1",
-resblock_kernel_sizes=[3, 7, 11],
-resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-upsample_rates=[10, 6, 2, 2, 2],
-upsample_initial_channel=512,
-upsample_kernel_sizes=[20, 12, 4, 4, 4],
-gin_channels=0, is_bias=True
-)
-self.vocoder_model.eval()
-self.vocoder_model.remove_weight_norm()
-state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
-print("loading v4 vocoder", self.vocoder_model.load_state_dict(state_dict_g))
+def init_vocoder(self, version: str):
+if version == "v3":
+if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
+return
+if self.vocoder is not None:
+self.vocoder.cpu()
+del self.vocoder
+self.empty_cache()
+self.vocoder = BigVGAN.from_pretrained(
+"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
+use_cuda_kernel=False,
+) # if True, RuntimeError: Ninja is required to load C++ extensions
+# remove weight norm in the model and set to eval mode
+self.vocoder.remove_weight_norm()
+self.vocoder_configs["sr"] = 24000
+self.vocoder_configs["T_ref"] = 468
+self.vocoder_configs["T_chunk"] = 934
+self.vocoder_configs["upsample_rate"] = 256
+self.vocoder_configs["overlapped_len"] = 12
+elif version == "v4":
+if self.vocoder is not None and self.vocoder.__class__.__name__ == "Generator":
+return
+if self.vocoder is not None:
+self.vocoder.cpu()
+del self.vocoder
+self.empty_cache()
+self.vocoder = Generator(
+initial_channel=100,
+resblock="1",
+resblock_kernel_sizes=[3, 7, 11],
+resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+upsample_rates=[10, 6, 2, 2, 2],
+upsample_initial_channel=512,
+upsample_kernel_sizes=[20, 12, 4, 4, 4],
+gin_channels=0, is_bias=True
+)
+self.vocoder.remove_weight_norm()
+state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
+print("loading vocoder",self.vocoder.load_state_dict(state_dict_g))
+self.vocoder_configs["sr"] = 48000
+self.vocoder_configs["T_ref"] = 500
+self.vocoder_configs["T_chunk"] = 1000
+self.vocoder_configs["upsample_rate"] = 480
+self.vocoder_configs["overlapped_len"] = 12
+self.vocoder = self.vocoder.eval()
if self.configs.is_half == True:
-self.vocoder_model = self.vocoder_model.half().to(self.configs.device)
+self.vocoder = self.vocoder.half().to(self.configs.device)
else:
-self.vocoder_model = self.vocoder_model.to(self.configs.device)
+self.vocoder = self.vocoder.to(self.configs.device)
def init_sr_model(self):
if self.sr_model is not None:
@@ -572,8 +656,8 @@
self.bert_model = self.bert_model.half()
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.half()
-if self.vocoder_model is not None:
-self.vocoder_model = self.vocoder_model.half()
+if self.vocoder is not None:
+self.vocoder = self.vocoder.half()
else:
if self.t2s_model is not None:
self.t2s_model = self.t2s_model.float()
@@ -583,8 +667,8 @@
self.bert_model = self.bert_model.float()
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.float()
-if self.vocoder_model is not None:
-self.vocoder_model = self.vocoder_model.float()
+if self.vocoder is not None:
+self.vocoder = self.vocoder.float()
def set_device(self, device: torch.device, save: bool = True):
"""
@@ -603,8 +687,8 @@
self.bert_model = self.bert_model.to(device)
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.to(device)
-if self.vocoder_model is not None:
-self.vocoder_model = self.vocoder_model.to(device)
+if self.vocoder is not None:
+self.vocoder = self.vocoder.to(device)
if self.sr_model is not None:
self.sr_model = self.sr_model.to(device)
@@ -915,13 +999,13 @@
split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
-if split_bucket and speed_factor == 1.0 and not (self.configs.is_v3_synthesizer!=False and parallel_infer):
+if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer):
print(i18n("分桶处理模式已开启"))
elif speed_factor != 1.0:
print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理"))
split_bucket = False
-elif self.configs.is_v3_synthesizer!=False and parallel_infer:
-print(i18n("当开启并行推理模式时SoVits V3V4模型不支持分桶处理已自动关闭分桶处理"))
+elif self.configs.use_vocoder and parallel_infer:
+print(i18n("当开启并行推理模式时SoVits V3/4模型不支持分桶处理已自动关闭分桶处理"))
split_bucket = False
else:
print(i18n("分桶处理模式已关闭"))
@@ -938,7 +1022,7 @@
if not no_prompt_text:
assert prompt_lang in self.configs.languages
-if no_prompt_text and self.configs.is_v3_synthesizer!=False:
+if no_prompt_text and self.configs.use_vocoder:
raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3")
if ref_audio_path in [None, ""] and (
@@ -1046,12 +1130,7 @@
t_34 = 0.0
t_45 = 0.0
audio = []
-if self.configs.is_v3_synthesizer==False:
-output_sr = 32000
-elif self.configs.is_v3_synthesizer == "v3":
-output_sr = 24000
-else:
-output_sr = 48000 # v4
+output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
for item in data:
t3 = time.perf_counter()
if return_fragment:
@@ -1113,7 +1192,7 @@
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
# ))
print(f"############ {i18n('合成音频')} ############")
-if not self.configs.is_v3_synthesizer:
+if not self.configs.use_vocoder:
if speed_factor == 1.0:
print(f"{i18n('并行合成中')}...")
# ## vits并行推理 method 2
@@ -1150,8 +1229,8 @@
else:
if parallel_infer:
print(f"{i18n('并行合成中')}...")
-audio_fragments = self.v3_synthesis_batched_infer(
-idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps,model_version=self.configs.is_v3_synthesizer
+audio_fragments = self.using_vocoder_synthesis_batched_infer(
+idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.extend(audio_fragments)
else:
@@ -1160,8 +1239,8 @@
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
-audio_fragment = self.v3_synthesis(
-_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps,model_version=self.configs.is_v3_synthesizer
+audio_fragment = self.using_vocoder_synthesis(
+_pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
)
batch_audio_fragment.append(audio_fragment)
@@ -1176,7 +1255,7 @@
speed_factor,
False,
fragment_interval,
-super_sampling if self.configs.is_v3_synthesizer=="v3" else False,
+super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
else:
audio.append(batch_audio_fragment)
@@ -1197,7 +1276,7 @@
speed_factor,
split_bucket,
fragment_interval,
-super_sampling if self.configs.is_v3_synthesizer=="v3" else False,
+super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
except Exception as e:
@@ -1279,8 +1358,8 @@
return sr, audio
-def v3_synthesis(
-self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32,model_version="v4"
+def using_vocoder_synthesis(
+self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32
):
prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
@@ -1292,22 +1371,24 @@
ref_audio = ref_audio.to(self.configs.device).float()
if ref_audio.shape[0] == 2:
ref_audio = ref_audio.mean(0).unsqueeze(0)
-tgt_sr = 24000 if model_version == "v3" else 32000
-if ref_sr != tgt_sr:
-ref_audio = resample(ref_audio, ref_sr,tgt_sr, self.configs.device)
-mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
+# tgt_sr = self.vocoder_configs["sr"]
+tgt_sr = 24000 if self.configs.version == "v3" else 32000
+if ref_sr != tgt_sr:
+ref_audio = resample(ref_audio, ref_sr, tgt_sr, self.configs.device)
+mel2 = mel_fn(ref_audio) if self.configs.version == "v3" else mel_fn_v4(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
-Tref = 468 if model_version == "v3" else 500
-Tchunk = 934 if model_version == "v3" else 1000
-if T_min > Tref:
-mel2 = mel2[:, :, -Tref:]
-fea_ref = fea_ref[:, :, -Tref:]
-T_min = Tref
-chunk_len = Tchunk - T_min
+T_ref = self.vocoder_configs["T_ref"]
+T_chunk = self.vocoder_configs["T_chunk"]
+if T_min > T_ref:
+mel2 = mel2[:, :, -T_ref:]
+fea_ref = fea_ref[:, :, -T_ref:]
+T_min = T_ref
+chunk_len = T_chunk - T_min
mel2 = mel2.to(self.precision)
fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
@@ -1334,18 +1415,18 @@
cfm_res = denorm_spec(cfm_res)
with torch.inference_mode():
-wav_gen = self.vocoder_model(cfm_res)
+wav_gen = self.vocoder(cfm_res)
audio = wav_gen[0][0] # .cpu().detach().numpy()
return audio
-def v3_synthesis_batched_infer(
+def using_vocoder_synthesis_batched_infer(
self,
idx_list: List[int],
semantic_tokens_list: List[torch.Tensor],
batch_phones: List[torch.Tensor],
speed: float = 1.0,
-sample_steps: int = 32,model_version="v4"
+sample_steps: int = 32,
) -> List[torch.Tensor]:
prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
@@ -1357,27 +1438,29 @@
ref_audio = ref_audio.to(self.configs.device).float()
if ref_audio.shape[0] == 2:
ref_audio = ref_audio.mean(0).unsqueeze(0)
-tgt_sr = 24000 if model_version == "v3" else 32000
+# tgt_sr = self.vocoder_configs["sr"]
+tgt_sr = 24000 if self.configs.version == "v3" else 32000
if ref_sr != tgt_sr:
-ref_audio = resample(ref_audio, ref_sr,tgt_sr, self.configs.device)
+ref_audio = resample(ref_audio, ref_sr, tgt_sr, self.configs.device)
-mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
+mel2 = mel_fn(ref_audio) if self.configs.version == "v3" else mel_fn_v4(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
-Tref = 468 if model_version == "v3" else 500
-Tchunk = 934 if model_version == "v3" else 1000
-if T_min > Tref:
-mel2 = mel2[:, :, -Tref:]
-fea_ref = fea_ref[:, :, -Tref:]
-T_min = Tref
-chunk_len = Tchunk - T_min
+T_ref = self.vocoder_configs["T_ref"]
+T_chunk = self.vocoder_configs["T_chunk"]
+if T_min > T_ref:
+mel2 = mel2[:, :, -T_ref:]
+fea_ref = fea_ref[:, :, -T_ref:]
+T_min = T_ref
+chunk_len = T_chunk - T_min
mel2 = mel2.to(self.precision)
# #### batched inference
-overlapped_len = 12
+overlapped_len = self.vocoder_configs["overlapped_len"]
feat_chunks = []
feat_lens = []
feat_list = []
@@ -1426,11 +1509,11 @@
pred_spec = denorm_spec(pred_spec)
with torch.no_grad():
-wav_gen = self.vocoder_model(pred_spec)
+wav_gen = self.vocoder(pred_spec)
audio = wav_gen[0][0] # .cpu().detach().numpy()
audio_fragments = []
-upsample_rate = 256
+upsample_rate = self.vocoder_configs["upsample_rate"]
pos = 0
while pos < audio.shape[-1]:
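The T_ref / T_chunk / overlapped_len / upsample_rate values stored in vocoder_configs are what the synthesis methods now read instead of hard-coded v3 numbers: reference mel frames are capped at T_ref, generation runs in windows of at most T_chunk frames, and upsample_rate converts frame counts into waveform samples when the generated audio is split back into fragments (the while-loop at the end of the hunk). A rough sketch of overlap-chunked vocoding under those assumptions; this is an illustration of the idea only, not the repo's using_vocoder_synthesis_batched_infer, which batches the chunks rather than looping over them:

import torch

def chunked_vocode(vocoder, mel, t_chunk, overlap, upsample_rate):
    # Hypothetical illustration. mel: [1, n_mels, T]; every window after the
    # first carries `overlap` frames of already-synthesized left context,
    # which is then cut from the waveform again.
    pieces = []
    start = 0
    while start < mel.shape[-1]:
        left = max(0, start - overlap)
        with torch.no_grad():
            wav = vocoder(mel[:, :, left : start + t_chunk])[0][0]
        if start > 0:
            wav = wav[overlap * upsample_rate:]  # drop the re-synthesized overlap
        pieces.append(wav)
        start += t_chunk
    return torch.cat(pieces, dim=-1)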

GPT_SoVITS/configs/tts_infer.yaml

@@ -1,40 +1,40 @@
custom:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cuda
is_half: true
-t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
-version: v4
-vits_weights_path: SoVITS_weights_v4/diangun1min_e2_s66_l32.pth
+t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
+version: v2
+vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
v1:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
version: v1
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
v2:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
version: v2
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
v3:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v3
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
v4:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v4
-vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
+vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth
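Two things stand out in this file: the custom block is reset to the stock v2 entry (the committer's personal v4 weights are removed), and the v4 section now points at gsv-v4-pretrained/s2Gv4.pth instead of the v3 checkpoint, matching the default_configs fallback in TTS.py. A hedged usage sketch, assuming TTS_Config and TTS accept a path to this YAML (TTS.__init__ is typed Union[dict, str, TTS_Config] in the diff above; the import path below also depends on how sys.path is set up):

# Hypothetical usage sketch; module path and constructor argument are assumptions.
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

# Missing or stale weight paths fall back to the matching default_configs entry.
cfg = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts = TTS(cfg)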

GPT_SoVITS/inference_webui_fast.py

@@ -211,6 +211,7 @@ pretrained_gpt_name = [
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
]
_ = [[], []]
for i in range(4):
if os.path.exists(pretrained_gpt_name[i]):
@@ -219,7 +220,6 @@
_[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name, pretrained_sovits_name = _
if os.path.exists("./weight.json"):
pass
else:
@@ -237,8 +237,8 @@ with open("./weight.json", "r", encoding="utf-8") as file:
sovits_path = sovits_path[0]
-SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
-GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
+SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
+GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
for path in SoVITS_weight_root + GPT_weight_root:
os.makedirs(path, exist_ok=True)
@@ -294,7 +294,6 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
else:
visible_sample_steps = False
visible_inp_refs = True
-# prompt_language,text_language,prompt_text,prompt_language,text,text_language,inp_refs,ref_text_free,
yield (
{"__type__": "update", "choices": list(dict_language.keys())},
{"__type__": "update", "choices": list(dict_language.keys())},
@@ -399,7 +398,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
)
sample_steps = gr.Radio(
-label=i18n("采样步数(仅对V3生效)"), value=32, choices=[4, 8, 16, 32], visible=True
+label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
)
with gr.Row():
fragment_interval = gr.Slider(