为api_v2和inference_webui_fast适配V3版本 (#2188)

* modified:   GPT_SoVITS/TTS_infer_pack/TTS.py
	modified:   GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
	modified:   GPT_SoVITS/inference_webui_fast.py

* 适配V3版本

* api_v2.py和inference_webui_fast.py的v3适配

* 修改了个远古bug,增加了更友好的提示信息

* 优化webui

* 修改为正确的path

* 修复v3 lora模型的载入问题

* 修复读取tts_infer.yaml文件时遇到的编码不匹配的问题
This commit is contained in:
ChasonJiang 2025-03-26 14:34:51 +08:00 committed by GitHub
parent 165882d64f
commit 7394dc7b0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 486 additions and 146 deletions

2
.gitignore vendored
View File

@ -18,3 +18,5 @@ TEMP
weight.json weight.json
ffmpeg* ffmpeg*
ffprobe* ffprobe*
tools/AP_BWE_main/24kto48k/*
!tools/AP_BWE_main/24kto48k/readme.txt

View File

@ -2,7 +2,7 @@
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
import torch.nn as nn import torch.nn as nn
from alias_free_activation.torch.resample import UpSample1d, DownSample1d from .resample import UpSample1d, DownSample1d
class Activation1d(nn.Module): class Activation1d(nn.Module):

View File

@ -3,8 +3,8 @@
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from alias_free_activation.torch.filter import LowPassFilter1d from .filter import LowPassFilter1d
from alias_free_activation.torch.filter import kaiser_sinc_filter1d from .filter import kaiser_sinc_filter1d
class UpSample1d(nn.Module): class UpSample1d(nn.Module):

View File

@ -14,10 +14,10 @@ import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm from torch.nn.utils import weight_norm, remove_weight_norm
import activations from . import activations
from utils0 import init_weights, get_padding from .utils0 import init_weights, get_padding
from alias_free_activation.torch.act import Activation1d as TorchActivation1d from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
from env import AttrDict from .env import AttrDict
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
@ -93,7 +93,7 @@ class AMPBlock1(torch.nn.Module):
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False): if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import ( from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d, Activation1d as CudaActivation1d,
) )
@ -193,7 +193,7 @@ class AMPBlock2(torch.nn.Module):
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False): if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import ( from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d, Activation1d as CudaActivation1d,
) )
@ -271,7 +271,7 @@ class BigVGAN(
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False): if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import ( from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d, Activation1d as CudaActivation1d,
) )

View File

@ -15,7 +15,7 @@ from librosa.filters import mel as librosa_mel_fn
import pathlib import pathlib
from tqdm import tqdm from tqdm import tqdm
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
from env import AttrDict from .env import AttrDict
MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases) MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)

View File

@ -9,7 +9,7 @@ from torch.nn.utils import weight_norm
matplotlib.use("Agg") matplotlib.use("Agg")
import matplotlib.pylab as plt import matplotlib.pylab as plt
from meldataset import MAX_WAV_VALUE from .meldataset import MAX_WAV_VALUE
from scipy.io.wavfile import write from scipy.io.wavfile import write

View File

@ -4,6 +4,7 @@ import os, sys, gc
import random import random
import traceback import traceback
import torchaudio
from tqdm import tqdm from tqdm import tqdm
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
@ -15,10 +16,11 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import yaml import yaml
from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import AutoModelForMaskedLM, AutoTokenizer
from tools.audio_sr import AP_BWE
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from feature_extractor.cnhubert import CNHubert from feature_extractor.cnhubert import CNHubert
from module.models import SynthesizerTrn from module.models import SynthesizerTrn, SynthesizerTrnV3
from peft import LoraConfig, get_peft_model
import librosa import librosa
from time import time as ttime from time import time as ttime
from tools.i18n.i18n import I18nAuto, scan_language_list from tools.i18n.i18n import I18nAuto, scan_language_list
@ -26,10 +28,98 @@ from tools.my_utils import load_audio
from module.mel_processing import spectrogram_torch from module.mel_processing import spectrogram_torch
from TTS_infer_pack.text_segmentation_method import splits from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from BigVGAN.bigvgan import BigVGAN
from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
language=os.environ.get("language","Auto") language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language) i18n = I18nAuto(language=language)
spec_min = -12
spec_max = 2
def norm_spec(x):
return (x - spec_min) / (spec_max - spec_min) * 2 - 1
def denorm_spec(x):
return (x + 1) / 2 * (spec_max - spec_min) + spec_min
mel_fn=lambda x: mel_spectrogram_torch(x, **{
"n_fft": 1024,
"win_size": 1024,
"hop_size": 256,
"num_mels": 100,
"sampling_rate": 24000,
"fmin": 0,
"fmax": None,
"center": False
})
def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# 将 NumPy 数组转换为原始 PCM 流
raw_audio = input_audio.astype(np.int16).tobytes()
# 设置 ffmpeg 输入流
input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
# 变速处理
output_stream = input_stream.filter('atempo', speed)
# 输出流到管道
out, _ = (
output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
.run(input=raw_audio, capture_stdout=True, capture_stderr=True)
)
# 将管道输出解码为 NumPy 数组
processed_audio = np.frombuffer(out, np.int16)
return processed_audio
resample_transform_dict={}
def resample(audio_tensor, sr0, device):
global resample_transform_dict
if sr0 not in resample_transform_dict:
resample_transform_dict[sr0] = torchaudio.transforms.Resample(
sr0, 24000
).to(device)
return resample_transform_dict[sr0](audio_tensor)
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
class NO_PROMPT_ERROR(Exception):
pass
# configs/tts_infer.yaml # configs/tts_infer.yaml
""" """
custom: custom:
@ -56,11 +146,19 @@ default_v2:
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
version: v2 version: v2
default_v3:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
version: v3
""" """
def set_seed(seed:int): def set_seed(seed:int):
seed = int(seed) seed = int(seed)
seed = seed if seed != -1 else random.randrange(1 << 32) seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
print(f"Set seed to {seed}") print(f"Set seed to {seed}")
os.environ['PYTHONHASHSEED'] = str(seed) os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed) random.seed(seed)
@ -82,7 +180,7 @@ def set_seed(seed:int):
class TTS_Config: class TTS_Config:
default_configs={ default_configs={
"default":{ "v1":{
"device": "cpu", "device": "cpu",
"is_half": False, "is_half": False,
"version": "v1", "version": "v1",
@ -91,7 +189,7 @@ class TTS_Config:
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
}, },
"default_v2":{ "v2":{
"device": "cpu", "device": "cpu",
"is_half": False, "is_half": False,
"version": "v2", "version": "v2",
@ -100,6 +198,15 @@ class TTS_Config:
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
}, },
"v3":{
"device": "cpu",
"is_half": False,
"version": "v3",
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
} }
configs:dict = None configs:dict = None
v1_languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] v1_languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@ -136,12 +243,9 @@ class TTS_Config:
assert isinstance(configs, dict) assert isinstance(configs, dict)
version = configs.get("version", "v2").lower() version = configs.get("version", "v2").lower()
assert version in ["v1", "v2"] assert version in ["v1", "v2", "v3"]
self.default_configs["default"] = configs.get("default", self.default_configs["default"]) self.default_configs[version] = configs.get(version, self.default_configs[version])
self.default_configs["default_v2"] = configs.get("default_v2", self.default_configs["default_v2"]) self.configs:dict = configs.get("custom", deepcopy(self.default_configs[version]))
default_config_key = "default"if version=="v1" else "default_v2"
self.configs:dict = configs.get("custom", deepcopy(self.default_configs[default_config_key]))
self.device = self.configs.get("device", torch.device("cpu")) self.device = self.configs.get("device", torch.device("cpu"))
@ -159,20 +263,22 @@ class TTS_Config:
self.vits_weights_path = self.configs.get("vits_weights_path", None) self.vits_weights_path = self.configs.get("vits_weights_path", None)
self.bert_base_path = self.configs.get("bert_base_path", None) self.bert_base_path = self.configs.get("bert_base_path", None)
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None) self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
self.languages = self.v2_languages if self.version=="v2" else self.v1_languages self.languages = self.v1_languages if self.version=="v1" else self.v2_languages
self.is_v3_synthesizer:bool = False
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)): if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
self.t2s_weights_path = self.default_configs[default_config_key]['t2s_weights_path'] self.t2s_weights_path = self.default_configs[version]['t2s_weights_path']
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}") print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)): if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
self.vits_weights_path = self.default_configs[default_config_key]['vits_weights_path'] self.vits_weights_path = self.default_configs[version]['vits_weights_path']
print(f"fall back to default vits_weights_path: {self.vits_weights_path}") print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)): if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
self.bert_base_path = self.default_configs[default_config_key]['bert_base_path'] self.bert_base_path = self.default_configs[version]['bert_base_path']
print(f"fall back to default bert_base_path: {self.bert_base_path}") print(f"fall back to default bert_base_path: {self.bert_base_path}")
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)): if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
self.cnhuhbert_base_path = self.default_configs[default_config_key]['cnhuhbert_base_path'] self.cnhuhbert_base_path = self.default_configs[version]['cnhuhbert_base_path']
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}") print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
self.update_configs() self.update_configs()
@ -195,7 +301,7 @@ class TTS_Config:
else: else:
print(i18n("路径不存在,使用默认配置")) print(i18n("路径不存在,使用默认配置"))
self.save_configs(configs_path) self.save_configs(configs_path)
with open(configs_path, 'r') as f: with open(configs_path, 'r', encoding='utf-8') as f:
configs = yaml.load(f, Loader=yaml.FullLoader) configs = yaml.load(f, Loader=yaml.FullLoader)
return configs return configs
@ -224,7 +330,7 @@ class TTS_Config:
def update_version(self, version:str)->None: def update_version(self, version:str)->None:
self.version = version self.version = version
self.languages = self.v2_languages if self.version=="v2" else self.v1_languages self.languages = self.v1_languages if self.version=="v1" else self.v2_languages
def __str__(self): def __str__(self):
self.configs = self.update_configs() self.configs = self.update_configs()
@ -252,10 +358,13 @@ class TTS:
self.configs:TTS_Config = TTS_Config(configs) self.configs:TTS_Config = TTS_Config(configs)
self.t2s_model:Text2SemanticLightningModule = None self.t2s_model:Text2SemanticLightningModule = None
self.vits_model:SynthesizerTrn = None self.vits_model:Union[SynthesizerTrn, SynthesizerTrnV3] = None
self.bert_tokenizer:AutoTokenizer = None self.bert_tokenizer:AutoTokenizer = None
self.bert_model:AutoModelForMaskedLM = None self.bert_model:AutoModelForMaskedLM = None
self.cnhuhbert_model:CNHubert = None self.cnhuhbert_model:CNHubert = None
self.bigvgan_model:BigVGAN = None
self.sr_model:AP_BWE = None
self.sr_model_not_exist:bool = False
self._init_models() self._init_models()
@ -310,38 +419,83 @@ class TTS:
self.bert_model = self.bert_model.half() self.bert_model = self.bert_model.half()
def init_vits_weights(self, weights_path: str): def init_vits_weights(self, weights_path: str):
print(f"Loading VITS weights from {weights_path}")
self.configs.vits_weights_path = weights_path
dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.configs.update_version("v1")
else:
self.configs.update_version("v2")
self.configs.save_configs()
hps["model"]["version"] = self.configs.version self.configs.vits_weights_path = weights_path
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(weights_path)
path_sovits_v3=self.configs.default_configs["v3"]["vits_weights_path"]
if if_lora_v3==True and os.path.exists(path_sovits_v3)==False:
info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
raise FileExistsError(info)
# dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
dict_s2 = load_sovits_new(weights_path)
hps = dict_s2["config"]
hps["model"]["semantic_frame_rate"] = "25hz"
if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
hps["model"]["version"] = "v2"#v3model,v2sybomls
elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps["model"]["version"] = "v1"
else:
hps["model"]["version"] = "v2"
# version = hps["model"]["version"]
self.configs.filter_length = hps["data"]["filter_length"] self.configs.filter_length = hps["data"]["filter_length"]
self.configs.segment_size = hps["train"]["segment_size"] self.configs.segment_size = hps["train"]["segment_size"]
self.configs.sampling_rate = hps["data"]["sampling_rate"] self.configs.sampling_rate = hps["data"]["sampling_rate"]
self.configs.hop_length = hps["data"]["hop_length"] self.configs.hop_length = hps["data"]["hop_length"]
self.configs.win_length = hps["data"]["win_length"] self.configs.win_length = hps["data"]["win_length"]
self.configs.n_speakers = hps["data"]["n_speakers"] self.configs.n_speakers = hps["data"]["n_speakers"]
self.configs.semantic_frame_rate = "25hz" self.configs.semantic_frame_rate = hps["model"]["semantic_frame_rate"]
kwargs = hps["model"] kwargs = hps["model"]
vits_model = SynthesizerTrn( # print(f"self.configs.sampling_rate:{self.configs.sampling_rate}")
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length, self.configs.update_version(model_version)
n_speakers=self.configs.n_speakers,
**kwargs # print(f"model_version:{model_version}")
) # print(f'hps["model"]["version"]:{hps["model"]["version"]}')
if model_version!="v3":
vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers,
**kwargs
)
if hasattr(vits_model, "enc_q"):
del vits_model.enc_q
self.configs.is_v3_synthesizer = False
else:
vits_model = SynthesizerTrnV3(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers,
**kwargs
)
self.configs.is_v3_synthesizer = True
self.init_bigvgan()
if if_lora_v3==False:
print(f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
else:
print(f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}")
lora_rank=dict_s2["lora_rank"]
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights=True,
)
vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
vits_model.cfm = vits_model.cfm.merge_and_unload()
if hasattr(vits_model, "enc_q"):
del vits_model.enc_q
vits_model = vits_model.to(self.configs.device) vits_model = vits_model.to(self.configs.device)
vits_model = vits_model.eval() vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model self.vits_model = vits_model
if self.configs.is_half and str(self.configs.device)!="cpu": if self.configs.is_half and str(self.configs.device)!="cpu":
self.vits_model = self.vits_model.half() self.vits_model = self.vits_model.half()
@ -363,6 +517,30 @@ class TTS:
if self.configs.is_half and str(self.configs.device)!="cpu": if self.configs.is_half and str(self.configs.device)!="cpu":
self.t2s_model = self.t2s_model.half() self.t2s_model = self.t2s_model.half()
def init_bigvgan(self):
if self.bigvgan_model is not None:
return
self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode
self.bigvgan_model.remove_weight_norm()
self.bigvgan_model = self.bigvgan_model.eval()
if self.configs.is_half == True:
self.bigvgan_model = self.bigvgan_model.half().to(self.configs.device)
else:
self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
def init_sr_model(self):
if self.sr_model is not None:
return
try:
self.sr_model:AP_BWE=AP_BWE(self.configs.device,DictToAttrRecursive)
self.sr_model_not_exist = False
except FileNotFoundError:
print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
self.sr_model_not_exist = True
def enable_half_precision(self, enable: bool = True, save: bool = True): def enable_half_precision(self, enable: bool = True, save: bool = True):
''' '''
To enable half precision for the TTS model. To enable half precision for the TTS model.
@ -387,6 +565,8 @@ class TTS:
self.bert_model =self.bert_model.half() self.bert_model =self.bert_model.half()
if self.cnhuhbert_model is not None: if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.half() self.cnhuhbert_model = self.cnhuhbert_model.half()
if self.bigvgan_model is not None:
self.bigvgan_model = self.bigvgan_model.half()
else: else:
if self.t2s_model is not None: if self.t2s_model is not None:
self.t2s_model = self.t2s_model.float() self.t2s_model = self.t2s_model.float()
@ -396,6 +576,8 @@ class TTS:
self.bert_model = self.bert_model.float() self.bert_model = self.bert_model.float()
if self.cnhuhbert_model is not None: if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.float() self.cnhuhbert_model = self.cnhuhbert_model.float()
if self.bigvgan_model is not None:
self.bigvgan_model = self.bigvgan_model.float()
def set_device(self, device: torch.device, save: bool = True): def set_device(self, device: torch.device, save: bool = True):
''' '''
@ -414,6 +596,11 @@ class TTS:
self.bert_model = self.bert_model.to(device) self.bert_model = self.bert_model.to(device)
if self.cnhuhbert_model is not None: if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.to(device) self.cnhuhbert_model = self.cnhuhbert_model.to(device)
if self.bigvgan_model is not None:
self.bigvgan_model = self.bigvgan_model.to(device)
if self.sr_model is not None:
self.sr_model = self.sr_model.to(device)
def set_ref_audio(self, ref_audio_path:str): def set_ref_audio(self, ref_audio_path:str):
''' '''
@ -437,6 +624,11 @@ class TTS:
self.prompt_cache["refer_spec"][0] = spec self.prompt_cache["refer_spec"][0] = spec
def _get_ref_spec(self, ref_audio_path): def _get_ref_spec(self, ref_audio_path):
raw_audio, raw_sr = torchaudio.load(ref_audio_path)
raw_audio=raw_audio.to(self.configs.device).float()
self.prompt_cache["raw_audio"] = raw_audio
self.prompt_cache["raw_sr"] = raw_sr
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate)) audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
audio = torch.FloatTensor(audio) audio = torch.FloatTensor(audio)
maxx=audio.abs().max() maxx=audio.abs().max()
@ -625,11 +817,11 @@ class TTS:
Recovery the order of the audio according to the batch_index_list. Recovery the order of the audio according to the batch_index_list.
Args: Args:
data (List[list(np.ndarray)]): the out of order audio . data (List[list(torch.Tensor)]): the out of order audio .
batch_index_list (List[list[int]]): the batch index list. batch_index_list (List[list[int]]): the batch index list.
Returns: Returns:
list (List[np.ndarray]): the data in the original order. list (List[torch.Tensor]): the data in the original order.
''' '''
length = len(sum(batch_index_list, [])) length = len(sum(batch_index_list, []))
_data = [None]*length _data = [None]*length
@ -671,6 +863,8 @@ class TTS:
"seed": -1, # int. random seed for reproducibility. "seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference. "parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model. "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
} }
returns: returns:
Tuple[int, np.ndarray]: sampling rate and audio data. Tuple[int, np.ndarray]: sampling rate and audio data.
@ -698,6 +892,8 @@ class TTS:
actual_seed = set_seed(seed) actual_seed = set_seed(seed)
parallel_infer = inputs.get("parallel_infer", True) parallel_infer = inputs.get("parallel_infer", True)
repetition_penalty = inputs.get("repetition_penalty", 1.35) repetition_penalty = inputs.get("repetition_penalty", 1.35)
sample_steps = inputs.get("sample_steps", 32)
super_sampling = inputs.get("super_sampling", False)
if parallel_infer: if parallel_infer:
print(i18n("并行推理模式已开启")) print(i18n("并行推理模式已开启"))
@ -732,6 +928,9 @@ class TTS:
if not no_prompt_text: if not no_prompt_text:
assert prompt_lang in self.configs.languages assert prompt_lang in self.configs.languages
if no_prompt_text and self.configs.is_v3_synthesizer:
raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3")
if ref_audio_path in [None, ""] and \ if ref_audio_path in [None, ""] and \
((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])): ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])):
raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()") raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
@ -761,13 +960,13 @@ class TTS:
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_lang != "en" else "." if (prompt_text[-1] not in splits): prompt_text += "" if prompt_lang != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text) print(i18n("实际输入的参考文本:"), prompt_text)
if self.prompt_cache["prompt_text"] != prompt_text: if self.prompt_cache["prompt_text"] != prompt_text:
self.prompt_cache["prompt_text"] = prompt_text
self.prompt_cache["prompt_lang"] = prompt_lang
phones, bert_features, norm_text = \ phones, bert_features, norm_text = \
self.text_preprocessor.segment_and_extract_feature_for_text( self.text_preprocessor.segment_and_extract_feature_for_text(
prompt_text, prompt_text,
prompt_lang, prompt_lang,
self.configs.version) self.configs.version)
self.prompt_cache["prompt_text"] = prompt_text
self.prompt_cache["prompt_lang"] = prompt_lang
self.prompt_cache["phones"] = phones self.prompt_cache["phones"] = phones
self.prompt_cache["bert_features"] = bert_features self.prompt_cache["bert_features"] = bert_features
self.prompt_cache["norm_text"] = norm_text self.prompt_cache["norm_text"] = norm_text
@ -781,8 +980,7 @@ class TTS:
if not return_fragment: if not return_fragment:
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version) data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
if len(data) == 0: if len(data) == 0:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), yield 16000, np.zeros(int(16000), dtype=np.int16)
dtype=np.int16)
return return
batch_index_list:list = None batch_index_list:list = None
@ -836,6 +1034,7 @@ class TTS:
t_34 = 0.0 t_34 = 0.0
t_45 = 0.0 t_45 = 0.0
audio = [] audio = []
output_sr = self.configs.sampling_rate if not self.configs.is_v3_synthesizer else 24000
for item in data: for item in data:
t3 = ttime() t3 = ttime()
if return_fragment: if return_fragment:
@ -858,7 +1057,7 @@ class TTS:
else: else:
prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
print(f"############ {i18n('预测语义Token')} ############")
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
all_phoneme_ids, all_phoneme_ids,
all_phoneme_lens, all_phoneme_lens,
@ -892,70 +1091,80 @@ class TTS:
# batch_audio_fragment = (self.vits_model.batched_decode( # batch_audio_fragment = (self.vits_model.batched_decode(
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
# )) # ))
print(f"############ {i18n('合成音频')} ############")
if speed_factor == 1.0: if not self.configs.is_v3_synthesizer:
# ## vits并行推理 method 2 if speed_factor == 1.0:
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] # ## vits并行推理 method 2
upsample_rate = math.prod(self.vits_model.upsample_rates) pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = (self.vits_model.decode( _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor _batch_audio_fragment = (self.vits_model.decode(
).detach()[0, 0, :]) all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
audio_frag_end_idx.insert(0, 0) ).detach()[0, 0, :])
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] audio_frag_end_idx.insert(0, 0)
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
else:
# ## vits串行推理
for i, idx in enumerate(tqdm(idx_list)):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment =(self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :])
batch_audio_fragment.append(
audio_fragment
) ###试试重建不带上prompt部分
else: else:
# ## vits串行推理 for i, idx in enumerate(tqdm(idx_list)):
for i, idx in enumerate(idx_list):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device) phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次 _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment =(self.vits_model.decode( audio_fragment = self.v3_synthesis(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
).detach()[0, 0, :]) )
batch_audio_fragment.append( batch_audio_fragment.append(
audio_fragment audio_fragment
) ###试试重建不带上prompt部分 )
t5 = ttime() t5 = ttime()
t_45 += t5 - t4 t_45 += t5 - t4
if return_fragment: if return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
yield self.audio_postprocess([batch_audio_fragment], yield self.audio_postprocess([batch_audio_fragment],
self.configs.sampling_rate, output_sr,
None, None,
speed_factor, speed_factor,
False, False,
fragment_interval fragment_interval,
super_sampling if self.configs.is_v3_synthesizer else False
) )
else: else:
audio.append(batch_audio_fragment) audio.append(batch_audio_fragment)
if self.stop_flag: if self.stop_flag:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), yield 16000, np.zeros(int(16000), dtype=np.int16)
dtype=np.int16)
return return
if not return_fragment: if not return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
if len(audio) == 0: if len(audio) == 0:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), yield 16000, np.zeros(int(16000), dtype=np.int16)
dtype=np.int16)
return return
yield self.audio_postprocess(audio, yield self.audio_postprocess(audio,
self.configs.sampling_rate, output_sr,
batch_index_list, batch_index_list,
speed_factor, speed_factor,
split_bucket, split_bucket,
fragment_interval fragment_interval,
super_sampling if self.configs.is_v3_synthesizer else False
) )
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
# 必须返回一个空音频, 否则会导致显存不释放。 # 必须返回一个空音频, 否则会导致显存不释放。
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), yield 16000, np.zeros(int(16000), dtype=np.int16)
dtype=np.int16)
# 重置模型, 否则会导致显存释放不完全。 # 重置模型, 否则会导致显存释放不完全。
del self.t2s_model del self.t2s_model
del self.vits_model del self.vits_model
@ -983,7 +1192,8 @@ class TTS:
batch_index_list:list=None, batch_index_list:list=None,
speed_factor:float=1.0, speed_factor:float=1.0,
split_bucket:bool=True, split_bucket:bool=True,
fragment_interval:float=0.3 fragment_interval:float=0.3,
super_sampling:bool=False,
)->Tuple[int, np.ndarray]: )->Tuple[int, np.ndarray]:
zero_wav = torch.zeros( zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval), int(self.configs.sampling_rate * fragment_interval),
@ -996,7 +1206,7 @@ class TTS:
max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音 max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
if max_audio>1: audio_fragment/=max_audio if max_audio>1: audio_fragment/=max_audio
audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
audio[i][j] = audio_fragment.cpu().numpy() audio[i][j] = audio_fragment
if split_bucket: if split_bucket:
@ -1005,8 +1215,21 @@ class TTS:
# audio = [item for batch in audio for item in batch] # audio = [item for batch in audio for item in batch]
audio = sum(audio, []) audio = sum(audio, [])
audio = torch.cat(audio, dim=0)
if super_sampling:
print(f"############ {i18n('音频超采样')} ############")
t1 = ttime()
self.init_sr_model()
if not self.sr_model_not_exist:
audio,sr=self.sr_model(audio.unsqueeze(0),sr)
max_audio=np.abs(audio).max()
if max_audio > 1: audio /= max_audio
t2 = ttime()
print(f"超采样用时:{t2-t1:.3f}s")
else:
audio = audio.cpu().numpy()
audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16) audio = (audio * 32768).astype(np.int16)
# try: # try:
@ -1018,25 +1241,59 @@ class TTS:
return sr, audio return sr, audio
def v3_synthesis(self,
semantic_tokens:torch.Tensor,
phones:torch.Tensor,
speed:float=1.0,
sample_steps:int=32
):
prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
def speed_change(input_audio:np.ndarray, speed:float, sr:int): fea_ref,ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
# 将 NumPy 数组转换为原始 PCM 流 ref_audio:torch.Tensor = self.prompt_cache["raw_audio"]
raw_audio = input_audio.astype(np.int16).tobytes() ref_sr = self.prompt_cache["raw_sr"]
ref_audio=ref_audio.to(self.configs.device).float()
if (ref_audio.shape[0] == 2):
ref_audio = ref_audio.mean(0).unsqueeze(0)
if ref_sr!=24000:
ref_audio=resample(ref_audio, ref_sr, self.configs.device)
# print("ref_audio",ref_audio.abs().mean())
mel2 = mel_fn(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
if (T_min > 468):
mel2 = mel2[:, :, -468:]
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
# 设置 ffmpeg 输入流 mel2=mel2.to(self.precision)
input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1) fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
# 变速处理 cfm_resss = []
output_stream = input_stream.filter('atempo', speed) idx = 0
while (1):
fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
if (fea_todo_chunk.shape[-1] == 0): break
idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
# 输出流到管道 cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
out, _ = ( cfm_res = cfm_res[:, :, mel2.shape[2]:]
output_stream.output('pipe:', format='s16le', acodec='pcm_s16le') mel2 = cfm_res[:, :, -T_min:]
.run(input=raw_audio, capture_stdout=True, capture_stderr=True)
)
# 将管道输出解码为 NumPy 数组 fea_ref = fea_todo_chunk[:, :, -T_min:]
processed_audio = np.frombuffer(out, np.int16) cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2)
cmf_res = denorm_spec(cmf_res)
return processed_audio with torch.inference_mode():
wav_gen = self.bigvgan_model(cmf_res)
audio=wav_gen[0][0]#.cpu().detach().numpy()
return audio

View File

@ -118,11 +118,11 @@ class TextPreprocessor:
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False): def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","") # language = language.replace("all_","")
formattext = text formattext = text
while " " in formattext: while " " in formattext:
formattext = formattext.replace(" ", " ") formattext = formattext.replace(" ", " ")
if language == "zh": if language == "all_zh":
if re.search(r'[A-Za-z]', formattext): if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext) formattext = chinese.mix_text_normalize(formattext)
@ -130,7 +130,7 @@ class TextPreprocessor:
else: else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = self.get_bert_feature(norm_text, word2ph).to(self.device) bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext): elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext) formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"yue",version) return self.get_phones_and_bert(formattext,"yue",version)
@ -199,6 +199,7 @@ class TextPreprocessor:
return phone_level_feature.T return phone_level_feature.T
def clean_text_inf(self, text:str, language:str, version:str="v2"): def clean_text_inf(self, text:str, language:str, version:str="v2"):
language = language.replace("all_","")
phones, word2ph, norm_text = clean_text(text, language, version) phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version) phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text return phones, word2ph, norm_text

View File

@ -6,7 +6,7 @@ custom:
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
version: v2 version: v2
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
default: v1:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu device: cpu
@ -14,7 +14,7 @@ default:
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
version: v1 version: v1
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
default_v2: v2:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu device: cpu
@ -22,3 +22,11 @@ default_v2:
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
version: v2 version: v2
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
v3:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v3
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth

View File

@ -7,7 +7,7 @@
全部按日文识别 全部按日文识别
''' '''
import random import random
import os, re, logging import os, re, logging, json
import sys import sys
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
@ -44,7 +44,7 @@ bert_path = os.environ.get("bert_path", None)
version=os.environ.get("version","v2") version=os.environ.get("version","v2")
import gradio as gr import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR
from TTS_infer_pack.text_segmentation_method import get_method from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto, scan_language_list from tools.i18n.i18n import I18nAuto, scan_language_list
@ -62,6 +62,9 @@ if torch.cuda.is_available():
else: else:
device = "cpu" device = "cpu"
# is_half = False
# device = "cpu"
dict_language_v1 = { dict_language_v1 = {
i18n("中文"): "all_zh",#全部按中文识别 i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变 i18n("英文"): "en",#全部按英文识别#######不变
@ -123,11 +126,11 @@ def inference(text, text_lang,
speed_factor, ref_text_free, speed_factor, ref_text_free,
split_bucket,fragment_interval, split_bucket,fragment_interval,
seed, keep_random, parallel_infer, seed, keep_random, parallel_infer,
repetition_penalty repetition_penalty, sample_steps, super_sampling,
): ):
seed = -1 if keep_random else seed seed = -1 if keep_random else seed
actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32) actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
inputs={ inputs={
"text": text, "text": text,
"text_lang": dict_language[text_lang], "text_lang": dict_language[text_lang],
@ -147,9 +150,14 @@ def inference(text, text_lang,
"seed":actual_seed, "seed":actual_seed,
"parallel_infer": parallel_infer, "parallel_infer": parallel_infer,
"repetition_penalty": repetition_penalty, "repetition_penalty": repetition_penalty,
"sample_steps": int(sample_steps),
"super_sampling": super_sampling,
} }
for item in tts_pipeline.run(inputs): try:
yield item, actual_seed for item in tts_pipeline.run(inputs):
yield item, actual_seed
except NO_PROMPT_ERROR:
gr.Warning(i18n('V3不支持无参考文本模式请填写参考文本'))
def custom_sort_key(s): def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分 # 使用正则表达式提取字符串中的数字部分和非数字部分
@ -163,19 +171,38 @@ def change_choices():
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root) SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"} return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
_ =[[],[]] _ =[[],[]]
for i in range(2): for i in range(3):
if os.path.exists(pretrained_gpt_name[i]): if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
_[0].append(pretrained_gpt_name[i]) if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
if os.path.exists(pretrained_sovits_name[i]):
_[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name,pretrained_sovits_name = _ pretrained_gpt_name,pretrained_sovits_name = _
SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"]
GPT_weight_root=["GPT_weights_v2","GPT_weights"] if os.path.exists(f"./weight.json"):
pass
else:
with open(f"./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file)
with open(f"./weight.json", 'r', encoding="utf-8") as file:
weight_data = file.read()
weight_data=json.loads(weight_data)
gpt_path = os.environ.get(
"gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
sovits_path = os.environ.get(
"sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
if isinstance(gpt_path,list):
gpt_path = gpt_path[0]
if isinstance(sovits_path,list):
sovits_path = sovits_path[0]
SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
for path in SoVITS_weight_root+GPT_weight_root: for path in SoVITS_weight_root+GPT_weight_root:
os.makedirs(path,exist_ok=True) os.makedirs(path,exist_ok=True)
@ -194,10 +221,18 @@ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root) SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
from process_ckpt import get_sovits_version_from_path_fast
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None): def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
tts_pipeline.init_vits_weights(sovits_path)
global version, dict_language global version, dict_language
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
if if_lora_v3 and not os.path.exists(path_sovits_v3):
info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
gr.Warning(info)
raise FileExistsError(info)
tts_pipeline.init_vits_weights(sovits_path)
dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2 dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
if prompt_language is not None and text_language is not None: if prompt_language is not None and text_language is not None:
if prompt_language in list(dict_language.keys()): if prompt_language in list(dict_language.keys()):
@ -210,9 +245,19 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
else: else:
text_update = {'__type__':'update', 'value':''} text_update = {'__type__':'update', 'value':''}
text_language_update = {'__type__':'update', 'value':i18n("中文")} text_language_update = {'__type__':'update', 'value':i18n("中文")}
return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update if model_version=="v3":
visible_sample_steps=True
visible_inp_refs=False
else:
visible_sample_steps=False
visible_inp_refs=True
yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}
with open("./weight.json")as f:
data=f.read()
data=json.loads(data)
data["SoVITS"][version]=sovits_path
with open("./weight.json","w")as f:f.write(json.dumps(data))
with gr.Blocks(title="GPT-SoVITS WebUI") as app: with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown( gr.Markdown(
@ -257,13 +302,19 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True) with gr.Row():
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True) batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True) sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True) with gr.Row():
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True) fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True) speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True)
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True) with gr.Row():
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
with gr.Row():
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
how_to_cut = gr.Dropdown( how_to_cut = gr.Dropdown(
@ -272,10 +323,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
value=i18n("凑四句一切"), value=i18n("凑四句一切"),
interactive=True, scale=1 interactive=True, scale=1
) )
super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True)
with gr.Row():
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True) parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True) split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
with gr.Row(): with gr.Row():
seed = gr.Number(label=i18n("随机种子"),value=-1) seed = gr.Number(label=i18n("随机种子"),value=-1)
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True) keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
@ -295,7 +350,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
speed_factor, ref_text_free, speed_factor, ref_text_free,
split_bucket,fragment_interval, split_bucket,fragment_interval,
seed, keep_random, parallel_infer, seed, keep_random, parallel_infer,
repetition_penalty repetition_penalty, sample_steps, super_sampling,
], ],
[output, seed], [output, seed],
) )

View File

@ -39,6 +39,8 @@ POST:
"seed": -1, # int. random seed for reproducibility. "seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference. "parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model. "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
} }
``` ```
@ -164,6 +166,8 @@ class TTS_Request(BaseModel):
streaming_mode:bool = False streaming_mode:bool = False
parallel_infer:bool = True parallel_infer:bool = True
repetition_penalty:float = 1.35 repetition_penalty:float = 1.35
sample_steps:int = 32
super_sampling:bool = False
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int): def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
@ -295,6 +299,8 @@ async def tts_handle(req:dict):
"streaming_mode": False, # bool. whether to return a streaming response. "streaming_mode": False, # bool. whether to return a streaming response.
"parallel_infer": True, # bool.(optional) whether to use parallel inference. "parallel_infer": True, # bool.(optional) whether to use parallel inference.
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
} }
returns: returns:
StreamingResponse: audio stream response. StreamingResponse: audio stream response.
@ -316,10 +322,12 @@ async def tts_handle(req:dict):
if streaming_mode: if streaming_mode:
def streaming_generator(tts_generator:Generator, media_type:str): def streaming_generator(tts_generator:Generator, media_type:str):
if media_type == "wav": if_frist_chunk = True
yield wave_header_chunk()
media_type = "raw"
for sr, chunk in tts_generator: for sr, chunk in tts_generator:
if if_frist_chunk and media_type == "wav":
yield wave_header_chunk(sample_rate=sr)
media_type = "raw"
if_frist_chunk = False
yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
# _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}") return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
@ -365,7 +373,9 @@ async def tts_get_endpoint(
media_type:str = "wav", media_type:str = "wav",
streaming_mode:bool = False, streaming_mode:bool = False,
parallel_infer:bool = True, parallel_infer:bool = True,
repetition_penalty:float = 1.35 repetition_penalty:float = 1.35,
sample_steps:int =32,
super_sampling:bool = False
): ):
req = { req = {
"text": text, "text": text,
@ -387,7 +397,9 @@ async def tts_get_endpoint(
"media_type":media_type, "media_type":media_type,
"streaming_mode":streaming_mode, "streaming_mode":streaming_mode,
"parallel_infer":parallel_infer, "parallel_infer":parallel_infer,
"repetition_penalty":float(repetition_penalty) "repetition_penalty":float(repetition_penalty),
"sample_steps":int(sample_steps),
"super_sampling":super_sampling
} }
return await tts_handle(req) return await tts_handle(req)

View File

@ -39,6 +39,11 @@ class AP_BWE():
self.model=model self.model=model
self.h=h self.h=h
def to(self, *arg, **kwargs):
self.model.to(*arg, **kwargs)
self.device = self.model.conv_pre_mag.weight.device
return self
def __call__(self, audio,orig_sampling_rate): def __call__(self, audio,orig_sampling_rate):
with torch.no_grad(): with torch.no_grad():
# audio, orig_sampling_rate = torchaudio.load(inp_path) # audio, orig_sampling_rate = torchaudio.load(inp_path)