mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-06 14:40:00 +08:00

Restore t2s_model.py; move the changes into export_torch_script.py

This commit is contained in:
parent 41dbc179c3
commit 4ed0b8bdcc
@@ -83,7 +83,7 @@ class T2SMLP:
 class T2SBlock:
     def __init__(
         self,
-        num_heads: int,
+        num_heads,
         hidden_dim: int,
         mlp: T2SMLP,
         qkv_w,
@@ -92,12 +92,12 @@ class T2SBlock:
         out_b,
         norm_w1,
         norm_b1,
-        norm_eps1: float,
+        norm_eps1,
         norm_w2,
         norm_b2,
-        norm_eps2: float,
+        norm_eps2,
     ):
-        self.num_heads:int = num_heads
+        self.num_heads = num_heads
         self.mlp = mlp
         self.hidden_dim: int = hidden_dim
         self.qkv_w = qkv_w
@@ -266,7 +266,7 @@ class Text2SemanticDecoder(nn.Module):
         self.norm_first = norm_first
         self.vocab_size = config["model"]["vocab_size"]
         self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
-        self.p_dropout = float(config["model"]["dropout"])
+        self.p_dropout = config["model"]["dropout"]
         self.EOS = config["model"]["EOS"]
         self.norm_first = norm_first
         assert self.EOS == self.vocab_size - 1
@@ -34,6 +34,7 @@ default_config = {
 
 def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
     config = dict_s1["config"]
+    config["model"]["dropout"] = float(config["model"]["dropout"])
     t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
     t2s_model.load_state_dict(dict_s1["weight"])
     t2s_model = t2s_model.eval()
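
Note: the float() cast on dropout now happens here, at checkpoint-load time, rather than inside Text2SemanticDecoder (see the reverted hunk above), presumably because checkpoint configs may store dropout as a string or int. A minimal sketch of driving this loader (the checkpoint path is the one used elsewhere in this file):

    import torch

    # dict_s1 is the stage-1 (GPT) checkpoint: a "config" dict plus a "weight" state_dict
    dict_s1 = torch.load("GPT_weights_v2/chen1-e15.ckpt", map_location="cpu")
    raw_t2s = get_raw_t2s_model(dict_s1)  # dropout normalized to float before the Lightning module is built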
@@ -105,7 +106,7 @@ def sample(
 
 
 @torch.jit.script
-def spectrogram_torch(y, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
+def spectrogram_torch(y:Tensor, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
     hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
     y = torch.nn.functional.pad(
         y.unsqueeze(1),
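
The only change here is the explicit y:Tensor annotation on the scripted function. A usage sketch (the STFT parameter values below are assumed 32 kHz SoVITS defaults, not taken from this diff):

    spec = spectrogram_torch(ref_audio_sr, n_fft=2048, sampling_rate=32000, hop_size=640, win_size=2048, center=False)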
@@ -155,7 +156,180 @@ class DictToAttrRecursive(dict):
             del self[item]
         except KeyError:
             raise AttributeError(f"Attribute {item} not found")
 
+
+@torch.jit.script
+class T2SMLP:
+    def __init__(self, w1, b1, w2, b2):
+        self.w1 = w1
+        self.b1 = b1
+        self.w2 = w2
+        self.b2 = b2
+
+    def forward(self, x):
+        x = F.relu(F.linear(x, self.w1, self.b1))
+        x = F.linear(x, self.w2, self.b2)
+        return x
+
+
+@torch.jit.script
+class T2SBlock:
+    def __init__(
+        self,
+        num_heads: int,
+        hidden_dim: int,
+        mlp: T2SMLP,
+        qkv_w,
+        qkv_b,
+        out_w,
+        out_b,
+        norm_w1,
+        norm_b1,
+        norm_eps1: float,
+        norm_w2,
+        norm_b2,
+        norm_eps2: float,
+    ):
+        self.num_heads = num_heads
+        self.mlp = mlp
+        self.hidden_dim: int = hidden_dim
+        self.qkv_w = qkv_w
+        self.qkv_b = qkv_b
+        self.out_w = out_w
+        self.out_b = out_b
+        self.norm_w1 = norm_w1
+        self.norm_b1 = norm_b1
+        self.norm_eps1 = norm_eps1
+        self.norm_w2 = norm_w2
+        self.norm_b2 = norm_b2
+        self.norm_eps2 = norm_eps2
+
+        self.false = torch.tensor(False, dtype=torch.bool)
+
+    @torch.jit.ignore
+    def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
+        if padding_mask is None:
+            return x
+
+        if padding_mask.dtype == torch.bool:
+            return x.masked_fill(padding_mask, 0)
+        else:
+            return x * padding_mask
+
+    def process_prompt(self, x:torch.Tensor, attn_mask:torch.Tensor, padding_mask:Optional[torch.Tensor]=None):
+        q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+        kv_len = k.shape[1]
+
+        q = self.to_mask(q, padding_mask)
+        k_cache = self.to_mask(k, padding_mask)
+        v_cache = self.to_mask(v, padding_mask)
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
+
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
+        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
+
+        if padding_mask is not None:
+            for i in range(batch_size):
+                # mask = padding_mask[i,:,0]
+                if self.false.device != padding_mask.device:
+                    self.false = self.false.to(padding_mask.device)
+                idx = torch.where(padding_mask[i,:,0] == self.false)[0]
+                x_item = x[i,idx,:].unsqueeze(0)
+                attn_item = attn[i,idx,:].unsqueeze(0)
+                x_item = x_item + attn_item
+                x_item = F.layer_norm(
+                    x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+                )
+                x_item = x_item + self.mlp.forward(x_item)
+                x_item = F.layer_norm(
+                    x_item,
+                    [self.hidden_dim],
+                    self.norm_w2,
+                    self.norm_b2,
+                    self.norm_eps2,
+                )
+                x[i,idx,:] = x_item.squeeze(0)
+            x = self.to_mask(x, padding_mask)
+        else:
+            x = x + attn
+            x = F.layer_norm(
+                x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+            )
+            x = x + self.mlp.forward(x)
+            x = F.layer_norm(
+                x,
+                [self.hidden_dim],
+                self.norm_w2,
+                self.norm_b2,
+                self.norm_eps2,
+            )
+        return x, k_cache, v_cache
+
+    def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor):
+        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        k_cache = torch.cat([k_cache, k], dim=1)
+        v_cache = torch.cat([v_cache, v], dim=1)
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+        kv_len = k_cache.shape[1]
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v)
+
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
+        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = F.linear(attn, self.out_w, self.out_b)
+
+        x = x + attn
+        x = F.layer_norm(
+            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = x + self.mlp.forward(x)
+        x = F.layer_norm(
+            x,
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
+        return x, k_cache, v_cache
+
+
+@torch.jit.script
+class T2STransformer:
+    def __init__(self, num_blocks:int, blocks:list[T2SBlock]):
+        self.num_blocks:int = num_blocks
+        self.blocks = blocks
+
+    def process_prompt(
+        self, x:torch.Tensor, attn_mask:torch.Tensor, padding_mask:Optional[torch.Tensor]=None):
+        k_cache:list[torch.Tensor] = []
+        v_cache:list[torch.Tensor] = []
+        for i in range(self.num_blocks):
+            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
+            k_cache.append(k_cache_)
+            v_cache.append(v_cache_)
+        return x, k_cache, v_cache
+
+    def decode_next_token(
+        self, x:torch.Tensor,
+        k_cache: list[torch.Tensor],
+        v_cache: list[torch.Tensor]):
+        for i in range(self.num_blocks):
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
+        return x, k_cache, v_cache
+
+
 class VitsModel(nn.Module):
     def __init__(self, vits_path):
         super().__init__()
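
These scripted classes replace the stock transformer at inference time: process_prompt runs full self-attention over the prompt once and returns per-layer KV caches, and decode_next_token then attends each new token against the growing caches. A minimal sketch of the intended call pattern, assuming xy_pos and attn_mask are already built as in T2SModel.forward further down:

    # one full pass over the prompt builds the per-layer caches
    xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask)
    for idx in range(max_steps):  # max_steps is hypothetical; forward() stops via early_stop_num / EOS
        # each step feeds only the newest token's embedding, shape [B, 1, hidden_dim]
        xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)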
@@ -189,34 +363,19 @@ class VitsModel(nn.Module):
         return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
 
 class T2SModel(nn.Module):
-    def __init__(self, config, raw_t2s:Text2SemanticLightningModule, norm_first=False, top_k=3):
+    def __init__(self, raw_t2s:Text2SemanticLightningModule):
         super(T2SModel, self).__init__()
-        self.model_dim = config["model"]["hidden_dim"]
-        self.embedding_dim = config["model"]["embedding_dim"]
-        self.num_head = config["model"]["head"]
-        self.num_layers = config["model"]["n_layer"]
-        self.vocab_size = config["model"]["vocab_size"]
-        self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
-        self.p_dropout = float(config["model"]["dropout"])
-        self.EOS:int = config["model"]["EOS"]
-        self.norm_first = norm_first
+        self.model_dim = raw_t2s.model.model_dim
+        self.embedding_dim = raw_t2s.model.embedding_dim
+        self.num_head = raw_t2s.model.num_head
+        self.num_layers = raw_t2s.model.num_layers
+        self.vocab_size = raw_t2s.model.vocab_size
+        self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
+        # self.p_dropout = float(raw_t2s.model.p_dropout)
+        self.EOS:int = int(raw_t2s.model.EOS)
+        self.norm_first = raw_t2s.model.norm_first
         assert self.EOS == self.vocab_size - 1
         self.hz = 50
-        self.config = config
 
-        # self.bert_proj = nn.Linear(1024, self.embedding_dim)
-        # self.ar_text_embedding = TokenEmbedding(
-        #     self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
-        # )
-        # self.ar_text_position = SinePositionalEmbedding(
-        #     self.embedding_dim, dropout=0.1, scale=False, alpha=True
-        # )
-        # self.ar_audio_embedding = TokenEmbedding(
-        #     self.embedding_dim, self.vocab_size, self.p_dropout
-        # )
-        # self.ar_audio_position = SinePositionalEmbedding(
-        #     self.embedding_dim, dropout=0.1, scale=False, alpha=True
-        # )
 
         self.bert_proj = raw_t2s.model.bert_proj
         self.ar_text_embedding = raw_t2s.model.ar_text_embedding
@@ -225,13 +384,45 @@ class T2SModel(nn.Module):
         self.ar_audio_position = raw_t2s.model.ar_audio_position
 
-        # self.t2s_transformer = T2STransformer(self.num_layers, blocks)
-        self.t2s_transformer = raw_t2s.model.t2s_transformer
+        # self.t2s_transformer = raw_t2s.model.t2s_transformer
+
+        blocks = []
+        h = raw_t2s.model.h
+
+        for i in range(self.num_layers):
+            layer = h.layers[i]
+            t2smlp = T2SMLP(
+                layer.linear1.weight,
+                layer.linear1.bias,
+                layer.linear2.weight,
+                layer.linear2.bias
+            )
+
+            block = T2SBlock(
+                self.num_head,
+                self.model_dim,
+                t2smlp,
+                layer.self_attn.in_proj_weight,
+                layer.self_attn.in_proj_bias,
+                layer.self_attn.out_proj.weight,
+                layer.self_attn.out_proj.bias,
+                layer.norm1.weight,
+                layer.norm1.bias,
+                layer.norm1.eps,
+                layer.norm2.weight,
+                layer.norm2.bias,
+                layer.norm2.eps
+            )
+
+            blocks.append(block)
+
+        self.t2s_transformer = T2STransformer(self.num_layers, blocks)
 
         # self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
         self.ar_predict_layer = raw_t2s.model.ar_predict_layer
         # self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
-        self.max_sec = self.config["data"]["max_sec"]
-        self.top_k = int(self.config["inference"]["top_k"])
+        self.max_sec = raw_t2s.config["data"]["max_sec"]
+        self.top_k = int(raw_t2s.config["inference"]["top_k"])
         self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
 
     def forward(self, prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor):
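
T2SBlock can be handed the trained layer's weights directly because nn.MultiheadAttention packs the Q, K and V projections row-wise into in_proj_weight; one F.linear followed by chunk(3, dim=-1) recovers all three, which is exactly what process_prompt and decode_next_token do. A self-contained sketch of that equivalence (shapes are illustrative only):

    import torch
    import torch.nn.functional as F

    hidden_dim = 512
    x = torch.randn(1, 10, hidden_dim)
    qkv_w = torch.randn(3 * hidden_dim, hidden_dim)  # stands in for layer.self_attn.in_proj_weight
    qkv_b = torch.randn(3 * hidden_dim)              # stands in for layer.self_attn.in_proj_bias
    q, k, v = F.linear(x, qkv_w, qkv_b).chunk(3, dim=-1)  # each [1, 10, hidden_dim]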
@@ -296,6 +487,10 @@ class T2SModel(nn.Module):
             # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
             xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
             logits = self.ar_predict_layer(xy_dec[:, -1])
+
+            if(idx<11):  ### require at least 10 predicted tokens before allowing a stop (0.4s)
+                logits = logits[:, :-1]
+
             samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
 
             y = torch.concat([y, samples], dim=1)
@@ -305,13 +500,13 @@ class T2SModel(nn.Module):
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                 stop = True
             if stop:
                 if y.shape[1] == 0:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                 break
 
             y_emb = self.ar_audio_embedding(y[:, -1:])
             xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype, device=y_emb.device)
 
         y[0, -1] = 0
 
         return y[:, -idx:].unsqueeze(0)
 
 bert_path = os.environ.get(
@@ -362,20 +557,19 @@ class ExportSSLModel(torch.nn.Module):
         audio = resamplex(ref_audio, src_sr, dst_sr).float()
         return audio
 
-def export_bert(tokenizer, ref_text, word2ph):
-    ref_bert_inputs = tokenizer(ref_text, return_tensors="pt")
+def export_bert(ref_bert_inputs):
     ref_bert_inputs = {
-        'input_ids': torch.jit.annotate(torch.Tensor, ref_bert_inputs['input_ids']),
-        'attention_mask': torch.jit.annotate(torch.Tensor, ref_bert_inputs['attention_mask']),
-        'token_type_ids': torch.jit.annotate(torch.Tensor, ref_bert_inputs['token_type_ids']),
-        'word2ph': torch.jit.annotate(torch.Tensor, word2ph)
+        'input_ids': ref_bert_inputs['input_ids'],
+        'attention_mask': ref_bert_inputs['attention_mask'],
+        'token_type_ids': ref_bert_inputs['token_type_ids'],
+        'word2ph': ref_bert_inputs['word2ph']
     }
     bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True)
     my_bert_model = MyBertModel(bert_model)
 
     my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
     my_bert_model.save("onnx/bert_model.pt")
-    print('exported bert')
+    print('#### exported bert ####')
 
 def export(gpt_path, vits_path):
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
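
export_bert now expects the caller to tokenize and attach word2ph itself. A sketch of assembling that input dict (ref_text and word2ph are placeholders; word2ph comes from the text frontend):

    ref_bert_inputs = tokenizer(ref_text, return_tensors="pt")  # input_ids / attention_mask / token_type_ids
    ref_bert_inputs['word2ph'] = torch.tensor(word2ph)          # phones-per-character counts from the cleaner
    export_bert(ref_bert_inputs)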
@@ -392,18 +586,14 @@ def export(gpt_path, vits_path):
 
     bert = MyBertModel(bert_model)
 
-    # export_bert(tokenizer, "声音,是有温度的.夜晚的声音,会发光", ref_bert_inputs['word2ph'])
+    # export_bert(ref_bert_inputs)
 
     ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
     ssl = SSLModel()
-    s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(torch.jit.annotate(torch.Tensor, ref_audio))))
+    s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
     torch.jit.script(s).save("onnx/xw/ssl_model.pt")
+    print('#### exported ssl ####')
 
-    print('exported ssl')
-
-    # ref_seq = torch.LongTensor([cleaned_text_to_sequence(["zh", "ai4", "ch", "an1","j" ,"ia1","r","ua4","s","i3","t","e3","ch","un1","w","an3","d","e1", "sh", "i2", "h", "ou4", "y", "ou3", "r", "en2","w","en4","l","e1","zh","e4","y","ang4","y","i2","g","e4","w","en4","t","i2"],version='v2')])
-    # ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
-    # text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
     ref_bert = bert(**ref_bert_inputs)
     text_bert = bert(**text_berf_inputs)
     ssl_content = ssl(ref_audio)
@@ -415,24 +605,30 @@ def export(gpt_path, vits_path):
     # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
     dict_s1 = torch.load(gpt_path, map_location="cpu")
     raw_t2s = get_raw_t2s_model(dict_s1)
-    t2s_m = T2SModel(dict_s1['config'], raw_t2s, top_k=3)
+    t2s_m = T2SModel(raw_t2s)
     t2s_m.eval()
     t2s = torch.jit.script(t2s_m)
-    print('exported t2s_m')
-
-    gpt_sovits = GPT_SoVITS(t2s, vits)
-    ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
-
-    # audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)
+    print('#### script t2s_m ####')
 
-    torch.jit.trace(gpt_sovits, example_inputs=(
-        torch.jit.annotate(torch.Tensor, ssl_content),
-        torch.jit.annotate(torch.Tensor, ref_audio_sr),
-        torch.jit.annotate(torch.Tensor, ref_seq),
-        torch.jit.annotate(torch.Tensor, text_seq),
-        torch.jit.annotate(torch.Tensor, ref_bert),
-        torch.jit.annotate(torch.Tensor, text_bert))).save("onnx/xw/gpt_sovits_model.pt")
-    print('exported vits')
     print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
+    gpt_sovits = GPT_SoVITS(t2s, vits)
+    gpt_sovits.eval()
+    ref_audio_sr = s.resample(ref_audio, 16000, 32000)
+    print('ref_audio_sr:', ref_audio_sr.shape)
+
+    gpt_sovits_export = torch.jit.trace(
+        gpt_sovits,
+        example_inputs=(
+            ssl_content,
+            ref_audio_sr,
+            ref_seq,
+            text_seq,
+            ref_bert,
+            text_bert),
+        check_trace=False)  # check_trace defaults to True, but the check can feed a randomly generated tensor with odd dimensions and raise a spurious error
+
+    gpt_sovits_export.save("onnx/xw/gpt_sovits_model.pt")
+    print('#### exported gpt_sovits ####')
 
 @torch.jit.script
 def parse_audio(ref_audio):
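
Once saved, the traced module can be reloaded without any of the Python classes in this file. A minimal sketch, assuming inputs shaped like the trace-time examples:

    import torch

    gpt_sovits = torch.jit.load("onnx/xw/gpt_sovits_model.pt", map_location="cpu")
    audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)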
@@ -459,20 +655,20 @@ class GPT_SoVITS(nn.Module):
         audio = self.vits(text_seq, pred_semantic, ref_audio_sr)
         return audio
 
-def test():
+def test(gpt_path, vits_path):
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
     bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True)
     bert = MyBertModel(bert_model)
     # bert = torch.jit.load("onnx/bert_model.pt", map_location='cuda')
 
-    gpt_path = "GPT_weights_v2/xw-e15.ckpt"
+    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
     dict_s1 = torch.load(gpt_path, map_location="cpu")
     raw_t2s = get_raw_t2s_model(dict_s1)
-    t2s = T2SModel(dict_s1['config'], raw_t2s, top_k=3)
+    t2s = T2SModel(raw_t2s)
     t2s.eval()
     # t2s = torch.jit.load("onnx/xw/t2s_model.pt", map_location='cuda')
 
-    vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
+    # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
     vits = VitsModel(vits_path)
     vits.eval()
@@ -506,7 +702,7 @@ def test():
         text_berf_inputs['word2ph'])
 
     # [1,N]
-    ref_audio = torch.tensor(load_audio("output/denoise_opt/xw.mp3_0000000000_0000156480.wav", 16000)).float().unsqueeze(0)
+    ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
     print('ref_audio:', ref_audio.shape)
 
     ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
@@ -537,5 +733,5 @@ def export_symbel(version='v2'):
 
 if __name__ == "__main__":
     export(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
-    # test()
+    # test(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
     # export_symbel()