v3 adaptation of api_v2.py and inference_webui_fast.py

ChasonJiang 2025-03-13 17:41:17 +08:00
parent bf06ac589c
commit f25d830ef4
12 changed files with 235 additions and 112 deletions

.gitignore

@@ -17,4 +17,6 @@ SoVITS_weights_v3
TEMP
weight.json
ffmpeg*
ffprobe*
+tools/AP_BWE_main/24kto48k/*
+!tools/AP_BWE_main/24kto48k/readme.txt


@@ -2,7 +2,7 @@
# LICENSE is in incl_licenses directory.
import torch.nn as nn
-from alias_free_activation.torch.resample import UpSample1d, DownSample1d
+from .resample import UpSample1d, DownSample1d
class Activation1d(nn.Module):


@@ -3,8 +3,8 @@
import torch.nn as nn
from torch.nn import functional as F
-from alias_free_activation.torch.filter import LowPassFilter1d
-from alias_free_activation.torch.filter import kaiser_sinc_filter1d
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
class UpSample1d(nn.Module):


@@ -14,10 +14,10 @@ import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
-import activations
-from utils0 import init_weights, get_padding
-from alias_free_activation.torch.act import Activation1d as TorchActivation1d
-from env import AttrDict
+from . import activations
+from .utils0 import init_weights, get_padding
+from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
+from .env import AttrDict
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
@@ -93,7 +93,7 @@ class AMPBlock1(torch.nn.Module):
        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
-           from alias_free_activation.cuda.activation1d import (
+           from .alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )
@@ -193,7 +193,7 @@ class AMPBlock2(torch.nn.Module):
        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
-           from alias_free_activation.cuda.activation1d import (
+           from .alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )
@@ -271,7 +271,7 @@ class BigVGAN(
        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
-           from alias_free_activation.cuda.activation1d import (
+           from .alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )


@@ -15,7 +15,7 @@ from librosa.filters import mel as librosa_mel_fn
import pathlib
from tqdm import tqdm
from typing import List, Tuple, Optional
-from env import AttrDict
+from .env import AttrDict
MAX_WAV_VALUE = 32767.0  # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)


@@ -9,7 +9,7 @@ from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt
-from meldataset import MAX_WAV_VALUE
+from .meldataset import MAX_WAV_VALUE
from scipy.io.wavfile import write


@@ -16,7 +16,7 @@ import torch
import torch.nn.functional as F
import yaml
from transformers import AutoModelForMaskedLM, AutoTokenizer
+from tools.audio_sr import AP_BWE
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from feature_extractor.cnhubert import CNHubert
from module.models import SynthesizerTrn, SynthesizerTrnV3
@@ -55,6 +55,65 @@ mel_fn=lambda x: mel_spectrogram_torch(x, **{
})
+
+def speed_change(input_audio:np.ndarray, speed:float, sr:int):
+    # 将 NumPy 数组转换为原始 PCM 流
+    raw_audio = input_audio.astype(np.int16).tobytes()
+    # 设置 ffmpeg 输入流
+    input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
+    # 变速处理
+    output_stream = input_stream.filter('atempo', speed)
+    # 输出流到管道
+    out, _ = (
+        output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
+        .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
+    )
+    # 将管道输出解码为 NumPy 数组
+    processed_audio = np.frombuffer(out, np.int16)
+    return processed_audio
+
+resample_transform_dict={}
+def resample(audio_tensor, sr0, device):
+    global resample_transform_dict
+    if sr0 not in resample_transform_dict:
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
+            sr0, 24000
+        ).to(device)
+    return resample_transform_dict[sr0](audio_tensor)
+
+class DictToAttrRecursive(dict):
+    def __init__(self, input_dict):
+        super().__init__(input_dict)
+        for key, value in input_dict.items():
+            if isinstance(value, dict):
+                value = DictToAttrRecursive(value)
+            self[key] = value
+            setattr(self, key, value)
+
+    def __getattr__(self, item):
+        try:
+            return self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")
+
+    def __setattr__(self, key, value):
+        if isinstance(value, dict):
+            value = DictToAttrRecursive(value)
+        super(DictToAttrRecursive, self).__setitem__(key, value)
+        super().__setattr__(key, value)
+
+    def __delattr__(self, item):
+        try:
+            del self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")
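The two module-level helpers added above are small but load-bearing for v3: resample() caches one torchaudio Resample transform per source rate (audio is brought to 24 kHz for the v3 synthesis path), and DictToAttrRecursive exposes nested config dicts as attributes. A minimal usage sketch of the helper class (values invented purely for illustration):

```python
# Illustration only: invented config values.
cfg = DictToAttrRecursive({"model": {"version": "v3"}, "data": {"sampling_rate": 32000}})
assert cfg.model.version == "v3"              # nested dicts become attribute-accessible
assert cfg["data"]["sampling_rate"] == 32000  # plain dict-style access still works

# resample() builds torchaudio.transforms.Resample(sr0, 24000) once per source rate
# and reuses the cached transform from resample_transform_dict on later calls.
```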
@@ -303,6 +362,8 @@ class TTS:
        self.bert_model:AutoModelForMaskedLM = None
        self.cnhuhbert_model:CNHubert = None
        self.bigvgan_model:BigVGAN = None
+       self.sr_model:AP_BWE = None
+       self.sr_model_not_exist:bool = False
        self._init_models()
@@ -370,13 +431,13 @@ class TTS:
        hps = dict_s2["config"]
        hps["model"]["semantic_frame_rate"] = "25hz"
-       if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
-           hps["model"]["version"] = "v2"#v3model,v2sybomls
-       elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
-           hps["model"]["version"] = "v1"
-       else:
-           hps["model"]["version"] = "v2"
-       version = hps["model"]["version"]
+       # if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
+       #     hps["model"]["version"] = "v2"#v3model,v2sybomls
+       # elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+       #     hps["model"]["version"] = "v1"
+       # else:
+       #     hps["model"]["version"] = "v2"
+       # version = hps["model"]["version"]
        self.configs.filter_length = hps["data"]["filter_length"]
        self.configs.segment_size = hps["train"]["segment_size"]
@@ -386,8 +447,9 @@ class TTS:
        self.configs.n_speakers = hps["data"]["n_speakers"]
        self.configs.semantic_frame_rate = hps["model"]["semantic_frame_rate"]
        kwargs = hps["model"]
+       # print(f"self.configs.sampling_rate:{self.configs.sampling_rate}")
-       self.configs.update_version(version)
+       self.configs.update_version(model_version)
        if model_version!="v3":
@@ -397,7 +459,6 @@ class TTS:
                n_speakers=self.configs.n_speakers,
                **kwargs
            )
-           model_version=version
            if hasattr(vits_model, "enc_q"):
                del vits_model.enc_q
            self.configs.is_v3_synthesizer = False
@@ -413,9 +474,9 @@ class TTS:
        if if_lora_v3==False:
-           print(f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2["weight"], strict=False)}")
+           print(f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
        else:
-           print(f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)}")
+           print(f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}")
            lora_rank=dict_s2["lora_rank"]
            lora_config = LoraConfig(
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@@ -424,7 +485,7 @@ class TTS:
                init_lora_weights=True,
            )
            vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
-           print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2["weight"], strict=False)}")
+           print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
            vits_model.cfm = vits_model.cfm.merge_and_unload()
@@ -466,6 +527,15 @@ class TTS:
        else:
            self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
+   def init_sr_model(self):
+       if self.sr_model is not None:
+           return
+       try:
+           self.sr_model:AP_BWE=AP_BWE(self.configs.device,DictToAttrRecursive)
+           self.sr_model_not_exist = False
+       except FileNotFoundError:
+           print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
+           self.sr_model_not_exist = True
    def enable_half_precision(self, enable: bool = True, save: bool = True):
@@ -523,6 +593,11 @@ class TTS:
        self.bert_model = self.bert_model.to(device)
        if self.cnhuhbert_model is not None:
            self.cnhuhbert_model = self.cnhuhbert_model.to(device)
+       if self.bigvgan_model is not None:
+           self.bigvgan_model = self.bigvgan_model.to(device)
+       if self.sr_model is not None:
+           self.sr_model = self.sr_model.to(device)
    def set_ref_audio(self, ref_audio_path:str):
        '''
@@ -546,6 +621,11 @@ class TTS:
            self.prompt_cache["refer_spec"][0] = spec
    def _get_ref_spec(self, ref_audio_path):
+       raw_audio, raw_sr = torchaudio.load(ref_audio_path)
+       raw_audio=raw_audio.to(self.configs.device).float()
+       self.prompt_cache["raw_audio"] = raw_audio
+       self.prompt_cache["raw_sr"] = raw_sr
        audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
        audio = torch.FloatTensor(audio)
        maxx=audio.abs().max()
@@ -734,11 +814,11 @@ class TTS:
        Recovery the order of the audio according to the batch_index_list.
        Args:
-           data (List[list(np.ndarray)]): the out of order audio .
+           data (List[list(torch.Tensor)]): the out of order audio .
            batch_index_list (List[list[int]]): the batch index list.
        Returns:
-           list (List[np.ndarray]): the data in the original order.
+           list (List[torch.Tensor]): the data in the original order.
        '''
        length = len(sum(batch_index_list, []))
        _data = [None]*length
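The docstring change above only swaps np.ndarray for torch.Tensor; the recovery logic stays the same. For intuition, the inverse bucketing it describes looks like this (a standalone sketch, not the class method: with batch_index_list = [[0, 2], [1]], fragments produced in batch order are written back to their original positions):

```python
def recover_order(data, batch_index_list):
    # data[i][j] is the j-th fragment of batch i; batch_index_list[i][j] is its original position.
    length = len(sum(batch_index_list, []))
    out = [None] * length
    for i, batch in enumerate(batch_index_list):
        for j, original_index in enumerate(batch):
            out[original_index] = data[i][j]
    return out

print(recover_order([["a", "c"], ["b"]], [[0, 2], [1]]))  # ['a', 'b', 'c']
```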
@@ -780,6 +860,8 @@ class TTS:
                    "seed": -1,                   # int. random seed for reproducibility.
                    "parallel_infer": True,       # bool. whether to use parallel inference.
                    "repetition_penalty": 1.35    # float. repetition penalty for T2S model.
+                   "sample_steps": 32,           # int. number of sampling steps for VITS model V3.
+                   "super_sampling": False,      # bool. whether to use super-sampling for audio when using VITS model V3.
                }
        returns:
            Tuple[int, np.ndarray]: sampling rate and audio data.
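With the two new keys, driving the pipeline directly looks roughly like this (a hedged sketch: the pipeline instance, texts and reference-audio path are placeholders, and the other fields follow the inputs documented above):

```python
inputs = {
    "text": "Text to synthesize.",                    # placeholder
    "text_lang": "en",
    "ref_audio_path": "ref.wav",                      # placeholder reference audio
    "prompt_text": "Transcript of the reference audio.",
    "prompt_lang": "en",
    "sample_steps": 32,      # new: sampling steps for the v3 synthesizer
    "super_sampling": True,  # new: audio super-sampling via AP_BWE (v3 only)
}
for sr, wav in tts_pipeline.run(inputs):  # yields (sampling rate, int16 numpy audio)
    print(sr, wav.shape)
```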
@@ -807,7 +889,8 @@ class TTS:
        actual_seed = set_seed(seed)
        parallel_infer = inputs.get("parallel_infer", True)
        repetition_penalty = inputs.get("repetition_penalty", 1.35)
-       sample_steps = inputs.get("sample_steps", 16)
+       sample_steps = inputs.get("sample_steps", 32)
+       super_sampling = inputs.get("super_sampling", False)
        if parallel_infer:
            print(i18n("并行推理模式已开启"))
@@ -894,8 +977,7 @@ class TTS:
        if not return_fragment:
            data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
            if len(data) == 0:
-               yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                          dtype=np.int16)
+               yield 16000, np.zeros(int(16000), dtype=np.int16)
                return
            batch_index_list:list = None
@@ -949,6 +1031,7 @@ class TTS:
        t_34 = 0.0
        t_45 = 0.0
        audio = []
+       output_sr = self.configs.sampling_rate if not self.configs.is_v3_synthesizer else 24000
        for item in data:
            t3 = ttime()
            if return_fragment:
@@ -971,7 +1054,7 @@ class TTS:
            else:
                prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
+           print(f"############ {i18n('预测语义Token')} ############")
            pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                all_phoneme_ids,
                all_phoneme_lens,
@@ -1005,7 +1088,7 @@ class TTS:
            # batch_audio_fragment = (self.vits_model.batched_decode(
            #     pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
            # ))
+           print(f"############ {i18n('合成音频')} ############")
            if not self.configs.is_v3_synthesizer:
                if speed_factor == 1.0:
                    # ## vits并行推理 method 2
@@ -1022,7 +1105,7 @@ class TTS:
                    batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
                else:
                    # ## vits串行推理
-                   for i, idx in enumerate(idx_list):
+                   for i, idx in enumerate(tqdm(idx_list)):
                        phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
                        _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # .unsqueeze(0)#mq要多unsqueeze一次
                        audio_fragment =(self.vits_model.decode(
@@ -1032,7 +1115,7 @@ class TTS:
                            audio_fragment
                        )  ###试试重建不带上prompt部分
            else:
-               for i, idx in enumerate(idx_list):
+               for i, idx in enumerate(tqdm(idx_list)):
                    phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
                    _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # .unsqueeze(0)#mq要多unsqueeze一次
                    audio_fragment = self.v3_synthesis(
@@ -1047,39 +1130,38 @@ class TTS:
            if return_fragment:
                print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
                yield self.audio_postprocess([batch_audio_fragment],
-                                             self.configs.sampling_rate,
+                                             output_sr,
                                              None,
                                              speed_factor,
                                              False,
-                                             fragment_interval
+                                             fragment_interval,
+                                             super_sampling if self.configs.is_v3_synthesizer else False
                                             )
            else:
                audio.append(batch_audio_fragment)
            if self.stop_flag:
-               yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                          dtype=np.int16)
+               yield 16000, np.zeros(int(16000), dtype=np.int16)
                return
        if not return_fragment:
            print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
            if len(audio) == 0:
-               yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                          dtype=np.int16)
+               yield 16000, np.zeros(int(16000), dtype=np.int16)
                return
            yield self.audio_postprocess(audio,
-                                         self.configs.sampling_rate,
+                                         output_sr,
                                          batch_index_list,
                                          speed_factor,
                                          split_bucket,
-                                         fragment_interval
+                                         fragment_interval,
+                                         super_sampling if self.configs.is_v3_synthesizer else False
                                        )
    except Exception as e:
        traceback.print_exc()
        # 必须返回一个空音频, 否则会导致显存不释放。
-       yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                  dtype=np.int16)
+       yield 16000, np.zeros(int(16000), dtype=np.int16)
        # 重置模型, 否则会导致显存释放不完全。
        del self.t2s_model
        del self.vits_model
@@ -1107,7 +1189,8 @@ class TTS:
                          batch_index_list:list=None,
                          speed_factor:float=1.0,
                          split_bucket:bool=True,
-                         fragment_interval:float=0.3
+                         fragment_interval:float=0.3,
+                         super_sampling:bool=False,
                          )->Tuple[int, np.ndarray]:
        zero_wav = torch.zeros(
                        int(self.configs.sampling_rate * fragment_interval),
@@ -1120,7 +1203,7 @@ class TTS:
                max_audio=torch.abs(audio_fragment).max()  #简单防止16bit爆音
                if max_audio>1: audio_fragment/=max_audio
                audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
-               audio[i][j] = audio_fragment.cpu().numpy()
+               audio[i][j] = audio_fragment
        if split_bucket:
@@ -1129,8 +1212,21 @@ class TTS:
            # audio = [item for batch in audio for item in batch]
            audio = sum(audio, [])
-       audio = np.concatenate(audio, 0)
+       audio = torch.cat(audio, dim=0)
+       if super_sampling:
+           print(f"############ {i18n('音频超采样')} ############")
+           t1 = ttime()
+           self.init_sr_model()
+           if not self.sr_model_not_exist:
+               audio,sr=self.sr_model(audio.unsqueeze(0),sr)
+               max_audio=np.abs(audio).max()
+               if max_audio > 1: audio /= max_audio
+           t2 = ttime()
+           print(f"超采样用时:{t2-t1:.3f}s")
+       else:
+           audio = audio.cpu().numpy()
        audio = (audio * 32768).astype(np.int16)
        # try:
@@ -1146,10 +1242,10 @@ class TTS:
                     semantic_tokens:torch.Tensor,
                     phones:torch.Tensor,
                     speed:float=1.0,
-                    sample_steps:int=16
+                    sample_steps:int=32
                     ):
-       prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
+       prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
@@ -1160,7 +1256,7 @@ class TTS:
        if (ref_audio.shape[0] == 2):
            ref_audio = ref_audio.mean(0).unsqueeze(0)
        if ref_sr!=24000:
-           ref_audio=resample(ref_audio, ref_sr)
+           ref_audio=resample(ref_audio, ref_sr, self.configs.device)
        # print("ref_audio",ref_audio.abs().mean())
        mel2 = mel_fn(ref_audio)
        mel2 = norm_spec(mel2)
@@ -1198,36 +1294,3 @@ class TTS:
        audio=wav_gen[0][0]  #.cpu().detach().numpy()
        return audio
-
-def speed_change(input_audio:np.ndarray, speed:float, sr:int):
-    # 将 NumPy 数组转换为原始 PCM 流
-    raw_audio = input_audio.astype(np.int16).tobytes()
-    # 设置 ffmpeg 输入流
-    input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
-    # 变速处理
-    output_stream = input_stream.filter('atempo', speed)
-    # 输出流到管道
-    out, _ = (
-        output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
-        .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
-    )
-    # 将管道输出解码为 NumPy 数组
-    processed_audio = np.frombuffer(out, np.int16)
-    return processed_audio
-
-resample_transform_dict={}
-def resample(audio_tensor, sr0, device):
-    global resample_transform_dict
-    if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
-            sr0, 24000
-        ).to(device)
-    return resample_transform_dict[sr0](audio_tensor)


@@ -199,6 +199,7 @@ class TextPreprocessor:
        return phone_level_feature.T
    def clean_text_inf(self, text:str, language:str, version:str="v2"):
+       language = language.replace("all_","")
        phones, word2ph, norm_text = clean_text(text, language, version)
        phones = cleaned_text_to_sequence(phones, version)
        return phones, word2ph, norm_text


@@ -3,10 +3,10 @@ custom:
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cuda
  is_half: true
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
-  version: v2
-  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
-default:
+  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
+  version: v3
+  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
+v1:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
@@ -14,7 +14,7 @@ default:
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
  version: v1
  vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
-default_v2:
+v2:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
@@ -22,3 +22,11 @@ default_v2:
  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
  version: v2
  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
+v3:
+  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
+  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
+  device: cpu
+  is_half: false
+  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
+  version: v3
+  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
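The presets are now keyed by version name (v1/v2/v3) instead of default/default_v2, so selecting one is a plain dictionary lookup. A minimal sketch assuming standard PyYAML; the path below is only the usual location of this config and may differ in your checkout:

```python
import yaml

# Assumed path of the tts_infer config edited above.
with open("GPT_SoVITS/configs/tts_infer.yaml", "r", encoding="utf-8") as f:
    presets = yaml.safe_load(f)

v3 = presets["v3"]             # preset added in this commit
print(v3["version"])           # v3
print(v3["t2s_weights_path"])  # GPT_SoVITS/pretrained_models/s1v3.ckpt
print(v3["vits_weights_path"]) # GPT_SoVITS/pretrained_models/s2Gv3.pth
```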


@@ -7,7 +7,7 @@
    全部按日文识别
'''
import random
-import os, re, logging
+import os, re, logging, json
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -62,6 +62,9 @@ if torch.cuda.is_available():
else:
    device = "cpu"
+# is_half = False
+# device = "cpu"
dict_language_v1 = {
    i18n("中文"): "all_zh",#全部按中文识别
    i18n("英文"): "en",#全部按英文识别#######不变
@@ -123,7 +126,7 @@ def inference(text, text_lang,
              speed_factor, ref_text_free,
              split_bucket,fragment_interval,
              seed, keep_random, parallel_infer,
-             repetition_penalty
+             repetition_penalty, sample_steps, super_sampling,
              ):
    seed = -1 if keep_random else seed
@@ -147,6 +150,8 @@ def inference(text, text_lang,
        "seed":actual_seed,
        "parallel_infer": parallel_infer,
        "repetition_penalty": repetition_penalty,
+       "sample_steps": int(sample_steps),
+       "super_sampling": super_sampling,
    }
    for item in tts_pipeline.run(inputs):
        yield item, actual_seed
@@ -163,19 +168,38 @@ def change_choices():
    SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
    return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
-pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
-pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
+path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
+pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
+pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
_ =[[],[]]
-for i in range(2):
-    if os.path.exists(pretrained_gpt_name[i]):
-        _[0].append(pretrained_gpt_name[i])
-    if os.path.exists(pretrained_sovits_name[i]):
-        _[-1].append(pretrained_sovits_name[i])
+for i in range(3):
+    if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
+    if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name,pretrained_sovits_name = _
-SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"]
-GPT_weight_root=["GPT_weights_v2","GPT_weights"]
+if os.path.exists(f"./weight.json"):
+    pass
+else:
+    with open(f"./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file)
+
+with open(f"./weight.json", 'r', encoding="utf-8") as file:
+    weight_data = file.read()
+    weight_data=json.loads(weight_data)
+    gpt_path = os.environ.get(
+        "gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
+    sovits_path = os.environ.get(
+        "sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
+    if isinstance(gpt_path,list):
+        gpt_path = gpt_path[0]
+    if isinstance(sovits_path,list):
+        sovits_path = sovits_path[0]
+
+SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
+GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
for path in SoVITS_weight_root+GPT_weight_root:
    os.makedirs(path,exist_ok=True)
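weight.json simply remembers the last selected weights per version; after the block above runs (and change_sovits_weights below updates it) it has roughly this shape (an illustrative Python view of the file; the model filenames are hypothetical):

```python
# Hypothetical contents of ./weight.json after picking v3 weights in the WebUI.
weight_data = {
    "GPT":    {"v3": "GPT_weights_v3/my_model-e15.ckpt"},
    "SoVITS": {"v3": "SoVITS_weights_v3/my_model_e8_s800.pth"},
}
# The "gpt_path"/"sovits_path" environment variables still override these lookups, as in the code above.
```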
@@ -194,12 +218,10 @@ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
-from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
+from process_ckpt import get_sovits_version_from_path_fast
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
    global version, dict_language
    version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
-   # print(sovits_path,version, model_version, if_lora_v3)
-   path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
    if if_lora_v3 and not os.path.exists(path_sovits_v3):
        info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
@@ -228,7 +250,11 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
        visible_inp_refs=True
    yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}
+   with open("./weight.json")as f:
+       data=f.read()
+       data=json.loads(data)
+       data["SoVITS"][version]=sovits_path
+   with open("./weight.json","w")as f:f.write(json.dumps(data))
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
    gr.Markdown(
@@ -274,12 +300,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                with gr.Column():
                    batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
+                   sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True)
                    fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
-                   speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True)
+                   speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True)
                    top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
                    top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
                    temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
                    repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
                with gr.Column():
                    with gr.Row():
                        how_to_cut = gr.Dropdown(
@@ -288,10 +316,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                            value=i18n("凑四句一切"),
                            interactive=True, scale=1
                        )
+                       super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True)
+                   with gr.Row():
                        parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
                        split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
                    with gr.Row():
                        seed = gr.Number(label=i18n("随机种子"),value=-1)
                        keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
@@ -311,7 +343,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
            speed_factor, ref_text_free,
            split_bucket,fragment_interval,
            seed, keep_random, parallel_infer,
-           repetition_penalty
+           repetition_penalty, sample_steps, super_sampling,
        ],
        [output, seed],
    )


@@ -39,6 +39,8 @@ POST:
    "seed": -1,                   # int. random seed for reproducibility.
    "parallel_infer": True,       # bool. whether to use parallel inference.
    "repetition_penalty": 1.35    # float. repetition penalty for T2S model.
+   "sample_steps": 32,           # int. number of sampling steps for VITS model V3.
+   "super_sampling": False,      # bool. whether to use super-sampling for audio when using VITS model V3.
}
```
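A request that exercises the two new fields might look like this (a sketch assuming api_v2.py is running on its default host and port; texts and the reference-audio path are placeholders):

```python
import requests

payload = {
    "text": "Hello, this is a v3 synthesis test.",   # placeholder
    "text_lang": "en",
    "ref_audio_path": "ref.wav",                     # placeholder, resolved on the server side
    "prompt_text": "Transcript of the reference audio.",
    "prompt_lang": "en",
    "sample_steps": 32,        # new: sampling steps for the v3 synthesizer
    "super_sampling": False,   # new: audio super-sampling for v3 output
}
resp = requests.post("http://127.0.0.1:9880/tts", json=payload)  # assumed default port
with open("out.wav", "wb") as f:
    f.write(resp.content)
```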
@@ -164,6 +166,8 @@ class TTS_Request(BaseModel):
    streaming_mode:bool = False
    parallel_infer:bool = True
    repetition_penalty:float = 1.35
+   sample_steps:int = 32
+   super_sampling:bool = False
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
@@ -294,7 +298,9 @@ async def tts_handle(req:dict):
                "media_type": "wav",          # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
                "streaming_mode": False,      # bool. whether to return a streaming response.
                "parallel_infer": True,       # bool.(optional) whether to use parallel inference.
                "repetition_penalty": 1.35    # float.(optional) repetition penalty for T2S model.
+               "sample_steps": 32,           # int. number of sampling steps for VITS model V3.
+               "super_sampling": False,      # bool. whether to use super-sampling for audio when using VITS model V3.
            }
    returns:
        StreamingResponse: audio stream response.
@@ -316,10 +322,12 @@ async def tts_handle(req:dict):
    if streaming_mode:
        def streaming_generator(tts_generator:Generator, media_type:str):
-           if media_type == "wav":
-               yield wave_header_chunk()
-               media_type = "raw"
+           if_frist_chunk = True
            for sr, chunk in tts_generator:
+               if if_frist_chunk and media_type == "wav":
+                   yield wave_header_chunk(sample_rate=sr)
+                   media_type = "raw"
+                   if_frist_chunk = False
                yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
        # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
        return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
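The effect of the change above: the WAV header is no longer emitted with a fixed rate before any audio exists; it is built from the first chunk's actual sample rate, which matters because the v3 synthesizer yields 24 kHz (and a higher rate when super-sampling is enabled) while v1/v2 use the configured rate. A hedged sketch of a streaming client (default host/port assumed, placeholders for the inputs):

```python
import requests

params = {
    "text": "Streaming test sentence.",   # placeholder
    "text_lang": "en",
    "ref_audio_path": "ref.wav",          # placeholder
    "prompt_lang": "en",
    "streaming_mode": True,
    "media_type": "wav",
}
with requests.get("http://127.0.0.1:9880/tts", params=params, stream=True) as resp:
    with open("streamed.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=4096):
            f.write(chunk)  # first bytes are the WAV header derived from the first audio chunk
```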
@@ -365,7 +373,9 @@ async def tts_get_endpoint(
                        media_type:str = "wav",
                        streaming_mode:bool = False,
                        parallel_infer:bool = True,
-                       repetition_penalty:float = 1.35
+                       repetition_penalty:float = 1.35,
+                       sample_steps:int =32,
+                       super_sampling:bool = False
                        ):
    req = {
        "text": text,
@@ -387,7 +397,9 @@ async def tts_get_endpoint(
        "media_type":media_type,
        "streaming_mode":streaming_mode,
        "parallel_infer":parallel_infer,
-       "repetition_penalty":float(repetition_penalty)
+       "repetition_penalty":float(repetition_penalty),
+       "sample_steps":int(sample_steps),
+       "super_sampling":super_sampling
    }
    return await tts_handle(req)


@@ -39,6 +39,11 @@ class AP_BWE():
        self.model=model
        self.h=h
+   def to(self, *arg, **kwargs):
+       self.model.to(*arg, **kwargs)
+       self.device = self.model.conv_pre_mag.weight.device
+       return self
    def __call__(self, audio,orig_sampling_rate):
        with torch.no_grad():
            # audio, orig_sampling_rate = torchaudio.load(inp_path)
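For context, the TTS pipeline constructs this class lazily in init_sr_model() and then calls it on a batched 24 kHz tensor. A hedged usage sketch (device and input are placeholders; DictToAttrRecursive is the dict-to-attribute helper defined in the TTS module above):

```python
import torch
from tools.audio_sr import AP_BWE

device = "cuda" if torch.cuda.is_available() else "cpu"  # placeholder device choice
sr_model = AP_BWE(device, DictToAttrRecursive)  # raises FileNotFoundError if the 24kto48k checkpoint is missing
sr_model = sr_model.to(device)                  # the new .to() keeps sr_model.device in sync with the weights

audio_24k = torch.randn(1, 24000)               # placeholder: one second of audio, shape (1, samples)
audio_up, sr = sr_model(audio_24k, 24000)       # returns the super-sampled audio and its new sample rate
```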