Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-10-06 14:40:00 +08:00)

Commit 787881a6ce: Merge branch 'RVC-Boss:main' into main
@@ -1,5 +1,6 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
+import argparse
 from typing import Optional
 from my_utils import load_audio
 from text import cleaned_text_to_sequence
@@ -15,7 +16,7 @@ from feature_extractor import cnhubert
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from module.models_onnx import SynthesizerTrn

+from inference_webui import get_phones_and_bert

 import os
 import soundfile
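Note on the new import: get_phones_and_bert (from inference_webui) replaces the hand-rolled tokenizer/cleaned_text_to_sequence plumbing that this commit deletes further down. A minimal sketch of how it is used later in this diff (the "all_zh" language tag and 'v2' symbol version are the values this commit passes; the 1024 feature width comes from the comment in build_phone_level_feature):

    # returns phoneme ids, a BERT feature matrix, and the normalized text
    seq_ids, bert_T, norm_text = get_phones_and_bert("这是一条测试语音", "all_zh", "v2")
    ref_seq = torch.LongTensor([seq_ids])   # [1, N] phoneme id sequence
    ref_bert = bert_T.T                     # transpose to [seq_len, 1024], as export() does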
@@ -329,11 +330,12 @@ class T2STransformer:
         for i in range(self.num_blocks):
             x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
         return x, k_cache, v_cache

 class VitsModel(nn.Module):
     def __init__(self, vits_path):
         super().__init__()
-        dict_s2 = torch.load(vits_path,map_location="cpu")
+        # dict_s2 = torch.load(vits_path,map_location="cpu")
+        dict_s2 = torch.load(vits_path)
         self.hps = dict_s2["config"]
         if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
             self.hps["model"]["version"] = "v1"
@@ -351,7 +353,7 @@ class VitsModel(nn.Module):
         self.vq_model.eval()
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)

-    def forward(self, text_seq, pred_semantic, ref_audio):
+    def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):
         refer = spectrogram_torch(
             ref_audio,
             self.hps.data.filter_length,
@@ -360,7 +362,7 @@ class VitsModel(nn.Module):
             self.hps.data.win_length,
             center=False
         )
-        return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
+        return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]

 class T2SModel(nn.Module):
     def __init__(self,raw_t2s:Text2SemanticLightningModule):
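Note: the new speed argument is only forwarded here; it takes effect inside SynthesizerTrn (see the models_onnx hunks near the end of this diff, where it drives an F.interpolate time-stretch). A hypothetical eager-mode call:

    # synthesize roughly 20% faster; values below 1.0 slow the speech down
    audio = vits(text_seq, pred_semantic, ref_audio_sr, speed=1.2)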
@@ -507,6 +509,8 @@ class T2SModel(nn.Module):
             y_emb = self.ar_audio_embedding(y[:, -1:])
             xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)

+        y[0,-1] = 0
+
         return y[:, -idx:].unsqueeze(0)

 bert_path = os.environ.get(
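A reading of the added y[0,-1] = 0 (my interpretation, not stated in the commit): the sampling loop exits once an EOS token is produced, and in GPT-SoVITS the semantic codebook has 1024 entries with EOS assumed to be id 1024, one past the end; zeroing the final position keeps every id the VITS decoder sees inside the codebook range:

    # assumption: the last sampled token is EOS and would be out of range for the
    # decoder's embedding lookup, so it is overwritten with a valid id (0)
    y[0, -1] = 0
    return y[:, -idx:].unsqueeze(0)   # keep only the idx newly generated tokens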
@@ -524,7 +528,7 @@ def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
     # [sum(word2ph), 1024]
     return phone_level_feature

 class MyBertModel(torch.nn.Module):
     def __init__(self, bert_model):
         super(MyBertModel, self).__init__()
@@ -532,7 +536,8 @@ class MyBertModel(torch.nn.Module):

     def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
-        res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
+        # res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
+        res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
         return build_phone_level_feature(res, word2ph)

 class SSLModel(torch.nn.Module):
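Why outputs[1]: export_bert below now loads the model with torchscript=True, and in that mode transformers models return plain tuples instead of keyed ModelOutput objects, so outputs["hidden_states"] no longer works. A sketch of the assumed tuple layout for a masked-LM with output_hidden_states=True (logits first, then the per-layer hidden states):

    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
    outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    logits = outputs[0]          # [batch, seq_len, vocab]
    hidden_states = outputs[1]   # tuple: embedding output plus one tensor per layer
    res = torch.cat(hidden_states[-3:-2], -1)[0][1:-1]   # third-from-last layer, drop [CLS]/[SEP]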
@@ -557,66 +562,99 @@ class ExportSSLModel(torch.nn.Module):
         audio = resamplex(ref_audio,src_sr,dst_sr).float()
         return audio

-def export_bert(ref_bert_inputs):
+def export_bert(output_path):
+    tokenizer = AutoTokenizer.from_pretrained(bert_path)
+
+    text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
+    ref_bert_inputs = tokenizer(text, return_tensors="pt")
+    word2ph = []
+    for c in text:
+        if c in [',','。',':','?',",",".","?"]:
+            word2ph.append(1)
+        else:
+            word2ph.append(2)
+    ref_bert_inputs['word2ph'] = torch.Tensor(word2ph).int()
+
+    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
+    my_bert_model = MyBertModel(bert_model)
+
     ref_bert_inputs = {
         'input_ids': ref_bert_inputs['input_ids'],
         'attention_mask': ref_bert_inputs['attention_mask'],
         'token_type_ids': ref_bert_inputs['token_type_ids'],
         'word2ph': ref_bert_inputs['word2ph']
     }
-    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
-    my_bert_model = MyBertModel(bert_model)
+
+    torch._dynamo.mark_dynamic(ref_bert_inputs['input_ids'], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs['attention_mask'], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs['token_type_ids'], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs['word2ph'], 0)

     my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
-    my_bert_model.save("onnx/bert_model.pt")
+    output_path = os.path.join(output_path, "bert_model.pt")
+    my_bert_model.save(output_path)
     print('#### exported bert ####')

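The word2ph list built above maps each character of the prompt to its phoneme count with a simple heuristic: a Chinese character expands to an initial plus a final (2 phonemes), punctuation to 1. The same loop reappears in test() further down; a self-contained sketch (the helper name is mine):

    import torch

    def build_word2ph(text: str) -> torch.Tensor:
        # punctuation set copied from this commit; it is not exhaustive
        puncts = {',', '。', ':', '?', ',', '.', '?'}
        return torch.tensor([1 if c in puncts else 2 for c in text], dtype=torch.int32)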
-def export(gpt_path, vits_path):
-    tokenizer = AutoTokenizer.from_pretrained(bert_path)
-
-    ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
-    ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
-    ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
-
-    text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
-    text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
-    text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
-
-    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
-    bert = MyBertModel(bert_model)
-    # export_bert(ref_bert_inputs)
-
-    ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
+def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device='cpu'):
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+        print(f"目录已创建: {output_path}")
+    else:
+        print(f"目录已存在: {output_path}")
+
+    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
     ssl = SSLModel()
-    s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
-    torch.jit.script(s).save("onnx/xw/ssl_model.pt")
-    print('#### exported ssl ####')
+    if export_bert_and_ssl:
+        s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
+        ssl_path = os.path.join(output_path, "ssl_model.pt")
+        torch.jit.script(s).save(ssl_path)
+        print('#### exported ssl ####')
+        export_bert(output_path)
+    else:
+        s = ExportSSLModel(ssl)

-    ref_bert = bert(**ref_bert_inputs)
-    text_bert = bert(**text_berf_inputs)
-    ssl_content = ssl(ref_audio)
+    print(f"device: {device}")
+
+    ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
+    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
+    ref_bert = ref_bert_T.T.to(ref_seq.device)
+    text_seq_id,text_bert_T,norm_text = get_phones_and_bert("这是一条测试语音,说什么无所谓,只是给它一个例子","all_zh",'v2')
+    text_seq = torch.LongTensor([text_seq_id]).to(device)
+    text_bert = text_bert_T.T.to(text_seq.device)
+
+    ssl_content = ssl(ref_audio).to(device)

     # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
-    vits = VitsModel(vits_path)
+    vits = VitsModel(vits_path).to(device)
     vits.eval()

     # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
-    dict_s1 = torch.load(gpt_path, map_location="cpu")
-    raw_t2s = get_raw_t2s_model(dict_s1)
+    # dict_s1 = torch.load(gpt_path, map_location=device)
+    dict_s1 = torch.load(gpt_path)
+    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
+    print('#### get_raw_t2s_model ####')
+    print(raw_t2s.config)
     t2s_m = T2SModel(raw_t2s)
     t2s_m.eval()
-    t2s = torch.jit.script(t2s_m)
+    t2s = torch.jit.script(t2s_m).to(device)
     print('#### script t2s_m ####')

     print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
-    gpt_sovits = GPT_SoVITS(t2s,vits)
+    gpt_sovits = GPT_SoVITS(t2s,vits).to(device)
     gpt_sovits.eval()

-    ref_audio_sr = s.resample(ref_audio,16000,32000)
-    print('ref_audio_sr:',ref_audio_sr.shape)
-
-    gpt_sovits_export = torch.jit.trace(
+    ref_audio_sr = s.resample(ref_audio,16000,32000).to(device)
+
+    torch._dynamo.mark_dynamic(ssl_content, 2)
+    torch._dynamo.mark_dynamic(ref_audio_sr, 1)
+    torch._dynamo.mark_dynamic(ref_seq, 1)
+    torch._dynamo.mark_dynamic(text_seq, 1)
+    torch._dynamo.mark_dynamic(ref_bert, 0)
+    torch._dynamo.mark_dynamic(text_bert, 0)
+
+    with torch.no_grad():
+        gpt_sovits_export = torch.jit.trace(
             gpt_sovits,
             example_inputs=(
                 ssl_content,
@@ -624,11 +662,11 @@ def export(gpt_path, vits_path):
                 ref_seq,
                 text_seq,
                 ref_bert,
-                text_bert),
-        check_trace=False) # defaults to True, but the check may feed randomly generated tensors with odd shapes and raise spurious errors
+                text_bert))

-    gpt_sovits_export.save("onnx/xw/gpt_sovits_model.pt")
-    print('#### exported gpt_sovits ####')
+    gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
+    gpt_sovits_export.save(gpt_sovits_path)
+    print('#### exported gpt_sovits ####')

 @torch.jit.script
 def parse_audio(ref_audio):
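Two details in the export path above: torch._dynamo.mark_dynamic(tensor, dim) tags a dimension as dynamic so torch.jit.trace does not specialize the graph to the example lengths (presumably why the old check_trace=False workaround could be dropped), and tracing under torch.no_grad() keeps autograd bookkeeping out of the recorded graph. The saved bundle is plain TorchScript; a sketch of reloading it without any of the Python classes above (mirroring test() later in this diff):

    import torch

    gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")
    with torch.no_grad():
        audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)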
@@ -646,63 +684,88 @@ class GPT_SoVITS(nn.Module):
         self.t2s = t2s
         self.vits = vits

-    def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor):
-        codes = self.vits.vq_model.extract_latent(ssl_content.float())
+    def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor, speed=1.0):
+        codes = self.vits.vq_model.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
         prompts = prompt_semantic.unsqueeze(0)

         pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
-        audio = self.vits(text_seq, pred_semantic, ref_audio_sr)
+        audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
         return audio

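One caveat (my inference, not stated in the commit): torch.jit.trace records forward with exactly the six example inputs, so the default speed=1.0 is baked into the traced export; the new argument is only adjustable when calling the eager module:

    # hypothetical eager-mode call; speed is not an input of the traced graph
    audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, speed=1.1)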
-def test(gpt_path, vits_path):
-    tokenizer = AutoTokenizer.from_pretrained(bert_path)
-    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
-    bert = MyBertModel(bert_model)
-    # bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
+def test():
+    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
+    parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
+    parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
+    parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
+    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
+    parser.add_argument('--output_path', required=True, help="Path to the output directory")

-    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
-    dict_s1 = torch.load(gpt_path, map_location="cpu")
-    raw_t2s = get_raw_t2s_model(dict_s1)
-    t2s = T2SModel(raw_t2s)
-    t2s.eval()
+    args = parser.parse_args()
+    gpt_path = args.gpt_model
+    vits_path = args.sovits_model
+    ref_audio_path = args.ref_audio
+    ref_text = args.ref_text
+
+    tokenizer = AutoTokenizer.from_pretrained(bert_path)
+    # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
+    # bert = MyBertModel(bert_model)
+    my_bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
+
+    # dict_s1 = torch.load(gpt_path, map_location="cuda")
+    # raw_t2s = get_raw_t2s_model(dict_s1)
+    # t2s = T2SModel(raw_t2s)
+    # t2s.eval()
     # t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')

     # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
-    vits = VitsModel(vits_path)
-    vits.eval()
+    # vits = VitsModel(vits_path)
+    # vits.eval()

-    ssl = ExportSSLModel(SSLModel())
-    ssl.eval()
+    # ssl = ExportSSLModel(SSLModel()).to('cuda')
+    # ssl.eval()
+    ssl = torch.jit.load("onnx/by/ssl_model.pt",map_location='cuda')

-    gpt_sovits = GPT_SoVITS(t2s,vits)
+    # gpt_sovits = GPT_SoVITS(t2s,vits)
+    gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt",map_location='cuda')

-    # vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda')
-    # ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda')
-
-    ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
-    ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
-    ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
+    ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
+    ref_seq = torch.LongTensor([ref_seq_id])
+    ref_bert = ref_bert_T.T.to(ref_seq.device)
+    # text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
+    text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
+
+    text_seq_id,text_bert_T,norm_text = get_phones_and_bert(text,"all_zh",'v2')
+    test_bert = tokenizer(text, return_tensors="pt")
+    word2ph = []
+    for c in text:
+        if c in [',','。',':','?',"?",",","."]:
+            word2ph.append(1)
+        else:
+            word2ph.append(2)
+    test_bert['word2ph'] = torch.Tensor(word2ph).int()

-    text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
-    text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
-    text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
-
-    ref_bert = bert(
-        ref_bert_inputs['input_ids'],
-        ref_bert_inputs['attention_mask'],
-        ref_bert_inputs['token_type_ids'],
-        ref_bert_inputs['word2ph']
+    test_bert = my_bert(
+        test_bert['input_ids'].to('cuda'),
+        test_bert['attention_mask'].to('cuda'),
+        test_bert['token_type_ids'].to('cuda'),
+        test_bert['word2ph'].to('cuda')
     )

-    text_bert = bert(text_berf_inputs['input_ids'],
-        text_berf_inputs['attention_mask'],
-        text_berf_inputs['token_type_ids'],
-        text_berf_inputs['word2ph'])
+    text_seq = torch.LongTensor([text_seq_id])
+    text_bert = text_bert_T.T.to(text_seq.device)
+
+    print('text_bert:',text_bert.shape,text_bert)
+    print('test_bert:',test_bert.shape,test_bert)
+    print(torch.allclose(text_bert.to('cuda'),test_bert))
+
+    print('text_seq:',text_seq.shape)
+    print('text_bert:',text_bert.shape,text_bert.type())

     #[1,N]
-    ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
+    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to('cuda')
     print('ref_audio:',ref_audio.shape)

     ref_audio_sr = ssl.resample(ref_audio,16000,32000)
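The rewritten test() doubles as a regression check for export_bert: the same sentence goes through the eager get_phones_and_bert pipeline (text_bert) and through the reloaded bert_model.pt (test_bert), and torch.allclose compares them at its default tolerances (rtol=1e-05, atol=1e-08), which is why both tensors are first moved to the same device:

    # condensed form of the check above; True means the traced BERT is faithful
    print(torch.allclose(text_bert.to('cuda'), test_bert))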
@@ -710,13 +773,22 @@ def test(gpt_path, vits_path):
     ssl_content = ssl(ref_audio)

     print('start gpt_sovits:')
+    print('ssl_content:',ssl_content.shape)
+    print('ref_audio_sr:',ref_audio_sr.shape)
+    print('ref_seq:',ref_seq.shape)
+    ref_seq=ref_seq.to('cuda')
+    print('text_seq:',text_seq.shape)
+    text_seq=text_seq.to('cuda')
+    print('ref_bert:',ref_bert.shape)
+    ref_bert=ref_bert.to('cuda')
+    print('text_bert:',text_bert.shape)
+    text_bert=text_bert.to('cuda')

     with torch.no_grad():
-        audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)
+        audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert)
     print('start write wav')
     soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)

-    # audio = vits(text_seq, pred_semantic1, ref_audio)
-    # soundfile.write("out.wav", audio, 32000)

 import text
 import json
@@ -731,7 +803,30 @@ def export_symbel(version='v2'):
     with open(f"onnx/symbols_v2.json", "w") as file:
         json.dump(symbols, file, indent=4)

+def main():
+    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
+    parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
+    parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
+    parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
+    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
+    parser.add_argument('--output_path', required=True, help="Path to the output directory")
+    parser.add_argument('--export_common_model', action='store_true', help="Export Bert and SSL model")
+    parser.add_argument('--device', help="Device to use")
+
+    args = parser.parse_args()
+    export(
+        gpt_path=args.gpt_model,
+        vits_path=args.sovits_model,
+        ref_audio_path=args.ref_audio,
+        ref_text=args.ref_text,
+        output_path=args.output_path,
+        device=args.device,
+        export_bert_and_ssl=args.export_common_model,
+    )
+
+import inference_webui
 if __name__ == "__main__":
-    export(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
-    # test(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
-    # export_symbel()
+    inference_webui.is_half=False
+    inference_webui.dtype=torch.float32
+    main()
+    # test()
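With main() in place the export becomes a CLI. A hypothetical invocation (the script file name is my assumption; the flags and checkpoint names are taken from this diff, and note that despite its help text --ref_text is consumed as the transcript string itself, since export() passes it straight to get_phones_and_bert):

    python GPT_SoVITS/export_torch_script.py \
        --gpt_model GPT_weights_v2/chen1-e15.ckpt \
        --sovits_model SoVITS_weights_v2/chen1_e8_s208.pth \
        --ref_audio ref.wav \
        --ref_text "这是参考音频对应的文本" \
        --output_path onnx/export \
        --export_common_model \
        --device cuda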
@ -231,7 +231,7 @@ class TextEncoder(nn.Module):
|
|||||||
|
|
||||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
def forward(self, y, text, ge):
|
def forward(self, y, text, ge, speed=1):
|
||||||
y_mask = torch.ones_like(y[:1,:1,:])
|
y_mask = torch.ones_like(y[:1,:1,:])
|
||||||
|
|
||||||
y = self.ssl_proj(y * y_mask) * y_mask
|
y = self.ssl_proj(y * y_mask) * y_mask
|
||||||
@ -244,6 +244,9 @@ class TextEncoder(nn.Module):
|
|||||||
y = self.mrte(y, y_mask, text, text_mask, ge)
|
y = self.mrte(y, y_mask, text, text_mask, ge)
|
||||||
|
|
||||||
y = self.encoder2(y * y_mask, y_mask)
|
y = self.encoder2(y * y_mask, y_mask)
|
||||||
|
if(speed!=1):
|
||||||
|
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
|
||||||
|
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
|
||||||
|
|
||||||
stats = self.proj(y) * y_mask
|
stats = self.proj(y) * y_mask
|
||||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||||
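This hunk is where the speed value threaded through VitsModel and GPT_SoVITS takes effect: the text encoder's output and its mask are resampled along the time axis, so speed > 1 produces fewer frames (faster speech) and speed < 1 more frames (slower). A standalone illustration of the resampling:

    import torch
    import torch.nn.functional as F

    y = torch.randn(1, 192, 100)   # [batch, channels, frames]; sizes are illustrative
    speed = 1.25
    y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
    print(y.shape)                 # torch.Size([1, 192, 81]), about 20% fewer frames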
@@ -887,7 +890,7 @@ class SynthesizerTrn(nn.Module):
         # self.enc_p.encoder_text.requires_grad_(False)
         # self.enc_p.mrte.requires_grad_(False)

-    def forward(self, codes, text, refer):
+    def forward(self, codes, text, refer,noise_scale=0.5, speed=1):
         refer_mask = torch.ones_like(refer[:1,:1,:])
         if (self.version == "v1"):
             ge = self.ref_enc(refer * refer_mask, refer_mask)
@@ -900,10 +903,10 @@ class SynthesizerTrn(nn.Module):
         quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)

         x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, text, ge
+            quantized, text, ge, speed
         )

-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

         z = self.flow(z_p, y_mask, g=ge, reverse=True)
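noise_scale damps the stochastic part of the VITS prior: with eps ~ N(0, I), the latent is now z_p = m_p + eps * exp(logs_p) * noise_scale, i.e. a sample from N(m_p, (noise_scale * exp(logs_p))^2). A value of 0 collapses to the deterministic mean, 1 reproduces the old behavior, and the new default of 0.5 trades some prosodic variation for steadier output:

    # equivalent view of the changed sampling line
    eps = torch.randn_like(m_p)                          # eps ~ N(0, I)
    z_p = m_p + eps * torch.exp(logs_p) * noise_scale    # std shrunk by noise_scale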
@@ -127,14 +127,14 @@ def get_dict():

 def read_dict():
     polyphonic_dict = {}
-    with open(PP_DICT_PATH) as f:
+    with open(PP_DICT_PATH,encoding="utf-8") as f:
         line = f.readline()
         while line:
             key, value_str = line.split(':')
             value = eval(value_str.strip())
             polyphonic_dict[key.strip()] = value
             line = f.readline()
-    with open(PP_FIX_DICT_PATH) as f:
+    with open(PP_FIX_DICT_PATH,encoding="utf-8") as f:
         line = f.readline()
         while line:
             key, value_str = line.split(':')
@@ -151,4 +151,4 @@ def correct_pronunciation(word,word_pinyins):
     return word_pinyins


 pp_dict = get_dict()
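Without an explicit encoding, open() uses the locale's preferred encoding, which is frequently cp936/GBK on Chinese-locale Windows and then fails with UnicodeDecodeError on UTF-8 dictionary files; pinning encoding="utf-8" makes the read platform-independent:

    import locale
    print(locale.getpreferredencoding(False))   # e.g. 'cp936' on zh-CN Windows

    with open(PP_DICT_PATH, encoding="utf-8") as f:   # deterministic everywhere
        ...

As a side note (not part of this commit), ast.literal_eval(value_str.strip()) would parse the same literal values without eval()'s arbitrary-code-execution risk.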
@@ -451,6 +451,8 @@ async def set_sovits_weights(weights_path: str = None):

 if __name__ == "__main__":
     try:
+        if host == 'None':  # pass -a None on the command line to make the api listen dual-stack
+            host = None
         uvicorn.run(app=APP, host=host, port=port, workers=1)
     except Exception as e:
         traceback.print_exc()
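argparse hands over the literal string 'None' rather than Python's None, hence the string comparison; with host=None, uvicorn binds all interfaces over both IPv4 and IPv6. A hypothetical launch (the script name is assumed from this hunk's api context):

    python api_v2.py -a None -p 9880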