mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
Add the v4 export (#2417)

* feat: add a script for exporting v4
* rename export_torch_script_v3.py to export_torch_script_v3v4.py
* clean up function names and parameters in export_torch_script_v3v4
parent e909c93c63
commit 6d12a6a6cb
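This commit extends the TorchScript export pipeline to GPT-SoVITS v4: a v4 mel front end (mel_fn_v4), ExportGPTSovitsV4Half and GPTSoVITSV4 counterparts to the v3 classes, a HiFi-GAN vocoder loader (init_hifigan), and version switches in the renamed entry points export_1 (formerly export) and export_2 (formerly test_). A minimal sketch of how the two stages are driven, mirroring the `with torch.no_grad():` block at the bottom of this diff (the reference wav path and transcript are the script's own test values):

with torch.no_grad():
    # stage 1: script the half model, trace the vocoder, write out.export.wav
    export_1("onnx/ad/ref.wav", "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "v4")
    # stage 2: assemble and save the full end-to-end module (onnx/ad/gpt_sovits_v4.pt)
    export_2("v4")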
GPT_SoVITS/export_torch_script_v3v4.py

@@ -10,7 +10,7 @@ from inference_webui import get_phones_and_bert
 import librosa
 from module import commons
 from module.mel_processing import mel_spectrogram_torch
-from module.models_onnx import CFM, SynthesizerTrnV3
+from module.models_onnx import CFM, Generator, SynthesizerTrnV3
 import numpy as np
 import torch._dynamo.config
 import torchaudio
@@ -46,7 +46,7 @@ class MelSpectrgram(torch.nn.Module):
         center=False,
     ):
         super().__init__()
-        self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
+        self.hann_window = torch.hann_window(win_size).to(device=device, dtype=dtype)
         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
         self.n_fft: int = n_fft
@@ -189,6 +189,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
         "center": False,
     },
 )
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)

 spec_min = -12
 spec_max = 2
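The v4 front end runs the mel analysis at 32 kHz with a 320-sample hop and 100 bins, where the v3 path implies 256 output samples per frame at 24 kHz (see wav_gen_length * 256 and sr = 24000 below). A quick sanity check of the v4 timing, as an illustrative sketch rather than part of the commit:

# one v4 mel frame per 10 ms; a t-second reference yields roughly 100 * t frames
hop_size, sampling_rate = 320, 32000
assert hop_size / sampling_rate == 0.01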
@@ -285,6 +298,84 @@ class ExportGPTSovitsHalf(torch.nn.Module):
         return fea_ref, fea_todo, mel2


+class ExportGPTSovitsV4Half(torch.nn.Module):
+    def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
+        super().__init__()
+        self.hps = hps
+        self.t2s_m = t2s_m
+        self.vq_model = vq_model
+        self.mel2 = MelSpectrgram(
+            dtype=torch.float32,
+            device=device,
+            n_fft=1280,
+            num_mels=100,
+            sampling_rate=32000,
+            hop_size=320,
+            win_size=1280,
+            fmin=0,
+            fmax=None,
+            center=False,
+        )
+        # self.dtype = dtype
+        self.filter_length: int = hps.data.filter_length
+        self.sampling_rate: int = hps.data.sampling_rate
+        self.hop_length: int = hps.data.hop_length
+        self.win_length: int = hps.data.win_length
+
+    def forward(
+        self,
+        ssl_content,
+        ref_audio_32k: torch.FloatTensor,
+        phoneme_ids0,
+        phoneme_ids1,
+        bert1,
+        bert2,
+        top_k,
+    ):
+        refer = spectrogram_torch(
+            ref_audio_32k,
+            self.filter_length,
+            self.sampling_rate,
+            self.hop_length,
+            self.win_length,
+            center=False,
+        ).to(ssl_content.dtype)
+
+        codes = self.vq_model.extract_latent(ssl_content)
+        prompt_semantic = codes[0, 0]
+        prompt = prompt_semantic.unsqueeze(0)
+        # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+
+        pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
+        # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+
+        ge = self.vq_model.create_ge(refer)
+        # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+
+        prompt_ = prompt.unsqueeze(0)
+        fea_ref = self.vq_model(prompt_, phoneme_ids0, ge)
+        # print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+        # print(prompt_.shape, phoneme_ids0.shape, ge.shape)
+        # print(fea_ref.shape)
+
+        ref_32k = ref_audio_32k
+        mel2 = norm_spec(self.mel2(ref_32k)).to(ssl_content.dtype)
+        T_min = min(mel2.shape[2], fea_ref.shape[2])
+        mel2 = mel2[:, :, :T_min]
+        fea_ref = fea_ref[:, :, :T_min]
+        if T_min > 500:
+            mel2 = mel2[:, :, -500:]
+            fea_ref = fea_ref[:, :, -500:]
+            T_min = 500
+
+        fea_todo = self.vq_model(pred_semantic, phoneme_ids1, ge)
+        # print('fea_todo',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+        # print(pred_semantic.shape, phoneme_ids1.shape, ge.shape)
+        # print(fea_todo.shape)
+
+        return fea_ref, fea_todo, mel2
+
+
 class GPTSoVITSV3(torch.nn.Module):
     def __init__(self, gpt_sovits_half, cfm, bigvgan):
         super().__init__()
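For reference, the tensors ExportGPTSovitsV4Half.forward hands back, with shapes inferred from this diff rather than asserted anywhere in it:

# fea_ref : [1, 512, T_ref]   reference features, T_ref capped at 500 frames
# mel2    : [1, 100, T_ref]   normalized reference mel, trimmed to the same T_ref
# fea_todo: [1, 512, T_todo]  features for the text still to be vocoded
# (the 512 channel count matches the zero-padding torch.zeros(1, 512, ...) in GPTSoVITSV4.forward)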
@@ -311,6 +402,7 @@ class GPTSoVITSV3(torch.nn.Module):
         chunk_len = 934 - fea_ref.shape[2]
         wav_gen_list = []
         idx = 0
+        fea_todo = fea_todo[:, :, :-5]
         wav_gen_length = fea_todo.shape[2] * 256
         while 1:
             # current_time = datetime.now()
@@ -342,6 +434,65 @@ class GPTSoVITSV3(torch.nn.Module):

         wav_gen = torch.cat(wav_gen_list, 2)
         return wav_gen[0][0][:wav_gen_length]
+
+
+class GPTSoVITSV4(torch.nn.Module):
+    def __init__(self, gpt_sovits_half, cfm, hifigan):
+        super().__init__()
+        self.gpt_sovits_half = gpt_sovits_half
+        self.cfm = cfm
+        self.hifigan = hifigan
+
+    def forward(
+        self,
+        ssl_content,
+        ref_audio_32k: torch.FloatTensor,
+        phoneme_ids0: torch.LongTensor,
+        phoneme_ids1: torch.LongTensor,
+        bert1,
+        bert2,
+        top_k: torch.LongTensor,
+        sample_steps: torch.LongTensor,
+    ):
+        # current_time = datetime.now()
+        # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
+        fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
+            ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
+        )
+        chunk_len = 1000 - fea_ref.shape[2]
+        wav_gen_list = []
+        idx = 0
+        fea_todo = fea_todo[:, :, :-10]
+        wav_gen_length = fea_todo.shape[2] * 480
+        while 1:
+            # current_time = datetime.now()
+            # print("idx:",idx,current_time.strftime("%Y-%m-%d %H:%M:%S"))
+            fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
+            if fea_todo_chunk.shape[-1] == 0:
+                break
+
+            # The exported model seems to recompile (or something similar) whenever the
+            # input shape changes, stalling for about 10 s, so we zero-pad here to keep
+            # the shape constant. That makes the generated audio the wrong length, so we
+            # trim it at the end: after hifigan the audio length is fea_todo.shape[2] * 480.
+            complete_len = chunk_len - fea_todo_chunk.shape[-1]
+            if complete_len != 0:
+                fea_todo_chunk = torch.cat(
+                    [
+                        fea_todo_chunk,
+                        torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
+                    ],
+                    2,
+                )
+
+            cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
+            idx += chunk_len
+
+            cfm_res = denorm_spec(cfm_res)
+            hifigan_res = self.hifigan(cfm_res)
+            wav_gen_list.append(hifigan_res)
+
+        wav_gen = torch.cat(wav_gen_list, 2)
+        return wav_gen[0][0][:wav_gen_length]
+
+
 def init_bigvgan():
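GPTSoVITSV4.forward keeps every CFM/vocoder call at a fixed chunk width to dodge the roughly 10-second recompilation stall described in the comments, then trims the zero-padded overshoot at the end. A standalone toy sketch of that pad-then-trim pattern (the repeat_interleave stands in for cfm + hifigan; everything here is illustrative, not the real model):

import torch

def pad_trim_demo(fea_todo: torch.Tensor, chunk_len: int) -> torch.Tensor:
    out, idx = [], 0
    target_len = fea_todo.shape[2] * 480  # 480 output samples per input frame
    while True:
        chunk = fea_todo[:, :, idx : idx + chunk_len]
        if chunk.shape[-1] == 0:
            break
        pad = chunk_len - chunk.shape[-1]
        if pad:  # zero-pad so every call sees the same shape
            chunk = torch.cat([chunk, torch.zeros(1, fea_todo.shape[1], pad)], 2)
        out.append(chunk[:, :1].repeat_interleave(480, dim=2))  # fake "vocoder"
        idx += chunk_len
    return torch.cat(out, 2)[0][0][:target_len]  # cut away the padding overshoot

wav = pad_trim_demo(torch.randn(1, 512, 1234), chunk_len=500)
assert wav.shape[0] == 1234 * 480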
@@ -361,6 +512,31 @@ def init_bigvgan():
         bigvgan_model = bigvgan_model.to(device)


+def init_hifigan():
+    global hifigan_model, bigvgan_model
+    hifigan_model = Generator(
+        initial_channel=100,
+        resblock="1",
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        upsample_rates=[10, 6, 2, 2, 2],
+        upsample_initial_channel=512,
+        upsample_kernel_sizes=[20, 12, 4, 4, 4],
+        gin_channels=0,
+        is_bias=True,
+    )
+    hifigan_model.eval()
+    hifigan_model.remove_weight_norm()
+    state_dict_g = torch.load(
+        "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
+    )
+    print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
+    if is_half == True:
+        hifigan_model = hifigan_model.half().to(device)
+    else:
+        hifigan_model = hifigan_model.to(device)
+
+
 class Sovits:
     def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps):
         self.vq_model = vq_model
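This Generator configuration pins down the v4 rate math: the upsample_rates multiply to 10·6·2·2·2 = 480, so every mel frame (hop 320 at 32 kHz, i.e. 10 ms) becomes 480 audio samples, giving 48 kHz output, consistent with wav_gen_length = fea_todo.shape[2] * 480 above. A one-line check:

import math
assert math.prod([10, 6, 2, 2, 2]) == 480 and 32000 / 320 * 480 == 48000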
@@ -399,6 +575,7 @@ class DictToAttrRecursive(dict):

 from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new

+v3v4set = {"v3", "v4"}

 def get_sovits_weights(sovits_path):
     path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
@@ -419,8 +596,8 @@ def get_sovits_weights(sovits_path):
     else:
         hps.model.version = "v2"

-    if model_version == "v3":
-        hps.model.version = "v3"
+    if model_version in v3v4set:
+        hps.model.version = model_version

     logger.info(f"hps: {hps}")
@@ -522,10 +699,14 @@ def export_cfm(
     return export_cfm


-def export():
-    sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
-
-    init_bigvgan()
+def export_1(ref_wav_path, ref_wav_text, version="v3"):
+    if version == "v3":
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
+        init_bigvgan()
+    else:
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
+        init_hifigan()

     dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
     raw_t2s = get_raw_t2s_model(dict_s1).to(device)
@@ -540,9 +721,9 @@ def export():
     script_t2s = torch.jit.script(t2s_m).to(device)

     hps = sovits.hps
-    ref_wav_path = "onnx/ad/ref.wav"
+    # ref_wav_path = "onnx/ad/ref.wav"
     speed = 1.0
-    sample_steps = 32
+    sample_steps = 8
     dtype = torch.float16 if is_half == True else torch.float32
     refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
     zero_wav = np.zeros(
@@ -567,8 +748,11 @@ def export():
     prompt_semantic = codes[0, 0]
     prompt = prompt_semantic.unsqueeze(0).to(device)

+    # phones1, bert1, norm_text1 = get_phones_and_bert(
+    #     "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
+    # )
     phones1, bert1, norm_text1 = get_phones_and_bert(
-        "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
+        ref_wav_text, "auto", "v3"
     )
     phones2, bert2, norm_text2 = get_phones_and_bert(
         "这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
@@ -634,25 +818,33 @@ def export():
     # vq_model = sovits.vq_model
     vq_model = trace_vq_model

-    gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
-    torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
+    if version == "v3":
+        gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
+        torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
+    else:
+        gpt_sovits_half = ExportGPTSovitsV4Half(sovits.hps, script_t2s, trace_vq_model)
+        torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v4_half.pt")

     ref_audio, sr = torchaudio.load(ref_wav_path)
     ref_audio = ref_audio.to(device).float()
     if ref_audio.shape[0] == 2:
         ref_audio = ref_audio.mean(0).unsqueeze(0)
-    if sr != 24000:
-        ref_audio = resample(ref_audio, sr)
+    tgt_sr = 24000 if version == "v3" else 32000
+    if sr != tgt_sr:
+        ref_audio = resample(ref_audio, sr, tgt_sr)
     # mel2 = mel_fn(ref_audio)
-    mel2 = norm_spec(mel_fn(ref_audio))
+    mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
+    mel2 = norm_spec(mel2)
     T_min = min(mel2.shape[2], fea_ref.shape[2])
     fea_ref = fea_ref[:, :, :T_min]
     print("fea_ref:", fea_ref.shape, T_min)
-    if T_min > 468:
-        mel2 = mel2[:, :, -468:]
-        fea_ref = fea_ref[:, :, -468:]
-        T_min = 468
-    chunk_len = 934 - T_min
+    Tref = 468 if version == "v3" else 500
+    Tchunk = 934 if version == "v3" else 1000
+    if T_min > Tref:
+        mel2 = mel2[:, :, -Tref:]
+        fea_ref = fea_ref[:, :, -Tref:]
+        T_min = Tref
+    chunk_len = Tchunk - T_min
     mel2 = mel2.to(dtype)

     # fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge)
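The new constants keep v4 on the same time budgets as v3 despite the different mel grid: 500 reference frames and 1000 chunk frames at hop 320 / 32 kHz are 5 s and 10 s of audio, matching v3's 468 and 934 at hop 256 / 24 kHz (about 4.99 s and 9.96 s). Checking the arithmetic:

# v3 budgets (hop 256 @ 24 kHz) vs v4 budgets (hop 320 @ 32 kHz), in seconds
print(468 * 256 / 24000, 934 * 256 / 24000)   # 4.992  ~9.963
print(500 * 320 / 32000, 1000 * 320 / 32000)  # 5.0    10.0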
@@ -714,13 +906,19 @@ def export():
     with torch.inference_mode():
         cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
         torch._dynamo.mark_dynamic(cmf_res_rand, 2)
-        bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
-        bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
-        wav_gen = bigvgan_model(cmf_res)
+        if version == "v3":
+            bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
+            bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
+            wav_gen = bigvgan_model(cmf_res)
+        else:
+            hifigan_model_ = torch.jit.trace(hifigan_model, optimize=True, example_inputs=(cmf_res_rand,))
+            hifigan_model_.save("onnx/ad/hifigan_model.pt")
+            wav_gen = hifigan_model(cmf_res)

     print("wav_gen:", wav_gen.shape, wav_gen.dtype)
     audio = wav_gen[0][0].cpu().detach().numpy()

-    sr = 24000
+    sr = 24000 if version == "v3" else 48000
     soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr)
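Because dim 2 of the example input is marked dynamic, the traced vocoder accepts mels of any length even though tracing used a [1, 100, 934] tensor. A hedged sketch of reusing the saved artifact later (the file name comes from this diff; the shapes are assumptions):

import torch

vocoder = torch.jit.load("onnx/ad/hifigan_model.pt").cuda().half().eval()
mel = torch.randn(1, 100, 700).cuda().half()  # any length along dim 2
with torch.no_grad():
    wav = vocoder(mel)  # expected shape [1, 1, 700 * 480]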
@@ -846,10 +1044,11 @@ def test_export(
     soundfile.write(output, (audio * 32768).astype(np.int16), sr)


-def test_export1(
+def test_export(
     todo_text,
-    gpt_sovits_v3,
+    gpt_sovits_v3v4,
     output,
+    out_sr=24000,
 ):
     # hps = sovits.hps
     ref_wav_path = "onnx/ad/ref.wav"
@@ -859,7 +1058,7 @@ def test_export1(
     dtype = torch.float16 if is_half == True else torch.float32

     zero_wav = np.zeros(
-        int(24000 * 0.3),
+        int(out_sr * 0.3),
         dtype=np.float16 if is_half == True else np.float32,
     )

@@ -894,7 +1093,7 @@ def test_export1(

     bert1 = bert1.T.to(device)
     bert2 = bert2.T.to(device)
-    top_k = torch.LongTensor([15]).to(device)
+    top_k = torch.LongTensor([20]).to(device)

     current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     logger.info("start inference %s", current_time)
@@ -907,22 +1106,26 @@ def test_export1(
         bert2.shape,
         top_k.shape,
     )
-    wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
+    wav_gen = gpt_sovits_v3v4(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
     print("wav_gen:", wav_gen.shape, wav_gen.dtype)

     wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)

     audio = wav_gen.cpu().detach().numpy()
     logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
-    sr = 24000
-    soundfile.write(output, (audio * 32768).astype(np.int16), sr)
+    soundfile.write(output, (audio * 32768).astype(np.int16), out_sr)


 import time


-def test_():
-    sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
+def export_2(version="v3"):
+    if version == "v3":
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
+        # init_bigvgan()
+    else:
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
+        # init_hifigan()

     # cfm = ExportCFM(sovits.cfm)
     # cfm.cfm.estimator = dit
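Note that export_2 does not rebuild the vocoder: the branches below torch.jit.load onnx/ad/bigvgan_model.pt or onnx/ad/hifigan_model.pt, so export_1 must already have been run for the same version to produce those files (plus the scripted half model).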
@@ -963,25 +1166,41 @@ def test_():
     # gpt_sovits_v3_half = gpt_sovits_v3_half.half()
     # gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
     # gpt_sovits_v3_half.eval()
-    gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
-    logger.info("gpt_sovits_v3_half ok")
-
-    # init_bigvgan()
-    # global bigvgan_model
-    bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
-    # bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
-    bigvgan_model = bigvgan_model.half()
-    bigvgan_model = bigvgan_model.cuda()
-    bigvgan_model.eval()
-    logger.info("bigvgan ok")
-
-    gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
-    gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
-    gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
-    gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
-    gpt_sovits_v3.eval()
-    print("save gpt_sovits_v3 ok")
+    if version == "v3":
+        gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
+        logger.info("gpt_sovits_v3_half ok")
+        # init_bigvgan()
+        # global bigvgan_model
+        bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
+        # bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
+        bigvgan_model = bigvgan_model.half()
+        bigvgan_model = bigvgan_model.cuda()
+        bigvgan_model.eval()
+        logger.info("bigvgan ok")
+        gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
+        gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
+        gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
+        gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
+        gpt_sovits_v3.eval()
+        print("save gpt_sovits_v3 ok")
+    else:
+        gpt_sovits_v4_half = ExportGPTSovitsV4Half(sovits.hps, t2s_m, vq_model)
+        logger.info("gpt_sovits_v4 ok")
+        hifigan_model = torch.jit.load("onnx/ad/hifigan_model.pt")
+        hifigan_model = hifigan_model.half()
+        hifigan_model = hifigan_model.cuda()
+        hifigan_model.eval()
+        logger.info("hifigan ok")
+        gpt_sovits_v4 = GPTSoVITSV4(gpt_sovits_v4_half, cfm, hifigan_model)
+        gpt_sovits_v4 = torch.jit.script(gpt_sovits_v4)
+        gpt_sovits_v4.save("onnx/ad/gpt_sovits_v4.pt")
+        print("save gpt_sovits_v4 ok")
+
+    gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4
+    sr = 24000 if version == "v3" else 48000

     time.sleep(5)
     # print("thread:", torch.get_num_threads())
@@ -989,16 +1208,18 @@ def test_():
     # torch.set_num_interop_threads(1)
     # torch.set_num_threads(1)

-    test_export1(
+    test_export(
         "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
-        gpt_sovits_v3,
+        gpt_sovits_v3v4,
         "out.wav",
+        sr,
     )

-    test_export1(
+    test_export(
         "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
-        gpt_sovits_v3,
+        gpt_sovits_v3v4,
         "out2.wav",
+        sr,
     )

     # test_export(
@@ -1022,7 +1243,7 @@ def test_export_gpt_sovits_v3():
     # gpt_sovits_v3,
     # "out4.wav",
     # )
-    test_export1(
+    test_export(
         "风萧萧兮易水寒,壮士一去兮不复还.",
         gpt_sovits_v3,
         "out5.wav",
@@ -1030,6 +1251,6 @@ def test_export_gpt_sovits_v3():


 with torch.no_grad():
     # export()
-    test_()
+    export_1("onnx/ad/ref.wav", "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "v4")
+    # export_2("v4")
     # test_export_gpt_sovits_v3()
GPT_SoVITS/module/models_onnx.py

@@ -391,6 +391,7 @@ class Generator(torch.nn.Module):
         upsample_initial_channel,
         upsample_kernel_sizes,
         gin_channels=0,
+        is_bias=False,
     ):
         super(Generator, self).__init__()
         self.num_kernels = len(resblock_kernel_sizes)
@@ -418,7 +419,7 @@ class Generator(torch.nn.Module):
             for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                 self.resblocks.append(resblock(ch, k, d))

-        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
         self.ups.apply(init_weights)

         if gin_channels != 0:
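The new is_bias flag (default False, preserving v3 behavior) presumably exists because the v4 vocoder checkpoint gsv-v4-pretrained/vocoder.pth stores a bias tensor for conv_post: constructed with bias=False, a strict load_state_dict would reject the unexpected conv_post.bias key. That is why init_hifigan passes is_bias=True and prints the load_state_dict result as a sanity check.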