Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
Synced 2025-04-05 04:22:46 +08:00
Adapt api_v2 and inference_webui_fast to V3 (#2188)

* modified: GPT_SoVITS/TTS_infer_pack/TTS.py, GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py, GPT_SoVITS/inference_webui_fast.py
* Adapt to V3
* V3 adaptation for api_v2.py and inference_webui_fast.py
* Fix an old bug and add friendlier error messages
* Improve the WebUI
* Correct a path
* Fix loading of V3 LoRA models
* Fix an encoding mismatch when reading tts_infer.yaml
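For API users, the practical upshot of this commit is two new request fields, `sample_steps` and `super_sampling`, which only take effect with a V3 SoVITS model. A minimal client sketch against the updated api_v2 endpoint (host, port, and file paths are placeholders, not from this commit):

```python
import requests

# Hypothetical server address; api_v2.py takes host/port from its own CLI flags.
BASE_URL = "http://127.0.0.1:9880"

params = {
    "text": "text to synthesize",
    "text_lang": "zh",
    "ref_audio_path": "ref.wav",            # placeholder reference audio
    "prompt_text": "reference transcript",  # must be non-empty for V3 models
    "prompt_lang": "zh",
    "sample_steps": 32,      # new in this commit: CFM sampling steps (V3 only)
    "super_sampling": False, # new in this commit: 24 kHz -> 48 kHz super-resolution (V3 only)
}

resp = requests.get(f"{BASE_URL}/tts", params=params)
with open("out.wav", "wb") as f:
    f.write(resp.content)
```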
This commit is contained in:
parent 165882d64f
commit 7394dc7b0c
.gitignore (vendored): 4 lines changed

@@ -17,4 +17,6 @@ SoVITS_weights_v3
 TEMP
 weight.json
 ffmpeg*
-ffprobe*
+ffprobe*
+tools/AP_BWE_main/24kto48k/*
+!tools/AP_BWE_main/24kto48k/readme.txt
@@ -2,7 +2,7 @@
 # LICENSE is in incl_licenses directory.
 
 import torch.nn as nn
-from alias_free_activation.torch.resample import UpSample1d, DownSample1d
+from .resample import UpSample1d, DownSample1d
 
 
 class Activation1d(nn.Module):

@@ -3,8 +3,8 @@
 
 import torch.nn as nn
 from torch.nn import functional as F
-from alias_free_activation.torch.filter import LowPassFilter1d
-from alias_free_activation.torch.filter import kaiser_sinc_filter1d
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
 
 
 class UpSample1d(nn.Module):

@@ -14,10 +14,10 @@ import torch.nn as nn
 from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn.utils import weight_norm, remove_weight_norm
 
-import activations
-from utils0 import init_weights, get_padding
-from alias_free_activation.torch.act import Activation1d as TorchActivation1d
-from env import AttrDict
+from . import activations
+from .utils0 import init_weights, get_padding
+from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
+from .env import AttrDict
 
 from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
 

@@ -93,7 +93,7 @@ class AMPBlock1(torch.nn.Module):
 
         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
         if self.h.get("use_cuda_kernel", False):
-            from alias_free_activation.cuda.activation1d import (
+            from .alias_free_activation.cuda.activation1d import (
                 Activation1d as CudaActivation1d,
             )
 

@@ -193,7 +193,7 @@ class AMPBlock2(torch.nn.Module):
 
         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
         if self.h.get("use_cuda_kernel", False):
-            from alias_free_activation.cuda.activation1d import (
+            from .alias_free_activation.cuda.activation1d import (
                 Activation1d as CudaActivation1d,
             )
 

@@ -271,7 +271,7 @@ class BigVGAN(
 
         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
         if self.h.get("use_cuda_kernel", False):
-            from alias_free_activation.cuda.activation1d import (
+            from .alias_free_activation.cuda.activation1d import (
                 Activation1d as CudaActivation1d,
             )
 

@@ -15,7 +15,7 @@ from librosa.filters import mel as librosa_mel_fn
 import pathlib
 from tqdm import tqdm
 from typing import List, Tuple, Optional
-from env import AttrDict
+from .env import AttrDict
 
 MAX_WAV_VALUE = 32767.0  # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
 

@@ -9,7 +9,7 @@ from torch.nn.utils import weight_norm
 
 matplotlib.use("Agg")
 import matplotlib.pylab as plt
-from meldataset import MAX_WAV_VALUE
+from .meldataset import MAX_WAV_VALUE
 from scipy.io.wavfile import write
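The hunks above all make the same mechanical change: the vendored BigVGAN code switches from absolute imports (`from env import AttrDict`) to package-relative ones (`from .env import AttrDict`), so the vocoder can be imported as a package from the repository root instead of requiring its directory on `sys.path`. A sketch of what this enables downstream (the checkpoint path is shortened here; the exact local path TTS.py uses appears later in this diff):

```python
# With relative imports in place, the vendored copy loads as a regular package:
from BigVGAN.bigvgan import BigVGAN

# Instantiating from a local snapshot of nvidia/bigvgan_v2_24khz_100band_256x;
# use_cuda_kernel=False avoids the Ninja-compiled CUDA extension.
model = BigVGAN.from_pretrained("path/to/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
model.remove_weight_norm()
model = model.eval()
```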
@@ -4,6 +4,7 @@ import os, sys, gc
 import random
 import traceback
 
+import torchaudio
 from tqdm import tqdm
 now_dir = os.getcwd()
 sys.path.append(now_dir)

@@ -15,10 +16,11 @@ import torch
 import torch.nn.functional as F
 import yaml
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 
+from tools.audio_sr import AP_BWE
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from feature_extractor.cnhubert import CNHubert
-from module.models import SynthesizerTrn
+from module.models import SynthesizerTrn, SynthesizerTrnV3
+from peft import LoraConfig, get_peft_model
 import librosa
 from time import time as ttime
 from tools.i18n.i18n import I18nAuto, scan_language_list
@@ -26,10 +28,98 @@ from tools.my_utils import load_audio
-from module.mel_processing import spectrogram_torch
 from TTS_infer_pack.text_segmentation_method import splits
 from TTS_infer_pack.TextPreprocessor import TextPreprocessor
+from BigVGAN.bigvgan import BigVGAN
+from module.mel_processing import spectrogram_torch, mel_spectrogram_torch
+from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 language=os.environ.get("language","Auto")
 language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)
+
+
+spec_min = -12
+spec_max = 2
+
+def norm_spec(x):
+    return (x - spec_min) / (spec_max - spec_min) * 2 - 1
+
+def denorm_spec(x):
+    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
+
+mel_fn = lambda x: mel_spectrogram_torch(x, **{
+    "n_fft": 1024,
+    "win_size": 1024,
+    "hop_size": 256,
+    "num_mels": 100,
+    "sampling_rate": 24000,
+    "fmin": 0,
+    "fmax": None,
+    "center": False
+})
+
+
+def speed_change(input_audio:np.ndarray, speed:float, sr:int):
+    # Convert the NumPy array to a raw PCM stream
+    raw_audio = input_audio.astype(np.int16).tobytes()
+
+    # Set up the ffmpeg input stream
+    input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
+
+    # Apply the tempo change
+    output_stream = input_stream.filter('atempo', speed)
+
+    # Pipe the output
+    out, _ = (
+        output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
+        .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
+    )
+
+    # Decode the piped output back into a NumPy array
+    processed_audio = np.frombuffer(out, np.int16)
+
+    return processed_audio
+
+
+resample_transform_dict = {}
+def resample(audio_tensor, sr0, device):
+    global resample_transform_dict
+    if sr0 not in resample_transform_dict:
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
+            sr0, 24000
+        ).to(device)
+    return resample_transform_dict[sr0](audio_tensor)
+
+
+class DictToAttrRecursive(dict):
+    def __init__(self, input_dict):
+        super().__init__(input_dict)
+        for key, value in input_dict.items():
+            if isinstance(value, dict):
+                value = DictToAttrRecursive(value)
+            self[key] = value
+            setattr(self, key, value)
+
+    def __getattr__(self, item):
+        try:
+            return self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")
+
+    def __setattr__(self, key, value):
+        if isinstance(value, dict):
+            value = DictToAttrRecursive(value)
+        super(DictToAttrRecursive, self).__setitem__(key, value)
+        super().__setattr__(key, value)
+
+    def __delattr__(self, item):
+        try:
+            del self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")
+
+
+class NO_PROMPT_ERROR(Exception):
+    pass
+
+
 # configs/tts_infer.yaml
 """
 custom:
@@ -56,11 +146,19 @@ default_v2:
   t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
   vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
   version: v2
+default_v3:
+  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
+  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
+  device: cpu
+  is_half: false
+  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
+  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
+  version: v3
 """

 def set_seed(seed:int):
     seed = int(seed)
-    seed = seed if seed != -1 else random.randrange(1 << 32)
+    seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
     print(f"Set seed to {seed}")
     os.environ['PYTHONHASHSEED'] = str(seed)
     random.seed(seed)
@@ -82,7 +180,7 @@ def set_seed(seed:int):

 class TTS_Config:
     default_configs={
-        "default":{
+        "v1":{
             "device": "cpu",
             "is_half": False,
             "version": "v1",
@@ -91,7 +189,7 @@ class TTS_Config:
             "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
             "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
         },
-        "default_v2":{
+        "v2":{
             "device": "cpu",
             "is_half": False,
             "version": "v2",
@@ -100,6 +198,15 @@ class TTS_Config:
             "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
             "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
         },
+        "v3":{
+            "device": "cpu",
+            "is_half": False,
+            "version": "v3",
+            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
+            "vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth",
+            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
+            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+        },
     }
     configs:dict = None
     v1_languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
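Re-keying `default_configs` by version string ("v1"/"v2"/"v3" instead of "default"/"default_v2") turns config resolution into a single dictionary lookup. A condensed sketch of the lookup logic, simplified from the class above (the helper name is illustrative):

```python
def resolve_config(configs: dict, default_configs: dict) -> dict:
    # "version" picks the per-version defaults; a "custom" block overrides them wholesale.
    version = configs.get("version", "v2").lower()
    assert version in ["v1", "v2", "v3"]
    return configs.get("custom", dict(default_configs[version]))
```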
@@ -136,12 +243,9 @@ class TTS_Config:

         assert isinstance(configs, dict)
         version = configs.get("version", "v2").lower()
-        assert version in ["v1", "v2"]
-        self.default_configs["default"] = configs.get("default", self.default_configs["default"])
-        self.default_configs["default_v2"] = configs.get("default_v2", self.default_configs["default_v2"])
-
-        default_config_key = "default" if version=="v1" else "default_v2"
-        self.configs:dict = configs.get("custom", deepcopy(self.default_configs[default_config_key]))
+        assert version in ["v1", "v2", "v3"]
+        self.default_configs[version] = configs.get(version, self.default_configs[version])
+        self.configs:dict = configs.get("custom", deepcopy(self.default_configs[version]))


         self.device = self.configs.get("device", torch.device("cpu"))
@@ -159,20 +263,22 @@ class TTS_Config:
         self.vits_weights_path = self.configs.get("vits_weights_path", None)
         self.bert_base_path = self.configs.get("bert_base_path", None)
         self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
-        self.languages = self.v2_languages if self.version=="v2" else self.v1_languages
+        self.languages = self.v1_languages if self.version=="v1" else self.v2_languages
+
+        self.is_v3_synthesizer:bool = False

         if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
-            self.t2s_weights_path = self.default_configs[default_config_key]['t2s_weights_path']
+            self.t2s_weights_path = self.default_configs[version]['t2s_weights_path']
             print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
         if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
-            self.vits_weights_path = self.default_configs[default_config_key]['vits_weights_path']
+            self.vits_weights_path = self.default_configs[version]['vits_weights_path']
             print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
         if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
-            self.bert_base_path = self.default_configs[default_config_key]['bert_base_path']
+            self.bert_base_path = self.default_configs[version]['bert_base_path']
             print(f"fall back to default bert_base_path: {self.bert_base_path}")
         if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
-            self.cnhuhbert_base_path = self.default_configs[default_config_key]['cnhuhbert_base_path']
+            self.cnhuhbert_base_path = self.default_configs[version]['cnhuhbert_base_path']
             print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
         self.update_configs()
@@ -195,7 +301,7 @@ class TTS_Config:
         else:
             print(i18n("路径不存在,使用默认配置"))
             self.save_configs(configs_path)
-        with open(configs_path, 'r') as f:
+        with open(configs_path, 'r', encoding='utf-8') as f:
             configs = yaml.load(f, Loader=yaml.FullLoader)

         return configs
@@ -224,7 +330,7 @@ class TTS_Config:

     def update_version(self, version:str)->None:
         self.version = version
-        self.languages = self.v2_languages if self.version=="v2" else self.v1_languages
+        self.languages = self.v1_languages if self.version=="v1" else self.v2_languages

     def __str__(self):
         self.configs = self.update_configs()
@@ -252,10 +358,13 @@ class TTS:
         self.configs:TTS_Config = TTS_Config(configs)

         self.t2s_model:Text2SemanticLightningModule = None
-        self.vits_model:SynthesizerTrn = None
+        self.vits_model:Union[SynthesizerTrn, SynthesizerTrnV3] = None
         self.bert_tokenizer:AutoTokenizer = None
        self.bert_model:AutoModelForMaskedLM = None
         self.cnhuhbert_model:CNHubert = None
+        self.bigvgan_model:BigVGAN = None
+        self.sr_model:AP_BWE = None
+        self.sr_model_not_exist:bool = False

         self._init_models()
@@ -310,38 +419,83 @@ class TTS:
             self.bert_model = self.bert_model.half()

     def init_vits_weights(self, weights_path: str):
-        print(f"Loading VITS weights from {weights_path}")
         self.configs.vits_weights_path = weights_path
-        dict_s2 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
-        hps = dict_s2["config"]
-        if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
-            self.configs.update_version("v1")
-        else:
-            self.configs.update_version("v2")
-        self.configs.save_configs()
+        version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
+        path_sovits_v3 = self.configs.default_configs["v3"]["vits_weights_path"]
+
+        if if_lora_v3==True and os.path.exists(path_sovits_v3)==False:
+            info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
+            raise FileExistsError(info)
+
+        # dict_s2 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
+        dict_s2 = load_sovits_new(weights_path)
+        hps = dict_s2["config"]
+
+        hps["model"]["semantic_frame_rate"] = "25hz"
+        if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
+            hps["model"]["version"] = "v2"  # v3 model, v2 symbols
+        elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+            hps["model"]["version"] = "v1"
+        else:
+            hps["model"]["version"] = "v2"
+        # version = hps["model"]["version"]

-        hps["model"]["version"] = self.configs.version
         self.configs.filter_length = hps["data"]["filter_length"]
         self.configs.segment_size = hps["train"]["segment_size"]
         self.configs.sampling_rate = hps["data"]["sampling_rate"]
         self.configs.hop_length = hps["data"]["hop_length"]
         self.configs.win_length = hps["data"]["win_length"]
         self.configs.n_speakers = hps["data"]["n_speakers"]
-        self.configs.semantic_frame_rate = "25hz"
+        self.configs.semantic_frame_rate = hps["model"]["semantic_frame_rate"]
         kwargs = hps["model"]
-        vits_model = SynthesizerTrn(
-            self.configs.filter_length // 2 + 1,
-            self.configs.segment_size // self.configs.hop_length,
-            n_speakers=self.configs.n_speakers,
-            **kwargs
-        )
         # print(f"self.configs.sampling_rate:{self.configs.sampling_rate}")

+        self.configs.update_version(model_version)
+
+        # print(f"model_version:{model_version}")
+        # print(f'hps["model"]["version"]:{hps["model"]["version"]}')
+        if model_version != "v3":
+            vits_model = SynthesizerTrn(
+                self.configs.filter_length // 2 + 1,
+                self.configs.segment_size // self.configs.hop_length,
+                n_speakers=self.configs.n_speakers,
+                **kwargs
+            )
+            if hasattr(vits_model, "enc_q"):
+                del vits_model.enc_q
+            self.configs.is_v3_synthesizer = False
+        else:
+            vits_model = SynthesizerTrnV3(
+                self.configs.filter_length // 2 + 1,
+                self.configs.segment_size // self.configs.hop_length,
+                n_speakers=self.configs.n_speakers,
+                **kwargs
+            )
+            self.configs.is_v3_synthesizer = True
+            self.init_bigvgan()
+
+        if if_lora_v3 == False:
+            print(f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
+        else:
+            print(f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}")
+            lora_rank = dict_s2["lora_rank"]
+            lora_config = LoraConfig(
+                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+                r=lora_rank,
+                lora_alpha=lora_rank,
+                init_lora_weights=True,
+            )
+            vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
+            print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
+
+            vits_model.cfm = vits_model.cfm.merge_and_unload()

-        if hasattr(vits_model, "enc_q"):
-            del vits_model.enc_q
         vits_model = vits_model.to(self.configs.device)
         vits_model = vits_model.eval()
-        vits_model.load_state_dict(dict_s2["weight"], strict=False)
+
         self.vits_model = vits_model
         if self.configs.is_half and str(self.configs.device)!="cpu":
             self.vits_model = self.vits_model.half()
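For readers unfamiliar with the PEFT pattern used here: a V3 LoRA checkpoint ships only low-rank adapter weights for the flow-matching module (`cfm`), so the loader restores the s2Gv3 base weights first, wraps `cfm` in a PEFT model, loads the adapter with `strict=False`, and merges it back into a plain module. A generic sketch of that load-merge sequence (function and variable names are illustrative, not from the source):

```python
import torch
from peft import LoraConfig, get_peft_model

def merge_lora_into(base_module: torch.nn.Module, adapter_state: dict, rank: int) -> torch.nn.Module:
    cfg = LoraConfig(
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # attention projections
        r=rank,
        lora_alpha=rank,
        init_lora_weights=True,
    )
    wrapped = get_peft_model(base_module, cfg)
    # strict=False: the checkpoint holds only the adapter tensors.
    wrapped.load_state_dict(adapter_state, strict=False)
    # Fold the low-rank update back into the base weights and drop the wrapper.
    return wrapped.merge_and_unload()
```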
@@ -363,6 +517,30 @@ class TTS:
         if self.configs.is_half and str(self.configs.device)!="cpu":
             self.t2s_model = self.t2s_model.half()

+    def init_bigvgan(self):
+        if self.bigvgan_model is not None:
+            return
+        self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False)  # if True, RuntimeError: Ninja is required to load C++ extensions
+        # remove weight norm in the model and set to eval mode
+        self.bigvgan_model.remove_weight_norm()
+        self.bigvgan_model = self.bigvgan_model.eval()
+        if self.configs.is_half == True:
+            self.bigvgan_model = self.bigvgan_model.half().to(self.configs.device)
+        else:
+            self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
+
+    def init_sr_model(self):
+        if self.sr_model is not None:
+            return
+        try:
+            self.sr_model:AP_BWE = AP_BWE(self.configs.device, DictToAttrRecursive)
+            self.sr_model_not_exist = False
+        except FileNotFoundError:
+            print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
+            self.sr_model_not_exist = True
+
     def enable_half_precision(self, enable: bool = True, save: bool = True):
         '''
         To enable half precision for the TTS model.
|
||||
self.bert_model =self.bert_model.half()
|
||||
if self.cnhuhbert_model is not None:
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
||||
if self.bigvgan_model is not None:
|
||||
self.bigvgan_model = self.bigvgan_model.half()
|
||||
else:
|
||||
if self.t2s_model is not None:
|
||||
self.t2s_model = self.t2s_model.float()
|
||||
@ -396,6 +576,8 @@ class TTS:
|
||||
self.bert_model = self.bert_model.float()
|
||||
if self.cnhuhbert_model is not None:
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.float()
|
||||
if self.bigvgan_model is not None:
|
||||
self.bigvgan_model = self.bigvgan_model.float()
|
||||
|
||||
def set_device(self, device: torch.device, save: bool = True):
|
||||
'''
|
||||
@@ -414,6 +596,11 @@ class TTS:
             self.bert_model = self.bert_model.to(device)
         if self.cnhuhbert_model is not None:
             self.cnhuhbert_model = self.cnhuhbert_model.to(device)
+        if self.bigvgan_model is not None:
+            self.bigvgan_model = self.bigvgan_model.to(device)
+        if self.sr_model is not None:
+            self.sr_model = self.sr_model.to(device)

     def set_ref_audio(self, ref_audio_path:str):
         '''
@@ -437,6 +624,11 @@ class TTS:
         self.prompt_cache["refer_spec"][0] = spec

     def _get_ref_spec(self, ref_audio_path):
+        raw_audio, raw_sr = torchaudio.load(ref_audio_path)
+        raw_audio = raw_audio.to(self.configs.device).float()
+        self.prompt_cache["raw_audio"] = raw_audio
+        self.prompt_cache["raw_sr"] = raw_sr
+
         audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
         audio = torch.FloatTensor(audio)
         maxx = audio.abs().max()
@@ -625,11 +817,11 @@
         Recovery the order of the audio according to the batch_index_list.

         Args:
-            data (List[list(np.ndarray)]): the out of order audio .
+            data (List[list(torch.Tensor)]): the out of order audio .
             batch_index_list (List[list[int]]): the batch index list.

         Returns:
-            list (List[np.ndarray]): the data in the original order.
+            list (List[torch.Tensor]): the data in the original order.
         '''
         length = len(sum(batch_index_list, []))
         _data = [None]*length
@@ -671,6 +863,8 @@ class TTS:
                     "seed": -1, # int. random seed for reproducibility.
                     "parallel_infer": True, # bool. whether to use parallel inference.
                     "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
+                    "sample_steps": 32, # int. number of sampling steps for VITS model V3.
+                    "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
                 }
         returns:
             Tuple[int, np.ndarray]: sampling rate and audio data.
@@ -698,6 +892,8 @@ class TTS:
         actual_seed = set_seed(seed)
         parallel_infer = inputs.get("parallel_infer", True)
         repetition_penalty = inputs.get("repetition_penalty", 1.35)
+        sample_steps = inputs.get("sample_steps", 32)
+        super_sampling = inputs.get("super_sampling", False)

         if parallel_infer:
             print(i18n("并行推理模式已开启"))
@@ -732,6 +928,9 @@ class TTS:
         if not no_prompt_text:
             assert prompt_lang in self.configs.languages

+        if no_prompt_text and self.configs.is_v3_synthesizer:
+            raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3")
+
         if ref_audio_path in [None, ""] and \
             ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])):
             raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
@@ -761,13 +960,13 @@ class TTS:
             if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "."
             print(i18n("实际输入的参考文本:"), prompt_text)
             if self.prompt_cache["prompt_text"] != prompt_text:
-                self.prompt_cache["prompt_text"] = prompt_text
-                self.prompt_cache["prompt_lang"] = prompt_lang
                 phones, bert_features, norm_text = \
                     self.text_preprocessor.segment_and_extract_feature_for_text(
                         prompt_text,
                         prompt_lang,
                         self.configs.version)
+                self.prompt_cache["prompt_text"] = prompt_text
+                self.prompt_cache["prompt_lang"] = prompt_lang
                 self.prompt_cache["phones"] = phones
                 self.prompt_cache["bert_features"] = bert_features
                 self.prompt_cache["norm_text"] = norm_text
@@ -781,8 +980,7 @@ class TTS:
         if not return_fragment:
             data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
             if len(data) == 0:
-                yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                           dtype=np.int16)
+                yield 16000, np.zeros(int(16000), dtype=np.int16)
                 return

             batch_index_list:list = None
@@ -836,6 +1034,7 @@ class TTS:
         t_34 = 0.0
         t_45 = 0.0
         audio = []
+        output_sr = self.configs.sampling_rate if not self.configs.is_v3_synthesizer else 24000
         for item in data:
             t3 = ttime()
             if return_fragment:
@@ -858,7 +1057,7 @@ class TTS:
             else:
                 prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)

-
+            print(f"############ {i18n('预测语义Token')} ############")
             pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                 all_phoneme_ids,
                 all_phoneme_lens,
@@ -892,70 +1091,80 @@ class TTS:
                 # batch_audio_fragment = (self.vits_model.batched_decode(
                 #     pred_semantic, pred_semantic_len, batch_phones, batch_phones_len, refer_audio_spec
                 # ))

-                if speed_factor == 1.0:
-                    # ## vits parallel inference, method 2
-                    pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
-                    upsample_rate = math.prod(self.vits_model.upsample_rates)
-                    audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
-                    audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
-                    all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
-                    _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-                    _batch_audio_fragment = (self.vits_model.decode(
-                            all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
-                        ).detach()[0, 0, :])
-                    audio_frag_end_idx.insert(0, 0)
-                    batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
-                else:
-                    # ## vits serial inference
-                    for i, idx in enumerate(idx_list):
-                        phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-                        _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # .unsqueeze(0)  # mq needs one extra unsqueeze
-                        audio_fragment = (self.vits_model.decode(
-                                _pred_semantic, phones, refer_audio_spec, speed=speed_factor
-                            ).detach()[0, 0, :])
-                        batch_audio_fragment.append(
-                            audio_fragment
-                        )  ### try reconstructing without the prompt part
+                print(f"############ {i18n('合成音频')} ############")
+                if not self.configs.is_v3_synthesizer:
+                    if speed_factor == 1.0:
+                        # ## vits parallel inference, method 2
+                        pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+                        upsample_rate = math.prod(self.vits_model.upsample_rates)
+                        audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
+                        audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
+                        all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
+                        _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
+                        _batch_audio_fragment = (self.vits_model.decode(
+                                all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
+                            ).detach()[0, 0, :])
+                        audio_frag_end_idx.insert(0, 0)
+                        batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
+                    else:
+                        # ## vits serial inference
+                        for i, idx in enumerate(tqdm(idx_list)):
+                            phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+                            _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # .unsqueeze(0)  # mq needs one extra unsqueeze
+                            audio_fragment = (self.vits_model.decode(
+                                    _pred_semantic, phones, refer_audio_spec, speed=speed_factor
+                                ).detach()[0, 0, :])
+                            batch_audio_fragment.append(
+                                audio_fragment
+                            )  ### try reconstructing without the prompt part
+                else:
+                    for i, idx in enumerate(tqdm(idx_list)):
+                        phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+                        _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))
+                        audio_fragment = self.v3_synthesis(
+                            _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
+                        )
+                        batch_audio_fragment.append(
+                            audio_fragment
+                        )

                 t5 = ttime()
                 t_45 += t5 - t4
                 if return_fragment:
                     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
                     yield self.audio_postprocess([batch_audio_fragment],
-                                                 self.configs.sampling_rate,
+                                                 output_sr,
                                                  None,
                                                  speed_factor,
                                                  False,
-                                                 fragment_interval
+                                                 fragment_interval,
+                                                 super_sampling if self.configs.is_v3_synthesizer else False
                                                  )
                 else:
                     audio.append(batch_audio_fragment)

                 if self.stop_flag:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                               dtype=np.int16)
+                    yield 16000, np.zeros(int(16000), dtype=np.int16)
                     return

             if not return_fragment:
                 print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
                 if len(audio) == 0:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                               dtype=np.int16)
+                    yield 16000, np.zeros(int(16000), dtype=np.int16)
                     return
                 yield self.audio_postprocess(audio,
-                                             self.configs.sampling_rate,
+                                             output_sr,
                                              batch_index_list,
                                              speed_factor,
                                              split_bucket,
-                                             fragment_interval
+                                             fragment_interval,
+                                             super_sampling if self.configs.is_v3_synthesizer else False
                                              )

         except Exception as e:
             traceback.print_exc()
             # must yield an empty audio clip, otherwise GPU memory is not released
-            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                       dtype=np.int16)
+            yield 16000, np.zeros(int(16000), dtype=np.int16)
             # reset the models, otherwise GPU memory is not fully released
             del self.t2s_model
             del self.vits_model
@@ -983,7 +1192,8 @@ class TTS:
                           batch_index_list:list=None,
                           speed_factor:float=1.0,
                           split_bucket:bool=True,
-                          fragment_interval:float=0.3
+                          fragment_interval:float=0.3,
+                          super_sampling:bool=False,
                           )->Tuple[int, np.ndarray]:
         zero_wav = torch.zeros(
             int(self.configs.sampling_rate * fragment_interval),
@@ -996,7 +1206,7 @@ class TTS:
                 max_audio = torch.abs(audio_fragment).max()  # simple guard against 16-bit clipping
                 if max_audio > 1: audio_fragment /= max_audio
                 audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
-                audio[i][j] = audio_fragment.cpu().numpy()
+                audio[i][j] = audio_fragment


         if split_bucket:
@@ -1005,8 +1215,21 @@ class TTS:
             # audio = [item for batch in audio for item in batch]
             audio = sum(audio, [])

-        audio = np.concatenate(audio, 0)
+        audio = torch.cat(audio, dim=0)
+
+        if super_sampling:
+            print(f"############ {i18n('音频超采样')} ############")
+            t1 = ttime()
+            self.init_sr_model()
+            if not self.sr_model_not_exist:
+                audio, sr = self.sr_model(audio.unsqueeze(0), sr)
+                max_audio = np.abs(audio).max()
+                if max_audio > 1: audio /= max_audio
+            t2 = ttime()
+            print(f"超采样用时:{t2-t1:.3f}s")
+        else:
+            audio = audio.cpu().numpy()
+
         audio = (audio * 32768).astype(np.int16)

         # try:
@@ -1018,25 +1241,59 @@ class TTS:
         return sr, audio


-def speed_change(input_audio:np.ndarray, speed:float, sr:int):
-    # Convert the NumPy array to a raw PCM stream
-    raw_audio = input_audio.astype(np.int16).tobytes()
-
-    # Set up the ffmpeg input stream
-    input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
-
-    # Apply the tempo change
-    output_stream = input_stream.filter('atempo', speed)
-
-    # Pipe the output
-    out, _ = (
-        output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
-        .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
-    )
-
-    # Decode the piped output back into a NumPy array
-    processed_audio = np.frombuffer(out, np.int16)
-
-    return processed_audio
+    def v3_synthesis(self,
+                     semantic_tokens:torch.Tensor,
+                     phones:torch.Tensor,
+                     speed:float=1.0,
+                     sample_steps:int=32
+                     ):
+
+        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
+        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
+        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
+
+        fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
+        ref_audio:torch.Tensor = self.prompt_cache["raw_audio"]
+        ref_sr = self.prompt_cache["raw_sr"]
+        ref_audio = ref_audio.to(self.configs.device).float()
+        if (ref_audio.shape[0] == 2):
+            ref_audio = ref_audio.mean(0).unsqueeze(0)
+        if ref_sr != 24000:
+            ref_audio = resample(ref_audio, ref_sr, self.configs.device)
+        # print("ref_audio", ref_audio.abs().mean())
+        mel2 = mel_fn(ref_audio)
+        mel2 = norm_spec(mel2)
+        T_min = min(mel2.shape[2], fea_ref.shape[2])
+        mel2 = mel2[:, :, :T_min]
+        fea_ref = fea_ref[:, :, :T_min]
+        if (T_min > 468):
+            mel2 = mel2[:, :, -468:]
+            fea_ref = fea_ref[:, :, -468:]
+            T_min = 468
+        chunk_len = 934 - T_min
+
+        mel2 = mel2.to(self.precision)
+        fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
+
+        cfm_resss = []
+        idx = 0
+        while (1):
+            fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
+            if (fea_todo_chunk.shape[-1] == 0): break
+            idx += chunk_len
+            fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
+
+            cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
+            cfm_res = cfm_res[:, :, mel2.shape[2]:]
+            mel2 = cfm_res[:, :, -T_min:]
+
+            fea_ref = fea_todo_chunk[:, :, -T_min:]
+            cfm_resss.append(cfm_res)
+        cmf_res = torch.cat(cfm_resss, 2)
+        cmf_res = denorm_spec(cmf_res)
+
+        with torch.inference_mode():
+            wav_gen = self.bigvgan_model(cmf_res)
+            audio = wav_gen[0][0]  # .cpu().detach().numpy()
+
+        return audio
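To make the chunking arithmetic in `v3_synthesis` concrete: the reference mel window is capped at 468 frames, each CFM call generates `934 - T_min` new frames, and the tail of each chunk is fed back as the next reference window. A back-of-envelope check (not from the source) of how much audio one chunk covers at the V3 decoder's settings:

```python
hop_size = 256           # mel hop used by mel_fn above
sampling_rate = 24000    # V3 native output rate
T_min = 468              # reference window cap, in mel frames
chunk_len = 934 - T_min  # new mel frames generated per CFM call

seconds_per_chunk = chunk_len * hop_size / sampling_rate
print(f"{chunk_len} frames ~ {seconds_per_chunk:.2f} s")  # 466 frames ~ 4.97 s
```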
@@ -118,11 +118,11 @@ class TextPreprocessor:

     def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
         if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
-            language = language.replace("all_","")
+            # language = language.replace("all_","")
             formattext = text
             while "  " in formattext:
                 formattext = formattext.replace("  ", " ")
-            if language == "zh":
+            if language == "all_zh":
                 if re.search(r'[A-Za-z]', formattext):
                     formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                     formattext = chinese.mix_text_normalize(formattext)
@@ -130,7 +130,7 @@ class TextPreprocessor:
                 else:
                     phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
                     bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
-            elif language == "yue" and re.search(r'[A-Za-z]', formattext):
+            elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
                 formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                 formattext = chinese.mix_text_normalize(formattext)
                 return self.get_phones_and_bert(formattext,"yue",version)
@@ -199,6 +199,7 @@ class TextPreprocessor:
         return phone_level_feature.T

     def clean_text_inf(self, text:str, language:str, version:str="v2"):
+        language = language.replace("all_","")
         phones, word2ph, norm_text = clean_text(text, language, version)
         phones = cleaned_text_to_sequence(phones, version)
         return phones, word2ph, norm_text
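The subtle point in these hunks: `get_phones_and_bert` previously stripped the `all_` prefix up front, so the later branches could no longer tell the force-one-language modes (`all_zh`, `all_yue`) apart from the mixed-language tags (`zh`, `yue`). The prefix is now carried through the branching and stripped only inside `clean_text_inf`, at the last moment. A toy illustration of the late-stripping rule (hypothetical helper, not from the source):

```python
def strip_mode_prefix(language: str) -> str:
    # Only the final cleaning step drops the prefix; earlier branches can still
    # tell "all_zh" (force-Chinese) apart from "zh" (mixed Chinese/English).
    return language.replace("all_", "")

assert strip_mode_prefix("all_yue") == "yue"
assert strip_mode_prefix("en") == "en"
```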
@@ -6,7 +6,7 @@ custom:
   t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
   version: v2
   vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
-default:
+v1:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cpu
@@ -14,7 +14,7 @@
   t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
   version: v1
   vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
-default_v2:
+v2:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cpu
@@ -22,3 +22,11 @@
   t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
   version: v2
   vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
+v3:
+  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
+  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
+  device: cpu
+  is_half: false
+  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
+  version: v3
+  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
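With the v3 defaults present, pointing the pipeline at a fine-tuned V3 model only takes a `custom` block in tts_infer.yaml; a sketch in the same format (the two weight paths are placeholders for your own checkpoints):

```yaml
custom:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cuda
  is_half: true
  t2s_weights_path: GPT_weights_v3/my_model.ckpt        # placeholder
  version: v3
  vits_weights_path: SoVITS_weights_v3/my_model.pth     # placeholder
```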
@@ -7,7 +7,7 @@
 全部按日文识别
 '''
 import random
-import os, re, logging
+import os, re, logging, json
 import sys
 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -44,7 +44,7 @@ bert_path = os.environ.get("bert_path", None)
 version = os.environ.get("version", "v2")

 import gradio as gr
-from TTS_infer_pack.TTS import TTS, TTS_Config
+from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR
 from TTS_infer_pack.text_segmentation_method import get_method
 from tools.i18n.i18n import I18nAuto, scan_language_list
@@ -62,6 +62,9 @@ if torch.cuda.is_available():
 else:
     device = "cpu"

+# is_half = False
+# device = "cpu"
+
 dict_language_v1 = {
     i18n("中文"): "all_zh",  # recognize everything as Chinese
     i18n("英文"): "en",  # recognize everything as English ####### unchanged
@@ -123,11 +126,11 @@ def inference(text, text_lang,
               speed_factor, ref_text_free,
               split_bucket, fragment_interval,
               seed, keep_random, parallel_infer,
-              repetition_penalty
+              repetition_penalty, sample_steps, super_sampling,
               ):

     seed = -1 if keep_random else seed
-    actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
+    actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
     inputs = {
         "text": text,
         "text_lang": dict_language[text_lang],
@@ -147,9 +150,14 @@ def inference(text, text_lang,
         "seed": actual_seed,
         "parallel_infer": parallel_infer,
         "repetition_penalty": repetition_penalty,
+        "sample_steps": int(sample_steps),
+        "super_sampling": super_sampling,
     }
-    for item in tts_pipeline.run(inputs):
-        yield item, actual_seed
+    try:
+        for item in tts_pipeline.run(inputs):
+            yield item, actual_seed
+    except NO_PROMPT_ERROR:
+        gr.Warning(i18n('V3不支持无参考文本模式,请填写参考文本!'))

 def custom_sort_key(s):
     # use a regex to split the string into its numeric and non-numeric parts
@@ -163,19 +171,38 @@ def change_choices():
     SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
     return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}

-pretrained_sovits_name = ["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
-pretrained_gpt_name = ["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
+path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
+pretrained_sovits_name = ["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", path_sovits_v3]
+pretrained_gpt_name = ["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
 _ = [[], []]
-for i in range(2):
-    if os.path.exists(pretrained_gpt_name[i]):
-        _[0].append(pretrained_gpt_name[i])
-    if os.path.exists(pretrained_sovits_name[i]):
-        _[-1].append(pretrained_sovits_name[i])
+for i in range(3):
+    if os.path.exists(pretrained_gpt_name[i]): _[0].append(pretrained_gpt_name[i])
+    if os.path.exists(pretrained_sovits_name[i]): _[-1].append(pretrained_sovits_name[i])
 pretrained_gpt_name, pretrained_sovits_name = _

-SoVITS_weight_root = ["SoVITS_weights_v2", "SoVITS_weights"]
-GPT_weight_root = ["GPT_weights_v2", "GPT_weights"]
+
+if os.path.exists(f"./weight.json"):
+    pass
+else:
+    with open(f"./weight.json", 'w', encoding="utf-8") as file: json.dump({'GPT': {}, 'SoVITS': {}}, file)
+
+with open(f"./weight.json", 'r', encoding="utf-8") as file:
+    weight_data = file.read()
+    weight_data = json.loads(weight_data)
+    gpt_path = os.environ.get(
+        "gpt_path", weight_data.get('GPT', {}).get(version, pretrained_gpt_name))
+    sovits_path = os.environ.get(
+        "sovits_path", weight_data.get('SoVITS', {}).get(version, pretrained_sovits_name))
+    if isinstance(gpt_path, list):
+        gpt_path = gpt_path[0]
+    if isinstance(sovits_path, list):
+        sovits_path = sovits_path[0]
+
+
+SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
+GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
 for path in SoVITS_weight_root + GPT_weight_root:
     os.makedirs(path, exist_ok=True)
@@ -194,10 +221,18 @@ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
 SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)


+from process_ckpt import get_sovits_version_from_path_fast
 def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
-    tts_pipeline.init_vits_weights(sovits_path)
     global version, dict_language
+    version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
+
+    if if_lora_v3 and not os.path.exists(path_sovits_v3):
+        info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
+        gr.Warning(info)
+        raise FileExistsError(info)
+
+    tts_pipeline.init_vits_weights(sovits_path)
+
     dict_language = dict_language_v1 if tts_pipeline.configs.version == 'v1' else dict_language_v2
     if prompt_language is not None and text_language is not None:
         if prompt_language in list(dict_language.keys()):
@@ -210,9 +245,19 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
         else:
             text_update = {'__type__':'update', 'value':''}
             text_language_update = {'__type__':'update', 'value':i18n("中文")}
-        return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
+        if model_version == "v3":
+            visible_sample_steps = True
+            visible_inp_refs = False
+        else:
+            visible_sample_steps = False
+            visible_inp_refs = True
+        yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update, {"__type__": "update", "visible": visible_sample_steps}, {"__type__": "update", "visible": visible_inp_refs}, {"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False}, {"__type__": "update", "visible": True if model_version == "v3" else False}
+
+    with open("./weight.json") as f:
+        data = f.read()
+        data = json.loads(data)
+        data["SoVITS"][version] = sovits_path
+    with open("./weight.json", "w") as f: f.write(json.dumps(data))

 with gr.Blocks(title="GPT-SoVITS WebUI") as app:
     gr.Markdown(
@@ -257,13 +302,19 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             with gr.Row():

                 with gr.Column():
-                    batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
-                    fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
-                    speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True)
-                    top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
-                    top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
-                    temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
-                    repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
+                    with gr.Row():
+                        batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
+                        sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True)
+                    with gr.Row():
+                        fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
+                        speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True)
+                    with gr.Row():
+                        top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
+                        top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
+                    with gr.Row():
+                        temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
+                        repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)

                 with gr.Column():
                     with gr.Row():
                         how_to_cut = gr.Dropdown(
@@ -272,10 +323,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                             value=i18n("凑四句一切"),
                             interactive=True, scale=1
                         )
+                        super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True)

                     with gr.Row():
                         parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
                         split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)

                     with gr.Row():
                         seed = gr.Number(label=i18n("随机种子"),value=-1)
                         keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
@@ -295,7 +350,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             speed_factor, ref_text_free,
             split_bucket, fragment_interval,
             seed, keep_random, parallel_infer,
-            repetition_penalty
+            repetition_penalty, sample_steps, super_sampling,
         ],
         [output, seed],
     )
api_v2.py: 24 lines changed
@@ -39,6 +39,8 @@ POST:
     "seed": -1, # int. random seed for reproducibility.
     "parallel_infer": True, # bool. whether to use parallel inference.
     "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
+    "sample_steps": 32, # int. number of sampling steps for VITS model V3.
+    "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
 }
 ```
@@ -164,6 +166,8 @@ class TTS_Request(BaseModel):
     streaming_mode:bool = False
     parallel_infer:bool = True
     repetition_penalty:float = 1.35
+    sample_steps:int = 32
+    super_sampling:bool = False

 ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
 def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
@@ -294,7 +298,9 @@ async def tts_handle(req:dict):
                 "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
                 "streaming_mode": False, # bool. whether to return a streaming response.
                 "parallel_infer": True, # bool.(optional) whether to use parallel inference.
-                "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
+                "repetition_penalty": 1.35, # float.(optional) repetition penalty for T2S model.
+                "sample_steps": 32, # int. number of sampling steps for VITS model V3.
+                "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
             }
     returns:
         StreamingResponse: audio stream response.
@@ -316,10 +322,12 @@ async def tts_handle(req:dict):

     if streaming_mode:
         def streaming_generator(tts_generator:Generator, media_type:str):
-            if media_type == "wav":
-                yield wave_header_chunk()
-                media_type = "raw"
+            if_frist_chunk = True
             for sr, chunk in tts_generator:
+                if if_frist_chunk and media_type == "wav":
+                    yield wave_header_chunk(sample_rate=sr)
+                    media_type = "raw"
+                    if_frist_chunk = False
                 yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
         # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
         return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
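The streaming fix matters because the output rate is no longer known before inference: a V3 model yields 24 kHz audio (or 48 kHz after super-sampling) instead of the fixed config rate, so the WAV header can only be written once the first chunk arrives with its true sample rate. A sketch of a client consuming the stream (URL and parameter values are placeholders):

```python
import requests

# Placeholder endpoint; streaming_mode=True makes the server chunk the response.
params = {
    "text": "...", "text_lang": "zh",
    "ref_audio_path": "ref.wav", "prompt_text": "...", "prompt_lang": "zh",
    "streaming_mode": True, "media_type": "wav",
}
with requests.get("http://127.0.0.1:9880/tts", params=params, stream=True) as resp:
    with open("out.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=4096):
            f.write(chunk)  # first chunk carries the WAV header at the real rate
```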
@@ -365,7 +373,9 @@ async def tts_get_endpoint(
                         media_type:str = "wav",
                         streaming_mode:bool = False,
                         parallel_infer:bool = True,
-                        repetition_penalty:float = 1.35
+                        repetition_penalty:float = 1.35,
+                        sample_steps:int = 32,
+                        super_sampling:bool = False
                         ):
     req = {
         "text": text,
@@ -387,7 +397,9 @@ async def tts_get_endpoint(
         "media_type": media_type,
         "streaming_mode": streaming_mode,
         "parallel_infer": parallel_infer,
-        "repetition_penalty": float(repetition_penalty)
+        "repetition_penalty": float(repetition_penalty),
+        "sample_steps": int(sample_steps),
+        "super_sampling": super_sampling
     }
     return await tts_handle(req)
@@ -39,6 +39,11 @@ class AP_BWE():
         self.model = model
         self.h = h

+    def to(self, *arg, **kwargs):
+        self.model.to(*arg, **kwargs)
+        self.device = self.model.conv_pre_mag.weight.device
+        return self
+
     def __call__(self, audio, orig_sampling_rate):
         with torch.no_grad():
             # audio, orig_sampling_rate = torchaudio.load(inp_path)
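The new `to()` wrapper lets the super-resolution model follow `TTS.set_device` like the other sub-models; reading the device back from a known parameter (`conv_pre_mag.weight`) keeps `self.device` consistent with where the weights actually live. Hypothetical usage:

```python
# AP_BWE wraps the bandwidth-extension network; its constructor loads the
# checkpoint from tools/AP_BWE_main/24kto48k (per the .gitignore hunk above).
sr_model = AP_BWE(device, DictToAttrRecursive)
sr_model = sr_model.to("cuda:0")  # moves the inner model and records its device
print(sr_model.device)            # device of conv_pre_mag.weight, i.e. cuda:0
```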