mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-06-24 13:33:33 +08:00
Merge pull request #2460 from L-jasmine/export_v2pro
优化 torch_script 导出模型
This commit is contained in:
parent
1a9b8854ee
commit
7dec5f5bb0
@ -103,7 +103,7 @@ def logits_to_probs(
|
|||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def multinomial_sample_one_no_sync(probs_sort):
|
def multinomial_sample_one_no_sync(probs_sort):
|
||||||
# Does multinomial sampling without a cuda synchronization
|
# Does multinomial sampling without a cuda synchronization
|
||||||
q = torch.randn_like(probs_sort)
|
q = torch.empty_like(probs_sort).exponential_(1.0)
|
||||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||||
|
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ def sample(
|
|||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
top_p: Optional[int] = None,
|
top_p: Optional[int] = None,
|
||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.35,
|
||||||
):
|
):
|
||||||
probs = logits_to_probs(
|
probs = logits_to_probs(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
@ -129,8 +129,8 @@ def sample(
|
|||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
|
def spectrogram_torch(hann_window:Tensor, y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
|
||||||
hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
|
# hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
|
||||||
y = torch.nn.functional.pad(
|
y = torch.nn.functional.pad(
|
||||||
y.unsqueeze(1),
|
y.unsqueeze(1),
|
||||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||||
@ -309,8 +309,9 @@ class T2SBlock:
|
|||||||
|
|
||||||
attn = F.scaled_dot_product_attention(q, k, v)
|
attn = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
|
||||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
# attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
# attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||||
|
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
|
||||||
attn = F.linear(attn, self.out_w, self.out_b)
|
attn = F.linear(attn, self.out_w, self.out_b)
|
||||||
|
|
||||||
x = x + attn
|
x = x + attn
|
||||||
@ -348,7 +349,7 @@ class T2STransformer:
|
|||||||
|
|
||||||
|
|
||||||
class VitsModel(nn.Module):
|
class VitsModel(nn.Module):
|
||||||
def __init__(self, vits_path, version=None):
|
def __init__(self, vits_path, version=None, is_half=True, device="cpu"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# dict_s2 = torch.load(vits_path,map_location="cpu")
|
# dict_s2 = torch.load(vits_path,map_location="cpu")
|
||||||
dict_s2 = load_sovits_new(vits_path)
|
dict_s2 = load_sovits_new(vits_path)
|
||||||
@ -373,11 +374,18 @@ class VitsModel(nn.Module):
|
|||||||
n_speakers=self.hps.data.n_speakers,
|
n_speakers=self.hps.data.n_speakers,
|
||||||
**self.hps.model,
|
**self.hps.model,
|
||||||
)
|
)
|
||||||
self.vq_model.eval()
|
|
||||||
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
|
self.vq_model.dec.remove_weight_norm()
|
||||||
|
if is_half:
|
||||||
|
self.vq_model = self.vq_model.half()
|
||||||
|
self.vq_model = self.vq_model.to(device)
|
||||||
|
self.vq_model.eval()
|
||||||
|
self.hann_window = torch.hann_window(self.hps.data.win_length, device=device, dtype= torch.float16 if is_half else torch.float32)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
|
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
|
||||||
refer = spectrogram_torch(
|
refer = spectrogram_torch(
|
||||||
|
self.hann_window,
|
||||||
ref_audio,
|
ref_audio,
|
||||||
self.hps.data.filter_length,
|
self.hps.data.filter_length,
|
||||||
self.hps.data.sampling_rate,
|
self.hps.data.sampling_rate,
|
||||||
@ -667,7 +675,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
|||||||
ssl_content = ssl(ref_audio).to(device)
|
ssl_content = ssl(ref_audio).to(device)
|
||||||
|
|
||||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||||
vits = VitsModel(vits_path).to(device)
|
vits = VitsModel(vits_path,device=device,is_half=False)
|
||||||
vits.eval()
|
vits.eval()
|
||||||
|
|
||||||
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||||
@ -765,10 +773,7 @@ def export_prov2(
|
|||||||
sv_model = ExportERes2NetV2(sv_cn_model)
|
sv_model = ExportERes2NetV2(sv_cn_model)
|
||||||
|
|
||||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||||
vits = VitsModel(vits_path, version)
|
vits = VitsModel(vits_path, version,is_half=is_half,device=device)
|
||||||
if is_half:
|
|
||||||
vits.vq_model = vits.vq_model.half()
|
|
||||||
vits.to(device)
|
|
||||||
vits.eval()
|
vits.eval()
|
||||||
|
|
||||||
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||||
|
@ -243,6 +243,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
|
|||||||
self.sampling_rate: int = hps.data.sampling_rate
|
self.sampling_rate: int = hps.data.sampling_rate
|
||||||
self.hop_length: int = hps.data.hop_length
|
self.hop_length: int = hps.data.hop_length
|
||||||
self.win_length: int = hps.data.win_length
|
self.win_length: int = hps.data.win_length
|
||||||
|
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -255,6 +256,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
|
|||||||
top_k,
|
top_k,
|
||||||
):
|
):
|
||||||
refer = spectrogram_torch(
|
refer = spectrogram_torch(
|
||||||
|
self.hann_window,
|
||||||
ref_audio_32k,
|
ref_audio_32k,
|
||||||
self.filter_length,
|
self.filter_length,
|
||||||
self.sampling_rate,
|
self.sampling_rate,
|
||||||
@ -321,6 +323,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
|
|||||||
self.sampling_rate: int = hps.data.sampling_rate
|
self.sampling_rate: int = hps.data.sampling_rate
|
||||||
self.hop_length: int = hps.data.hop_length
|
self.hop_length: int = hps.data.hop_length
|
||||||
self.win_length: int = hps.data.win_length
|
self.win_length: int = hps.data.win_length
|
||||||
|
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -333,6 +336,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
|
|||||||
top_k,
|
top_k,
|
||||||
):
|
):
|
||||||
refer = spectrogram_torch(
|
refer = spectrogram_torch(
|
||||||
|
self.hann_window,
|
||||||
ref_audio_32k,
|
ref_audio_32k,
|
||||||
self.filter_length,
|
self.filter_length,
|
||||||
self.sampling_rate,
|
self.sampling_rate,
|
||||||
@ -1149,7 +1153,7 @@ def export_2(version="v3"):
|
|||||||
raw_t2s = raw_t2s.half().to(device)
|
raw_t2s = raw_t2s.half().to(device)
|
||||||
t2s_m = T2SModel(raw_t2s).half().to(device)
|
t2s_m = T2SModel(raw_t2s).half().to(device)
|
||||||
t2s_m.eval()
|
t2s_m.eval()
|
||||||
t2s_m = torch.jit.script(t2s_m)
|
t2s_m = torch.jit.script(t2s_m).to(device)
|
||||||
t2s_m.eval()
|
t2s_m.eval()
|
||||||
# t2s_m.top_k = 15
|
# t2s_m.top_k = 15
|
||||||
logger.info("t2s_m ok")
|
logger.info("t2s_m ok")
|
||||||
@ -1251,6 +1255,6 @@ def test_export_gpt_sovits_v3():
|
|||||||
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
|
# export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
|
||||||
# export_2("v4")
|
export_2("v4")
|
||||||
# test_export_gpt_sovits_v3()
|
# test_export_gpt_sovits_v3()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user