mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
Fix: 由于 export_torch_script_v3 的改动,v2 现在需要传入 top_k
This commit is contained in:
parent
b12ac35b04
commit
1ceab938bb
@@ -654,6 +654,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
|||||||
torch._dynamo.mark_dynamic(ref_bert, 0)
|
torch._dynamo.mark_dynamic(ref_bert, 0)
|
||||||
torch._dynamo.mark_dynamic(text_bert, 0)
|
torch._dynamo.mark_dynamic(text_bert, 0)
|
||||||
|
|
||||||
|
top_k = torch.LongTensor([5]).to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
gpt_sovits_export = torch.jit.trace(
|
gpt_sovits_export = torch.jit.trace(
|
||||||
gpt_sovits,
|
gpt_sovits,
|
||||||
@@ -663,7 +665,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
|||||||
ref_seq,
|
ref_seq,
|
||||||
text_seq,
|
text_seq,
|
||||||
ref_bert,
|
ref_bert,
|
||||||
text_bert))
|
text_bert,
|
||||||
|
top_k))
|
||||||
|
|
||||||
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
|
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
|
||||||
gpt_sovits_export.save(gpt_sovits_path)
|
gpt_sovits_export.save(gpt_sovits_path)
|
||||||
@@ -685,15 +688,26 @@ class GPT_SoVITS(nn.Module):
|
|||||||
self.t2s = t2s
|
self.t2s = t2s
|
||||||
self.vits = vits
|
self.vits = vits
|
||||||
|
|
||||||
def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor, speed=1.0):
|
def forward(
|
||||||
|
self,
|
||||||
|
ssl_content: torch.Tensor,
|
||||||
|
ref_audio_sr: torch.Tensor,
|
||||||
|
ref_seq: Tensor,
|
||||||
|
text_seq: Tensor,
|
||||||
|
ref_bert: Tensor,
|
||||||
|
text_bert: Tensor,
|
||||||
|
top_k: LongTensor,
|
||||||
|
speed=1.0,
|
||||||
|
):
|
||||||
codes = self.vits.vq_model.extract_latent(ssl_content)
|
codes = self.vits.vq_model.extract_latent(ssl_content)
|
||||||
prompt_semantic = codes[0, 0]
|
prompt_semantic = codes[0, 0]
|
||||||
prompts = prompt_semantic.unsqueeze(0)
|
prompts = prompt_semantic.unsqueeze(0)
|
||||||
|
|
||||||
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
|
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||||
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
|
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||||
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
|
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
|
||||||
@@ -785,8 +799,10 @@ def test():
|
|||||||
print('text_bert:',text_bert.shape)
|
print('text_bert:',text_bert.shape)
|
||||||
text_bert=text_bert.to('cuda')
|
text_bert=text_bert.to('cuda')
|
||||||
|
|
||||||
|
top_k = torch.LongTensor([5]).to('cuda')
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert)
|
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
|
||||||
print('start write wav')
|
print('start write wav')
|
||||||
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
|
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user