添加导出 v4 的部分 (#2417)

* feat: 添加导出v4的script

* 改名 export_torch_script_v3.py 为 export_torch_script_v3v4.py

* export_torch_script_v3v4 中优化函数名称和参数
This commit is contained in:
zzz 2025-06-04 15:50:16 +08:00 committed by GitHub
parent e909c93c63
commit 6d12a6a6cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 280 additions and 58 deletions

View File

@ -10,7 +10,7 @@ from inference_webui import get_phones_and_bert
import librosa import librosa
from module import commons from module import commons
from module.mel_processing import mel_spectrogram_torch from module.mel_processing import mel_spectrogram_torch
from module.models_onnx import CFM, SynthesizerTrnV3 from module.models_onnx import CFM, Generator, SynthesizerTrnV3
import numpy as np import numpy as np
import torch._dynamo.config import torch._dynamo.config
import torchaudio import torchaudio
@ -46,7 +46,7 @@ class MelSpectrgram(torch.nn.Module):
center=False, center=False,
): ):
super().__init__() super().__init__()
self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype) self.hann_window = torch.hann_window(win_size).to(device=device, dtype=dtype)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device) self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
self.n_fft: int = n_fft self.n_fft: int = n_fft
@ -189,6 +189,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
"center": False, "center": False,
}, },
) )
def mel_fn_v4(x):
    """Return the mel spectrogram of ``x`` computed with the v4 settings.

    ``x`` is forwarded unchanged to ``mel_spectrogram_torch`` together with
    the fixed v4 parameters (n_fft/win_size 1280, hop_size 320, 100 mel bins,
    sampling_rate 32000, no centering).
    """
    v4_mel_settings = {
        "n_fft": 1280,
        "win_size": 1280,
        "hop_size": 320,
        "num_mels": 100,
        "sampling_rate": 32000,
        "fmin": 0,
        "fmax": None,
        "center": False,
    }
    return mel_spectrogram_torch(x, **v4_mel_settings)
spec_min = -12 spec_min = -12
spec_max = 2 spec_max = 2
@ -285,6 +298,84 @@ class ExportGPTSovitsHalf(torch.nn.Module):
return fea_ref, fea_todo, mel2 return fea_ref, fea_todo, mel2
class ExportGPTSovitsV4Half(torch.nn.Module):
    """Front half of the v4 export pipeline.

    Bundles the text-to-semantic model and the SoVITS model so that a single
    call returns the reference features, the features of the text to be
    synthesized and the normalized mel spectrogram of the reference audio.
    """

    def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
        super().__init__()
        self.hps = hps
        self.t2s_m = t2s_m
        self.vq_model = vq_model
        # Mel extractor for the reference audio; always built in float32,
        # the result is cast to the input dtype inside ``forward``.
        self.mel2 = MelSpectrgram(
            dtype=torch.float32,
            device=device,
            n_fft=1280,
            num_mels=100,
            sampling_rate=32000,
            hop_size=320,
            win_size=1280,
            fmin=0,
            fmax=None,
            center=False,
        )
        # Spectrogram parameters are copied out of ``hps.data`` as typed
        # attributes so that ``forward`` only reads plain ints.
        data_cfg = hps.data
        self.filter_length: int = data_cfg.filter_length
        self.sampling_rate: int = data_cfg.sampling_rate
        self.hop_length: int = data_cfg.hop_length
        self.win_length: int = data_cfg.win_length

    def forward(
        self,
        ssl_content,
        ref_audio_32k: torch.FloatTensor,
        phoneme_ids0,
        phoneme_ids1,
        bert1,
        bert2,
        top_k,
    ):
        """Compute ``(fea_ref, fea_todo, mel2)`` for one reference/target pair.

        ``fea_ref`` and ``mel2`` are cut to a common length along dim 2 and,
        when that length exceeds 500, reduced to their last 500 frames.
        ``fea_todo`` holds the features predicted for the target phonemes.
        """
        # Linear spectrogram of the reference audio, cast to the SSL dtype.
        refer = spectrogram_torch(
            ref_audio_32k,
            self.filter_length,
            self.sampling_rate,
            self.hop_length,
            self.win_length,
            center=False,
        ).to(ssl_content.dtype)

        # Semantic prompt taken from the first entry of the latent codes.
        latent_codes = self.vq_model.extract_latent(ssl_content)
        prompt = latent_codes[0, 0].unsqueeze(0)

        pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)

        ge = self.vq_model.create_ge(refer)
        fea_ref = self.vq_model(prompt.unsqueeze(0), phoneme_ids0, ge)

        mel2 = norm_spec(self.mel2(ref_audio_32k)).to(ssl_content.dtype)

        # Align both reference tensors on their shorter time axis.
        shared_len = min(mel2.shape[2], fea_ref.shape[2])
        mel2 = mel2[:, :, :shared_len]
        fea_ref = fea_ref[:, :, :shared_len]
        if shared_len > 500:
            # Keep only the most recent 500 frames of the reference.
            mel2 = mel2[:, :, -500:]
            fea_ref = fea_ref[:, :, -500:]

        fea_todo = self.vq_model(pred_semantic, phoneme_ids1, ge)
        return fea_ref, fea_todo, mel2
