feat: enable speed control for v1/v2

Author: zpeng11
Date:   2025-08-31 00:14:07 -04:00
parent 337da7454e
commit 8858492f56
2 changed files with 14 additions and 19 deletions


@@ -205,6 +205,8 @@ class TextEncoder(nn.Module):
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
     def forward(self, y, text, ge, speed=1):
+        if type(speed) == float:
+            speed = torch.FloatTensor([speed])
         y_mask = torch.ones_like(y[:1, :1, :])
         y = self.ssl_proj(y * y_mask) * y_mask
@@ -217,9 +219,8 @@ class TextEncoder(nn.Module):
         y = self.mrte(y, y_mask, text, text_mask, ge)
         y = self.encoder2(y * y_mask, y_mask)
         if speed != 1:
-            y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
-            y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
+            y = F.interpolate(y, size=(y.shape[-1] / speed).to(torch.int) + 1, mode="linear")
+            y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
         stats = self.proj(y) * y_mask
         m, logs = torch.split(stats, self.out_channels, dim=1)
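Note on the hunk above: the speed control amounts to resampling the encoder output along the time axis before the projection, so speed > 1 leaves the decoder fewer frames to vocode and the speech comes out proportionally faster. A minimal standalone sketch of that step; the shapes are illustrative and do not come from the model:

import torch
import torch.nn.functional as F

# Hypothetical encoder output: (batch, channels, frames)
y = torch.randn(1, 192, 100)
y_mask = torch.ones(1, 1, 100)
speed = torch.FloatTensor([1.5])  # a 1-element tensor, as the commit coerces it

# Shrink the time axis by 1/speed: 100 frames -> 67 for speed 1.5,
# then resize the mask to match the new length.
new_len = int((y.shape[-1] / speed).to(torch.int) + 1)
y = F.interpolate(y, size=new_len, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
print(y.shape, y_mask.shape)  # torch.Size([1, 192, 67]) torch.Size([1, 1, 67])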


@@ -233,7 +233,7 @@ class VitsModel(nn.Module):
         self.hps = DictToAttrRecursive(self.hps)
         self.hps.model.semantic_frame_rate = "25hz"
-        self.vq_model = SynthesizerTrn(
+        self.vq_model: SynthesizerTrn = SynthesizerTrn(
             self.hps.data.filter_length // 2 + 1,
             self.hps.train.segment_size // self.hps.data.hop_length,
             n_speakers=self.hps.data.n_speakers,
@@ -243,32 +243,27 @@ class VitsModel(nn.Module):
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
         # print(f"filter_length:{self.hps.data.filter_length} sampling_rate:{self.hps.data.sampling_rate} hop_length:{self.hps.data.hop_length} win_length:{self.hps.data.win_length}")
         # v2 filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048
 
-    def forward(self, text_seq, pred_semantic, spectrum, sv_emb):
+    def forward(self, text_seq, pred_semantic, spectrum, sv_emb, speed):
         if self.is_v2p:
-            return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb)[0, 0]
+            return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb, speed=speed)[0, 0]
         else:
-            return self.vq_model(pred_semantic, text_seq, spectrum)[0, 0]
+            return self.vq_model(pred_semantic, text_seq, spectrum, speed=speed)[0, 0]
 
 
-class GptSoVits(nn.Module):
+class GptSoVits():
     def __init__(self, vits, t2s):
         super().__init__()
         self.vits = vits
         self.t2s = t2s
 
     def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
         pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
         audio = self.vits(text_seq, pred_semantic, spectrum, sv_emb)
         return audio
 
-    def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
+    def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, speed, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
         self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
         pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
         torch.onnx.export(
             self.vits,
-            (text_seq, pred_semantic, spectrum, sv_emb),
+            (text_seq, pred_semantic, spectrum, sv_emb, speed),
             f"onnx/{project_name}/{project_name}_vits.onnx",
-            input_names=["input_text_phones", "pred_semantic", "spectrum", "sv_emb"],
+            input_names=["input_text_phones", "pred_semantic", "spectrum", "sv_emb", "speed"],
             output_names=["audio"],
             dynamic_axes={
                 "input_text_phones": {1: "text_length"},
@@ -379,13 +374,12 @@ def export(vits_path, gpt_path, project_name, voice_model_version, export_audio_
     top_p = torch.FloatTensor([1.0])
     repetition_penalty = torch.FloatTensor([1.0])
     temperature = torch.FloatTensor([1.0])
+    speed = torch.FloatTensor([1.0])
     os.makedirs(f"onnx/{project_name}", exist_ok=True)
     [ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
     gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
     # exit()
-    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), speed, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
 
     if export_audio_preprocessor:
         torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
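After export, the new "speed" input is fed alongside the others at inference time. A hedged onnxruntime sketch, assuming a project named my_project; the placeholder tensors stand in for the outputs of the t2s stage and the audio preprocessor, and their shapes are illustrative (except spectrum's 1025 bins, which follow from the v2 filter_length of 2048 noted in the comment above):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx/my_project/my_project_vits.onnx")

# Placeholder inputs; real values come from the t2s model and the
# preprocessor, exactly as in the export() driver above.
text_seq = np.zeros((1, 20), dtype=np.int64)        # phoneme ids
pred_semantic = np.zeros((1, 1, 200), dtype=np.int64)  # semantic tokens
spectrum = np.zeros((1, 1025, 300), dtype=np.float32)  # reference spectrogram
sv_emb = np.zeros((1, 20480), dtype=np.float32)     # speaker embedding

audio = sess.run(
    ["audio"],
    {
        "input_text_phones": text_seq,
        "pred_semantic": pred_semantic,
        "spectrum": spectrum,
        "sv_emb": sv_emb,
        "speed": np.array([1.25], dtype=np.float32),  # >1 faster, <1 slower
    },
)[0]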