From 7394dc7b0c9e5012b614f8d7b48404a1d6c5ad38 Mon Sep 17 00:00:00 2001
From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com>
Date: Wed, 26 Mar 2025 14:34:51 +0800
Subject: [PATCH] Adapt api_v2 and inference_webui_fast to V3 (#2188)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* modified:   GPT_SoVITS/TTS_infer_pack/TTS.py
  modified:   GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
  modified:   GPT_SoVITS/inference_webui_fast.py

* Adapt to the V3 version

* V3 adaptation for api_v2.py and inference_webui_fast.py

* Fix an ancient bug and add friendlier error messages

* Polish the WebUI

* Use the correct path

* Fix loading of V3 LoRA models

* Fix an encoding mismatch when reading tts_infer.yaml

---
 .gitignore                                    |   4 +-
 .../alias_free_activation/torch/act.py        |   2 +-
 .../alias_free_activation/torch/resample.py   |   4 +-
 GPT_SoVITS/BigVGAN/bigvgan.py                 |  14 +-
 GPT_SoVITS/BigVGAN/meldataset.py              |   2 +-
 GPT_SoVITS/BigVGAN/utils0.py                  |   2 +-
 GPT_SoVITS/TTS_infer_pack/TTS.py              | 447 ++++++++++++++----
 GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py |   7 +-
 GPT_SoVITS/configs/tts_infer.yaml             |  12 +-
 GPT_SoVITS/inference_webui_fast.py            | 109 +++--
 api_v2.py                                     |  24 +-
 tools/audio_sr.py                             |   5 +
 12 files changed, 486 insertions(+), 146 deletions(-)

diff --git a/.gitignore b/.gitignore
index b7fec30..e5cedbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,4 +17,6 @@ SoVITS_weights_v3
 TEMP
 weight.json
 ffmpeg*
-ffprobe*
\ No newline at end of file
+ffprobe*
+tools/AP_BWE_main/24kto48k/*
+!tools/AP_BWE_main/24kto48k/readme.txt
\ No newline at end of file
diff --git a/GPT_SoVITS/BigVGAN/alias_free_activation/torch/act.py b/GPT_SoVITS/BigVGAN/alias_free_activation/torch/act.py
index cc6e9f8..a6693aa 100644
--- a/GPT_SoVITS/BigVGAN/alias_free_activation/torch/act.py
+++ b/GPT_SoVITS/BigVGAN/alias_free_activation/torch/act.py
@@ -2,7 +2,7 @@
 # LICENSE is in incl_licenses directory.
 
 import torch.nn as nn
-from alias_free_activation.torch.resample import UpSample1d, DownSample1d
+from .resample import UpSample1d, DownSample1d
 
 
 class Activation1d(nn.Module):
diff --git a/GPT_SoVITS/BigVGAN/alias_free_activation/torch/resample.py b/GPT_SoVITS/BigVGAN/alias_free_activation/torch/resample.py
index f321150..a35380f 100644
--- a/GPT_SoVITS/BigVGAN/alias_free_activation/torch/resample.py
+++ b/GPT_SoVITS/BigVGAN/alias_free_activation/torch/resample.py
@@ -3,8 +3,8 @@
 
 import torch.nn as nn
 from torch.nn import functional as F
-from alias_free_activation.torch.filter import LowPassFilter1d
-from alias_free_activation.torch.filter import kaiser_sinc_filter1d
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
 
 
 class UpSample1d(nn.Module):
diff --git a/GPT_SoVITS/BigVGAN/bigvgan.py b/GPT_SoVITS/BigVGAN/bigvgan.py
index 214672e..6c4a223 100644
--- a/GPT_SoVITS/BigVGAN/bigvgan.py
+++ b/GPT_SoVITS/BigVGAN/bigvgan.py
@@ -14,10 +14,10 @@ import torch.nn as nn
 from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn.utils import weight_norm, remove_weight_norm
 
-import activations
-from utils0 import init_weights, get_padding
-from alias_free_activation.torch.act import Activation1d as TorchActivation1d
-from env import AttrDict
+from . 
import activations +from .utils0 import init_weights, get_padding +from .alias_free_activation.torch.act import Activation1d as TorchActivation1d +from .env import AttrDict from huggingface_hub import PyTorchModelHubMixin, hf_hub_download @@ -93,7 +93,7 @@ class AMPBlock1(torch.nn.Module): # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( + from .alias_free_activation.cuda.activation1d import ( Activation1d as CudaActivation1d, ) @@ -193,7 +193,7 @@ class AMPBlock2(torch.nn.Module): # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( + from .alias_free_activation.cuda.activation1d import ( Activation1d as CudaActivation1d, ) @@ -271,7 +271,7 @@ class BigVGAN( # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( + from .alias_free_activation.cuda.activation1d import ( Activation1d as CudaActivation1d, ) diff --git a/GPT_SoVITS/BigVGAN/meldataset.py b/GPT_SoVITS/BigVGAN/meldataset.py index bfbd4b6..a5859b9 100644 --- a/GPT_SoVITS/BigVGAN/meldataset.py +++ b/GPT_SoVITS/BigVGAN/meldataset.py @@ -15,7 +15,7 @@ from librosa.filters import mel as librosa_mel_fn import pathlib from tqdm import tqdm from typing import List, Tuple, Optional -from env import AttrDict +from .env import AttrDict MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases) diff --git a/GPT_SoVITS/BigVGAN/utils0.py b/GPT_SoVITS/BigVGAN/utils0.py index 888ea89..da98a24 100644 --- a/GPT_SoVITS/BigVGAN/utils0.py +++ b/GPT_SoVITS/BigVGAN/utils0.py @@ -9,7 +9,7 @@ from torch.nn.utils import weight_norm matplotlib.use("Agg") import matplotlib.pylab as plt -from meldataset import MAX_WAV_VALUE +from .meldataset import MAX_WAV_VALUE from scipy.io.wavfile import write diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 012cbf8..52402e9 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -4,6 +4,7 @@ import os, sys, gc import random import traceback +import torchaudio from tqdm import tqdm now_dir = os.getcwd() sys.path.append(now_dir) @@ -15,10 +16,11 @@ import torch import torch.nn.functional as F import yaml from transformers import AutoModelForMaskedLM, AutoTokenizer - +from tools.audio_sr import AP_BWE from AR.models.t2s_lightning_module import Text2SemanticLightningModule from feature_extractor.cnhubert import CNHubert -from module.models import SynthesizerTrn +from module.models import SynthesizerTrn, SynthesizerTrnV3 +from peft import LoraConfig, get_peft_model import librosa from time import time as ttime from tools.i18n.i18n import I18nAuto, scan_language_list @@ -26,10 +28,98 @@ from tools.my_utils import load_audio from module.mel_processing import spectrogram_torch from TTS_infer_pack.text_segmentation_method import splits from TTS_infer_pack.TextPreprocessor import TextPreprocessor +from BigVGAN.bigvgan import BigVGAN +from module.mel_processing import spectrogram_torch,mel_spectrogram_torch +from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new language=os.environ.get("language","Auto") language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language i18n = 
I18nAuto(language=language) + + +spec_min = -12 +spec_max = 2 +def norm_spec(x): + return (x - spec_min) / (spec_max - spec_min) * 2 - 1 +def denorm_spec(x): + return (x + 1) / 2 * (spec_max - spec_min) + spec_min +mel_fn=lambda x: mel_spectrogram_torch(x, **{ + "n_fft": 1024, + "win_size": 1024, + "hop_size": 256, + "num_mels": 100, + "sampling_rate": 24000, + "fmin": 0, + "fmax": None, + "center": False +}) + + +def speed_change(input_audio:np.ndarray, speed:float, sr:int): + # 将 NumPy 数组转换为原始 PCM 流 + raw_audio = input_audio.astype(np.int16).tobytes() + + # 设置 ffmpeg 输入流 + input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1) + + # 变速处理 + output_stream = input_stream.filter('atempo', speed) + + # 输出流到管道 + out, _ = ( + output_stream.output('pipe:', format='s16le', acodec='pcm_s16le') + .run(input=raw_audio, capture_stdout=True, capture_stderr=True) + ) + + # 将管道输出解码为 NumPy 数组 + processed_audio = np.frombuffer(out, np.int16) + + return processed_audio + + + +resample_transform_dict={} +def resample(audio_tensor, sr0, device): + global resample_transform_dict + if sr0 not in resample_transform_dict: + resample_transform_dict[sr0] = torchaudio.transforms.Resample( + sr0, 24000 + ).to(device) + return resample_transform_dict[sr0](audio_tensor) + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +class NO_PROMPT_ERROR(Exception): + pass + + # configs/tts_infer.yaml """ custom: @@ -56,11 +146,19 @@ default_v2: t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth version: v2 +default_v3: + bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large + cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base + device: cpu + is_half: false + t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt + vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth + version: v3 """ def set_seed(seed:int): seed = int(seed) - seed = seed if seed != -1 else random.randrange(1 << 32) + seed = seed if seed != -1 else random.randint(0, 2**32 - 1) print(f"Set seed to {seed}") os.environ['PYTHONHASHSEED'] = str(seed) random.seed(seed) @@ -82,7 +180,7 @@ def set_seed(seed:int): class TTS_Config: default_configs={ - "default":{ + "v1":{ "device": "cpu", "is_half": False, "version": "v1", @@ -91,7 +189,7 @@ class TTS_Config: "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", }, - "default_v2":{ + "v2":{ "device": "cpu", "is_half": False, "version": "v2", @@ -100,6 +198,15 @@ class TTS_Config: "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", 
}, + "v3":{ + "device": "cpu", + "is_half": False, + "version": "v3", + "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt", + "vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth", + "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", + "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + }, } configs:dict = None v1_languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] @@ -136,12 +243,9 @@ class TTS_Config: assert isinstance(configs, dict) version = configs.get("version", "v2").lower() - assert version in ["v1", "v2"] - self.default_configs["default"] = configs.get("default", self.default_configs["default"]) - self.default_configs["default_v2"] = configs.get("default_v2", self.default_configs["default_v2"]) - - default_config_key = "default"if version=="v1" else "default_v2" - self.configs:dict = configs.get("custom", deepcopy(self.default_configs[default_config_key])) + assert version in ["v1", "v2", "v3"] + self.default_configs[version] = configs.get(version, self.default_configs[version]) + self.configs:dict = configs.get("custom", deepcopy(self.default_configs[version])) self.device = self.configs.get("device", torch.device("cpu")) @@ -159,20 +263,22 @@ class TTS_Config: self.vits_weights_path = self.configs.get("vits_weights_path", None) self.bert_base_path = self.configs.get("bert_base_path", None) self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None) - self.languages = self.v2_languages if self.version=="v2" else self.v1_languages + self.languages = self.v1_languages if self.version=="v1" else self.v2_languages + + self.is_v3_synthesizer:bool = False if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)): - self.t2s_weights_path = self.default_configs[default_config_key]['t2s_weights_path'] + self.t2s_weights_path = self.default_configs[version]['t2s_weights_path'] print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}") if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)): - self.vits_weights_path = self.default_configs[default_config_key]['vits_weights_path'] + self.vits_weights_path = self.default_configs[version]['vits_weights_path'] print(f"fall back to default vits_weights_path: {self.vits_weights_path}") if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)): - self.bert_base_path = self.default_configs[default_config_key]['bert_base_path'] + self.bert_base_path = self.default_configs[version]['bert_base_path'] print(f"fall back to default bert_base_path: {self.bert_base_path}") if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)): - self.cnhuhbert_base_path = self.default_configs[default_config_key]['cnhuhbert_base_path'] + self.cnhuhbert_base_path = self.default_configs[version]['cnhuhbert_base_path'] print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}") self.update_configs() @@ -195,7 +301,7 @@ class TTS_Config: else: print(i18n("路径不存在,使用默认配置")) self.save_configs(configs_path) - with open(configs_path, 'r') as f: + with open(configs_path, 'r', encoding='utf-8') as f: configs = yaml.load(f, Loader=yaml.FullLoader) return configs @@ -224,7 +330,7 @@ class TTS_Config: def update_version(self, version:str)->None: self.version = version - self.languages = self.v2_languages if self.version=="v2" else self.v1_languages + self.languages = self.v1_languages if self.version=="v1" else self.v2_languages def 
__str__(self): self.configs = self.update_configs() @@ -252,10 +358,13 @@ class TTS: self.configs:TTS_Config = TTS_Config(configs) self.t2s_model:Text2SemanticLightningModule = None - self.vits_model:SynthesizerTrn = None + self.vits_model:Union[SynthesizerTrn, SynthesizerTrnV3] = None self.bert_tokenizer:AutoTokenizer = None self.bert_model:AutoModelForMaskedLM = None self.cnhuhbert_model:CNHubert = None + self.bigvgan_model:BigVGAN = None + self.sr_model:AP_BWE = None + self.sr_model_not_exist:bool = False self._init_models() @@ -310,38 +419,83 @@ class TTS: self.bert_model = self.bert_model.half() def init_vits_weights(self, weights_path: str): - print(f"Loading VITS weights from {weights_path}") + self.configs.vits_weights_path = weights_path - dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False) - hps = dict_s2["config"] - if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: - self.configs.update_version("v1") - else: - self.configs.update_version("v2") - self.configs.save_configs() + version, model_version, if_lora_v3=get_sovits_version_from_path_fast(weights_path) + path_sovits_v3=self.configs.default_configs["v3"]["vits_weights_path"] + + if if_lora_v3==True and os.path.exists(path_sovits_v3)==False: + info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") + raise FileExistsError(info) + + # dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False) + dict_s2 = load_sovits_new(weights_path) + hps = dict_s2["config"] + + hps["model"]["semantic_frame_rate"] = "25hz" + if 'enc_p.text_embedding.weight'not in dict_s2['weight']: + hps["model"]["version"] = "v2"#v3model,v2sybomls + elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: + hps["model"]["version"] = "v1" + else: + hps["model"]["version"] = "v2" + # version = hps["model"]["version"] - hps["model"]["version"] = self.configs.version self.configs.filter_length = hps["data"]["filter_length"] self.configs.segment_size = hps["train"]["segment_size"] self.configs.sampling_rate = hps["data"]["sampling_rate"] self.configs.hop_length = hps["data"]["hop_length"] self.configs.win_length = hps["data"]["win_length"] self.configs.n_speakers = hps["data"]["n_speakers"] - self.configs.semantic_frame_rate = "25hz" + self.configs.semantic_frame_rate = hps["model"]["semantic_frame_rate"] kwargs = hps["model"] - vits_model = SynthesizerTrn( - self.configs.filter_length // 2 + 1, - self.configs.segment_size // self.configs.hop_length, - n_speakers=self.configs.n_speakers, - **kwargs - ) + # print(f"self.configs.sampling_rate:{self.configs.sampling_rate}") + + self.configs.update_version(model_version) + + # print(f"model_version:{model_version}") + # print(f'hps["model"]["version"]:{hps["model"]["version"]}') + if model_version!="v3": + vits_model = SynthesizerTrn( + self.configs.filter_length // 2 + 1, + self.configs.segment_size // self.configs.hop_length, + n_speakers=self.configs.n_speakers, + **kwargs + ) + if hasattr(vits_model, "enc_q"): + del vits_model.enc_q + self.configs.is_v3_synthesizer = False + else: + vits_model = SynthesizerTrnV3( + self.configs.filter_length // 2 + 1, + self.configs.segment_size // self.configs.hop_length, + n_speakers=self.configs.n_speakers, + **kwargs + ) + self.configs.is_v3_synthesizer = True + self.init_bigvgan() + + + if if_lora_v3==False: + print(f"Loading VITS weights from {weights_path}. 
{vits_model.load_state_dict(dict_s2['weight'], strict=False)}") + else: + print(f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}") + lora_rank=dict_s2["lora_rank"] + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + vits_model.cfm = get_peft_model(vits_model.cfm, lora_config) + print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}") + + vits_model.cfm = vits_model.cfm.merge_and_unload() - if hasattr(vits_model, "enc_q"): - del vits_model.enc_q vits_model = vits_model.to(self.configs.device) vits_model = vits_model.eval() - vits_model.load_state_dict(dict_s2["weight"], strict=False) + self.vits_model = vits_model if self.configs.is_half and str(self.configs.device)!="cpu": self.vits_model = self.vits_model.half() @@ -363,6 +517,30 @@ class TTS: if self.configs.is_half and str(self.configs.device)!="cpu": self.t2s_model = self.t2s_model.half() + + def init_bigvgan(self): + if self.bigvgan_model is not None: + return + self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions + # remove weight norm in the model and set to eval mode + self.bigvgan_model.remove_weight_norm() + self.bigvgan_model = self.bigvgan_model.eval() + if self.configs.is_half == True: + self.bigvgan_model = self.bigvgan_model.half().to(self.configs.device) + else: + self.bigvgan_model = self.bigvgan_model.to(self.configs.device) + + def init_sr_model(self): + if self.sr_model is not None: + return + try: + self.sr_model:AP_BWE=AP_BWE(self.configs.device,DictToAttrRecursive) + self.sr_model_not_exist = False + except FileNotFoundError: + print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好")) + self.sr_model_not_exist = True + + def enable_half_precision(self, enable: bool = True, save: bool = True): ''' To enable half precision for the TTS model. 
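[Note: the two initializers above, `init_bigvgan()` and `init_sr_model()`, share one pattern: the heavyweight decoder is built lazily on first use, cached on the instance, and a missing-weights condition is recorded in a flag rather than raised. A minimal standalone sketch of that pattern under those assumptions — the class and names here are illustrative, not part of this patch:]

```python
from typing import Optional

import torch.nn as nn


class LazyVocoderExample:
    """Toy stand-in showing the lazy-init pattern, not the real TTS class."""

    def __init__(self, device: str = "cpu", is_half: bool = False):
        self.device = device
        self.is_half = is_half
        self.vocoder: Optional[nn.Module] = None   # loaded on first use
        self.vocoder_not_exist = False             # set when weights are missing

    def init_vocoder(self) -> None:
        if self.vocoder is not None:
            return  # already initialized; repeated calls are cheap no-ops
        try:
            model = nn.Identity()  # stand-in for BigVGAN.from_pretrained(...) / AP_BWE(...)
            model = model.eval()
            # cast before moving, mirroring init_bigvgan()'s half()/to() order
            self.vocoder = model.half().to(self.device) if self.is_half else model.to(self.device)
        except FileNotFoundError:
            # init_sr_model() degrades gracefully instead of crashing
            self.vocoder_not_exist = True
```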
@@ -387,6 +565,8 @@ class TTS: self.bert_model =self.bert_model.half() if self.cnhuhbert_model is not None: self.cnhuhbert_model = self.cnhuhbert_model.half() + if self.bigvgan_model is not None: + self.bigvgan_model = self.bigvgan_model.half() else: if self.t2s_model is not None: self.t2s_model = self.t2s_model.float() @@ -396,6 +576,8 @@ class TTS: self.bert_model = self.bert_model.float() if self.cnhuhbert_model is not None: self.cnhuhbert_model = self.cnhuhbert_model.float() + if self.bigvgan_model is not None: + self.bigvgan_model = self.bigvgan_model.float() def set_device(self, device: torch.device, save: bool = True): ''' @@ -414,6 +596,11 @@ class TTS: self.bert_model = self.bert_model.to(device) if self.cnhuhbert_model is not None: self.cnhuhbert_model = self.cnhuhbert_model.to(device) + if self.bigvgan_model is not None: + self.bigvgan_model = self.bigvgan_model.to(device) + if self.sr_model is not None: + self.sr_model = self.sr_model.to(device) + def set_ref_audio(self, ref_audio_path:str): ''' @@ -437,6 +624,11 @@ class TTS: self.prompt_cache["refer_spec"][0] = spec def _get_ref_spec(self, ref_audio_path): + raw_audio, raw_sr = torchaudio.load(ref_audio_path) + raw_audio=raw_audio.to(self.configs.device).float() + self.prompt_cache["raw_audio"] = raw_audio + self.prompt_cache["raw_sr"] = raw_sr + audio = load_audio(ref_audio_path, int(self.configs.sampling_rate)) audio = torch.FloatTensor(audio) maxx=audio.abs().max() @@ -625,11 +817,11 @@ class TTS: Recovery the order of the audio according to the batch_index_list. Args: - data (List[list(np.ndarray)]): the out of order audio . + data (List[list(torch.Tensor)]): the out of order audio . batch_index_list (List[list[int]]): the batch index list. Returns: - list (List[np.ndarray]): the data in the original order. + list (List[torch.Tensor]): the data in the original order. ''' length = len(sum(batch_index_list, [])) _data = [None]*length @@ -671,6 +863,8 @@ class TTS: "seed": -1, # int. random seed for reproducibility. "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35 # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. } returns: Tuple[int, np.ndarray]: sampling rate and audio data. @@ -698,6 +892,8 @@ class TTS: actual_seed = set_seed(seed) parallel_infer = inputs.get("parallel_infer", True) repetition_penalty = inputs.get("repetition_penalty", 1.35) + sample_steps = inputs.get("sample_steps", 32) + super_sampling = inputs.get("super_sampling", False) if parallel_infer: print(i18n("并行推理模式已开启")) @@ -732,6 +928,9 @@ class TTS: if not no_prompt_text: assert prompt_lang in self.configs.languages + if no_prompt_text and self.configs.is_v3_synthesizer: + raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3") + if ref_audio_path in [None, ""] and \ ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])): raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()") @@ -761,13 +960,13 @@ class TTS: if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "." 
print(i18n("实际输入的参考文本:"), prompt_text) if self.prompt_cache["prompt_text"] != prompt_text: - self.prompt_cache["prompt_text"] = prompt_text - self.prompt_cache["prompt_lang"] = prompt_lang phones, bert_features, norm_text = \ self.text_preprocessor.segment_and_extract_feature_for_text( prompt_text, prompt_lang, self.configs.version) + self.prompt_cache["prompt_text"] = prompt_text + self.prompt_cache["prompt_lang"] = prompt_lang self.prompt_cache["phones"] = phones self.prompt_cache["bert_features"] = bert_features self.prompt_cache["norm_text"] = norm_text @@ -781,8 +980,7 @@ class TTS: if not return_fragment: data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version) if len(data) == 0: - yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), - dtype=np.int16) + yield 16000, np.zeros(int(16000), dtype=np.int16) return batch_index_list:list = None @@ -836,6 +1034,7 @@ class TTS: t_34 = 0.0 t_45 = 0.0 audio = [] + output_sr = self.configs.sampling_rate if not self.configs.is_v3_synthesizer else 24000 for item in data: t3 = ttime() if return_fragment: @@ -858,7 +1057,7 @@ class TTS: else: prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) - + print(f"############ {i18n('预测语义Token')} ############") pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( all_phoneme_ids, all_phoneme_lens, @@ -892,70 +1091,80 @@ class TTS: # batch_audio_fragment = (self.vits_model.batched_decode( # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec # )) - - if speed_factor == 1.0: - # ## vits并行推理 method 2 - pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - upsample_rate = math.prod(self.vits_model.upsample_rates) - audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] - audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] - all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) - _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) - _batch_audio_fragment = (self.vits_model.decode( - all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :]) - audio_frag_end_idx.insert(0, 0) - batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] + print(f"############ {i18n('合成音频')} ############") + if not self.configs.is_v3_synthesizer: + if speed_factor == 1.0: + # ## vits并行推理 method 2 + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + upsample_rate = math.prod(self.vits_model.upsample_rates) + audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] + audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] + all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) + _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) + _batch_audio_fragment = (self.vits_model.decode( + all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor + ).detach()[0, 0, :]) + audio_frag_end_idx.insert(0, 0) + batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] + else: + # ## vits串行推理 + for i, idx in 
enumerate(tqdm(idx_list)): + phones = batch_phones[i].unsqueeze(0).to(self.configs.device) + _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次 + audio_fragment =(self.vits_model.decode( + _pred_semantic, phones, refer_audio_spec, speed=speed_factor + ).detach()[0, 0, :]) + batch_audio_fragment.append( + audio_fragment + ) ###试试重建不带上prompt部分 else: - # ## vits串行推理 - for i, idx in enumerate(idx_list): + for i, idx in enumerate(tqdm(idx_list)): phones = batch_phones[i].unsqueeze(0).to(self.configs.device) _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次 - audio_fragment =(self.vits_model.decode( - _pred_semantic, phones, refer_audio_spec, speed=speed_factor - ).detach()[0, 0, :]) + audio_fragment = self.v3_synthesis( + _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps + ) batch_audio_fragment.append( audio_fragment - ) ###试试重建不带上prompt部分 + ) t5 = ttime() t_45 += t5 - t4 if return_fragment: print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) yield self.audio_postprocess([batch_audio_fragment], - self.configs.sampling_rate, + output_sr, None, speed_factor, False, - fragment_interval + fragment_interval, + super_sampling if self.configs.is_v3_synthesizer else False ) else: audio.append(batch_audio_fragment) if self.stop_flag: - yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), - dtype=np.int16) + yield 16000, np.zeros(int(16000), dtype=np.int16) return if not return_fragment: print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) if len(audio) == 0: - yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), - dtype=np.int16) + yield 16000, np.zeros(int(16000), dtype=np.int16) return yield self.audio_postprocess(audio, - self.configs.sampling_rate, + output_sr, batch_index_list, speed_factor, split_bucket, - fragment_interval + fragment_interval, + super_sampling if self.configs.is_v3_synthesizer else False ) except Exception as e: traceback.print_exc() # 必须返回一个空音频, 否则会导致显存不释放。 - yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), - dtype=np.int16) + yield 16000, np.zeros(int(16000), dtype=np.int16) # 重置模型, 否则会导致显存释放不完全。 del self.t2s_model del self.vits_model @@ -983,7 +1192,8 @@ class TTS: batch_index_list:list=None, speed_factor:float=1.0, split_bucket:bool=True, - fragment_interval:float=0.3 + fragment_interval:float=0.3, + super_sampling:bool=False, )->Tuple[int, np.ndarray]: zero_wav = torch.zeros( int(self.configs.sampling_rate * fragment_interval), @@ -996,7 +1206,7 @@ class TTS: max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音 if max_audio>1: audio_fragment/=max_audio audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) - audio[i][j] = audio_fragment.cpu().numpy() + audio[i][j] = audio_fragment if split_bucket: @@ -1005,8 +1215,21 @@ class TTS: # audio = [item for batch in audio for item in batch] audio = sum(audio, []) + audio = torch.cat(audio, dim=0) + + if super_sampling: + print(f"############ {i18n('音频超采样')} ############") + t1 = ttime() + self.init_sr_model() + if not self.sr_model_not_exist: + audio,sr=self.sr_model(audio.unsqueeze(0),sr) + max_audio=np.abs(audio).max() + if max_audio > 1: audio /= max_audio + t2 = ttime() + print(f"超采样用时:{t2-t1:.3f}s") + else: + audio = audio.cpu().numpy() - audio = np.concatenate(audio, 0) audio = (audio * 32768).astype(np.int16) # try: @@ -1018,25 +1241,59 @@ class TTS: return sr, audio + def 
v3_synthesis(self, + semantic_tokens:torch.Tensor, + phones:torch.Tensor, + speed:float=1.0, + sample_steps:int=32 + ): + + prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device) + prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device) + refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device) + fea_ref,ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + ref_audio:torch.Tensor = self.prompt_cache["raw_audio"] + ref_sr = self.prompt_cache["raw_sr"] + ref_audio=ref_audio.to(self.configs.device).float() + if (ref_audio.shape[0] == 2): + ref_audio = ref_audio.mean(0).unsqueeze(0) + if ref_sr!=24000: + ref_audio=resample(ref_audio, ref_sr, self.configs.device) + # print("ref_audio",ref_audio.abs().mean()) + mel2 = mel_fn(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + if (T_min > 468): + mel2 = mel2[:, :, -468:] + fea_ref = fea_ref[:, :, -468:] + T_min = 468 + chunk_len = 934 - T_min -def speed_change(input_audio:np.ndarray, speed:float, sr:int): - # 将 NumPy 数组转换为原始 PCM 流 - raw_audio = input_audio.astype(np.int16).tobytes() + mel2=mel2.to(self.precision) + fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) - # 设置 ffmpeg 输入流 - input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1) + cfm_resss = [] + idx = 0 + while (1): + fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len] + if (fea_todo_chunk.shape[-1] == 0): break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) - # 变速处理 - output_stream = input_stream.filter('atempo', speed) + cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0) + cfm_res = cfm_res[:, :, mel2.shape[2]:] + mel2 = cfm_res[:, :, -T_min:] - # 输出流到管道 - out, _ = ( - output_stream.output('pipe:', format='s16le', acodec='pcm_s16le') - .run(input=raw_audio, capture_stdout=True, capture_stderr=True) - ) - - # 将管道输出解码为 NumPy 数组 - processed_audio = np.frombuffer(out, np.int16) - - return processed_audio + fea_ref = fea_todo_chunk[:, :, -T_min:] + cfm_resss.append(cfm_res) + cmf_res = torch.cat(cfm_resss, 2) + cmf_res = denorm_spec(cmf_res) + + with torch.inference_mode(): + wav_gen = self.bigvgan_model(cmf_res) + audio=wav_gen[0][0]#.cpu().detach().numpy() + + return audio diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 9def3da..653656a 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -118,11 +118,11 @@ class TextPreprocessor: def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False): if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: - language = language.replace("all_","") + # language = language.replace("all_","") formattext = text while " " in formattext: formattext = formattext.replace(" ", " ") - if language == "zh": + if language == "all_zh": if re.search(r'[A-Za-z]', formattext): formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = chinese.mix_text_normalize(formattext) @@ -130,7 +130,7 @@ class TextPreprocessor: else: phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) bert 
= self.get_bert_feature(norm_text, word2ph).to(self.device) - elif language == "yue" and re.search(r'[A-Za-z]', formattext): + elif language == "all_yue" and re.search(r'[A-Za-z]', formattext): formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = chinese.mix_text_normalize(formattext) return self.get_phones_and_bert(formattext,"yue",version) @@ -199,6 +199,7 @@ class TextPreprocessor: return phone_level_feature.T def clean_text_inf(self, text:str, language:str, version:str="v2"): + language = language.replace("all_","") phones, word2ph, norm_text = clean_text(text, language, version) phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml index 66f1193..344aae4 100644 --- a/GPT_SoVITS/configs/tts_infer.yaml +++ b/GPT_SoVITS/configs/tts_infer.yaml @@ -6,7 +6,7 @@ custom: t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt version: v2 vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth -default: +v1: bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base device: cpu @@ -14,7 +14,7 @@ default: t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt version: v1 vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth -default_v2: +v2: bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base device: cpu @@ -22,3 +22,11 @@ default_v2: t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt version: v2 vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth +v3: + bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large + cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base + device: cpu + is_half: false + t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt + version: v3 + vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py index 5a6910d..9017aa4 100644 --- a/GPT_SoVITS/inference_webui_fast.py +++ b/GPT_SoVITS/inference_webui_fast.py @@ -7,7 +7,7 @@ 全部按日文识别 ''' import random -import os, re, logging +import os, re, logging, json import sys now_dir = os.getcwd() sys.path.append(now_dir) @@ -44,7 +44,7 @@ bert_path = os.environ.get("bert_path", None) version=os.environ.get("version","v2") import gradio as gr -from TTS_infer_pack.TTS import TTS, TTS_Config +from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR from TTS_infer_pack.text_segmentation_method import get_method from tools.i18n.i18n import I18nAuto, scan_language_list @@ -62,6 +62,9 @@ if torch.cuda.is_available(): else: device = "cpu" +# is_half = False +# device = "cpu" + dict_language_v1 = { i18n("中文"): "all_zh",#全部按中文识别 i18n("英文"): "en",#全部按英文识别#######不变 @@ -123,11 +126,11 @@ def inference(text, text_lang, speed_factor, ref_text_free, split_bucket,fragment_interval, seed, keep_random, parallel_infer, - repetition_penalty + repetition_penalty, sample_steps, super_sampling, ): seed = -1 if keep_random else seed - actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32) + actual_seed = seed if seed not in [-1, "", None] else 
random.randint(0, 2**32 - 1) inputs={ "text": text, "text_lang": dict_language[text_lang], @@ -147,9 +150,14 @@ def inference(text, text_lang, "seed":actual_seed, "parallel_infer": parallel_infer, "repetition_penalty": repetition_penalty, + "sample_steps": int(sample_steps), + "super_sampling": super_sampling, } - for item in tts_pipeline.run(inputs): - yield item, actual_seed + try: + for item in tts_pipeline.run(inputs): + yield item, actual_seed + except NO_PROMPT_ERROR: + gr.Warning(i18n('V3不支持无参考文本模式,请填写参考文本!')) def custom_sort_key(s): # 使用正则表达式提取字符串中的数字部分和非数字部分 @@ -163,19 +171,38 @@ def change_choices(): SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root) return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"} +path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth" +pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3] +pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"] -pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"] -pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"] _ =[[],[]] -for i in range(2): - if os.path.exists(pretrained_gpt_name[i]): - _[0].append(pretrained_gpt_name[i]) - if os.path.exists(pretrained_sovits_name[i]): - _[-1].append(pretrained_sovits_name[i]) +for i in range(3): + if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i]) + if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i]) pretrained_gpt_name,pretrained_sovits_name = _ -SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"] -GPT_weight_root=["GPT_weights_v2","GPT_weights"] + +if os.path.exists(f"./weight.json"): + pass +else: + with open(f"./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file) + +with open(f"./weight.json", 'r', encoding="utf-8") as file: + weight_data = file.read() + weight_data=json.loads(weight_data) + gpt_path = os.environ.get( + "gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name)) + sovits_path = os.environ.get( + "sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name)) + if isinstance(gpt_path,list): + gpt_path = gpt_path[0] + if isinstance(sovits_path,list): + sovits_path = sovits_path[0] + + + +SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"] +GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"] for path in SoVITS_weight_root+GPT_weight_root: os.makedirs(path,exist_ok=True) @@ -194,10 +221,18 @@ def get_weights_names(GPT_weight_root, SoVITS_weight_root): SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root) - +from process_ckpt import get_sovits_version_from_path_fast def change_sovits_weights(sovits_path,prompt_language=None,text_language=None): - tts_pipeline.init_vits_weights(sovits_path) global version, dict_language + version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path) + + if if_lora_v3 and not 
os.path.exists(path_sovits_v3): + info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") + gr.Warning(info) + raise FileExistsError(info) + + tts_pipeline.init_vits_weights(sovits_path) + dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2 if prompt_language is not None and text_language is not None: if prompt_language in list(dict_language.keys()): @@ -210,9 +245,19 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None): else: text_update = {'__type__':'update', 'value':''} text_language_update = {'__type__':'update', 'value':i18n("中文")} - return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update - + if model_version=="v3": + visible_sample_steps=True + visible_inp_refs=False + else: + visible_sample_steps=False + visible_inp_refs=True + yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False} + with open("./weight.json")as f: + data=f.read() + data=json.loads(data) + data["SoVITS"][version]=sovits_path + with open("./weight.json","w")as f:f.write(json.dumps(data)) with gr.Blocks(title="GPT-SoVITS WebUI") as app: gr.Markdown( @@ -257,13 +302,19 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: with gr.Row(): with gr.Column(): - batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True) - fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True) - speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True) - top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True) - top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True) - temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True) - repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True) + with gr.Row(): + batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True) + sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True) + with gr.Row(): + fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True) + speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True) + with gr.Row(): + top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True) + top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True) + with gr.Row(): + temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True) + repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True) + with gr.Column(): with gr.Row(): how_to_cut = gr.Dropdown( @@ -272,10 +323,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") 
as app: value=i18n("凑四句一切"), interactive=True, scale=1 ) + super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True) + + with gr.Row(): parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True) split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True) with gr.Row(): + seed = gr.Number(label=i18n("随机种子"),value=-1) keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True) @@ -295,7 +350,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: speed_factor, ref_text_free, split_bucket,fragment_interval, seed, keep_random, parallel_infer, - repetition_penalty + repetition_penalty, sample_steps, super_sampling, ], [output, seed], ) diff --git a/api_v2.py b/api_v2.py index 92a18f3..3a8566a 100644 --- a/api_v2.py +++ b/api_v2.py @@ -39,6 +39,8 @@ POST: "seed": -1, # int. random seed for reproducibility. "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35 # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. } ``` @@ -164,6 +166,8 @@ class TTS_Request(BaseModel): streaming_mode:bool = False parallel_infer:bool = True repetition_penalty:float = 1.35 + sample_steps:int = 32 + super_sampling:bool = False ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int): @@ -294,7 +298,9 @@ async def tts_handle(req:dict): "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac". "streaming_mode": False, # bool. whether to return a streaming response. "parallel_infer": True, # bool.(optional) whether to use parallel inference. - "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. + "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. } returns: StreamingResponse: audio stream response. 
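[Note: the two new request fields are plain query/body parameters, so existing clients only need to pass them through. A minimal client sketch, assuming api_v2.py is running at its default 127.0.0.1:9880 and exposing the GET /tts route; the file paths and texts are placeholders:]

```python
import requests  # third-party HTTP client

params = {
    "text": "Hello there.",
    "text_lang": "en",
    "ref_audio_path": "path/to/ref.wav",      # placeholder reference clip
    "prompt_text": "transcript of the clip",  # V3 rejects empty prompt text (NO_PROMPT_ERROR)
    "prompt_lang": "en",
    "media_type": "wav",
    "streaming_mode": False,
    "parallel_infer": True,
    "repetition_penalty": 1.35,
    "sample_steps": 32,       # new: CFM sampling steps, used by V3 only
    "super_sampling": False,  # new: 24k->48k AP-BWE super-sampling, V3 only
}
resp = requests.get("http://127.0.0.1:9880/tts", params=params, timeout=600)
resp.raise_for_status()
with open("out.wav", "wb") as f:
    f.write(resp.content)
```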
@@ -316,10 +322,12 @@ async def tts_handle(req:dict):
 
     if streaming_mode:
         def streaming_generator(tts_generator:Generator, media_type:str):
-            if media_type == "wav":
-                yield wave_header_chunk()
-                media_type = "raw"
+            is_first_chunk = True
             for sr, chunk in tts_generator:
+                if is_first_chunk and media_type == "wav":
+                    yield wave_header_chunk(sample_rate=sr)
+                    media_type = "raw"
+                    is_first_chunk = False
                 yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
     # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
     return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
@@ -365,7 +373,9 @@ async def tts_get_endpoint(
                         media_type:str = "wav",
                         streaming_mode:bool = False,
                         parallel_infer:bool = True,
-                        repetition_penalty:float = 1.35
+                        repetition_penalty:float = 1.35,
+                        sample_steps:int = 32,
+                        super_sampling:bool = False
                 ):
     req = {
         "text": text,
@@ -387,7 +397,9 @@ async def tts_get_endpoint(
         "media_type":media_type,
         "streaming_mode":streaming_mode,
         "parallel_infer":parallel_infer,
-        "repetition_penalty":float(repetition_penalty)
+        "repetition_penalty":float(repetition_penalty),
+        "sample_steps":int(sample_steps),
+        "super_sampling":super_sampling
     }
     return await tts_handle(req)
 
diff --git a/tools/audio_sr.py b/tools/audio_sr.py
index d51f055..009ad26 100644
--- a/tools/audio_sr.py
+++ b/tools/audio_sr.py
@@ -39,6 +39,11 @@ class AP_BWE():
         self.model=model
         self.h=h
 
+    def to(self, *args, **kwargs):
+        self.model.to(*args, **kwargs)
+        self.device = self.model.conv_pre_mag.weight.device
+        return self
+
     def __call__(self, audio,orig_sampling_rate):
         with torch.no_grad():
             # audio, orig_sampling_rate = torchaudio.load(inp_path)
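[Note on the tools/audio_sr.py change above: AP_BWE is a plain wrapper class rather than an nn.Module, so it needs an explicit to() that forwards the move to the wrapped model, re-reads the device the weights actually landed on, and returns self so the call chains like a torch module. This is what lets TTS.set_device() relocate the super-resolution model along with everything else. The same pattern in isolation, with illustrative names:]

```python
import torch.nn as nn


class ModelWrapper:
    """A plain (non-nn.Module) wrapper that still supports .to(), like AP_BWE."""

    def __init__(self, model: nn.Module):
        self.model = model
        self.device = next(model.parameters()).device

    def to(self, *args, **kwargs):
        # Forward the move to the wrapped module, then re-read where the
        # weights actually landed (AP_BWE reads model.conv_pre_mag.weight.device).
        self.model.to(*args, **kwargs)
        self.device = next(self.model.parameters()).device
        return self  # return self so the call chains like torch's Module.to


wrapper = ModelWrapper(nn.Linear(4, 4)).to("cpu")
print(wrapper.device)  # -> cpu
```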