mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 08:49:59 +08:00
feat:enable speed control for v1v2
This commit is contained in:
parent
337da7454e
commit
8858492f56
@ -205,6 +205,8 @@ class TextEncoder(nn.Module):
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, y, text, ge, speed=1):
|
||||
if type(speed) == float:
|
||||
speed = torch.FloatTensor([speed])
|
||||
y_mask = torch.ones_like(y[:1, :1, :])
|
||||
|
||||
y = self.ssl_proj(y * y_mask) * y_mask
|
||||
@ -217,8 +219,7 @@ class TextEncoder(nn.Module):
|
||||
y = self.mrte(y, y_mask, text, text_mask, ge)
|
||||
|
||||
y = self.encoder2(y * y_mask, y_mask)
|
||||
if speed != 1:
|
||||
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
|
||||
y = F.interpolate(y, size=(y.shape[-1] / speed).to(torch.int) + 1, mode="linear")
|
||||
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
|
||||
|
||||
stats = self.proj(y) * y_mask
|
||||
|
@ -233,7 +233,7 @@ class VitsModel(nn.Module):
|
||||
|
||||
self.hps = DictToAttrRecursive(self.hps)
|
||||
self.hps.model.semantic_frame_rate = "25hz"
|
||||
self.vq_model = SynthesizerTrn(
|
||||
self.vq_model:SynthesizerTrn = SynthesizerTrn(
|
||||
self.hps.data.filter_length // 2 + 1,
|
||||
self.hps.train.segment_size // self.hps.data.hop_length,
|
||||
n_speakers=self.hps.data.n_speakers,
|
||||
@ -243,32 +243,27 @@ class VitsModel(nn.Module):
|
||||
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
# print(f"filter_length:{self.hps.data.filter_length} sampling_rate:{self.hps.data.sampling_rate} hop_length:{self.hps.data.hop_length} win_length:{self.hps.data.win_length}")
|
||||
#v2 filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048
|
||||
def forward(self, text_seq, pred_semantic, spectrum, sv_emb):
|
||||
def forward(self, text_seq, pred_semantic, spectrum, sv_emb, speed):
|
||||
if self.is_v2p:
|
||||
return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb)[0, 0]
|
||||
return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb, speed=speed)[0, 0]
|
||||
else:
|
||||
return self.vq_model(pred_semantic, text_seq, spectrum)[0, 0]
|
||||
return self.vq_model(pred_semantic, text_seq, spectrum, speed=speed)[0, 0]
|
||||
|
||||
|
||||
class GptSoVits(nn.Module):
|
||||
class GptSoVits():
|
||||
def __init__(self, vits, t2s):
|
||||
super().__init__()
|
||||
self.vits = vits
|
||||
self.t2s = t2s
|
||||
|
||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||
audio = self.vits(text_seq, pred_semantic, spectrum, sv_emb)
|
||||
return audio
|
||||
|
||||
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, speed, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||
torch.onnx.export(
|
||||
self.vits,
|
||||
(text_seq, pred_semantic, spectrum, sv_emb),
|
||||
(text_seq, pred_semantic, spectrum, sv_emb, speed),
|
||||
f"onnx/{project_name}/{project_name}_vits.onnx",
|
||||
input_names=["input_text_phones", "pred_semantic", "spectrum", "sv_emb"],
|
||||
input_names=["input_text_phones", "pred_semantic", "spectrum", "sv_emb", "speed"],
|
||||
output_names=["audio"],
|
||||
dynamic_axes={
|
||||
"input_text_phones": {1: "text_length"},
|
||||
@ -379,13 +374,12 @@ def export(vits_path, gpt_path, project_name, voice_model_version, export_audio_
|
||||
top_p = torch.FloatTensor([1.0])
|
||||
repetition_penalty = torch.FloatTensor([1.0])
|
||||
temperature = torch.FloatTensor([1.0])
|
||||
speed = torch.FloatTensor([1.0])
|
||||
|
||||
os.makedirs(f"onnx/{project_name}", exist_ok=True)
|
||||
|
||||
[ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
|
||||
gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||
# exit()
|
||||
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), speed, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||
|
||||
if export_audio_preprocessor:
|
||||
torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
|
||||
|
Loading…
x
Reference in New Issue
Block a user