class GPTSoVITSV3(torch.nn.Module): class GPTSoVITSV3(torch.nn.Module):
def __init__(self, gpt_sovits_half, cfm, bigvgan): def __init__(self, gpt_sovits_half, cfm, bigvgan):
super().__init__() super().__init__()
@ -311,6 +402,7 @@ class GPTSoVITSV3(torch.nn.Module):
chunk_len = 934 - fea_ref.shape[2] chunk_len = 934 - fea_ref.shape[2]
wav_gen_list = [] wav_gen_list = []
idx = 0 idx = 0
fea_todo = fea_todo[:,:,:-5]
wav_gen_length = fea_todo.shape[2] * 256 wav_gen_length = fea_todo.shape[2] * 256
while 1: while 1:
# current_time = datetime.now() # current_time = datetime.now()
@ -343,6 +435,65 @@ class GPTSoVITSV3(torch.nn.Module):
wav_gen = torch.cat(wav_gen_list, 2) wav_gen = torch.cat(wav_gen_list, 2)
return wav_gen[0][0][:wav_gen_length] return wav_gen[0][0][:wav_gen_length]
class GPTSoVITSV4(torch.nn.Module):
    """Complete v4 synthesis module.

    Chains the feature front half, the CFM sampler (run on fixed-size chunks)
    and the HiFi-GAN vocoder, and returns the generated waveform.
    """

    def __init__(self, gpt_sovits_half, cfm, hifigan):
        super().__init__()
        self.gpt_sovits_half = gpt_sovits_half
        self.cfm = cfm
        self.hifigan = hifigan

    def forward(
        self,
        ssl_content,
        ref_audio_32k: torch.FloatTensor,
        phoneme_ids0: torch.LongTensor,
        phoneme_ids1: torch.LongTensor,
        bert1,
        bert2,
        top_k: torch.LongTensor,
        sample_steps: torch.LongTensor,
    ):
        """Synthesize audio for the target phonemes and return a 1-D waveform.

        The target features are processed in chunks so that reference plus
        chunk always spans 1000 frames; the final waveform is trimmed to
        ``fea_todo.shape[2] * 480`` samples.
        """
        fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
            ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
        )
        chunk_len = 1000 - fea_ref.shape[2]
        # Drop the last 10 frames of the target features before synthesis.
        fea_todo = fea_todo[:, :, :-10]
        wav_gen_length = fea_todo.shape[2] * 480

        wav_chunks = []
        offset = 0
        while True:
            fea_todo_chunk = fea_todo[:, :, offset : offset + chunk_len]
            if fea_todo_chunk.shape[-1] == 0:
                break

            # The exported model seems to recompile (or something similar)
            # whenever the input shape changes, which stalls for about 10 s.
            # Zero-padding keeps the shape constant. That makes the generated
            # audio too long, so it is trimmed at the very end.
            # After hifigan the audio length is fea_todo.shape[2] * 480.
            missing = chunk_len - fea_todo_chunk.shape[-1]
            if missing != 0:
                zero_pad = torch.zeros(1, 512, missing).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype)
                fea_todo_chunk = torch.cat([fea_todo_chunk, zero_pad], 2)

            # The CFM step also returns the reference features and mel that
            # are fed into the next chunk.
            cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
            offset += chunk_len

            wav_chunks.append(self.hifigan(denorm_spec(cfm_res)))

        wav_gen = torch.cat(wav_chunks, 2)
        return wav_gen[0][0][:wav_gen_length]
