diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py index e40c8aac..eb359f40 100644 --- a/GPT_SoVITS/export_torch_script.py +++ b/GPT_SoVITS/export_torch_script.py @@ -2,18 +2,19 @@ # reference: https://github.com/lifeiteng/vall-e from typing import Optional from my_utils import load_audio -from onnx_export import VitsModel from text import cleaned_text_to_sequence import torch import torchaudio -from torch import IntTensor, LongTensor, nn +from torch import IntTensor, LongTensor, Tensor, nn from torch.nn import functional as F from transformers import AutoModelForMaskedLM, AutoTokenizer from feature_extractor import cnhubert from AR.models.t2s_lightning_module import Text2SemanticLightningModule +from module.models_onnx import SynthesizerTrn + import os @@ -31,6 +32,14 @@ default_config = { "EOS": 1024, } +def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule: + config = dict_s1["config"] + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) + t2s_model.load_state_dict(dict_s1["weight"]) + t2s_model = t2s_model.eval() + return t2s_model + +@torch.jit.script def logits_to_probs( logits, previous_tokens: Optional[torch.Tensor] = None, @@ -73,20 +82,13 @@ def logits_to_probs( probs = torch.nn.functional.softmax(logits, dim=-1) return probs -def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule: - config = dict_s1["config"] - t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) - t2s_model.load_state_dict(dict_s1["weight"]) - t2s_model = t2s_model.eval() - return t2s_model - -def multinomial_sample_one_no_sync( - probs_sort -): # Does multinomial sampling without a cuda synchronization +@torch.jit.script +def multinomial_sample_one_no_sync(probs_sort): + # Does multinomial sampling without a cuda synchronization q = torch.randn_like(probs_sort) return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) - +@torch.jit.script def sample( logits, previous_tokens, @@ -102,6 +104,90 @@ def sample( return idx_next, probs +@torch.jit.script +def spectrogram_torch(y, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False): + hann_window = torch.hann_window(win_size,device=y.device,dtype=y.dtype) + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + +class VitsModel(nn.Module): + def __init__(self, vits_path): + super().__init__() + dict_s2 = torch.load(vits_path,map_location="cpu") + self.hps = dict_s2["config"] + if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: + self.hps["model"]["version"] = "v1" + else: + self.hps["model"]["version"] = "v2" + + self.hps = DictToAttrRecursive(self.hps) + self.hps.model.semantic_frame_rate = "25hz" + self.vq_model = SynthesizerTrn( + self.hps.data.filter_length // 2 + 1, + self.hps.train.segment_size // self.hps.data.hop_length, + n_speakers=self.hps.data.n_speakers, + **self.hps.model + ) + self.vq_model.eval() + self.vq_model.load_state_dict(dict_s2["weight"], strict=False) + + def forward(self, text_seq, pred_semantic, ref_audio): + refer = spectrogram_torch( + ref_audio, + self.hps.data.filter_length, + self.hps.data.sampling_rate, + self.hps.data.hop_length, + self.hps.data.win_length, + center=False + ) + return self.vq_model(pred_semantic, text_seq, refer)[0, 0] + class T2SModel(nn.Module): def __init__(self, config,raw_t2s:Text2SemanticLightningModule, norm_first=False, top_k=3): super(T2SModel, self).__init__() @@ -148,7 +234,6 @@ class T2SModel(nn.Module): self.top_k = int(self.config["inference"]["top_k"]) self.early_stop_num = torch.LongTensor([self.hz * self.max_sec]) - # def forward(self, x:LongTensor, prompts:LongTensor): def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor): bert = torch.cat([ref_bert.T, text_bert.T], 1) all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) @@ -228,85 +313,6 @@ class T2SModel(nn.Module): y[0, -1] = 0 return y[:, -idx:].unsqueeze(0) - - # def first_stage_decoder(self, x, prompt): - # y = prompt - # x_example = x[:,:,0] * 0.0 - # #N, 1, 512 - # cache = { - # "all_stage": self.num_layers, - # "k": None, - # "v": None, - # "y_emb": None, - # "first_infer": 1, - # "stage": 0, - # } - - # y_emb = self.ar_audio_embedding(y) - - # cache["y_emb"] = y_emb - # y_pos = self.ar_audio_position(y_emb) - - # xy_pos = torch.concat([x, y_pos], dim=1) - - # y_example = y_pos[:,:,0] * 0.0 - # x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool() - # y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64) - # y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum( - # torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0 - # ) - # y_attn_mask = y_attn_mask > 0 - - # x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool() - # y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool() - # x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1) - # y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1) - # xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) - # cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\ - # .unsqueeze(1).repeat(self.num_layers, 1, 1, 1) - # cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\ - # .unsqueeze(1).repeat(self.num_layers, 1, 1, 1) - - # xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) - # logits = self.ar_predict_layer(xy_dec[:, -1]) - # samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) - - # y = torch.concat([y, samples], dim=1) - - # return y, cache["k"], cache["v"], cache["y_emb"], x_example - - # def stage_decoder(self, y, k, v, y_emb, x_example): - # cache = { - # "all_stage": self.num_layers, - # "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)), - # "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)), - # "y_emb": y_emb, - # "first_infer": 0, - # "stage": 0, - # } - - # y_emb = torch.cat( - # [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1 - # ) - # cache["y_emb"] = y_emb - # y_pos = self.ar_audio_position(y_emb) - - # xy_pos = y_pos[:, -1:] - - # y_example = y_pos[:,:,0] * 0.0 - - # xy_attn_mask = torch.cat([x_example, y_example], dim=1) - # xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool) - - # xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) - # logits = self.ar_predict_layer(xy_dec[:, -1]) - # samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) - - # y = torch.concat([y, samples], dim=1) - - # return y, cache["k"], cache["v"], cache["y_emb"], logits, samples - - bert_path = os.environ.get( "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large" @@ -314,114 +320,118 @@ bert_path = os.environ.get( cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" cnhubert.cnhubert_base_path = cnhubert_base_path -class BertModel(torch.nn.Module): - def __init__(self, bert_model): - super(BertModel, self).__init__() - self.bert = bert_model - - def forward(self, input_ids, attention_mask, token_type_ids, word2ph:IntTensor): - res = self.bert(input_ids, attention_mask, token_type_ids) - phone_level_feature = [] - for i in range(word2ph.shape[0]): - repeat_feature = res[i].repeat(word2ph[i].item(), 1) - phone_level_feature.append(repeat_feature) - phone_level_feature = torch.cat(phone_level_feature, dim=0) - # [sum(word2ph), 1024] - return phone_level_feature +@torch.jit.script +def build_phone_level_feature(res:Tensor, word2ph:IntTensor): + phone_level_feature = [] + for i in range(word2ph.shape[0]): + repeat_feature = res[i].repeat(word2ph[i].item(), 1) + phone_level_feature.append(repeat_feature) + phone_level_feature = torch.cat(phone_level_feature, dim=0) + # [sum(word2ph), 1024] + return phone_level_feature class MyBertModel(torch.nn.Module): def __init__(self, bert_model): super(MyBertModel, self).__init__() self.bert = bert_model - def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor): + def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor): outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1] - return res + return build_phone_level_feature(res, word2ph) class SSLModel(torch.nn.Module): - def __init__(self,vits:VitsModel): + def __init__(self): super().__init__() self.ssl = cnhubert.get_model().model - self.ssl_proj = vits.vq_model.ssl_proj - self.quantizer = vits.vq_model.quantizer - def forward(self, ref_audio)->LongTensor: - ref_audio_16k,ref_audio_sr = parse_audio(ref_audio) + def forward(self, ref_audio_16k)-> torch.Tensor: ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2) - codes = self.extract_latent(ssl_content.float()) - prompt_semantic = codes[0, 0] - prompts = prompt_semantic.unsqueeze(0) - return prompts,ref_audio_sr + return ssl_content - def extract_latent(self, x): - ssl = self.ssl_proj(x) - quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) - return codes.transpose(0, 1) +class ExportSSLModel(torch.nn.Module): + def __init__(self,ssl:SSLModel): + super().__init__() + self.ssl = ssl -def export_bert(tokenizer): - ref_bert_inputs = tokenizer("在参加挼死特春晚的时候有人问了这样一个问题", return_tensors="pt") + def forward(self, ref_audio:torch.Tensor): + return self.ssl(ref_audio) + + @torch.jit.export + def resample(self,ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor: + audio = resamplex(ref_audio,src_sr,dst_sr).float() + return audio + +def export_bert(tokenizer,ref_text,word2ph): + ref_bert_inputs = tokenizer(ref_text, return_tensors="pt") ref_bert_inputs = { 'input_ids': torch.jit.annotate(torch.Tensor,ref_bert_inputs['input_ids']), 'attention_mask': torch.jit.annotate(torch.Tensor,ref_bert_inputs['attention_mask']), 'token_type_ids': torch.jit.annotate(torch.Tensor,ref_bert_inputs['token_type_ids']), + 'word2ph': torch.jit.annotate(torch.Tensor,word2ph) } bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True) my_bert_model = MyBertModel(bert_model) my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs) - print('trace my_bert_model') - bert = BertModel(my_bert_model) - torch.jit.script(bert).save("onnx/bert_model.pt") + my_bert_model.save("onnx/bert_model.pt") print('exported bert') def export(gpt_path, vits_path): + tokenizer = AutoTokenizer.from_pretrained(bert_path) + + ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt") + ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')]) + ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int() + + text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt") + text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')]) + text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int() + + bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True) + + bert = MyBertModel(bert_model) + + # export_bert(tokenizer,"声音,是有温度的.夜晚的声音,会发光",ref_bert_inputs['word2ph']) + + ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float() + ssl = SSLModel() + s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(torch.jit.annotate(torch.Tensor,ref_audio)))) + torch.jit.script(s).save("onnx/xw/ssl_model.pt") + + print('exported ssl') + + # ref_seq = torch.LongTensor([cleaned_text_to_sequence(["zh", "ai4", "ch", "an1","j" ,"ia1","r","ua4","s","i3","t","e3","ch","un1","w","an3","d","e1", "sh", "i2", "h", "ou4", "y", "ou3", "r", "en2","w","en4","l","e1","zh","e4","y","ang4","y","i2","g","e4","w","en4","t","i2"],version='v2')]) + # ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')]) + # text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')]) + ref_bert = bert(**ref_bert_inputs) + text_bert = bert(**text_berf_inputs) + ssl_content = ssl(ref_audio) + + # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + vits = VitsModel(vits_path) + vits.eval() + # gpt_path = "GPT_weights_v2/xw-e15.ckpt" dict_s1 = torch.load(gpt_path, map_location="cpu") raw_t2s = get_raw_t2s_model(dict_s1) t2s_m = T2SModel(dict_s1['config'],raw_t2s,top_k=3) t2s_m.eval() - torch.jit.script(t2s_m).save("onnx/xw/t2s_model.pt") + t2s = torch.jit.script(t2s_m) print('exported t2s_m') - tokenizer = AutoTokenizer.from_pretrained(bert_path) + gpt_sovits = GPT_SoVITS(t2s,vits) + ref_audio_sr = ssl.resample(ref_audio,16000,32000) + + # audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert) - ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt") - ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int() - - text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt") - text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int() - - bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True) - - my_bert_model = MyBertModel(bert_model) - - bert = BertModel(my_bert_model) - - # export_bert(tokenizer) - - - # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" - vits = VitsModel(vits_path) - vits.eval() - - ref_audio = torch.tensor([load_audio("output/denoise_opt/xw.mp3_0000000000_0000156480.wav", 48000)]).float() - ssl = SSLModel(vits) - torch.jit.trace(ssl,example_inputs=(torch.jit.annotate(torch.Tensor,ref_audio))).save("onnx/xw/ssl_model.pt") - print('exported ssl') - - # ref_seq = torch.LongTensor([cleaned_text_to_sequence(["zh", "ai4", "ch", "an1","j" ,"ia1","r","ua4","s","i3","t","e3","ch","un1","w","an3","d","e1", "sh", "i2", "h", "ou4", "y", "ou3", "r", "en2","w","en4","l","e1","zh","e4","y","ang4","y","i2","g","e4","w","en4","t","i2"],version='v2')]) - ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')]) - text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')]) - ref_bert = bert(**ref_bert_inputs) - text_bert = bert(**text_berf_inputs) - prompts,ref_audio_sr = ssl(ref_audio) - pred_semantic = t2s_m(prompts, ref_seq, text_seq, ref_bert, text_bert) - - torch.jit.trace(vits,example_inputs=( + torch.jit.trace(gpt_sovits,example_inputs=( + torch.jit.annotate(torch.Tensor,ssl_content), + torch.jit.annotate(torch.Tensor,ref_audio_sr), + torch.jit.annotate(torch.Tensor,ref_seq), torch.jit.annotate(torch.Tensor,text_seq), - torch.jit.annotate(torch.Tensor,pred_semantic), - torch.jit.annotate(torch.Tensor,ref_audio_sr))).save("onnx/xw/vits_model.pt") + torch.jit.annotate(torch.Tensor,ref_bert), + torch.jit.annotate(torch.Tensor,text_bert))).save("onnx/xw/gpt_sovits_model.pt") print('exported vits') @torch.jit.script @@ -430,69 +440,84 @@ def parse_audio(ref_audio): ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,32000).float()#.to(ref_audio.device) return ref_audio_16k,ref_audio_sr +@torch.jit.script +def resamplex(ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor: + return torchaudio.functional.resample(ref_audio,src_sr,dst_sr).float() + +class GPT_SoVITS(nn.Module): + def __init__(self, t2s:T2SModel,vits:VitsModel): + super().__init__() + self.t2s = t2s + self.vits = vits + + def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor): + codes = self.vits.vq_model.extract_latent(ssl_content.float()) + prompt_semantic = codes[0, 0] + prompts = prompt_semantic.unsqueeze(0) + + pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert) + audio = self.vits(text_seq, pred_semantic, ref_audio_sr) + return audio + def test(): - tokenizer = AutoTokenizer.from_pretrained(bert_path) - # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True) - # bert_model.bert.embeddings = MyBertEmbeddings(bert_model.bert.config) - # bert_model.bert.encoder = MyBertEncoder(bert_model.bert.config) - # my_bert_model = MyBertModel(bert_model) - # bert = BertModel(my_bert_model) - bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda') + bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True) + bert = MyBertModel(bert_model) + # bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda') - # gpt_path = "GPT_weights_v2/xw-e15.ckpt" - # dict_s1 = torch.load(gpt_path, map_location="cpu") - # raw_t2s = get_raw_t2s_model(dict_s1) - # t2s = T2SModel(dict_s1['config'],raw_t2s,top_k=3) - # t2s.eval() - t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda') + gpt_path = "GPT_weights_v2/xw-e15.ckpt" + dict_s1 = torch.load(gpt_path, map_location="cpu") + raw_t2s = get_raw_t2s_model(dict_s1) + t2s = T2SModel(dict_s1['config'],raw_t2s,top_k=3) + t2s.eval() + # t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda') - # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" - # vits = VitsModel(vits_path).to('cuda') - # vits.eval() + vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + vits = VitsModel(vits_path) + vits.eval() - # ssl = SSLModel(vits).to('cuda') - # ssl.eval() + ssl = ExportSSLModel(SSLModel()) + ssl.eval() - vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda') + gpt_sovits = GPT_SoVITS(t2s,vits) - ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda') + # vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda') + # ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda') - ref_seq = torch.LongTensor([cleaned_text_to_sequence(["zh", "ai4", "ch", "an1","j" ,"ia1","r","ua4","s","i3","t","e3","ch","un1","w","an3","d","e1", "sh", "i2", "h", "ou4", "y", "ou3", "r", "en2","w","en4","l","e1","zh","e4","y","ang4","y","i2","g","e4","w","en4","t","i2"],version='v2')]) - ref_seq=ref_seq.to('cuda') - text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","d","e1","w","en4","t","i2","."],version='v2')]) - text_seq=text_seq.to('cuda') + + ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt") + ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')]) + ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int() + + text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt") + text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')]) + text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int() - ref_bert_inputs = tokenizer("在参加挼死特春晚的时候有人问了这样一个问题", return_tensors="pt") - ref_bert_inputs['word2ph'] = torch.Tensor([2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2]).int().to('cuda') ref_bert = bert( - ref_bert_inputs['input_ids'].to('cuda'), - ref_bert_inputs['attention_mask'].to('cuda'), - ref_bert_inputs['token_type_ids'].to('cuda'), - ref_bert_inputs['word2ph'].to('cuda')) + ref_bert_inputs['input_ids'], + ref_bert_inputs['attention_mask'], + ref_bert_inputs['token_type_ids'], + ref_bert_inputs['word2ph'] + ) - print('ref_bert:',ref_bert.device) - - text_berf_inputs = tokenizer("大家好,我有一个奇怪的问题.", return_tensors="pt") - text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,2,1]).int().to('cuda') - text_bert = bert(text_berf_inputs['input_ids'].to('cuda'), - text_berf_inputs['attention_mask'].to('cuda'), - text_berf_inputs['token_type_ids'].to('cuda'), + text_bert = bert(text_berf_inputs['input_ids'], + text_berf_inputs['attention_mask'], + text_berf_inputs['token_type_ids'], text_berf_inputs['word2ph']) - ref_audio = torch.tensor([load_audio("output/denoise_opt/xw.mp3_0000000000_0000156480.wav", 48000)]).float().to('cuda') + #[1,N] + ref_audio = torch.tensor(load_audio("output/denoise_opt/xw.mp3_0000000000_0000156480.wav", 16000)).float().unsqueeze(0) + print('ref_audio:',ref_audio.shape) + ref_audio_sr = ssl.resample(ref_audio,16000,32000) print('start ssl') - prompts,ref_audio_sr = ssl(ref_audio) + ssl_content = ssl(ref_audio) - pred_semantic = t2s(prompts, ref_seq, text_seq, ref_bert, text_bert) - - print('start vits:',pred_semantic.shape) - print('ref_audio_sr:',ref_audio_sr.device) - audio = vits(text_seq, pred_semantic, ref_audio_sr) + print('start gpt_sovits:') + with torch.no_grad(): + audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert) print('start write wav') soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000) - torch.load("onnx/symbols_v2.json") # audio = vits(text_seq, pred_semantic1, ref_audio) # soundfile.write("out.wav", audio, 32000) @@ -513,10 +538,4 @@ def export_symbel(version='v2'): if __name__ == "__main__": export(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth") # test() - # export_symbel() - # tokenizer = AutoTokenizer.from_pretrained(bert_path) - # text_berf_inputs = tokenizer("大家好,我有一个奇怪的问题.", return_tensors="pt") - # print(text_berf_inputs) - # ref_audio = load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 48000) - # print(ref_audio.shape) - # soundfile.write("chen1_ref.wav", ref_audio, 48000) \ No newline at end of file + # export_symbel() \ No newline at end of file