mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 04:22:46 +08:00
Fix: 由于 export_torch_script_v3 的改动,v2 现在需要传入 top_k
This commit is contained in:
parent
b12ac35b04
commit
1ceab938bb
@ -654,6 +654,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
torch._dynamo.mark_dynamic(ref_bert, 0)
|
||||
torch._dynamo.mark_dynamic(text_bert, 0)
|
||||
|
||||
top_k = torch.LongTensor([5]).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
gpt_sovits_export = torch.jit.trace(
|
||||
gpt_sovits,
|
||||
@ -663,7 +665,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
ref_seq,
|
||||
text_seq,
|
||||
ref_bert,
|
||||
text_bert))
|
||||
text_bert,
|
||||
top_k))
|
||||
|
||||
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
|
||||
gpt_sovits_export.save(gpt_sovits_path)
|
||||
@ -685,15 +688,26 @@ class GPT_SoVITS(nn.Module):
|
||||
self.t2s = t2s
|
||||
self.vits = vits
|
||||
|
||||
def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor, speed=1.0):
|
||||
def forward(
|
||||
self,
|
||||
ssl_content: torch.Tensor,
|
||||
ref_audio_sr: torch.Tensor,
|
||||
ref_seq: Tensor,
|
||||
text_seq: Tensor,
|
||||
ref_bert: Tensor,
|
||||
text_bert: Tensor,
|
||||
top_k: LongTensor,
|
||||
speed=1.0,
|
||||
):
|
||||
codes = self.vits.vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompts = prompt_semantic.unsqueeze(0)
|
||||
|
||||
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
|
||||
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
|
||||
return audio
|
||||
|
||||
|
||||
def test():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
|
||||
@ -785,8 +799,10 @@ def test():
|
||||
print('text_bert:',text_bert.shape)
|
||||
text_bert=text_bert.to('cuda')
|
||||
|
||||
top_k = torch.LongTensor([5]).to('cuda')
|
||||
|
||||
with torch.no_grad():
|
||||
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert)
|
||||
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
|
||||
print('start write wav')
|
||||
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user