def init_bigvgan(): def init_bigvgan():
global bigvgan_model global bigvgan_model
@ -361,6 +512,31 @@ def init_bigvgan():
bigvgan_model = bigvgan_model.to(device) bigvgan_model = bigvgan_model.to(device)
def init_hifigan():
    """Create the v4 HiFi-GAN vocoder and publish it as ``hifigan_model``.

    Builds the ``Generator``, switches it to eval mode, removes weight norm,
    loads the pretrained vocoder weights from ``now_dir`` and moves the model
    to ``device`` (converted to half precision first when ``is_half`` is set).
    """
    global hifigan_model

    hifigan_model = Generator(
        initial_channel=100,
        resblock="1",
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[10, 6, 2, 2, 2],
        upsample_initial_channel=512,
        upsample_kernel_sizes=[20, 12, 4, 4, 4],
        gin_channels=0,
        is_bias=True,
    )
    hifigan_model.eval()
    hifigan_model.remove_weight_norm()

    # Weights are read onto the CPU first; the device move happens below.
    vocoder_path = "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,)
    vocoder_weights = torch.load(vocoder_path, map_location="cpu")
    load_result = hifigan_model.load_state_dict(vocoder_weights)
    print("loading vocoder", load_result)

    if is_half == True:
        hifigan_model = hifigan_model.half()
    hifigan_model = hifigan_model.to(device)
