v3 adaptation of api_v2.py and inference_webui_fast.py

ChasonJiang 2025-03-13 17:41:17 +08:00
parent bf06ac589c
commit f25d830ef4
12 changed files with 235 additions and 112 deletions

.gitignore

@@ -17,4 +17,6 @@ SoVITS_weights_v3
 TEMP
 weight.json
 ffmpeg*
-ffprobe*
+ffprobe*
+tools/AP_BWE_main/24kto48k/*
+!tools/AP_BWE_main/24kto48k/readme.txt
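The new patterns keep the downloaded AP-BWE 24 kHz to 48 kHz super-resolution checkpoints out of version control, while the `!` negation keeps the directory's readme.txt tracked.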

GPT_SoVITS/BigVGAN/alias_free_activation/torch/act.py

@@ -2,7 +2,7 @@
 # LICENSE is in incl_licenses directory.
 import torch.nn as nn
-from alias_free_activation.torch.resample import UpSample1d, DownSample1d
+from .resample import UpSample1d, DownSample1d
 class Activation1d(nn.Module):

GPT_SoVITS/BigVGAN/alias_free_activation/torch/resample.py

@@ -3,8 +3,8 @@
 import torch.nn as nn
 from torch.nn import functional as F
-from alias_free_activation.torch.filter import LowPassFilter1d
-from alias_free_activation.torch.filter import kaiser_sinc_filter1d
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
 class UpSample1d(nn.Module):

GPT_SoVITS/BigVGAN/bigvgan.py

@@ -14,10 +14,10 @@ import torch.nn as nn
 from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn.utils import weight_norm, remove_weight_norm
-import activations
-from utils0 import init_weights, get_padding
-from alias_free_activation.torch.act import Activation1d as TorchActivation1d
-from env import AttrDict
+from . import activations
+from .utils0 import init_weights, get_padding
+from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
+from .env import AttrDict
 from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
@@ -93,7 +93,7 @@ class AMPBlock1(torch.nn.Module):
         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
         if self.h.get("use_cuda_kernel", False):
-            from alias_free_activation.cuda.activation1d import (
+            from .alias_free_activation.cuda.activation1d import (
                 Activation1d as CudaActivation1d,
             )
@@ -193,7 +193,7 @@ class AMPBlock2(torch.nn.Module):
         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
         if self.h.get("use_cuda_kernel", False):
-            from alias_free_activation.cuda.activation1d import (
+            from .alias_free_activation.cuda.activation1d import (
                 Activation1d as CudaActivation1d,
             )
@@ -271,7 +271,7 @@ class BigVGAN(
         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
         if self.h.get("use_cuda_kernel", False):
-            from alias_free_activation.cuda.activation1d import (
+            from .alias_free_activation.cuda.activation1d import (
                 Activation1d as CudaActivation1d,
            )
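The switch from absolute to package-relative imports lets the vendored BigVGAN copy be imported as a sub-package without putting its directory on `sys.path`. A hedged sketch of the import this enables (the vendored path and the checkpoint id are assumptions, not shown in this diff):

```python
# Hypothetical usage of the vendored package enabled by the relative imports:
from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN  # assumed vendored location

# from_pretrained is provided by PyTorchModelHubMixin (imported in the hunk above);
# the Hugging Face checkpoint id is illustrative.
model = BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
model.remove_weight_norm()  # common inference-time step for weight-normed vocoders
model.eval()
```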

GPT_SoVITS/BigVGAN/meldataset.py

@@ -15,7 +15,7 @@ from librosa.filters import mel as librosa_mel_fn
 import pathlib
 from tqdm import tqdm
 from typing import List, Tuple, Optional
-from env import AttrDict
+from .env import AttrDict
 MAX_WAV_VALUE = 32767.0  # NOTE: 32768.0 - 1 to prevent int16 overflow (results in popping sound in corner cases)

GPT_SoVITS/BigVGAN/utils0.py

@@ -9,7 +9,7 @@ from torch.nn.utils import weight_norm
 matplotlib.use("Agg")
 import matplotlib.pylab as plt
-from meldataset import MAX_WAV_VALUE
+from .meldataset import MAX_WAV_VALUE
 from scipy.io.wavfile import write

GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -16,7 +16,7 @@ import torch
 import torch.nn.functional as F
 import yaml
 from transformers import AutoModelForMaskedLM, AutoTokenizer
+from tools.audio_sr import AP_BWE
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from feature_extractor.cnhubert import CNHubert
 from module.models import SynthesizerTrn, SynthesizerTrnV3
@@ -55,6 +55,65 @@ mel_fn=lambda x: mel_spectrogram_torch(x, **{
 })
+def speed_change(input_audio:np.ndarray, speed:float, sr:int):
+    # Convert the NumPy array to a raw PCM byte stream
+    raw_audio = input_audio.astype(np.int16).tobytes()
+    # Set up the ffmpeg input stream
+    input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
+    # Apply the tempo change
+    output_stream = input_stream.filter('atempo', speed)
+    # Pipe the output
+    out, _ = (
+        output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
+        .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
+    )
+    # Decode the piped output back into a NumPy array
+    processed_audio = np.frombuffer(out, np.int16)
+    return processed_audio
+resample_transform_dict={}
+def resample(audio_tensor, sr0, device):
+    global resample_transform_dict
+    if sr0 not in resample_transform_dict:
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
+            sr0, 24000
+        ).to(device)
+    return resample_transform_dict[sr0](audio_tensor)
+class DictToAttrRecursive(dict):
+    def __init__(self, input_dict):
+        super().__init__(input_dict)
+        for key, value in input_dict.items():
+            if isinstance(value, dict):
+                value = DictToAttrRecursive(value)
+            self[key] = value
+            setattr(self, key, value)
+    def __getattr__(self, item):
+        try:
+            return self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")
+    def __setattr__(self, key, value):
+        if isinstance(value, dict):
+            value = DictToAttrRecursive(value)
+        super(DictToAttrRecursive, self).__setitem__(key, value)
+        super().__setattr__(key, value)
+    def __delattr__(self, item):
+        try:
+            del self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")
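These three helpers move to module scope in TTS.py (the matching definitions are deleted from the bottom of the file in the last hunk of this section). A quick usage sketch with dummy inputs, assuming the definitions above are in scope:

```python
# Usage sketch for the three helpers above (dummy inputs throughout):
import numpy as np
import torch

cfg = DictToAttrRecursive({"model": {"hidden_dim": 512}})
assert cfg.model.hidden_dim == 512 and cfg["model"]["hidden_dim"] == 512

wav_48k = torch.randn(1, 48000)            # one second of noise at 48 kHz
wav_24k = resample(wav_48k, 48000, "cpu")  # cached torchaudio Resample, fixed 24 kHz target

pcm = (np.random.randn(32000) * 1000).astype(np.int16)
faster = speed_change(pcm, 1.25, 32000)    # requires the ffmpeg binary on PATH
```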
@@ -303,6 +362,8 @@ class TTS:
         self.bert_model:AutoModelForMaskedLM = None
         self.cnhuhbert_model:CNHubert = None
         self.bigvgan_model:BigVGAN = None
+        self.sr_model:AP_BWE = None
+        self.sr_model_not_exist:bool = False
         self._init_models()
@@ -370,13 +431,13 @@
         hps = dict_s2["config"]
         hps["model"]["semantic_frame_rate"] = "25hz"
-        if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
-            hps["model"]["version"] = "v2"#v3model,v2sybomls
-        elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
-            hps["model"]["version"] = "v1"
-        else:
-            hps["model"]["version"] = "v2"
-        version = hps["model"]["version"]
+        # if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
+        #     hps["model"]["version"] = "v2"  # v3 model, v2 symbols
+        # elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+        #     hps["model"]["version"] = "v1"
+        # else:
+        #     hps["model"]["version"] = "v2"
+        # version = hps["model"]["version"]
         self.configs.filter_length = hps["data"]["filter_length"]
         self.configs.segment_size = hps["train"]["segment_size"]
@@ -386,8 +447,9 @@
         self.configs.n_speakers = hps["data"]["n_speakers"]
         self.configs.semantic_frame_rate = hps["model"]["semantic_frame_rate"]
         kwargs = hps["model"]
+        # print(f"self.configs.sampling_rate:{self.configs.sampling_rate}")
-        self.configs.update_version(version)
+        self.configs.update_version(model_version)
         if model_version!="v3":
@@ -397,7 +459,6 @@
                 n_speakers=self.configs.n_speakers,
                 **kwargs
             )
-            model_version=version
             if hasattr(vits_model, "enc_q"):
                 del vits_model.enc_q
             self.configs.is_v3_synthesizer = False
@@ -413,9 +474,9 @@
             if if_lora_v3==False:
-                print(f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2["weight"], strict=False)}")
+                print(f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
             else:
-                print(f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)}")
+                print(f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}")
                 lora_rank=dict_s2["lora_rank"]
                 lora_config = LoraConfig(
                     target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@@ -424,7 +485,7 @@
                     init_lora_weights=True,
                 )
                 vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
-                print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2["weight"], strict=False)}")
+                print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
                 vits_model.cfm = vits_model.cfm.merge_and_unload()
@@ -466,6 +527,15 @@
         else:
             self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
+    def init_sr_model(self):
+        if self.sr_model is not None:
+            return
+        try:
+            self.sr_model:AP_BWE = AP_BWE(self.configs.device, DictToAttrRecursive)
+            self.sr_model_not_exist = False
+        except FileNotFoundError:
+            print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
+            self.sr_model_not_exist = True
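The super-resolution model is constructed lazily and fails soft: a missing AP-BWE checkpoint only sets a flag, and synthesis simply stays at 24 kHz. A sketch of the intended call pattern (the `pipeline` instance and tensor names are assumptions):

```python
# Hypothetical call pattern around the lazy loader above:
pipeline.init_sr_model()               # no-op once loaded; sets the flag on failure
if not pipeline.sr_model_not_exist:
    audio, sr = pipeline.sr_model(audio.unsqueeze(0), 24000)  # 24 kHz -> 48 kHz
```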
def enable_half_precision(self, enable: bool = True, save: bool = True):
@ -523,6 +593,11 @@ class TTS:
self.bert_model = self.bert_model.to(device)
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.to(device)
if self.bigvgan_model is not None:
self.bigvgan_model = self.bigvgan_model.to(device)
if self.sr_model is not None:
self.sr_model = self.sr_model.to(device)
def set_ref_audio(self, ref_audio_path:str):
'''
@@ -546,6 +621,11 @@
             self.prompt_cache["refer_spec"][0] = spec
     def _get_ref_spec(self, ref_audio_path):
+        raw_audio, raw_sr = torchaudio.load(ref_audio_path)
+        raw_audio = raw_audio.to(self.configs.device).float()
+        self.prompt_cache["raw_audio"] = raw_audio
+        self.prompt_cache["raw_sr"] = raw_sr
         audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
         audio = torch.FloatTensor(audio)
         maxx=audio.abs().max()
@@ -734,11 +814,11 @@
         Recover the order of the audio according to the batch_index_list.
         Args:
-            data (List[list(np.ndarray)]): the out of order audio .
+            data (List[list(torch.Tensor)]): the out-of-order audio.
             batch_index_list (List[list[int]]): the batch index list.
         Returns:
-            list (List[np.ndarray]): the data in the original order.
+            list (List[torch.Tensor]): the data in the original order.
         '''
         length = len(sum(batch_index_list, []))
         _data = [None]*length
@@ -780,6 +860,8 @@
                     "seed": -1, # int. random seed for reproducibility.
                     "parallel_infer": True, # bool. whether to use parallel inference.
                     "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
+                    "sample_steps": 32, # int. number of sampling steps for VITS model V3.
+                    "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
                 }
         returns:
             Tuple[int, np.ndarray]: sampling rate and audio data.
@@ -807,7 +889,8 @@
         actual_seed = set_seed(seed)
         parallel_infer = inputs.get("parallel_infer", True)
         repetition_penalty = inputs.get("repetition_penalty", 1.35)
-        sample_steps = inputs.get("sample_steps", 16)
+        sample_steps = inputs.get("sample_steps", 32)
+        super_sampling = inputs.get("super_sampling", False)
         if parallel_infer:
             print(i18n("并行推理模式已开启"))
@@ -894,8 +977,7 @@
             if not return_fragment:
                 data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
                 if len(data) == 0:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                               dtype=np.int16)
+                    yield 16000, np.zeros(int(16000), dtype=np.int16)
                     return
             batch_index_list:list = None
@@ -949,6 +1031,7 @@
             t_34 = 0.0
             t_45 = 0.0
             audio = []
+            output_sr = self.configs.sampling_rate if not self.configs.is_v3_synthesizer else 24000
             for item in data:
                 t3 = ttime()
                 if return_fragment:
@@ -971,7 +1054,7 @@
                 else:
                     prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
+                print(f"############ {i18n('预测语义Token')} ############")
                 pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                     all_phoneme_ids,
                     all_phoneme_lens,
@@ -1005,7 +1088,7 @@
                 # batch_audio_fragment = (self.vits_model.batched_decode(
                 #     pred_semantic, pred_semantic_len, batch_phones, batch_phones_len, refer_audio_spec
                 # ))
+                print(f"############ {i18n('合成音频')} ############")
                 if not self.configs.is_v3_synthesizer:
                     if speed_factor == 1.0:
                         # ## VITS parallel inference, method 2
@@ -1022,7 +1105,7 @@
                         batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
                     else:
                         # ## VITS serial inference
-                        for i, idx in enumerate(idx_list):
+                        for i, idx in enumerate(tqdm(idx_list)):
                             phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
                             _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # mq needs one extra unsqueeze
                             audio_fragment =(self.vits_model.decode(
@@ -1032,7 +1115,7 @@
                                 audio_fragment
                             )  ### try reconstructing without the prompt part
                 else:
-                    for i, idx in enumerate(idx_list):
+                    for i, idx in enumerate(tqdm(idx_list)):
                         phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
                         _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # mq needs one extra unsqueeze
                         audio_fragment = self.v3_synthesis(
@@ -1047,39 +1130,38 @@
                 if return_fragment:
                     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
                     yield self.audio_postprocess([batch_audio_fragment],
-                                                 self.configs.sampling_rate,
+                                                 output_sr,
                                                  None,
                                                  speed_factor,
                                                  False,
-                                                 fragment_interval
+                                                 fragment_interval,
+                                                 super_sampling if self.configs.is_v3_synthesizer else False
                                                  )
                 else:
                     audio.append(batch_audio_fragment)
                 if self.stop_flag:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                               dtype=np.int16)
+                    yield 16000, np.zeros(int(16000), dtype=np.int16)
                     return
             if not return_fragment:
                 print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
                 if len(audio) == 0:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                               dtype=np.int16)
+                    yield 16000, np.zeros(int(16000), dtype=np.int16)
                     return
                 yield self.audio_postprocess(audio,
-                                             self.configs.sampling_rate,
+                                             output_sr,
                                              batch_index_list,
                                              speed_factor,
                                              split_bucket,
-                                             fragment_interval
+                                             fragment_interval,
+                                             super_sampling if self.configs.is_v3_synthesizer else False
                                              )
         except Exception as e:
             traceback.print_exc()
             # Must yield an empty audio clip, otherwise GPU memory is not released.
-            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                       dtype=np.int16)
+            yield 16000, np.zeros(int(16000), dtype=np.int16)
             # Reset the models, otherwise GPU memory is not fully released.
             del self.t2s_model
             del self.vits_model
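Both the error path and the stop path now yield a fixed 16 kHz silence rather than `self.configs.sampling_rate`, so a consumer always receives a valid `(sample_rate, np.int16 array)` pair even when synthesis aborts. A hedged consumption sketch (`pipeline`, `inputs`, and the use of soundfile are assumptions):

```python
# Hypothetical consumer of TTS.run(); each item is (sample_rate, int16 ndarray).
import soundfile as sf

for i, (sr, wav) in enumerate(pipeline.run(inputs)):
    # v1/v2 models yield 32 kHz, v3 yields 24 kHz (48 kHz after super-sampling).
    sf.write(f"out_{i}.wav", wav, sr)
```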
@@ -1107,7 +1189,8 @@
                           batch_index_list:list=None,
                           speed_factor:float=1.0,
                           split_bucket:bool=True,
-                          fragment_interval:float=0.3
+                          fragment_interval:float=0.3,
+                          super_sampling:bool=False,
                           )->Tuple[int, np.ndarray]:
         zero_wav = torch.zeros(
             int(self.configs.sampling_rate * fragment_interval),
@@ -1120,7 +1203,7 @@
                 max_audio=torch.abs(audio_fragment).max()  # simple guard against 16-bit clipping
                 if max_audio>1: audio_fragment/=max_audio
                 audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
-                audio[i][j] = audio_fragment.cpu().numpy()
+                audio[i][j] = audio_fragment
         if split_bucket:
@@ -1129,8 +1212,21 @@
             # audio = [item for batch in audio for item in batch]
             audio = sum(audio, [])
+        audio = torch.cat(audio, dim=0)
+        if super_sampling:
+            print(f"############ {i18n('音频超采样')} ############")
+            t1 = ttime()
+            self.init_sr_model()
+            if not self.sr_model_not_exist:
+                audio, sr = self.sr_model(audio.unsqueeze(0), sr)
+                max_audio = np.abs(audio).max()
+                if max_audio > 1: audio /= max_audio
+            t2 = ttime()
+            print(f"超采样用时:{t2-t1:.3f}s")
+        else:
+            audio = audio.cpu().numpy()
-        audio = np.concatenate(audio, 0)
         audio = (audio * 32768).astype(np.int16)
         # try:
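When super-sampling runs, the concatenated 24 kHz tensor goes through AP-BWE to 48 kHz and is re-normalized before the int16 conversion. A standalone illustration of that final guard on dummy data:

```python
# Standalone illustration of the clipping guard and int16 conversion above:
import numpy as np

wav = (np.random.randn(48000) * 1.3).astype(np.float32)  # deliberately "hot" signal
peak = np.abs(wav).max()
if peak > 1:
    wav /= peak                        # normalize only when int16 would clip
pcm = (wav * 32768).astype(np.int16)   # same scaling as the hunk above
```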
@@ -1146,10 +1242,10 @@
                      semantic_tokens:torch.Tensor,
                      phones:torch.Tensor,
                      speed:float=1.0,
-                     sample_steps:int=16
+                     sample_steps:int=32
                      ):
-        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
+        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
         prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
         refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
@@ -1160,7 +1256,7 @@
         if (ref_audio.shape[0] == 2):
             ref_audio = ref_audio.mean(0).unsqueeze(0)
         if ref_sr!=24000:
-            ref_audio=resample(ref_audio, ref_sr)
+            ref_audio=resample(ref_audio, ref_sr, self.configs.device)
         # print("ref_audio", ref_audio.abs().mean())
         mel2 = mel_fn(ref_audio)
         mel2 = norm_spec(mel2)
@@ -1198,36 +1294,3 @@
         audio = wav_gen[0][0]  # .cpu().detach().numpy()
         return audio
-def speed_change(input_audio:np.ndarray, speed:float, sr:int):
-    # Convert the NumPy array to a raw PCM byte stream
-    raw_audio = input_audio.astype(np.int16).tobytes()
-    # Set up the ffmpeg input stream
-    input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
-    # Apply the tempo change
-    output_stream = input_stream.filter('atempo', speed)
-    # Pipe the output
-    out, _ = (
-        output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
-        .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
-    )
-    # Decode the piped output back into a NumPy array
-    processed_audio = np.frombuffer(out, np.int16)
-    return processed_audio
-resample_transform_dict={}
-def resample(audio_tensor, sr0, device):
-    global resample_transform_dict
-    if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
-            sr0, 24000
-        ).to(device)
-    return resample_transform_dict[sr0](audio_tensor)

GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py

@@ -199,6 +199,7 @@ class TextPreprocessor:
         return phone_level_feature.T
     def clean_text_inf(self, text:str, language:str, version:str="v2"):
+        language = language.replace("all_","")
         phones, word2ph, norm_text = clean_text(text, language, version)
         phones = cleaned_text_to_sequence(phones, version)
         return phones, word2ph, norm_text
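The added line strips the `all_` prefix so that modes like `all_zh` reach `clean_text` as the bare language code. A one-liner to illustrate:

```python
# The prefix strip performed by the added line:
for lang in ("all_zh", "all_ja", "all_ko", "en"):
    print(lang.replace("all_", ""))  # -> zh, ja, ko, en
```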

GPT_SoVITS/configs/tts_infer.yaml

@@ -3,10 +3,10 @@ custom:
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cuda
   is_half: true
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
-  version: v2
-  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
-default:
+  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
+  version: v3
+  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
+v1:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cpu
@@ -14,7 +14,7 @@
   t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
   version: v1
   vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
-default_v2:
+v2:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cpu
@@ -22,3 +22,11 @@
   t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
   version: v2
   vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
+v3:
+  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
+  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
+  device: cpu
+  is_half: false
+  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
+  version: v3
+  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
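The presets are now keyed by bare version names (`v1`/`v2`/`v3`), and `custom` defaults to the v3 weights. A hedged loading sketch (class names and paths assumed from the TTS_infer_pack code earlier in this commit):

```python
# Hypothetical loading of the v3 preset through the TTS_infer_pack pipeline:
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")  # reads the `custom` block
tts_pipeline = TTS(config)  # v3 synthesizer + s1v3.ckpt by default after this commit
```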

GPT_SoVITS/inference_webui_fast.py

@@ -7,7 +7,7 @@
     Recognize all text as Japanese
 '''
 import random
-import os, re, logging
+import os, re, logging, json
 import sys
 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -62,6 +62,9 @@ if torch.cuda.is_available():
 else:
     device = "cpu"
+# is_half = False
+# device = "cpu"
 dict_language_v1 = {
     i18n("中文"): "all_zh",  # recognize all text as Chinese
     i18n("英文"): "en",  # recognize all text as English ####### unchanged
@@ -123,7 +126,7 @@ def inference(text, text_lang,
               speed_factor, ref_text_free,
               split_bucket, fragment_interval,
               seed, keep_random, parallel_infer,
-              repetition_penalty
+              repetition_penalty, sample_steps, super_sampling,
               ):
     seed = -1 if keep_random else seed
@@ -147,6 +150,8 @@
         "seed": actual_seed,
         "parallel_infer": parallel_infer,
         "repetition_penalty": repetition_penalty,
+        "sample_steps": int(sample_steps),
+        "super_sampling": super_sampling,
     }
     for item in tts_pipeline.run(inputs):
         yield item, actual_seed
@@ -163,19 +168,38 @@ def change_choices():
     SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
     return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
+path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
+pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
+pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
-pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
-pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
 _ =[[],[]]
-for i in range(2):
-    if os.path.exists(pretrained_gpt_name[i]):
-        _[0].append(pretrained_gpt_name[i])
-    if os.path.exists(pretrained_sovits_name[i]):
-        _[-1].append(pretrained_sovits_name[i])
+for i in range(3):
+    if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
+    if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
 pretrained_gpt_name,pretrained_sovits_name = _
-SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"]
-GPT_weight_root=["GPT_weights_v2","GPT_weights"]
+if os.path.exists(f"./weight.json"):
+    pass
+else:
+    with open(f"./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file)
+with open(f"./weight.json", 'r', encoding="utf-8") as file:
+    weight_data = file.read()
+    weight_data=json.loads(weight_data)
+    gpt_path = os.environ.get(
+        "gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
+    sovits_path = os.environ.get(
+        "sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
+    if isinstance(gpt_path,list):
+        gpt_path = gpt_path[0]
+    if isinstance(sovits_path,list):
+        sovits_path = sovits_path[0]
+SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
+GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
 for path in SoVITS_weight_root+GPT_weight_root:
     os.makedirs(path,exist_ok=True)
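`weight.json` persists the last-used weight path per model version. The structure the block above reads and writes, with hypothetical paths:

```python
# Illustrative weight.json contents after switching models in the UI:
import json

example = {
    "GPT": {"v2": "GPT_weights_v2/my_voice-e15.ckpt"},           # hypothetical path
    "SoVITS": {"v3": "SoVITS_weights_v3/my_voice_e8_s200.pth"},  # hypothetical path
}
print(json.dumps(example, indent=2))
```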
@@ -194,12 +218,10 @@ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
 SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
-from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
+from process_ckpt import get_sovits_version_from_path_fast
 def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
     global version, dict_language
     version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
     # print(sovits_path,version, model_version, if_lora_v3)
-    path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
     if if_lora_v3 and not os.path.exists(path_sovits_v3):
         info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
@@ -228,7 +250,11 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
     visible_inp_refs=True
     yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}
+    with open("./weight.json")as f:
+        data=f.read()
+        data=json.loads(data)
+        data["SoVITS"][version]=sovits_path
+    with open("./weight.json","w")as f:f.write(json.dumps(data))
 with gr.Blocks(title="GPT-SoVITS WebUI") as app:
     gr.Markdown(
@@ -274,12 +300,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             with gr.Column():
                 batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
+                sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True)
                 fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
-                speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True)
+                speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True)
                 top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
                 top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
                 temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
                 repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
             with gr.Column():
                 with gr.Row():
                     how_to_cut = gr.Dropdown(
@@ -288,10 +316,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                         value=i18n("凑四句一切"),
                         interactive=True, scale=1
                     )
+                    super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True)
                 with gr.Row():
                     parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
                     split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
                 with gr.Row():
                     seed = gr.Number(label=i18n("随机种子"),value=-1)
                     keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
@@ -311,7 +343,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                 speed_factor, ref_text_free,
                 split_bucket, fragment_interval,
                 seed, keep_random, parallel_infer,
-                repetition_penalty
+                repetition_penalty, sample_steps, super_sampling,
             ],
             [output, seed],
         )

api_v2.py

@@ -39,6 +39,8 @@ POST:
     "seed": -1, # int. random seed for reproducibility.
     "parallel_infer": True, # bool. whether to use parallel inference.
     "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
+    "sample_steps": 32, # int. number of sampling steps for VITS model V3.
+    "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
 }
 ```
@@ -164,6 +166,8 @@ class TTS_Request(BaseModel):
     streaming_mode:bool = False
     parallel_infer:bool = True
     repetition_penalty:float = 1.35
+    sample_steps:int = 32
+    super_sampling:bool = False
 ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
 def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
@@ -294,7 +298,9 @@ async def tts_handle(req:dict):
                 "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
                 "streaming_mode": False, # bool. whether to return a streaming response.
                 "parallel_infer": True, # bool.(optional) whether to use parallel inference.
-                "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
+                "repetition_penalty": 1.35, # float.(optional) repetition penalty for T2S model.
+                "sample_steps": 32, # int. number of sampling steps for VITS model V3.
+                "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
             }
     returns:
         StreamingResponse: audio stream response.
@@ -316,10 +322,12 @@
     if streaming_mode:
         def streaming_generator(tts_generator:Generator, media_type:str):
-            if media_type == "wav":
-                yield wave_header_chunk()
-                media_type = "raw"
+            if_frist_chunk = True
             for sr, chunk in tts_generator:
+                if if_frist_chunk and media_type == "wav":
+                    yield wave_header_chunk(sample_rate=sr)
+                    media_type = "raw"
+                    if_frist_chunk = False
                 yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
         # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
         return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
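The header can no longer be emitted up front because the stream's sample rate now depends on the model: 32 kHz for v1/v2, 24 kHz for v3, 48 kHz after super-sampling. A minimal sketch of a deferred RIFF header (the `wave_header_chunk` helper in api_v2.py is assumed to do roughly this):

```python
# Minimal sketch: build a RIFF/WAVE header only after the first chunk reveals
# the true sample rate (wave_header_chunk in api_v2.py is assumed similar).
import io
import wave

def wav_header(sample_rate: int, channels: int = 1, sample_width: int = 2) -> bytes:
    buf = io.BytesIO()
    with wave.open(buf, "wb") as f:
        f.setnchannels(channels)
        f.setsampwidth(sample_width)
        f.setframerate(sample_rate)
        f.writeframes(b"")  # empty data chunk; raw PCM frames follow in the stream
    return buf.getvalue()
```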
@@ -365,7 +373,9 @@ async def tts_get_endpoint(
     media_type:str = "wav",
     streaming_mode:bool = False,
     parallel_infer:bool = True,
-    repetition_penalty:float = 1.35
+    repetition_penalty:float = 1.35,
+    sample_steps:int = 32,
+    super_sampling:bool = False
 ):
     req = {
         "text": text,
@@ -387,7 +397,9 @@
         "media_type":media_type,
         "streaming_mode":streaming_mode,
         "parallel_infer":parallel_infer,
-        "repetition_penalty":float(repetition_penalty)
+        "repetition_penalty":float(repetition_penalty),
+        "sample_steps":int(sample_steps),
+        "super_sampling":super_sampling
     }
     return await tts_handle(req)
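With both endpoints accepting the new fields, a client can opt into more CFM sampling steps and super-sampling per request. A hedged client sketch (host, port, and file paths are assumptions):

```python
# Hypothetical GET request against a local api_v2 server (default port assumed):
import requests

params = {
    "text": "Hello there.",
    "text_lang": "en",
    "ref_audio_path": "ref.wav",  # assumed local reference clip
    "prompt_lang": "en",
    "prompt_text": "",
    "sample_steps": 32,           # v3 flow-matching sampling steps
    "super_sampling": True,       # 24 kHz -> 48 kHz via AP-BWE
}
resp = requests.get("http://127.0.0.1:9880/tts", params=params)
with open("out.wav", "wb") as f:
    f.write(resp.content)
```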

tools/audio_sr.py

@@ -39,6 +39,11 @@ class AP_BWE():
         self.model=model
         self.h=h
+    def to(self, *arg, **kwargs):
+        self.model.to(*arg, **kwargs)
+        self.device = self.model.conv_pre_mag.weight.device
+        return self
     def __call__(self, audio, orig_sampling_rate):
         with torch.no_grad():
             # audio, orig_sampling_rate = torchaudio.load(inp_path)
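`to()` returning `self` (and refreshing the cached device from a real parameter) is what lets `TTS.set_device` move the SR model like the other sub-models. A usage sketch, with the constructor arguments taken from `TTS.init_sr_model` above (the checkpoint must already be downloaded for this to run):

```python
# Hypothetical round trip through the new to():
import torch

audio_24k = torch.randn(24000)                           # 1 s dummy signal at 24 kHz
sr_model = AP_BWE("cpu", DictToAttrRecursive).to("cpu")  # to() now returns self
audio_48k, sr = sr_model(audio_24k.unsqueeze(0), 24000)  # __call__ per the hunk above
```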