mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-30 18:49:48 +08:00
feat: 添加导出v4的script
This commit is contained in:
parent
acd68355c9
commit
150cd842df
@ -10,7 +10,7 @@ from inference_webui import get_phones_and_bert
|
||||
import librosa
|
||||
from module import commons
|
||||
from module.mel_processing import mel_spectrogram_torch
|
||||
from module.models_onnx import CFM, SynthesizerTrnV3
|
||||
from module.models_onnx import CFM, Generator, SynthesizerTrnV3
|
||||
import numpy as np
|
||||
import torch._dynamo.config
|
||||
import torchaudio
|
||||
@ -46,7 +46,7 @@ class MelSpectrgram(torch.nn.Module):
|
||||
center=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
|
||||
self.hann_window = torch.hann_window(win_size).to(device=device, dtype=dtype)
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
|
||||
self.n_fft: int = n_fft
|
||||
@ -189,6 +189,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
|
||||
"center": False,
|
||||
},
|
||||
)
|
||||
mel_fn_v4 = lambda x: mel_spectrogram_torch(
|
||||
x,
|
||||
**{
|
||||
"n_fft": 1280,
|
||||
"win_size": 1280,
|
||||
"hop_size": 320,
|
||||
"num_mels": 100,
|
||||
"sampling_rate": 32000,
|
||||
"fmin": 0,
|
||||
"fmax": None,
|
||||
"center": False,
|
||||
},
|
||||
)
|
||||
|
||||
spec_min = -12
|
||||
spec_max = 2
|
||||
@ -285,6 +298,84 @@ class ExportGPTSovitsHalf(torch.nn.Module):
|
||||
return fea_ref, fea_todo, mel2
|
||||
|
||||
|
||||
class ExportGPTSovitsV4Half(torch.nn.Module):
|
||||
def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
|
||||
super().__init__()
|
||||
self.hps = hps
|
||||
self.t2s_m = t2s_m
|
||||
self.vq_model = vq_model
|
||||
self.mel2 = MelSpectrgram(
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
n_fft=1280,
|
||||
num_mels=100,
|
||||
sampling_rate=32000,
|
||||
hop_size=320,
|
||||
win_size=1280,
|
||||
fmin=0,
|
||||
fmax=None,
|
||||
center=False,
|
||||
)
|
||||
# self.dtype = dtype
|
||||
self.filter_length: int = hps.data.filter_length
|
||||
self.sampling_rate: int = hps.data.sampling_rate
|
||||
self.hop_length: int = hps.data.hop_length
|
||||
self.win_length: int = hps.data.win_length
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ssl_content,
|
||||
ref_audio_32k: torch.FloatTensor,
|
||||
phoneme_ids0,
|
||||
phoneme_ids1,
|
||||
bert1,
|
||||
bert2,
|
||||
top_k,
|
||||
):
|
||||
refer = spectrogram_torch(
|
||||
ref_audio_32k,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
center=False,
|
||||
).to(ssl_content.dtype)
|
||||
|
||||
codes = self.vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompt = prompt_semantic.unsqueeze(0)
|
||||
# print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
|
||||
pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
|
||||
# print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
|
||||
ge = self.vq_model.create_ge(refer)
|
||||
# print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
|
||||
prompt_ = prompt.unsqueeze(0)
|
||||
fea_ref = self.vq_model(prompt_, phoneme_ids0, ge)
|
||||
# print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
# print(prompt_.shape, phoneme_ids0.shape, ge.shape)
|
||||
# print(fea_ref.shape)
|
||||
|
||||
ref_32k = ref_audio_32k
|
||||
mel2 = norm_spec(self.mel2(ref_32k)).to(ssl_content.dtype)
|
||||
T_min = min(mel2.shape[2], fea_ref.shape[2])
|
||||
mel2 = mel2[:, :, :T_min]
|
||||
fea_ref = fea_ref[:, :, :T_min]
|
||||
if T_min > 500:
|
||||
mel2 = mel2[:, :, -500:]
|
||||
fea_ref = fea_ref[:, :, -500:]
|
||||
T_min = 500
|
||||
|
||||
fea_todo = self.vq_model(pred_semantic, phoneme_ids1, ge)
|
||||
# print('fea_todo',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
# print(pred_semantic.shape, phoneme_ids1.shape, ge.shape)
|
||||
# print(fea_todo.shape)
|
||||
|
||||
return fea_ref, fea_todo, mel2
|
||||
|
||||
|
||||
class GPTSoVITSV3(torch.nn.Module):
|
||||
def __init__(self, gpt_sovits_half, cfm, bigvgan):
|
||||
super().__init__()
|
||||
@ -311,6 +402,7 @@ class GPTSoVITSV3(torch.nn.Module):
|
||||
chunk_len = 934 - fea_ref.shape[2]
|
||||
wav_gen_list = []
|
||||
idx = 0
|
||||
fea_todo = fea_todo[:,:,:-5]
|
||||
wav_gen_length = fea_todo.shape[2] * 256
|
||||
while 1:
|
||||
# current_time = datetime.now()
|
||||
@ -342,6 +434,65 @@ class GPTSoVITSV3(torch.nn.Module):
|
||||
|
||||
wav_gen = torch.cat(wav_gen_list, 2)
|
||||
return wav_gen[0][0][:wav_gen_length]
|
||||
|
||||
class GPTSoVITSV4(torch.nn.Module):
|
||||
def __init__(self, gpt_sovits_half, cfm, hifigan):
|
||||
super().__init__()
|
||||
self.gpt_sovits_half = gpt_sovits_half
|
||||
self.cfm = cfm
|
||||
self.hifigan = hifigan
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ssl_content,
|
||||
ref_audio_32k: torch.FloatTensor,
|
||||
phoneme_ids0: torch.LongTensor,
|
||||
phoneme_ids1: torch.LongTensor,
|
||||
bert1,
|
||||
bert2,
|
||||
top_k: torch.LongTensor,
|
||||
sample_steps: torch.LongTensor,
|
||||
):
|
||||
# current_time = datetime.now()
|
||||
# print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
|
||||
fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
|
||||
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
|
||||
)
|
||||
chunk_len = 1000 - fea_ref.shape[2]
|
||||
wav_gen_list = []
|
||||
idx = 0
|
||||
fea_todo = fea_todo[:,:,:-10]
|
||||
wav_gen_length = fea_todo.shape[2] * 480
|
||||
while 1:
|
||||
# current_time = datetime.now()
|
||||
# print("idx:",idx,current_time.strftime("%Y-%m-%d %H:%M:%S"))
|
||||
fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
|
||||
if fea_todo_chunk.shape[-1] == 0:
|
||||
break
|
||||
|
||||
# 因为导出的模型在不同shape时会重新编译还是怎么的,会卡顿10s这样,
|
||||
# 所以在这里补0让他shape维持不变
|
||||
# 但是这样会导致生成的音频长度不对,所以在最后截取一下。
|
||||
# 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256
|
||||
complete_len = chunk_len - fea_todo_chunk.shape[-1]
|
||||
if complete_len != 0:
|
||||
fea_todo_chunk = torch.cat(
|
||||
[
|
||||
fea_todo_chunk,
|
||||
torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
|
||||
],
|
||||
2,
|
||||
)
|
||||
|
||||
cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
|
||||
idx += chunk_len
|
||||
|
||||
cfm_res = denorm_spec(cfm_res)
|
||||
hifigan_res = self.hifigan(cfm_res)
|
||||
wav_gen_list.append(hifigan_res)
|
||||
|
||||
wav_gen = torch.cat(wav_gen_list, 2)
|
||||
return wav_gen[0][0][:wav_gen_length]
|
||||
|
||||
|
||||
def init_bigvgan():
|
||||
@ -361,6 +512,31 @@ def init_bigvgan():
|
||||
bigvgan_model = bigvgan_model.to(device)
|
||||
|
||||
|
||||
def init_hifigan():
|
||||
global hifigan_model, bigvgan_model
|
||||
hifigan_model = Generator(
|
||||
initial_channel=100,
|
||||
resblock="1",
|
||||
resblock_kernel_sizes=[3, 7, 11],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_rates=[10, 6, 2, 2, 2],
|
||||
upsample_initial_channel=512,
|
||||
upsample_kernel_sizes=[20, 12, 4, 4, 4],
|
||||
gin_channels=0,
|
||||
is_bias=True,
|
||||
)
|
||||
hifigan_model.eval()
|
||||
hifigan_model.remove_weight_norm()
|
||||
state_dict_g = torch.load(
|
||||
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
|
||||
)
|
||||
print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
|
||||
if is_half == True:
|
||||
hifigan_model = hifigan_model.half().to(device)
|
||||
else:
|
||||
hifigan_model = hifigan_model.to(device)
|
||||
|
||||
|
||||
class Sovits:
|
||||
def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps):
|
||||
self.vq_model = vq_model
|
||||
@ -399,6 +575,7 @@ class DictToAttrRecursive(dict):
|
||||
|
||||
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
||||
|
||||
v3v4set = {"v3", "v4"}
|
||||
|
||||
def get_sovits_weights(sovits_path):
|
||||
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||
@ -419,8 +596,8 @@ def get_sovits_weights(sovits_path):
|
||||
else:
|
||||
hps.model.version = "v2"
|
||||
|
||||
if model_version == "v3":
|
||||
hps.model.version = "v3"
|
||||
if model_version in v3v4set:
|
||||
hps.model.version = model_version
|
||||
|
||||
logger.info(f"hps: {hps}")
|
||||
|
||||
@ -522,10 +699,14 @@ def export_cfm(
|
||||
return export_cfm
|
||||
|
||||
|
||||
def export():
|
||||
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
|
||||
|
||||
init_bigvgan()
|
||||
def export(version="v3"):
|
||||
if version == "v3":
|
||||
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
|
||||
init_bigvgan()
|
||||
else:
|
||||
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
|
||||
init_hifigan()
|
||||
|
||||
|
||||
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
|
||||
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||
@ -542,7 +723,7 @@ def export():
|
||||
hps = sovits.hps
|
||||
ref_wav_path = "onnx/ad/ref.wav"
|
||||
speed = 1.0
|
||||
sample_steps = 32
|
||||
sample_steps = 8
|
||||
dtype = torch.float16 if is_half == True else torch.float32
|
||||
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
|
||||
zero_wav = np.zeros(
|
||||
@ -634,25 +815,33 @@ def export():
|
||||
# vq_model = sovits.vq_model
|
||||
vq_model = trace_vq_model
|
||||
|
||||
gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
|
||||
torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
|
||||
if version == "v3":
|
||||
gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
|
||||
torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
|
||||
else:
|
||||
gpt_sovits_half = ExportGPTSovitsV4Half(sovits.hps, script_t2s, trace_vq_model)
|
||||
torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v4_half.pt")
|
||||
|
||||
ref_audio, sr = torchaudio.load(ref_wav_path)
|
||||
ref_audio = ref_audio.to(device).float()
|
||||
if ref_audio.shape[0] == 2:
|
||||
ref_audio = ref_audio.mean(0).unsqueeze(0)
|
||||
if sr != 24000:
|
||||
ref_audio = resample(ref_audio, sr)
|
||||
tgt_sr = 24000 if version == "v3" else 32000
|
||||
if sr != tgt_sr:
|
||||
ref_audio = resample(ref_audio, sr, tgt_sr)
|
||||
# mel2 = mel_fn(ref_audio)
|
||||
mel2 = norm_spec(mel_fn(ref_audio))
|
||||
mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
|
||||
mel2 = norm_spec(mel2)
|
||||
T_min = min(mel2.shape[2], fea_ref.shape[2])
|
||||
fea_ref = fea_ref[:, :, :T_min]
|
||||
print("fea_ref:", fea_ref.shape, T_min)
|
||||
if T_min > 468:
|
||||
mel2 = mel2[:, :, -468:]
|
||||
fea_ref = fea_ref[:, :, -468:]
|
||||
T_min = 468
|
||||
chunk_len = 934 - T_min
|
||||
Tref = 468 if version == "v3" else 500
|
||||
Tchunk = 934 if version == "v3" else 1000
|
||||
if T_min > Tref:
|
||||
mel2 = mel2[:, :, -Tref:]
|
||||
fea_ref = fea_ref[:, :, -Tref:]
|
||||
T_min = Tref
|
||||
chunk_len = Tchunk - T_min
|
||||
mel2 = mel2.to(dtype)
|
||||
|
||||
# fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge)
|
||||
@ -714,13 +903,19 @@ def export():
|
||||
with torch.inference_mode():
|
||||
cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
|
||||
torch._dynamo.mark_dynamic(cmf_res_rand, 2)
|
||||
bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
|
||||
bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
|
||||
wav_gen = bigvgan_model(cmf_res)
|
||||
if version == "v3":
|
||||
bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
|
||||
bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
|
||||
wav_gen = bigvgan_model(cmf_res)
|
||||
else:
|
||||
hifigan_model_ = torch.jit.trace(hifigan_model, optimize=True, example_inputs=(cmf_res_rand,))
|
||||
hifigan_model_.save("onnx/ad/hifigan_model.pt")
|
||||
wav_gen = hifigan_model(cmf_res)
|
||||
|
||||
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
|
||||
audio = wav_gen[0][0].cpu().detach().numpy()
|
||||
|
||||
sr = 24000
|
||||
sr = 24000 if version == "v3" else 48000
|
||||
soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr)
|
||||
|
||||
|
||||
@ -848,8 +1043,9 @@ def test_export(
|
||||
|
||||
def test_export1(
|
||||
todo_text,
|
||||
gpt_sovits_v3,
|
||||
gpt_sovits_v3v4,
|
||||
output,
|
||||
out_sr=24000,
|
||||
):
|
||||
# hps = sovits.hps
|
||||
ref_wav_path = "onnx/ad/ref.wav"
|
||||
@ -859,7 +1055,7 @@ def test_export1(
|
||||
dtype = torch.float16 if is_half == True else torch.float32
|
||||
|
||||
zero_wav = np.zeros(
|
||||
int(24000 * 0.3),
|
||||
int(out_sr * 0.3),
|
||||
dtype=np.float16 if is_half == True else np.float32,
|
||||
)
|
||||
|
||||
@ -907,22 +1103,26 @@ def test_export1(
|
||||
bert2.shape,
|
||||
top_k.shape,
|
||||
)
|
||||
wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
|
||||
wav_gen = gpt_sovits_v3v4(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
|
||||
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
|
||||
|
||||
wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
|
||||
|
||||
audio = wav_gen.cpu().detach().numpy()
|
||||
logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
sr = 24000
|
||||
soundfile.write(output, (audio * 32768).astype(np.int16), sr)
|
||||
soundfile.write(output, (audio * 32768).astype(np.int16), out_sr)
|
||||
|
||||
|
||||
import time
|
||||
|
||||
|
||||
def test_():
|
||||
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
|
||||
def test_(version="v3"):
|
||||
if version == "v3":
|
||||
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
|
||||
# init_bigvgan()
|
||||
else:
|
||||
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
|
||||
# init_hifigan()
|
||||
|
||||
# cfm = ExportCFM(sovits.cfm)
|
||||
# cfm.cfm.estimator = dit
|
||||
@ -963,25 +1163,41 @@ def test_():
|
||||
# gpt_sovits_v3_half = gpt_sovits_v3_half.half()
|
||||
# gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
|
||||
# gpt_sovits_v3_half.eval()
|
||||
gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
|
||||
logger.info("gpt_sovits_v3_half ok")
|
||||
if version == "v3":
|
||||
gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
|
||||
logger.info("gpt_sovits_v3_half ok")
|
||||
# init_bigvgan()
|
||||
# global bigvgan_model
|
||||
bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
|
||||
# bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
|
||||
bigvgan_model = bigvgan_model.half()
|
||||
bigvgan_model = bigvgan_model.cuda()
|
||||
bigvgan_model.eval()
|
||||
|
||||
# init_bigvgan()
|
||||
# global bigvgan_model
|
||||
bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
|
||||
# bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
|
||||
bigvgan_model = bigvgan_model.half()
|
||||
bigvgan_model = bigvgan_model.cuda()
|
||||
bigvgan_model.eval()
|
||||
logger.info("bigvgan ok")
|
||||
gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
|
||||
gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
|
||||
gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
|
||||
gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
|
||||
gpt_sovits_v3.eval()
|
||||
print("save gpt_sovits_v3 ok")
|
||||
else:
|
||||
gpt_sovits_v4_half = ExportGPTSovitsV4Half(sovits.hps, t2s_m, vq_model)
|
||||
logger.info("gpt_sovits_v4 ok")
|
||||
|
||||
logger.info("bigvgan ok")
|
||||
hifigan_model = torch.jit.load("onnx/ad/hifigan_model.pt")
|
||||
hifigan_model = hifigan_model.half()
|
||||
hifigan_model = hifigan_model.cuda()
|
||||
hifigan_model.eval()
|
||||
logger.info("hifigan ok")
|
||||
gpt_sovits_v4 = GPTSoVITSV4(gpt_sovits_v4_half, cfm, hifigan_model)
|
||||
gpt_sovits_v4 = torch.jit.script(gpt_sovits_v4)
|
||||
gpt_sovits_v4.save("onnx/ad/gpt_sovits_v4.pt")
|
||||
print("save gpt_sovits_v4 ok")
|
||||
|
||||
gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4
|
||||
sr = 24000 if version == "v3" else 48000
|
||||
|
||||
gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
|
||||
gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
|
||||
gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
|
||||
gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
|
||||
gpt_sovits_v3.eval()
|
||||
print("save gpt_sovits_v3 ok")
|
||||
|
||||
time.sleep(5)
|
||||
# print("thread:", torch.get_num_threads())
|
||||
@ -991,14 +1207,16 @@ def test_():
|
||||
|
||||
test_export1(
|
||||
"汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
|
||||
gpt_sovits_v3,
|
||||
gpt_sovits_v3v4,
|
||||
"out.wav",
|
||||
sr
|
||||
)
|
||||
|
||||
test_export1(
|
||||
"你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
|
||||
gpt_sovits_v3,
|
||||
gpt_sovits_v3v4,
|
||||
"out2.wav",
|
||||
sr
|
||||
)
|
||||
|
||||
# test_export(
|
||||
@ -1030,6 +1248,6 @@ def test_export_gpt_sovits_v3():
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
# export()
|
||||
test_()
|
||||
export("v4")
|
||||
# test_("v4")
|
||||
# test_export_gpt_sovits_v3()
|
||||
|
@ -391,6 +391,7 @@ class Generator(torch.nn.Module):
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=0,
|
||||
is_bias=False,
|
||||
):
|
||||
super(Generator, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
@ -418,7 +419,7 @@ class Generator(torch.nn.Module):
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
|
||||
self.ups.apply(init_weights)
|
||||
|
||||
if gin_channels != 0:
|
||||
|
Loading…
x
Reference in New Issue
Block a user