class Sovits: class Sovits:
def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps): def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps):
self.vq_model = vq_model self.vq_model = vq_model
@ -399,6 +575,7 @@ class DictToAttrRecursive(dict):
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
v3v4set = {"v3", "v4"}
def get_sovits_weights(sovits_path): def get_sovits_weights(sovits_path):
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth" path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
@ -419,8 +596,8 @@ def get_sovits_weights(sovits_path):
else: else:
hps.model.version = "v2" hps.model.version = "v2"
if model_version == "v3": if model_version in v3v4set:
hps.model.version = "v3" hps.model.version = model_version
logger.info(f"hps: {hps}") logger.info(f"hps: {hps}")
@ -522,10 +699,14 @@ def export_cfm(
return export_cfm return export_cfm
def export(): def export_1(ref_wav_path,ref_wav_text,version="v3"):
if version == "v3":
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth") sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
init_bigvgan() init_bigvgan()
else:
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
init_hifigan()
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt") dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
raw_t2s = get_raw_t2s_model(dict_s1).to(device) raw_t2s = get_raw_t2s_model(dict_s1).to(device)
@ -540,9 +721,9 @@ def export():
script_t2s = torch.jit.script(t2s_m).to(device) script_t2s = torch.jit.script(t2s_m).to(device)
hps = sovits.hps hps = sovits.hps
ref_wav_path = "onnx/ad/ref.wav" # ref_wav_path = "onnx/ad/ref.wav"
speed = 1.0 speed = 1.0
sample_steps = 32 sample_steps = 8
dtype = torch.float16 if is_half == True else torch.float32 dtype = torch.float16 if is_half == True else torch.float32
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype) refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
zero_wav = np.zeros( zero_wav = np.zeros(
@ -567,8 +748,11 @@ def export():
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device) prompt = prompt_semantic.unsqueeze(0).to(device)
# phones1, bert1, norm_text1 = get_phones_and_bert(
# "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
# )
phones1, bert1, norm_text1 = get_phones_and_bert( phones1, bert1, norm_text1 = get_phones_and_bert(
"你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3" ref_wav_text, "auto", "v3"
) )
phones2, bert2, norm_text2 = get_phones_and_bert( phones2, bert2, norm_text2 = get_phones_and_bert(
"这是一个简单的示例真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "这是一个简单的示例真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
@ -634,25 +818,33 @@ def export():
# vq_model = sovits.vq_model # vq_model = sovits.vq_model
vq_model = trace_vq_model vq_model = trace_vq_model
if version == "v3":
gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model) gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt") torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
else:
gpt_sovits_half = ExportGPTSovitsV4Half(sovits.hps, script_t2s, trace_vq_model)
torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v4_half.pt")
ref_audio, sr = torchaudio.load(ref_wav_path) ref_audio, sr = torchaudio.load(ref_wav_path)
ref_audio = ref_audio.to(device).float() ref_audio = ref_audio.to(device).float()
if ref_audio.shape[0] == 2: if ref_audio.shape[0] == 2:
ref_audio = ref_audio.mean(0).unsqueeze(0) ref_audio = ref_audio.mean(0).unsqueeze(0)
if sr != 24000: tgt_sr = 24000 if version == "v3" else 32000
ref_audio = resample(ref_audio, sr) if sr != tgt_sr:
ref_audio = resample(ref_audio, sr, tgt_sr)
# mel2 = mel_fn(ref_audio) # mel2 = mel_fn(ref_audio)
mel2 = norm_spec(mel_fn(ref_audio)) mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2]) T_min = min(mel2.shape[2], fea_ref.shape[2])
fea_ref = fea_ref[:, :, :T_min] fea_ref = fea_ref[:, :, :T_min]
print("fea_ref:", fea_ref.shape, T_min) print("fea_ref:", fea_ref.shape, T_min)
if T_min > 468: Tref = 468 if version == "v3" else 500
mel2 = mel2[:, :, -468:] Tchunk = 934 if version == "v3" else 1000
fea_ref = fea_ref[:, :, -468:] if T_min > Tref:
T_min = 468 mel2 = mel2[:, :, -Tref:]
chunk_len = 934 - T_min fea_ref = fea_ref[:, :, -Tref:]
T_min = Tref
chunk_len = Tchunk - T_min
mel2 = mel2.to(dtype) mel2 = mel2.to(dtype)
# fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge) # fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge)
@ -714,13 +906,19 @@ def export():
with torch.inference_mode(): with torch.inference_mode():
cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype) cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
torch._dynamo.mark_dynamic(cmf_res_rand, 2) torch._dynamo.mark_dynamic(cmf_res_rand, 2)
if version == "v3":
bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,)) bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
bigvgan_model_.save("onnx/ad/bigvgan_model.pt") bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
wav_gen = bigvgan_model(cmf_res) wav_gen = bigvgan_model(cmf_res)
else:
hifigan_model_ = torch.jit.trace(hifigan_model, optimize=True, example_inputs=(cmf_res_rand,))
hifigan_model_.save("onnx/ad/hifigan_model.pt")
wav_gen = hifigan_model(cmf_res)
print("wav_gen:", wav_gen.shape, wav_gen.dtype) print("wav_gen:", wav_gen.shape, wav_gen.dtype)
audio = wav_gen[0][0].cpu().detach().numpy() audio = wav_gen[0][0].cpu().detach().numpy()
sr = 24000 sr = 24000 if version == "v3" else 48000
soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr) soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr)
@ -846,10 +1044,11 @@ def test_export(
soundfile.write(output, (audio * 32768).astype(np.int16), sr) soundfile.write(output, (audio * 32768).astype(np.int16), sr)
def test_export1( def test_export(
todo_text, todo_text,
gpt_sovits_v3, gpt_sovits_v3v4,
output, output,
out_sr=24000,
): ):
# hps = sovits.hps # hps = sovits.hps
ref_wav_path = "onnx/ad/ref.wav" ref_wav_path = "onnx/ad/ref.wav"
@ -859,7 +1058,7 @@ def test_export1(
dtype = torch.float16 if is_half == True else torch.float32 dtype = torch.float16 if is_half == True else torch.float32
zero_wav = np.zeros( zero_wav = np.zeros(
int(24000 * 0.3), int(out_sr * 0.3),
dtype=np.float16 if is_half == True else np.float32, dtype=np.float16 if is_half == True else np.float32,
) )
@ -894,7 +1093,7 @@ def test_export1(
bert1 = bert1.T.to(device) bert1 = bert1.T.to(device)
bert2 = bert2.T.to(device) bert2 = bert2.T.to(device)
top_k = torch.LongTensor([15]).to(device) top_k = torch.LongTensor([20]).to(device)
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info("start inference %s", current_time) logger.info("start inference %s", current_time)
@ -907,22 +1106,26 @@ def test_export1(
bert2.shape, bert2.shape,
top_k.shape, top_k.shape,
) )
wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps) wav_gen = gpt_sovits_v3v4(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
print("wav_gen:", wav_gen.shape, wav_gen.dtype) print("wav_gen:", wav_gen.shape, wav_gen.dtype)
wav_gen = torch.cat([wav_gen, zero_wav_torch], 0) wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
audio = wav_gen.cpu().detach().numpy() audio = wav_gen.cpu().detach().numpy()
logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S")) logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
sr = 24000 soundfile.write(output, (audio * 32768).astype(np.int16), out_sr)
soundfile.write(output, (audio * 32768).astype(np.int16), sr)
import time import time
def test_(): def export_2(version="v3"):
if version == "v3":
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth") sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
# init_bigvgan()
else:
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
# init_hifigan()
# cfm = ExportCFM(sovits.cfm) # cfm = ExportCFM(sovits.cfm)
# cfm.cfm.estimator = dit # cfm.cfm.estimator = dit
@ -963,9 +1166,9 @@ def test_():
# gpt_sovits_v3_half = gpt_sovits_v3_half.half() # gpt_sovits_v3_half = gpt_sovits_v3_half.half()
# gpt_sovits_v3_half = gpt_sovits_v3_half.cuda() # gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
# gpt_sovits_v3_half.eval() # gpt_sovits_v3_half.eval()
if version == "v3":
gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model) gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
logger.info("gpt_sovits_v3_half ok") logger.info("gpt_sovits_v3_half ok")
# init_bigvgan() # init_bigvgan()
# global bigvgan_model # global bigvgan_model
bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt") bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
@ -975,13 +1178,29 @@ def test_():
bigvgan_model.eval() bigvgan_model.eval()
logger.info("bigvgan ok") logger.info("bigvgan ok")
gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model) gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3) gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt") gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
gpt_sovits_v3 = gpt_sovits_v3.half().to(device) gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
gpt_sovits_v3.eval() gpt_sovits_v3.eval()
print("save gpt_sovits_v3 ok") print("save gpt_sovits_v3 ok")
else:
gpt_sovits_v4_half = ExportGPTSovitsV4Half(sovits.hps, t2s_m, vq_model)
logger.info("gpt_sovits_v4 ok")
hifigan_model = torch.jit.load("onnx/ad/hifigan_model.pt")
hifigan_model = hifigan_model.half()
hifigan_model = hifigan_model.cuda()
hifigan_model.eval()
logger.info("hifigan ok")
gpt_sovits_v4 = GPTSoVITSV4(gpt_sovits_v4_half, cfm, hifigan_model)
gpt_sovits_v4 = torch.jit.script(gpt_sovits_v4)
gpt_sovits_v4.save("onnx/ad/gpt_sovits_v4.pt")
print("save gpt_sovits_v4 ok")
gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4
sr = 24000 if version == "v3" else 48000
time.sleep(5) time.sleep(5)
# print("thread:", torch.get_num_threads()) # print("thread:", torch.get_num_threads())
@ -989,16 +1208,18 @@ def test_():
# torch.set_num_interop_threads(1) # torch.set_num_interop_threads(1)
# torch.set_num_threads(1) # torch.set_num_threads(1)
test_export1( test_export(
"汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....", "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
gpt_sovits_v3, gpt_sovits_v3v4,
"out.wav", "out.wav",
sr
) )
test_export1( test_export(
"你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!", "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
gpt_sovits_v3, gpt_sovits_v3v4,
"out2.wav", "out2.wav",
sr
) )
# test_export( # test_export(
@ -1022,7 +1243,7 @@ def test_export_gpt_sovits_v3():
# gpt_sovits_v3, # gpt_sovits_v3,
# "out4.wav", # "out4.wav",
# ) # )
test_export1( test_export(
"风萧萧兮易水寒,壮士一去兮不复还.", "风萧萧兮易水寒,壮士一去兮不复还.",
gpt_sovits_v3, gpt_sovits_v3,
"out5.wav", "out5.wav",
@ -1030,6 +1251,6 @@ def test_export_gpt_sovits_v3():
with torch.no_grad(): with torch.no_grad():
# export() export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
test_() # export_2("v4")
# test_export_gpt_sovits_v3() # test_export_gpt_sovits_v3()

View File

@ -391,6 +391,7 @@ class Generator(torch.nn.Module):
upsample_initial_channel, upsample_initial_channel,
upsample_kernel_sizes, upsample_kernel_sizes,
gin_channels=0, gin_channels=0,
is_bias=False,
): ):
super(Generator, self).__init__() super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes) self.num_kernels = len(resblock_kernel_sizes)
@ -418,7 +419,7 @@ class Generator(torch.nn.Module):
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d)) self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
self.ups.apply(init_weights) self.ups.apply(init_weights)
if gin_channels != 0: if gin_channels != 0: