(export_torch_script.py) Merge vits and t2s into a single model for export

csh 2024-09-24 04:00:03 +08:00
parent dbaeb42e7f
commit 41dbc179c3


@@ -2,18 +2,19 @@
# reference: https://github.com/lifeiteng/vall-e
from typing import Optional
from my_utils import load_audio
from onnx_export import VitsModel
from text import cleaned_text_to_sequence
import torch
import torchaudio
from torch import IntTensor, LongTensor, nn
from torch import IntTensor, LongTensor, Tensor, nn
from torch.nn import functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer
from feature_extractor import cnhubert
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from module.models_onnx import SynthesizerTrn
import os
@@ -31,6 +32,14 @@ default_config = {
"EOS": 1024,
}
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
config = dict_s1["config"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
t2s_model = t2s_model.eval()
return t2s_model
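# Usage sketch (checkpoint path taken from the example further down): load a GPT
# weights file and get the eval-mode Lightning module that T2SModel wraps below.
#   dict_s1 = torch.load("GPT_weights_v2/xw-e15.ckpt", map_location="cpu")
#   raw_t2s = get_raw_t2s_model(dict_s1)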
@torch.jit.script
def logits_to_probs(
logits,
previous_tokens: Optional[torch.Tensor] = None,
@@ -73,20 +82,13 @@ def logits_to_probs(
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
config = dict_s1["config"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
t2s_model = t2s_model.eval()
return t2s_model
def multinomial_sample_one_no_sync(
probs_sort
): # Does multinomial sampling without a cuda synchronization
@torch.jit.script
def multinomial_sample_one_no_sync(probs_sort):
# Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
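# Note: this draws a token without a CUDA sync; the formulation more commonly
# seen elsewhere (e.g. gpt-fast) uses Exponential(1) noise, which makes it the
# exact Gumbel-max trick:
#   q = torch.empty_like(probs_sort).exponential_(1)
#   idx = torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)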
@torch.jit.script
def sample(
logits,
previous_tokens,
@@ -102,6 +104,90 @@ def sample(
return idx_next, probs
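# Typical call (mirroring the decoder loop further down): pick the next semantic
# token from the last-step logits while penalizing tokens already present in y.
# The top_k value here is illustrative; the real one comes from the GPT config.
#   samples = sample(logits[0], y, top_k=15, top_p=1.0,
#                    repetition_penalty=1.35)[0].unsqueeze(0)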
@torch.jit.script
def spectrogram_torch(y, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
hann_window = torch.hann_window(win_size,device=y.device,dtype=y.dtype)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
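# Usage sketch: VitsModel.forward below feeds the 32 kHz reference audio through
# this STFT using the hparams from the SoVITS config (typical v2 values shown
# here; the real ones come from self.hps.data).
#   ref_audio_32k = torch.randn(1, 3 * 32000)   # [batch, samples]
#   spec = spectrogram_torch(ref_audio_32k, n_fft=2048, sampling_rate=32000,
#                            hop_size=640, win_size=2048, center=False)
#   # -> [batch, n_fft // 2 + 1, frames] = [1, 1025, frames]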
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
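# Example: the config dict loaded from the SoVITS checkpoint becomes
# attribute-accessible at every nesting level.
#   hps = DictToAttrRecursive({"data": {"sampling_rate": 32000}})
#   assert hps.data.sampling_rate == hps["data"]["sampling_rate"] == 32000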
class VitsModel(nn.Module):
def __init__(self, vits_path):
super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu")
self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
self.vq_model = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
def forward(self, text_seq, pred_semantic, ref_audio):
refer = spectrogram_torch(
ref_audio,
self.hps.data.filter_length,
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False
)
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
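# Usage sketch (checkpoint path as used in export() below): decode predicted
# semantic tokens back to a waveform, conditioned on the 32 kHz reference audio.
#   vits = VitsModel("SoVITS_weights_v2/xw_e8_s216.pth")
#   vits.eval()
#   audio = vits(text_seq, pred_semantic, ref_audio_32k)   # 1-D float waveform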
class T2SModel(nn.Module):
def __init__(self, config,raw_t2s:Text2SemanticLightningModule, norm_first=False, top_k=3):
super(T2SModel, self).__init__()
@@ -148,7 +234,6 @@ class T2SModel(nn.Module):
self.top_k = int(self.config["inference"]["top_k"])
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
# def forward(self, x:LongTensor, prompts:LongTensor):
def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor):
bert = torch.cat([ref_bert.T, text_bert.T], 1)
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
@@ -229,199 +314,124 @@ class T2SModel(nn.Module):
return y[:, -idx:].unsqueeze(0)
# def first_stage_decoder(self, x, prompt):
# y = prompt
# x_example = x[:,:,0] * 0.0
# #N, 1, 512
# cache = {
# "all_stage": self.num_layers,
# "k": None,
# "v": None,
# "y_emb": None,
# "first_infer": 1,
# "stage": 0,
# }
# y_emb = self.ar_audio_embedding(y)
# cache["y_emb"] = y_emb
# y_pos = self.ar_audio_position(y_emb)
# xy_pos = torch.concat([x, y_pos], dim=1)
# y_example = y_pos[:,:,0] * 0.0
# x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
# y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
# y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
# torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
# )
# y_attn_mask = y_attn_mask > 0
# x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
# y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
# x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
# y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
# xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
# cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
# .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
# cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
# .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
# xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
# logits = self.ar_predict_layer(xy_dec[:, -1])
# samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
# y = torch.concat([y, samples], dim=1)
# return y, cache["k"], cache["v"], cache["y_emb"], x_example
# def stage_decoder(self, y, k, v, y_emb, x_example):
# cache = {
# "all_stage": self.num_layers,
# "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
# "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
# "y_emb": y_emb,
# "first_infer": 0,
# "stage": 0,
# }
# y_emb = torch.cat(
# [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
# )
# cache["y_emb"] = y_emb
# y_pos = self.ar_audio_position(y_emb)
# xy_pos = y_pos[:, -1:]
# y_example = y_pos[:,:,0] * 0.0
# xy_attn_mask = torch.cat([x_example, y_example], dim=1)
# xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
# xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
# logits = self.ar_predict_layer(xy_dec[:, -1])
# samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
# y = torch.concat([y, samples], dim=1)
# return y, cache["k"], cache["v"], cache["y_emb"], logits, samples
bert_path = os.environ.get(
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
)
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
class BertModel(torch.nn.Module):
def __init__(self, bert_model):
super(BertModel, self).__init__()
self.bert = bert_model
def forward(self, input_ids, attention_mask, token_type_ids, word2ph:IntTensor):
res = self.bert(input_ids, attention_mask, token_type_ids)
phone_level_feature = []
for i in range(word2ph.shape[0]):
repeat_feature = res[i].repeat(word2ph[i].item(), 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
# [sum(word2ph), 1024]
return phone_level_feature
@torch.jit.script
def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
phone_level_feature = []
for i in range(word2ph.shape[0]):
repeat_feature = res[i].repeat(word2ph[i].item(), 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
# [sum(word2ph), 1024]
return phone_level_feature
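# Example: word2ph holds the phone count per character, so a [n_chars, 1024]
# BERT hidden state becomes a [sum(word2ph), 1024] phone-level feature.
#   res = torch.randn(3, 1024)                       # 3 characters
#   word2ph = torch.IntTensor([2, 1, 2])             # phones per character
#   feats = build_phone_level_feature(res, word2ph)  # -> [5, 1024]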
class MyBertModel(torch.nn.Module):
def __init__(self, bert_model):
super(MyBertModel, self).__init__()
self.bert = bert_model
def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor):
def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
return res
return build_phone_level_feature(res, word2ph)
class SSLModel(torch.nn.Module):
def __init__(self,vits:VitsModel):
def __init__(self):
super().__init__()
self.ssl = cnhubert.get_model().model
self.ssl_proj = vits.vq_model.ssl_proj
self.quantizer = vits.vq_model.quantizer
def forward(self, ref_audio)->LongTensor:
ref_audio_16k,ref_audio_sr = parse_audio(ref_audio)
def forward(self, ref_audio_16k)-> torch.Tensor:
ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
codes = self.extract_latent(ssl_content.float())
prompt_semantic = codes[0, 0]
prompts = prompt_semantic.unsqueeze(0)
return prompts,ref_audio_sr
return ssl_content
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1)
class ExportSSLModel(torch.nn.Module):
def __init__(self,ssl:SSLModel):
super().__init__()
self.ssl = ssl
def export_bert(tokenizer):
ref_bert_inputs = tokenizer("在参加挼死特春晚的时候有人问了这样一个问题", return_tensors="pt")
def forward(self, ref_audio:torch.Tensor):
return self.ssl(ref_audio)
@torch.jit.export
def resample(self,ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
audio = resamplex(ref_audio,src_sr,dst_sr).float()
return audio
def export_bert(tokenizer,ref_text,word2ph):
ref_bert_inputs = tokenizer(ref_text, return_tensors="pt")
ref_bert_inputs = {
'input_ids': torch.jit.annotate(torch.Tensor,ref_bert_inputs['input_ids']),
'attention_mask': torch.jit.annotate(torch.Tensor,ref_bert_inputs['attention_mask']),
'token_type_ids': torch.jit.annotate(torch.Tensor,ref_bert_inputs['token_type_ids']),
'word2ph': torch.jit.annotate(torch.Tensor,word2ph)
}
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
my_bert_model = MyBertModel(bert_model)
my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
print('trace my_bert_model')
bert = BertModel(my_bert_model)
torch.jit.script(bert).save("onnx/bert_model.pt")
my_bert_model.save("onnx/bert_model.pt")
print('exported bert')
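# Usage sketch (text and word2ph taken from the example in export() below): trace
# the BERT front-end once, then reuse the saved TorchScript module on its own.
#   tokenizer = AutoTokenizer.from_pretrained(bert_path)
#   export_bert(tokenizer, "声音,是有温度的.夜晚的声音,会发光",
#               torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int())
#   bert = torch.jit.load("onnx/bert_model.pt")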
def export(gpt_path, vits_path):
tokenizer = AutoTokenizer.from_pretrained(bert_path)
ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
bert = MyBertModel(bert_model)
# export_bert(tokenizer,"声音,是有温度的.夜晚的声音,会发光",ref_bert_inputs['word2ph'])
ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
ssl = SSLModel()
s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(torch.jit.annotate(torch.Tensor,ref_audio))))
torch.jit.script(s).save("onnx/xw/ssl_model.pt")
print('exported ssl')
# ref_seq = torch.LongTensor([cleaned_text_to_sequence(["zh", "ai4", "ch", "an1","j" ,"ia1","r","ua4","s","i3","t","e3","ch","un1","w","an3","d","e1", "sh", "i2", "h", "ou4", "y", "ou3", "r", "en2","w","en4","l","e1","zh","e4","y","ang4","y","i2","g","e4","w","en4","t","i2"],version='v2')])
# ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
# text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
ref_bert = bert(**ref_bert_inputs)
text_bert = bert(**text_berf_inputs)
ssl_content = ssl(ref_audio)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path)
vits.eval()
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
dict_s1 = torch.load(gpt_path, map_location="cpu")
raw_t2s = get_raw_t2s_model(dict_s1)
t2s_m = T2SModel(dict_s1['config'],raw_t2s,top_k=3)
t2s_m.eval()
torch.jit.script(t2s_m).save("onnx/xw/t2s_model.pt")
t2s = torch.jit.script(t2s_m)
print('exported t2s_m')
tokenizer = AutoTokenizer.from_pretrained(bert_path)
gpt_sovits = GPT_SoVITS(t2s,vits)
ref_audio_sr = ssl.resample(ref_audio,16000,32000)
ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
# audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)
text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
my_bert_model = MyBertModel(bert_model)
bert = BertModel(my_bert_model)
# export_bert(tokenizer)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path)
vits.eval()
ref_audio = torch.tensor([load_audio("output/denoise_opt/xw.mp3_0000000000_0000156480.wav", 48000)]).float()
ssl = SSLModel(vits)
torch.jit.trace(ssl,example_inputs=(torch.jit.annotate(torch.Tensor,ref_audio))).save("onnx/xw/ssl_model.pt")
print('exported ssl')
# ref_seq = torch.LongTensor([cleaned_text_to_sequence(["zh", "ai4", "ch", "an1","j" ,"ia1","r","ua4","s","i3","t","e3","ch","un1","w","an3","d","e1", "sh", "i2", "h", "ou4", "y", "ou3", "r", "en2","w","en4","l","e1","zh","e4","y","ang4","y","i2","g","e4","w","en4","t","i2"],version='v2')])
ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
ref_bert = bert(**ref_bert_inputs)
text_bert = bert(**text_berf_inputs)
prompts,ref_audio_sr = ssl(ref_audio)
pred_semantic = t2s_m(prompts, ref_seq, text_seq, ref_bert, text_bert)
torch.jit.trace(vits,example_inputs=(
torch.jit.trace(gpt_sovits,example_inputs=(
torch.jit.annotate(torch.Tensor,ssl_content),
torch.jit.annotate(torch.Tensor,ref_audio_sr),
torch.jit.annotate(torch.Tensor,ref_seq),
torch.jit.annotate(torch.Tensor,text_seq),
torch.jit.annotate(torch.Tensor,pred_semantic),
torch.jit.annotate(torch.Tensor,ref_audio_sr))).save("onnx/xw/vits_model.pt")
torch.jit.annotate(torch.Tensor,ref_bert),
torch.jit.annotate(torch.Tensor,text_bert))).save("onnx/xw/gpt_sovits_model.pt")
print('exported vits')
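# After export, inference only needs the saved TorchScript artifacts; a rough
# sketch of reusing them (same paths as written above):
#   ssl = torch.jit.load("onnx/xw/ssl_model.pt")
#   gpt_sovits = torch.jit.load("onnx/xw/gpt_sovits_model.pt")
#   ssl_content = ssl(ref_audio_16k)
#   ref_audio_32k = ssl.resample(ref_audio_16k, 16000, 32000)
#   audio = gpt_sovits(ssl_content, ref_audio_32k, ref_seq, text_seq, ref_bert, text_bert)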
@torch.jit.script
@@ -430,69 +440,84 @@ def parse_audio(ref_audio):
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,32000).float()#.to(ref_audio.device)
return ref_audio_16k,ref_audio_sr
@torch.jit.script
def resamplex(ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
return torchaudio.functional.resample(ref_audio,src_sr,dst_sr).float()
class GPT_SoVITS(nn.Module):
def __init__(self, t2s:T2SModel,vits:VitsModel):
super().__init__()
self.t2s = t2s
self.vits = vits
def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor):
codes = self.vits.vq_model.extract_latent(ssl_content.float())
prompt_semantic = codes[0, 0]
prompts = prompt_semantic.unsqueeze(0)
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
audio = self.vits(text_seq, pred_semantic, ref_audio_sr)
return audio
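# GPT_SoVITS is the merged module this commit exports: VQ extraction, the T2S
# decoder and the VITS vocoder run in one forward pass, so a single traced
# artifact covers the whole text-to-waveform path. Usage (mirrors export()/test()):
#   gpt_sovits = GPT_SoVITS(t2s_m, vits)
#   audio = gpt_sovits(ssl_content, ref_audio_32k, ref_seq, text_seq, ref_bert, text_bert)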
def test():
tokenizer = AutoTokenizer.from_pretrained(bert_path)
# bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
# bert_model.bert.embeddings = MyBertEmbeddings(bert_model.bert.config)
# bert_model.bert.encoder = MyBertEncoder(bert_model.bert.config)
# my_bert_model = MyBertModel(bert_model)
# bert = BertModel(my_bert_model)
bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
bert = MyBertModel(bert_model)
# bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
# dict_s1 = torch.load(gpt_path, map_location="cpu")
# raw_t2s = get_raw_t2s_model(dict_s1)
# t2s = T2SModel(dict_s1['config'],raw_t2s,top_k=3)
# t2s.eval()
t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')
gpt_path = "GPT_weights_v2/xw-e15.ckpt"
dict_s1 = torch.load(gpt_path, map_location="cpu")
raw_t2s = get_raw_t2s_model(dict_s1)
t2s = T2SModel(dict_s1['config'],raw_t2s,top_k=3)
t2s.eval()
# t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
# vits = VitsModel(vits_path).to('cuda')
# vits.eval()
vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path)
vits.eval()
# ssl = SSLModel(vits).to('cuda')
# ssl.eval()
ssl = ExportSSLModel(SSLModel())
ssl.eval()
vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda')
gpt_sovits = GPT_SoVITS(t2s,vits)
ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda')
# vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda')
# ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda')
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["zh", "ai4", "ch", "an1","j" ,"ia1","r","ua4","s","i3","t","e3","ch","un1","w","an3","d","e1", "sh", "i2", "h", "ou4", "y", "ou3", "r", "en2","w","en4","l","e1","zh","e4","y","ang4","y","i2","g","e4","w","en4","t","i2"],version='v2')])
ref_seq=ref_seq.to('cuda')
text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","d","e1","w","en4","t","i2","."],version='v2')])
text_seq=text_seq.to('cuda')
ref_bert_inputs = tokenizer("在参加挼死特春晚的时候有人问了这样一个问题", return_tensors="pt")
ref_bert_inputs['word2ph'] = torch.Tensor([2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2]).int().to('cuda')
ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
ref_bert = bert(
ref_bert_inputs['input_ids'].to('cuda'),
ref_bert_inputs['attention_mask'].to('cuda'),
ref_bert_inputs['token_type_ids'].to('cuda'),
ref_bert_inputs['word2ph'].to('cuda'))
ref_bert_inputs['input_ids'],
ref_bert_inputs['attention_mask'],
ref_bert_inputs['token_type_ids'],
ref_bert_inputs['word2ph']
)
print('ref_bert:',ref_bert.device)
text_berf_inputs = tokenizer("大家好,我有一个奇怪的问题.", return_tensors="pt")
text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,2,1]).int().to('cuda')
text_bert = bert(text_berf_inputs['input_ids'].to('cuda'),
text_berf_inputs['attention_mask'].to('cuda'),
text_berf_inputs['token_type_ids'].to('cuda'),
text_bert = bert(text_berf_inputs['input_ids'],
text_berf_inputs['attention_mask'],
text_berf_inputs['token_type_ids'],
text_berf_inputs['word2ph'])
ref_audio = torch.tensor([load_audio("output/denoise_opt/xw.mp3_0000000000_0000156480.wav", 48000)]).float().to('cuda')
#[1,N]
ref_audio = torch.tensor(load_audio("output/denoise_opt/xw.mp3_0000000000_0000156480.wav", 16000)).float().unsqueeze(0)
print('ref_audio:',ref_audio.shape)
ref_audio_sr = ssl.resample(ref_audio,16000,32000)
print('start ssl')
prompts,ref_audio_sr = ssl(ref_audio)
ssl_content = ssl(ref_audio)
pred_semantic = t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
print('start vits:',pred_semantic.shape)
print('ref_audio_sr:',ref_audio_sr.device)
audio = vits(text_seq, pred_semantic, ref_audio_sr)
print('start gpt_sovits:')
with torch.no_grad():
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)
print('start write wav')
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
torch.load("onnx/symbols_v2.json")
# audio = vits(text_seq, pred_semantic1, ref_audio)
# soundfile.write("out.wav", audio, 32000)
@@ -514,9 +539,3 @@ if __name__ == "__main__":
export(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
# test()
# export_symbel()
# tokenizer = AutoTokenizer.from_pretrained(bert_path)
# text_berf_inputs = tokenizer("大家好,我有一个奇怪的问题.", return_tensors="pt")
# print(text_berf_inputs)
# ref_audio = load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 48000)
# print(ref_audio.shape)
# soundfile.write("chen1_ref.wav", ref_audio, 48000)