Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
feat: add a v4 export script
parent acd68355c9
commit 150cd842df
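In short: the script gains a v4 export path (ExportGPTSovitsV4Half, GPTSoVITSV4, init_hifigan) beside the existing v3 one, selected by a new version argument. How it is meant to be driven, per the diff's own tail (a sketch, not additional commit content; export and test_ are defined in this same script):

import torch

# Hedged sketch of the intended entry points.
with torch.no_grad():
    export("v4")   # saves onnx/ad/gpt_sovits_v4_half.pt, onnx/ad/hifigan_model.pt, out.export.wav
    # test_("v4")  # optional: reloads the saved modules and writes out.wav / out2.wav at 48 kHz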
@@ -10,7 +10,7 @@ from inference_webui import get_phones_and_bert
 import librosa
 from module import commons
 from module.mel_processing import mel_spectrogram_torch
-from module.models_onnx import CFM, SynthesizerTrnV3
+from module.models_onnx import CFM, Generator, SynthesizerTrnV3
 import numpy as np
 import torch._dynamo.config
 import torchaudio
@@ -46,7 +46,7 @@ class MelSpectrgram(torch.nn.Module):
         center=False,
     ):
         super().__init__()
-        self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
+        self.hann_window = torch.hann_window(win_size).to(device=device, dtype=dtype)
         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
         self.n_fft: int = n_fft
@@ -189,6 +189,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
         "center": False,
     },
 )
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)
 
 spec_min = -12
 spec_max = 2
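A note on the new v4 mel settings: at sampling_rate 32000 with hop_size 320, mel frames land at exactly 100 per second, which the 480x vocoder further down turns into 48 kHz audio. A small sanity check (a sketch: it reproduces the usual HiFi-GAN-style reflect padding of (n_fft - hop)/2 per side before a center=False STFT, which is my assumption about module.mel_processing, rather than calling mel_spectrogram_torch itself):

# v4 mel geometry check (pure arithmetic, no torch needed).
sr, n_fft, hop = 32000, 1280, 320
T = sr  # one second of audio
pad = (n_fft - hop) // 2
frames = 1 + (T + 2 * pad - n_fft) // hop
assert frames == T // hop == 100  # 100 mel frames per second at 32 kHz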
@@ -285,6 +298,84 @@ class ExportGPTSovitsHalf(torch.nn.Module):
         return fea_ref, fea_todo, mel2
 
 
+class ExportGPTSovitsV4Half(torch.nn.Module):
+    def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
+        super().__init__()
+        self.hps = hps
+        self.t2s_m = t2s_m
+        self.vq_model = vq_model
+        self.mel2 = MelSpectrgram(
+            dtype=torch.float32,
+            device=device,
+            n_fft=1280,
+            num_mels=100,
+            sampling_rate=32000,
+            hop_size=320,
+            win_size=1280,
+            fmin=0,
+            fmax=None,
+            center=False,
+        )
+        # self.dtype = dtype
+        self.filter_length: int = hps.data.filter_length
+        self.sampling_rate: int = hps.data.sampling_rate
+        self.hop_length: int = hps.data.hop_length
+        self.win_length: int = hps.data.win_length
+
+    def forward(
+        self,
+        ssl_content,
+        ref_audio_32k: torch.FloatTensor,
+        phoneme_ids0,
+        phoneme_ids1,
+        bert1,
+        bert2,
+        top_k,
+    ):
+        refer = spectrogram_torch(
+            ref_audio_32k,
+            self.filter_length,
+            self.sampling_rate,
+            self.hop_length,
+            self.win_length,
+            center=False,
+        ).to(ssl_content.dtype)
+
+        codes = self.vq_model.extract_latent(ssl_content)
+        prompt_semantic = codes[0, 0]
+        prompt = prompt_semantic.unsqueeze(0)
+        # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+
+        pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
+        # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+
+        ge = self.vq_model.create_ge(refer)
+        # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+
+        prompt_ = prompt.unsqueeze(0)
+        fea_ref = self.vq_model(prompt_, phoneme_ids0, ge)
+        # print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+        # print(prompt_.shape, phoneme_ids0.shape, ge.shape)
+        # print(fea_ref.shape)
+
+        ref_32k = ref_audio_32k
+        mel2 = norm_spec(self.mel2(ref_32k)).to(ssl_content.dtype)
+        T_min = min(mel2.shape[2], fea_ref.shape[2])
+        mel2 = mel2[:, :, :T_min]
+        fea_ref = fea_ref[:, :, :T_min]
+        if T_min > 500:
+            mel2 = mel2[:, :, -500:]
+            fea_ref = fea_ref[:, :, -500:]
+            T_min = 500
+
+        fea_todo = self.vq_model(pred_semantic, phoneme_ids1, ge)
+        # print('fea_todo',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+        # print(pred_semantic.shape, phoneme_ids1.shape, ge.shape)
+        # print(fea_todo.shape)
+
+        return fea_ref, fea_todo, mel2
+
+
 class GPTSoVITSV3(torch.nn.Module):
     def __init__(self, gpt_sovits_half, cfm, bigvgan):
         super().__init__()
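The reference-trimming rule in ExportGPTSovitsV4Half.forward (align mel and feature lengths, then keep at most the last 500 frames; the v3 class caps at 468) restated standalone. trim_ref is a hypothetical helper name, not part of the commit:

import torch

# Standalone sketch of the trimming rule above.
def trim_ref(mel2: torch.Tensor, fea_ref: torch.Tensor, cap: int = 500):
    T_min = min(mel2.shape[2], fea_ref.shape[2])
    mel2, fea_ref = mel2[:, :, :T_min], fea_ref[:, :, :T_min]
    if T_min > cap:
        mel2, fea_ref, T_min = mel2[:, :, -cap:], fea_ref[:, :, -cap:], cap
    return mel2, fea_ref, T_min

mel2, fea_ref, T_min = trim_ref(torch.randn(1, 100, 700), torch.randn(1, 512, 650))
print(mel2.shape, fea_ref.shape, T_min)  # (1, 100, 500) (1, 512, 500) 500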
@@ -311,6 +402,7 @@ class GPTSoVITSV3(torch.nn.Module):
         chunk_len = 934 - fea_ref.shape[2]
         wav_gen_list = []
         idx = 0
+        fea_todo = fea_todo[:,:,:-5]
         wav_gen_length = fea_todo.shape[2] * 256
         while 1:
             # current_time = datetime.now()
@@ -342,6 +434,65 @@ class GPTSoVITSV3(torch.nn.Module):
 
         wav_gen = torch.cat(wav_gen_list, 2)
         return wav_gen[0][0][:wav_gen_length]
+
+
+class GPTSoVITSV4(torch.nn.Module):
+    def __init__(self, gpt_sovits_half, cfm, hifigan):
+        super().__init__()
+        self.gpt_sovits_half = gpt_sovits_half
+        self.cfm = cfm
+        self.hifigan = hifigan
+
+    def forward(
+        self,
+        ssl_content,
+        ref_audio_32k: torch.FloatTensor,
+        phoneme_ids0: torch.LongTensor,
+        phoneme_ids1: torch.LongTensor,
+        bert1,
+        bert2,
+        top_k: torch.LongTensor,
+        sample_steps: torch.LongTensor,
+    ):
+        # current_time = datetime.now()
+        # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
+        fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
+            ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
+        )
+        chunk_len = 1000 - fea_ref.shape[2]
+        wav_gen_list = []
+        idx = 0
+        fea_todo = fea_todo[:,:,:-10]
+        wav_gen_length = fea_todo.shape[2] * 480
+        while 1:
+            # current_time = datetime.now()
+            # print("idx:",idx,current_time.strftime("%Y-%m-%d %H:%M:%S"))
+            fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
+            if fea_todo_chunk.shape[-1] == 0:
+                break
+
+            # The exported model recompiles (or stalls some other way, ~10 s)
+            # whenever the input shape changes, so zero-pad here to keep the
+            # shape constant. That makes the generated audio too long, so it
+            # is trimmed at the end: after the vocoder the audio length is
+            # fea_todo.shape[2] * 480.
+            complete_len = chunk_len - fea_todo_chunk.shape[-1]
+            if complete_len != 0:
+                fea_todo_chunk = torch.cat(
+                    [
+                        fea_todo_chunk,
+                        torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
+                    ],
+                    2,
+                )
+
+            cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
+            idx += chunk_len
+
+            cfm_res = denorm_spec(cfm_res)
+            hifigan_res = self.hifigan(cfm_res)
+            wav_gen_list.append(hifigan_res)
+
+        wav_gen = torch.cat(wav_gen_list, 2)
+        return wav_gen[0][0][:wav_gen_length]
+
+
 def init_bigvgan():
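The fixed-shape chunk loop above, isolated. chunked_vocode is a hypothetical stand-in: a dummy 480x upsampler replaces the cfm + hifigan calls so only the pad-then-trim arithmetic is exercised:

import torch

# Sketch of the pad-to-fixed-shape / trim-the-tail pattern used by GPTSoVITSV4.
def chunked_vocode(fea_todo: torch.Tensor, chunk_len: int, upsample: int = 480) -> torch.Tensor:
    outs, idx = [], 0
    target_len = fea_todo.shape[2] * upsample  # true output length, fixed up front
    while True:
        chunk = fea_todo[:, :, idx : idx + chunk_len]
        if chunk.shape[-1] == 0:
            break
        short = chunk_len - chunk.shape[-1]
        if short:  # zero-pad the last chunk so the traced graph always sees one shape
            chunk = torch.cat([chunk, chunk.new_zeros(1, chunk.shape[1], short)], 2)
        # dummy "vocoder": collapse channels, repeat each frame 480 times
        outs.append(chunk.mean(1, keepdim=True).repeat_interleave(upsample, 2))
        idx += chunk_len
    return torch.cat(outs, 2)[0][0][:target_len]  # cut off the padded tail

wav = chunked_vocode(torch.randn(1, 100, 777), chunk_len=300)
print(wav.shape)  # torch.Size([372960]) == 777 * 480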
@@ -361,6 +512,31 @@ def init_bigvgan():
         bigvgan_model = bigvgan_model.to(device)
 
 
+def init_hifigan():
+    global hifigan_model, bigvgan_model
+    hifigan_model = Generator(
+        initial_channel=100,
+        resblock="1",
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        upsample_rates=[10, 6, 2, 2, 2],
+        upsample_initial_channel=512,
+        upsample_kernel_sizes=[20, 12, 4, 4, 4],
+        gin_channels=0,
+        is_bias=True,
+    )
+    hifigan_model.eval()
+    hifigan_model.remove_weight_norm()
+    state_dict_g = torch.load(
+        "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
+    )
+    print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
+    if is_half == True:
+        hifigan_model = hifigan_model.half().to(device)
+    else:
+        hifigan_model = hifigan_model.to(device)
+
+
 class Sovits:
     def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps):
         self.vq_model = vq_model
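For reference, the Generator configuration in init_hifigan upsamples by 10·6·2·2·2 = 480x, so 100 mel frames per second come out as 48 kHz audio, matching the output sample rate used further down:

import math

# v4 vocoder geometry implied by init_hifigan above.
upsample_rates = [10, 6, 2, 2, 2]
assert math.prod(upsample_rates) == 480
assert 100 * math.prod(upsample_rates) == 48000  # mel frame rate x upsampling = output sr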
@@ -399,6 +575,7 @@ class DictToAttrRecursive(dict):
 
 from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 
+v3v4set = {"v3", "v4"}
 
 def get_sovits_weights(sovits_path):
     path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
@@ -419,8 +596,8 @@ def get_sovits_weights(sovits_path):
     else:
         hps.model.version = "v2"
 
-    if model_version == "v3":
-        hps.model.version = "v3"
+    if model_version in v3v4set:
+        hps.model.version = model_version
 
     logger.info(f"hps: {hps}")
 
@@ -522,10 +699,14 @@ def export_cfm(
     return export_cfm
 
 
-def export():
-    sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
-    init_bigvgan()
+def export(version="v3"):
+    if version == "v3":
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
+        init_bigvgan()
+    else:
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
+        init_hifigan()
 
     dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
     raw_t2s = get_raw_t2s_model(dict_s1).to(device)
@@ -542,7 +723,7 @@ def export():
     hps = sovits.hps
     ref_wav_path = "onnx/ad/ref.wav"
     speed = 1.0
-    sample_steps = 32
+    sample_steps = 8
     dtype = torch.float16 if is_half == True else torch.float32
     refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
     zero_wav = np.zeros(
@@ -634,25 +815,33 @@ def export():
     # vq_model = sovits.vq_model
     vq_model = trace_vq_model
 
-    gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
-    torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
+    if version == "v3":
+        gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
+        torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
+    else:
+        gpt_sovits_half = ExportGPTSovitsV4Half(sovits.hps, script_t2s, trace_vq_model)
+        torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v4_half.pt")
 
     ref_audio, sr = torchaudio.load(ref_wav_path)
     ref_audio = ref_audio.to(device).float()
     if ref_audio.shape[0] == 2:
         ref_audio = ref_audio.mean(0).unsqueeze(0)
-    if sr != 24000:
-        ref_audio = resample(ref_audio, sr)
+    tgt_sr = 24000 if version == "v3" else 32000
+    if sr != tgt_sr:
+        ref_audio = resample(ref_audio, sr, tgt_sr)
     # mel2 = mel_fn(ref_audio)
-    mel2 = norm_spec(mel_fn(ref_audio))
+    mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
+    mel2 = norm_spec(mel2)
     T_min = min(mel2.shape[2], fea_ref.shape[2])
     fea_ref = fea_ref[:, :, :T_min]
     print("fea_ref:", fea_ref.shape, T_min)
-    if T_min > 468:
-        mel2 = mel2[:, :, -468:]
-        fea_ref = fea_ref[:, :, -468:]
-        T_min = 468
-    chunk_len = 934 - T_min
+    Tref = 468 if version == "v3" else 500
+    Tchunk = 934 if version == "v3" else 1000
+    if T_min > Tref:
+        mel2 = mel2[:, :, -Tref:]
+        fea_ref = fea_ref[:, :, -Tref:]
+        T_min = Tref
+    chunk_len = Tchunk - T_min
     mel2 = mel2.to(dtype)
 
     # fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge)
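Collected in one place, the version-dependent constants that export() switches on (a summary sketch, not part of the commit; every value is read directly off this diff):

# Per-version constants used by export() above.
V3 = dict(ref_sr=24000, Tref=468, Tchunk=934, upsample=256, out_sr=24000)
V4 = dict(ref_sr=32000, Tref=500, Tchunk=1000, upsample=480, out_sr=48000)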
@@ -714,13 +903,19 @@ def export():
     with torch.inference_mode():
         cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
         torch._dynamo.mark_dynamic(cmf_res_rand, 2)
-        bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
-        bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
-        wav_gen = bigvgan_model(cmf_res)
+        if version == "v3":
+            bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
+            bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
+            wav_gen = bigvgan_model(cmf_res)
+        else:
+            hifigan_model_ = torch.jit.trace(hifigan_model, optimize=True, example_inputs=(cmf_res_rand,))
+            hifigan_model_.save("onnx/ad/hifigan_model.pt")
+            wav_gen = hifigan_model(cmf_res)
 
         print("wav_gen:", wav_gen.shape, wav_gen.dtype)
         audio = wav_gen[0][0].cpu().detach().numpy()
 
-    sr = 24000
+    sr = 24000 if version == "v3" else 48000
     soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr)
 
@@ -848,8 +1043,9 @@ def test_export(
 
 def test_export1(
     todo_text,
-    gpt_sovits_v3,
+    gpt_sovits_v3v4,
     output,
+    out_sr=24000,
 ):
     # hps = sovits.hps
     ref_wav_path = "onnx/ad/ref.wav"
@@ -859,7 +1055,7 @@ def test_export1(
     dtype = torch.float16 if is_half == True else torch.float32
 
     zero_wav = np.zeros(
-        int(24000 * 0.3),
+        int(out_sr * 0.3),
         dtype=np.float16 if is_half == True else np.float32,
     )
 
@@ -907,22 +1103,26 @@ def test_export1(
         bert2.shape,
         top_k.shape,
     )
-    wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
+    wav_gen = gpt_sovits_v3v4(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
     print("wav_gen:", wav_gen.shape, wav_gen.dtype)
 
     wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
 
     audio = wav_gen.cpu().detach().numpy()
     logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
-    sr = 24000
-    soundfile.write(output, (audio * 32768).astype(np.int16), sr)
+    soundfile.write(output, (audio * 32768).astype(np.int16), out_sr)
 
 
 import time
 
 
-def test_():
-    sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
+def test_(version="v3"):
+    if version == "v3":
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
+        # init_bigvgan()
+    else:
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
+        # init_hifigan()
 
     # cfm = ExportCFM(sovits.cfm)
     # cfm.cfm.estimator = dit
@@ -963,25 +1163,41 @@ def test_():
     # gpt_sovits_v3_half = gpt_sovits_v3_half.half()
     # gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
     # gpt_sovits_v3_half.eval()
-    gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
-    logger.info("gpt_sovits_v3_half ok")
+    if version == "v3":
+        gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
+        logger.info("gpt_sovits_v3_half ok")
 
-    # init_bigvgan()
-    # global bigvgan_model
-    bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
-    # bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
-    bigvgan_model = bigvgan_model.half()
-    bigvgan_model = bigvgan_model.cuda()
-    bigvgan_model.eval()
+        # init_bigvgan()
+        # global bigvgan_model
+        bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
+        # bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
+        bigvgan_model = bigvgan_model.half()
+        bigvgan_model = bigvgan_model.cuda()
+        bigvgan_model.eval()
 
-    logger.info("bigvgan ok")
-    gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
-    gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
-    gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
-    gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
-    gpt_sovits_v3.eval()
-    print("save gpt_sovits_v3 ok")
+        logger.info("bigvgan ok")
+        gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
+        gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
+        gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
+        gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
+        gpt_sovits_v3.eval()
+        print("save gpt_sovits_v3 ok")
+    else:
+        gpt_sovits_v4_half = ExportGPTSovitsV4Half(sovits.hps, t2s_m, vq_model)
+        logger.info("gpt_sovits_v4 ok")
+
+        hifigan_model = torch.jit.load("onnx/ad/hifigan_model.pt")
+        hifigan_model = hifigan_model.half()
+        hifigan_model = hifigan_model.cuda()
+        hifigan_model.eval()
+        logger.info("hifigan ok")
+        gpt_sovits_v4 = GPTSoVITSV4(gpt_sovits_v4_half, cfm, hifigan_model)
+        gpt_sovits_v4 = torch.jit.script(gpt_sovits_v4)
+        gpt_sovits_v4.save("onnx/ad/gpt_sovits_v4.pt")
+        print("save gpt_sovits_v4 ok")
+
+    gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4
+    sr = 24000 if version == "v3" else 48000
 
     time.sleep(5)
     # print("thread:", torch.get_num_threads())
@@ -991,14 +1207,16 @@ def test_():
 
     test_export1(
         "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
-        gpt_sovits_v3,
+        gpt_sovits_v3v4,
         "out.wav",
+        sr,
     )
 
     test_export1(
         "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
-        gpt_sovits_v3,
+        gpt_sovits_v3v4,
         "out2.wav",
+        sr,
     )
 
     # test_export(
@@ -1030,6 +1248,6 @@ def test_export_gpt_sovits_v3():
 
 
 with torch.no_grad():
-    # export()
-    test_()
+    export("v4")
+    # test_("v4")
     # test_export_gpt_sovits_v3()
Second file in this commit (the module that defines Generator, presumably module.models_onnx given the import change above; the file header was lost in this view):

@@ -391,6 +391,7 @@ class Generator(torch.nn.Module):
         upsample_initial_channel,
         upsample_kernel_sizes,
         gin_channels=0,
+        is_bias=False,
     ):
         super(Generator, self).__init__()
         self.num_kernels = len(resblock_kernel_sizes)
@@ -418,7 +419,7 @@ class Generator(torch.nn.Module):
         for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
             self.resblocks.append(resblock(ch, k, d))
 
-        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
         self.ups.apply(init_weights)
 
         if gin_channels != 0: