feat: 添加导出v4的script

This commit is contained in:
csh 2025-05-29 02:41:21 +08:00
parent acd68355c9
commit 150cd842df
2 changed files with 270 additions and 51 deletions

View File

@ -10,7 +10,7 @@ from inference_webui import get_phones_and_bert
import librosa
from module import commons
from module.mel_processing import mel_spectrogram_torch
from module.models_onnx import CFM, SynthesizerTrnV3
from module.models_onnx import CFM, Generator, SynthesizerTrnV3
import numpy as np
import torch._dynamo.config
import torchaudio
@ -46,7 +46,7 @@ class MelSpectrgram(torch.nn.Module):
center=False,
):
super().__init__()
self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
self.hann_window = torch.hann_window(win_size).to(device=device, dtype=dtype)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
self.n_fft: int = n_fft
@ -189,6 +189,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
"center": False,
},
)
mel_fn_v4 = lambda x: mel_spectrogram_torch(
x,
**{
"n_fft": 1280,
"win_size": 1280,
"hop_size": 320,
"num_mels": 100,
"sampling_rate": 32000,
"fmin": 0,
"fmax": None,
"center": False,
},
)
spec_min = -12
spec_max = 2
@ -285,6 +298,84 @@ class ExportGPTSovitsHalf(torch.nn.Module):
return fea_ref, fea_todo, mel2
class ExportGPTSovitsV4Half(torch.nn.Module):
def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
super().__init__()
self.hps = hps
self.t2s_m = t2s_m
self.vq_model = vq_model
self.mel2 = MelSpectrgram(
dtype=torch.float32,
device=device,
n_fft=1280,
num_mels=100,
sampling_rate=32000,
hop_size=320,
win_size=1280,
fmin=0,
fmax=None,
center=False,
)
# self.dtype = dtype
self.filter_length: int = hps.data.filter_length
self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length
def forward(
self,
ssl_content,
ref_audio_32k: torch.FloatTensor,
phoneme_ids0,
phoneme_ids1,
bert1,
bert2,
top_k,
):
refer = spectrogram_torch(
ref_audio_32k,
self.filter_length,
self.sampling_rate,
self.hop_length,
self.win_length,
center=False,
).to(ssl_content.dtype)
codes = self.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0)
# print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
# print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
ge = self.vq_model.create_ge(refer)
# print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
prompt_ = prompt.unsqueeze(0)
fea_ref = self.vq_model(prompt_, phoneme_ids0, ge)
# print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
# print(prompt_.shape, phoneme_ids0.shape, ge.shape)
# print(fea_ref.shape)
ref_32k = ref_audio_32k
mel2 = norm_spec(self.mel2(ref_32k)).to(ssl_content.dtype)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
if T_min > 500:
mel2 = mel2[:, :, -500:]
fea_ref = fea_ref[:, :, -500:]
T_min = 500
fea_todo = self.vq_model(pred_semantic, phoneme_ids1, ge)
# print('fea_todo',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
# print(pred_semantic.shape, phoneme_ids1.shape, ge.shape)
# print(fea_todo.shape)
return fea_ref, fea_todo, mel2
class GPTSoVITSV3(torch.nn.Module):
def __init__(self, gpt_sovits_half, cfm, bigvgan):
super().__init__()
@ -311,6 +402,7 @@ class GPTSoVITSV3(torch.nn.Module):
chunk_len = 934 - fea_ref.shape[2]
wav_gen_list = []
idx = 0
fea_todo = fea_todo[:,:,:-5]
wav_gen_length = fea_todo.shape[2] * 256
while 1:
# current_time = datetime.now()
@ -342,6 +434,65 @@ class GPTSoVITSV3(torch.nn.Module):
wav_gen = torch.cat(wav_gen_list, 2)
return wav_gen[0][0][:wav_gen_length]
class GPTSoVITSV4(torch.nn.Module):
def __init__(self, gpt_sovits_half, cfm, hifigan):
super().__init__()
self.gpt_sovits_half = gpt_sovits_half
self.cfm = cfm
self.hifigan = hifigan
def forward(
self,
ssl_content,
ref_audio_32k: torch.FloatTensor,
phoneme_ids0: torch.LongTensor,
phoneme_ids1: torch.LongTensor,
bert1,
bert2,
top_k: torch.LongTensor,
sample_steps: torch.LongTensor,
):
# current_time = datetime.now()
# print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
)
chunk_len = 1000 - fea_ref.shape[2]
wav_gen_list = []
idx = 0
fea_todo = fea_todo[:,:,:-10]
wav_gen_length = fea_todo.shape[2] * 480
while 1:
# current_time = datetime.now()
# print("idx:",idx,current_time.strftime("%Y-%m-%d %H:%M:%S"))
fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
if fea_todo_chunk.shape[-1] == 0:
break
# 因为导出的模型在不同shape时会重新编译还是怎么的会卡顿10s这样
# 所以在这里补0让他shape维持不变
# 但是这样会导致生成的音频长度不对,所以在最后截取一下。
# 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256
complete_len = chunk_len - fea_todo_chunk.shape[-1]
if complete_len != 0:
fea_todo_chunk = torch.cat(
[
fea_todo_chunk,
torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
],
2,
)
cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
idx += chunk_len
cfm_res = denorm_spec(cfm_res)
hifigan_res = self.hifigan(cfm_res)
wav_gen_list.append(hifigan_res)
wav_gen = torch.cat(wav_gen_list, 2)
return wav_gen[0][0][:wav_gen_length]
def init_bigvgan():
@ -361,6 +512,31 @@ def init_bigvgan():
bigvgan_model = bigvgan_model.to(device)
def init_hifigan():
global hifigan_model, bigvgan_model
hifigan_model = Generator(
initial_channel=100,
resblock="1",
resblock_kernel_sizes=[3, 7, 11],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
upsample_rates=[10, 6, 2, 2, 2],
upsample_initial_channel=512,
upsample_kernel_sizes=[20, 12, 4, 4, 4],
gin_channels=0,
is_bias=True,
)
hifigan_model.eval()
hifigan_model.remove_weight_norm()
state_dict_g = torch.load(
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
)
print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
if is_half == True:
hifigan_model = hifigan_model.half().to(device)
else:
hifigan_model = hifigan_model.to(device)
class Sovits:
def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps):
self.vq_model = vq_model
@ -399,6 +575,7 @@ class DictToAttrRecursive(dict):
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
v3v4set = {"v3", "v4"}
def get_sovits_weights(sovits_path):
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
@ -419,8 +596,8 @@ def get_sovits_weights(sovits_path):
else:
hps.model.version = "v2"
if model_version == "v3":
hps.model.version = "v3"
if model_version in v3v4set:
hps.model.version = model_version
logger.info(f"hps: {hps}")
@ -522,10 +699,14 @@ def export_cfm(
return export_cfm
def export():
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
init_bigvgan()
def export(version="v3"):
if version == "v3":
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
init_bigvgan()
else:
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
init_hifigan()
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
@ -542,7 +723,7 @@ def export():
hps = sovits.hps
ref_wav_path = "onnx/ad/ref.wav"
speed = 1.0
sample_steps = 32
sample_steps = 8
dtype = torch.float16 if is_half == True else torch.float32
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
zero_wav = np.zeros(
@ -634,25 +815,33 @@ def export():
# vq_model = sovits.vq_model
vq_model = trace_vq_model
gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
if version == "v3":
gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
else:
gpt_sovits_half = ExportGPTSovitsV4Half(sovits.hps, script_t2s, trace_vq_model)
torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v4_half.pt")
ref_audio, sr = torchaudio.load(ref_wav_path)
ref_audio = ref_audio.to(device).float()
if ref_audio.shape[0] == 2:
ref_audio = ref_audio.mean(0).unsqueeze(0)
if sr != 24000:
ref_audio = resample(ref_audio, sr)
tgt_sr = 24000 if version == "v3" else 32000
if sr != tgt_sr:
ref_audio = resample(ref_audio, sr, tgt_sr)
# mel2 = mel_fn(ref_audio)
mel2 = norm_spec(mel_fn(ref_audio))
mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
fea_ref = fea_ref[:, :, :T_min]
print("fea_ref:", fea_ref.shape, T_min)
if T_min > 468:
mel2 = mel2[:, :, -468:]
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
Tref = 468 if version == "v3" else 500
Tchunk = 934 if version == "v3" else 1000
if T_min > Tref:
mel2 = mel2[:, :, -Tref:]
fea_ref = fea_ref[:, :, -Tref:]
T_min = Tref
chunk_len = Tchunk - T_min
mel2 = mel2.to(dtype)
# fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge)
@ -714,13 +903,19 @@ def export():
with torch.inference_mode():
cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
torch._dynamo.mark_dynamic(cmf_res_rand, 2)
bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
wav_gen = bigvgan_model(cmf_res)
if version == "v3":
bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
wav_gen = bigvgan_model(cmf_res)
else:
hifigan_model_ = torch.jit.trace(hifigan_model, optimize=True, example_inputs=(cmf_res_rand,))
hifigan_model_.save("onnx/ad/hifigan_model.pt")
wav_gen = hifigan_model(cmf_res)
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
audio = wav_gen[0][0].cpu().detach().numpy()
sr = 24000
sr = 24000 if version == "v3" else 48000
soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr)
@ -848,8 +1043,9 @@ def test_export(
def test_export1(
todo_text,
gpt_sovits_v3,
gpt_sovits_v3v4,
output,
out_sr=24000,
):
# hps = sovits.hps
ref_wav_path = "onnx/ad/ref.wav"
@ -859,7 +1055,7 @@ def test_export1(
dtype = torch.float16 if is_half == True else torch.float32
zero_wav = np.zeros(
int(24000 * 0.3),
int(out_sr * 0.3),
dtype=np.float16 if is_half == True else np.float32,
)
@ -907,22 +1103,26 @@ def test_export1(
bert2.shape,
top_k.shape,
)
wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
wav_gen = gpt_sovits_v3v4(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
audio = wav_gen.cpu().detach().numpy()
logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
sr = 24000
soundfile.write(output, (audio * 32768).astype(np.int16), sr)
soundfile.write(output, (audio * 32768).astype(np.int16), out_sr)
import time
def test_():
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
def test_(version="v3"):
if version == "v3":
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
# init_bigvgan()
else:
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
# init_hifigan()
# cfm = ExportCFM(sovits.cfm)
# cfm.cfm.estimator = dit
@ -963,25 +1163,41 @@ def test_():
# gpt_sovits_v3_half = gpt_sovits_v3_half.half()
# gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
# gpt_sovits_v3_half.eval()
gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
logger.info("gpt_sovits_v3_half ok")
if version == "v3":
gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
logger.info("gpt_sovits_v3_half ok")
# init_bigvgan()
# global bigvgan_model
bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
# bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
bigvgan_model = bigvgan_model.half()
bigvgan_model = bigvgan_model.cuda()
bigvgan_model.eval()
# init_bigvgan()
# global bigvgan_model
bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
# bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
bigvgan_model = bigvgan_model.half()
bigvgan_model = bigvgan_model.cuda()
bigvgan_model.eval()
logger.info("bigvgan ok")
gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
gpt_sovits_v3.eval()
print("save gpt_sovits_v3 ok")
else:
gpt_sovits_v4_half = ExportGPTSovitsV4Half(sovits.hps, t2s_m, vq_model)
logger.info("gpt_sovits_v4 ok")
logger.info("bigvgan ok")
hifigan_model = torch.jit.load("onnx/ad/hifigan_model.pt")
hifigan_model = hifigan_model.half()
hifigan_model = hifigan_model.cuda()
hifigan_model.eval()
logger.info("hifigan ok")
gpt_sovits_v4 = GPTSoVITSV4(gpt_sovits_v4_half, cfm, hifigan_model)
gpt_sovits_v4 = torch.jit.script(gpt_sovits_v4)
gpt_sovits_v4.save("onnx/ad/gpt_sovits_v4.pt")
print("save gpt_sovits_v4 ok")
gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4
sr = 24000 if version == "v3" else 48000
gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
gpt_sovits_v3.eval()
print("save gpt_sovits_v3 ok")
time.sleep(5)
# print("thread:", torch.get_num_threads())
@ -991,14 +1207,16 @@ def test_():
test_export1(
"汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
gpt_sovits_v3,
gpt_sovits_v3v4,
"out.wav",
sr
)
test_export1(
"你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
gpt_sovits_v3,
gpt_sovits_v3v4,
"out2.wav",
sr
)
# test_export(
@ -1030,6 +1248,6 @@ def test_export_gpt_sovits_v3():
with torch.no_grad():
# export()
test_()
export("v4")
# test_("v4")
# test_export_gpt_sovits_v3()

View File

@ -391,6 +391,7 @@ class Generator(torch.nn.Module):
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=0,
is_bias=False,
):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
@ -418,7 +419,7 @@ class Generator(torch.nn.Module):
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
self.ups.apply(init_weights)
if gin_channels != 0: