mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 08:49:59 +08:00
feat:enable speed control for v1v2
This commit is contained in:
parent
337da7454e
commit
8858492f56
@ -205,6 +205,8 @@ class TextEncoder(nn.Module):
|
|||||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
def forward(self, y, text, ge, speed=1):
|
def forward(self, y, text, ge, speed=1):
|
||||||
|
if type(speed) == float:
|
||||||
|
speed = torch.FloatTensor([speed])
|
||||||
y_mask = torch.ones_like(y[:1, :1, :])
|
y_mask = torch.ones_like(y[:1, :1, :])
|
||||||
|
|
||||||
y = self.ssl_proj(y * y_mask) * y_mask
|
y = self.ssl_proj(y * y_mask) * y_mask
|
||||||
@ -217,9 +219,8 @@ class TextEncoder(nn.Module):
|
|||||||
y = self.mrte(y, y_mask, text, text_mask, ge)
|
y = self.mrte(y, y_mask, text, text_mask, ge)
|
||||||
|
|
||||||
y = self.encoder2(y * y_mask, y_mask)
|
y = self.encoder2(y * y_mask, y_mask)
|
||||||
if speed != 1:
|
y = F.interpolate(y, size=(y.shape[-1] / speed).to(torch.int) + 1, mode="linear")
|
||||||
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
|
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
|
||||||
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
|
|
||||||
|
|
||||||
stats = self.proj(y) * y_mask
|
stats = self.proj(y) * y_mask
|
||||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||||
|
@ -233,7 +233,7 @@ class VitsModel(nn.Module):
|
|||||||
|
|
||||||
self.hps = DictToAttrRecursive(self.hps)
|
self.hps = DictToAttrRecursive(self.hps)
|
||||||
self.hps.model.semantic_frame_rate = "25hz"
|
self.hps.model.semantic_frame_rate = "25hz"
|
||||||
self.vq_model = SynthesizerTrn(
|
self.vq_model:SynthesizerTrn = SynthesizerTrn(
|
||||||
self.hps.data.filter_length // 2 + 1,
|
self.hps.data.filter_length // 2 + 1,
|
||||||
self.hps.train.segment_size // self.hps.data.hop_length,
|
self.hps.train.segment_size // self.hps.data.hop_length,
|
||||||
n_speakers=self.hps.data.n_speakers,
|
n_speakers=self.hps.data.n_speakers,
|
||||||
@ -243,32 +243,27 @@ class VitsModel(nn.Module):
|
|||||||
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
# print(f"filter_length:{self.hps.data.filter_length} sampling_rate:{self.hps.data.sampling_rate} hop_length:{self.hps.data.hop_length} win_length:{self.hps.data.win_length}")
|
# print(f"filter_length:{self.hps.data.filter_length} sampling_rate:{self.hps.data.sampling_rate} hop_length:{self.hps.data.hop_length} win_length:{self.hps.data.win_length}")
|
||||||
#v2 filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048
|
#v2 filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048
|
||||||
def forward(self, text_seq, pred_semantic, spectrum, sv_emb):
|
def forward(self, text_seq, pred_semantic, spectrum, sv_emb, speed):
|
||||||
if self.is_v2p:
|
if self.is_v2p:
|
||||||
return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb)[0, 0]
|
return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb, speed=speed)[0, 0]
|
||||||
else:
|
else:
|
||||||
return self.vq_model(pred_semantic, text_seq, spectrum)[0, 0]
|
return self.vq_model(pred_semantic, text_seq, spectrum, speed=speed)[0, 0]
|
||||||
|
|
||||||
|
|
||||||
class GptSoVits(nn.Module):
|
class GptSoVits():
|
||||||
def __init__(self, vits, t2s):
|
def __init__(self, vits, t2s):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vits = vits
|
self.vits = vits
|
||||||
self.t2s = t2s
|
self.t2s = t2s
|
||||||
|
|
||||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, speed, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
|
||||||
audio = self.vits(text_seq, pred_semantic, spectrum, sv_emb)
|
|
||||||
return audio
|
|
||||||
|
|
||||||
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
|
||||||
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
self.vits,
|
self.vits,
|
||||||
(text_seq, pred_semantic, spectrum, sv_emb),
|
(text_seq, pred_semantic, spectrum, sv_emb, speed),
|
||||||
f"onnx/{project_name}/{project_name}_vits.onnx",
|
f"onnx/{project_name}/{project_name}_vits.onnx",
|
||||||
input_names=["input_text_phones", "pred_semantic", "spectrum", "sv_emb"],
|
input_names=["input_text_phones", "pred_semantic", "spectrum", "sv_emb", "speed"],
|
||||||
output_names=["audio"],
|
output_names=["audio"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"input_text_phones": {1: "text_length"},
|
"input_text_phones": {1: "text_length"},
|
||||||
@ -379,13 +374,12 @@ def export(vits_path, gpt_path, project_name, voice_model_version, export_audio_
|
|||||||
top_p = torch.FloatTensor([1.0])
|
top_p = torch.FloatTensor([1.0])
|
||||||
repetition_penalty = torch.FloatTensor([1.0])
|
repetition_penalty = torch.FloatTensor([1.0])
|
||||||
temperature = torch.FloatTensor([1.0])
|
temperature = torch.FloatTensor([1.0])
|
||||||
|
speed = torch.FloatTensor([1.0])
|
||||||
|
|
||||||
os.makedirs(f"onnx/{project_name}", exist_ok=True)
|
os.makedirs(f"onnx/{project_name}", exist_ok=True)
|
||||||
|
|
||||||
[ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
|
[ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
|
||||||
gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), speed, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
# exit()
|
|
||||||
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
|
||||||
|
|
||||||
if export_audio_preprocessor:
|
if export_audio_preprocessor:
|
||||||
torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
|
torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user