mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
添加导出成 TorchScript 的脚本用于支持python以外的语言 (#1640)
* Fix onnx_export to support v2 * delete some useless code & add some args type for export torch-script * Add export_torch_script.py * (export_torch_script.py) 整合 vits 和 t2s 成一个 model 导出 * 恢复 `t2s_model.py` 把改动移到 `export_torch_script.py`
This commit is contained in:
parent
78c68d46cb
commit
5efb960898
737
GPT_SoVITS/export_torch_script.py
Normal file
737
GPT_SoVITS/export_torch_script.py
Normal file
@ -0,0 +1,737 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
from typing import Optional
|
||||
from my_utils import load_audio
|
||||
from text import cleaned_text_to_sequence
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
from torch import IntTensor, LongTensor, Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from feature_extractor import cnhubert
|
||||
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from module.models_onnx import SynthesizerTrn
|
||||
|
||||
|
||||
|
||||
import os
|
||||
import soundfile
|
||||
|
||||
default_config = {
|
||||
"embedding_dim": 512,
|
||||
"hidden_dim": 512,
|
||||
"num_head": 8,
|
||||
"num_layers": 12,
|
||||
"num_codebook": 8,
|
||||
"p_dropout": 0.0,
|
||||
"vocab_size": 1024 + 1,
|
||||
"phoneme_vocab_size": 512,
|
||||
"EOS": 1024,
|
||||
}
|
||||
|
||||
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
|
||||
config = dict_s1["config"]
|
||||
config["model"]["dropout"] = float(config["model"]["dropout"])
|
||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
||||
t2s_model.load_state_dict(dict_s1["weight"])
|
||||
t2s_model = t2s_model.eval()
|
||||
return t2s_model
|
||||
|
||||
@torch.jit.script
|
||||
def logits_to_probs(
|
||||
logits,
|
||||
previous_tokens: Optional[torch.Tensor] = None,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
repetition_penalty: float = 1.0,
|
||||
):
|
||||
# if previous_tokens is not None:
|
||||
# previous_tokens = previous_tokens.squeeze()
|
||||
# print(logits.shape,previous_tokens.shape)
|
||||
# pdb.set_trace()
|
||||
if previous_tokens is not None and repetition_penalty != 1.0:
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
logits = logits / max(temperature, 1e-5)
|
||||
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
pivot = v[: , -1].unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
return probs
|
||||
|
||||
@torch.jit.script
|
||||
def multinomial_sample_one_no_sync(probs_sort):
|
||||
# Does multinomial sampling without a cuda synchronization
|
||||
q = torch.randn_like(probs_sort)
|
||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||
|
||||
@torch.jit.script
|
||||
def sample(
|
||||
logits,
|
||||
previous_tokens,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
repetition_penalty: float = 1.0,
|
||||
):
|
||||
probs = logits_to_probs(
|
||||
logits=logits, previous_tokens=previous_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty
|
||||
)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def spectrogram_torch(y:Tensor, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
|
||||
hann_window = torch.hann_window(win_size,device=y.device,dtype=y.dtype)
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
mode="reflect",
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
spec = torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window,
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=False,
|
||||
)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
return spec
|
||||
|
||||
|
||||
class DictToAttrRecursive(dict):
|
||||
def __init__(self, input_dict):
|
||||
super().__init__(input_dict)
|
||||
for key, value in input_dict.items():
|
||||
if isinstance(value, dict):
|
||||
value = DictToAttrRecursive(value)
|
||||
self[key] = value
|
||||
setattr(self, key, value)
|
||||
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return self[item]
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if isinstance(value, dict):
|
||||
value = DictToAttrRecursive(value)
|
||||
super(DictToAttrRecursive, self).__setitem__(key, value)
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __delattr__(self, item):
|
||||
try:
|
||||
del self[item]
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
@torch.jit.script
|
||||
class T2SMLP:
|
||||
def __init__(self, w1, b1, w2, b2):
|
||||
self.w1 = w1
|
||||
self.b1 = b1
|
||||
self.w2 = w2
|
||||
self.b2 = b2
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(F.linear(x, self.w1, self.b1))
|
||||
x = F.linear(x, self.w2, self.b2)
|
||||
return x
|
||||
|
||||
@torch.jit.script
|
||||
class T2SBlock:
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
hidden_dim: int,
|
||||
mlp: T2SMLP,
|
||||
qkv_w,
|
||||
qkv_b,
|
||||
out_w,
|
||||
out_b,
|
||||
norm_w1,
|
||||
norm_b1,
|
||||
norm_eps1: float,
|
||||
norm_w2,
|
||||
norm_b2,
|
||||
norm_eps2: float,
|
||||
):
|
||||
self.num_heads = num_heads
|
||||
self.mlp = mlp
|
||||
self.hidden_dim: int = hidden_dim
|
||||
self.qkv_w = qkv_w
|
||||
self.qkv_b = qkv_b
|
||||
self.out_w = out_w
|
||||
self.out_b = out_b
|
||||
self.norm_w1 = norm_w1
|
||||
self.norm_b1 = norm_b1
|
||||
self.norm_eps1 = norm_eps1
|
||||
self.norm_w2 = norm_w2
|
||||
self.norm_b2 = norm_b2
|
||||
self.norm_eps2 = norm_eps2
|
||||
|
||||
self.false = torch.tensor(False, dtype=torch.bool)
|
||||
|
||||
@torch.jit.ignore
|
||||
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
|
||||
if padding_mask is None:
|
||||
return x
|
||||
|
||||
if padding_mask.dtype == torch.bool:
|
||||
return x.masked_fill(padding_mask, 0)
|
||||
else:
|
||||
return x * padding_mask
|
||||
|
||||
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None):
|
||||
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
batch_size = q.shape[0]
|
||||
q_len = q.shape[1]
|
||||
kv_len = k.shape[1]
|
||||
|
||||
q = self.to_mask(q, padding_mask)
|
||||
k_cache = self.to_mask(k, padding_mask)
|
||||
v_cache = self.to_mask(v, padding_mask)
|
||||
|
||||
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
|
||||
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||
|
||||
if padding_mask is not None:
|
||||
for i in range(batch_size):
|
||||
# mask = padding_mask[i,:,0]
|
||||
if self.false.device!= padding_mask.device:
|
||||
self.false = self.false.to(padding_mask.device)
|
||||
idx = torch.where(padding_mask[i,:,0]==self.false)[0]
|
||||
x_item = x[i,idx,:].unsqueeze(0)
|
||||
attn_item = attn[i,idx,:].unsqueeze(0)
|
||||
x_item = x_item + attn_item
|
||||
x_item = F.layer_norm(
|
||||
x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x_item = x_item + self.mlp.forward(x_item)
|
||||
x_item = F.layer_norm(
|
||||
x_item,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
x[i,idx,:] = x_item.squeeze(0)
|
||||
x = self.to_mask(x, padding_mask)
|
||||
else:
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor):
|
||||
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
k_cache = torch.cat([k_cache, k], dim=1)
|
||||
v_cache = torch.cat([v_cache, v], dim=1)
|
||||
|
||||
batch_size = q.shape[0]
|
||||
q_len = q.shape[1]
|
||||
kv_len = k_cache.shape[1]
|
||||
|
||||
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
|
||||
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
attn = F.linear(attn, self.out_w, self.out_b)
|
||||
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
@torch.jit.script
|
||||
class T2STransformer:
|
||||
def __init__(self, num_blocks : int, blocks: list[T2SBlock]):
|
||||
self.num_blocks : int = num_blocks
|
||||
self.blocks = blocks
|
||||
|
||||
def process_prompt(
|
||||
self, x:torch.Tensor, attn_mask : torch.Tensor,padding_mask : Optional[torch.Tensor]=None):
|
||||
k_cache : list[torch.Tensor] = []
|
||||
v_cache : list[torch.Tensor] = []
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
|
||||
k_cache.append(k_cache_)
|
||||
v_cache.append(v_cache_)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(
|
||||
self, x:torch.Tensor,
|
||||
k_cache: list[torch.Tensor],
|
||||
v_cache: list[torch.Tensor]):
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
|
||||
return x, k_cache, v_cache
|
||||
|
||||
class VitsModel(nn.Module):
|
||||
def __init__(self, vits_path):
|
||||
super().__init__()
|
||||
dict_s2 = torch.load(vits_path,map_location="cpu")
|
||||
self.hps = dict_s2["config"]
|
||||
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
||||
self.hps["model"]["version"] = "v1"
|
||||
else:
|
||||
self.hps["model"]["version"] = "v2"
|
||||
|
||||
self.hps = DictToAttrRecursive(self.hps)
|
||||
self.hps.model.semantic_frame_rate = "25hz"
|
||||
self.vq_model = SynthesizerTrn(
|
||||
self.hps.data.filter_length // 2 + 1,
|
||||
self.hps.train.segment_size // self.hps.data.hop_length,
|
||||
n_speakers=self.hps.data.n_speakers,
|
||||
**self.hps.model
|
||||
)
|
||||
self.vq_model.eval()
|
||||
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
|
||||
def forward(self, text_seq, pred_semantic, ref_audio):
|
||||
refer = spectrogram_torch(
|
||||
ref_audio,
|
||||
self.hps.data.filter_length,
|
||||
self.hps.data.sampling_rate,
|
||||
self.hps.data.hop_length,
|
||||
self.hps.data.win_length,
|
||||
center=False
|
||||
)
|
||||
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
|
||||
|
||||
class T2SModel(nn.Module):
|
||||
def __init__(self,raw_t2s:Text2SemanticLightningModule):
|
||||
super(T2SModel, self).__init__()
|
||||
self.model_dim = raw_t2s.model.model_dim
|
||||
self.embedding_dim = raw_t2s.model.embedding_dim
|
||||
self.num_head = raw_t2s.model.num_head
|
||||
self.num_layers = raw_t2s.model.num_layers
|
||||
self.vocab_size = raw_t2s.model.vocab_size
|
||||
self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
|
||||
# self.p_dropout = float(raw_t2s.model.p_dropout)
|
||||
self.EOS:int = int(raw_t2s.model.EOS)
|
||||
self.norm_first = raw_t2s.model.norm_first
|
||||
assert self.EOS == self.vocab_size - 1
|
||||
self.hz = 50
|
||||
|
||||
self.bert_proj = raw_t2s.model.bert_proj
|
||||
self.ar_text_embedding = raw_t2s.model.ar_text_embedding
|
||||
self.ar_text_position = raw_t2s.model.ar_text_position
|
||||
self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding
|
||||
self.ar_audio_position = raw_t2s.model.ar_audio_position
|
||||
|
||||
# self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||
# self.t2s_transformer = raw_t2s.model.t2s_transformer
|
||||
|
||||
blocks = []
|
||||
h = raw_t2s.model.h
|
||||
|
||||
for i in range(self.num_layers):
|
||||
layer = h.layers[i]
|
||||
t2smlp = T2SMLP(
|
||||
layer.linear1.weight,
|
||||
layer.linear1.bias,
|
||||
layer.linear2.weight,
|
||||
layer.linear2.bias
|
||||
)
|
||||
|
||||
block = T2SBlock(
|
||||
self.num_head,
|
||||
self.model_dim,
|
||||
t2smlp,
|
||||
layer.self_attn.in_proj_weight,
|
||||
layer.self_attn.in_proj_bias,
|
||||
layer.self_attn.out_proj.weight,
|
||||
layer.self_attn.out_proj.bias,
|
||||
layer.norm1.weight,
|
||||
layer.norm1.bias,
|
||||
layer.norm1.eps,
|
||||
layer.norm2.weight,
|
||||
layer.norm2.bias,
|
||||
layer.norm2.eps
|
||||
)
|
||||
|
||||
blocks.append(block)
|
||||
|
||||
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||
|
||||
# self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
|
||||
self.ar_predict_layer = raw_t2s.model.ar_predict_layer
|
||||
# self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
|
||||
self.max_sec = raw_t2s.config["data"]["max_sec"]
|
||||
self.top_k = int(raw_t2s.config["inference"]["top_k"])
|
||||
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
|
||||
|
||||
def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor):
|
||||
bert = torch.cat([ref_bert.T, text_bert.T], 1)
|
||||
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
||||
bert = bert.unsqueeze(0)
|
||||
|
||||
x = self.ar_text_embedding(all_phoneme_ids)
|
||||
x = x + self.bert_proj(bert.transpose(1, 2))
|
||||
x:torch.Tensor = self.ar_text_position(x)
|
||||
|
||||
early_stop_num = self.early_stop_num
|
||||
|
||||
|
||||
#[1,N,512] [1,N]
|
||||
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||
y = prompts
|
||||
# x_example = x[:,:,0] * 0.0
|
||||
|
||||
x_len = x.shape[1]
|
||||
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
||||
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
y_len = y_emb.shape[1]
|
||||
prefix_len = y.shape[1]
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
|
||||
bsz = x.shape[0]
|
||||
src_len = x_len + y_len
|
||||
x_attn_mask_pad = F.pad(
|
||||
x_attn_mask,
|
||||
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
||||
value=True,
|
||||
)
|
||||
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
|
||||
.unsqueeze(0)\
|
||||
.expand(bsz*self.num_head, -1, -1)\
|
||||
.view(bsz, self.num_head, src_len, src_len)\
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
|
||||
idx = 0
|
||||
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
|
||||
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
logits = logits[:, :-1]
|
||||
samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
|
||||
|
||||
stop = False
|
||||
# for idx in range(1, 50):
|
||||
for idx in range(1, 1500):
|
||||
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if(idx<11):###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
|
||||
samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
|
||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
stop = True
|
||||
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
||||
stop = True
|
||||
if stop:
|
||||
if y.shape[1] == 0:
|
||||
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||
break
|
||||
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
|
||||
|
||||
return y[:, -idx:].unsqueeze(0)
|
||||
|
||||
bert_path = os.environ.get(
|
||||
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
|
||||
)
|
||||
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
|
||||
@torch.jit.script
|
||||
def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
|
||||
phone_level_feature = []
|
||||
for i in range(word2ph.shape[0]):
|
||||
repeat_feature = res[i].repeat(word2ph[i].item(), 1)
|
||||
phone_level_feature.append(repeat_feature)
|
||||
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
||||
# [sum(word2ph), 1024]
|
||||
return phone_level_feature
|
||||
|
||||
class MyBertModel(torch.nn.Module):
|
||||
def __init__(self, bert_model):
|
||||
super(MyBertModel, self).__init__()
|
||||
self.bert = bert_model
|
||||
|
||||
def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
|
||||
res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
|
||||
return build_phone_level_feature(res, word2ph)
|
||||
|
||||
class SSLModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ssl = cnhubert.get_model().model
|
||||
|
||||
def forward(self, ref_audio_16k)-> torch.Tensor:
|
||||
ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
|
||||
return ssl_content
|
||||
|
||||
class ExportSSLModel(torch.nn.Module):
|
||||
def __init__(self,ssl:SSLModel):
|
||||
super().__init__()
|
||||
self.ssl = ssl
|
||||
|
||||
def forward(self, ref_audio:torch.Tensor):
|
||||
return self.ssl(ref_audio)
|
||||
|
||||
@torch.jit.export
|
||||
def resample(self,ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
|
||||
audio = resamplex(ref_audio,src_sr,dst_sr).float()
|
||||
return audio
|
||||
|
||||
def export_bert(ref_bert_inputs):
|
||||
ref_bert_inputs = {
|
||||
'input_ids': ref_bert_inputs['input_ids'],
|
||||
'attention_mask': ref_bert_inputs['attention_mask'],
|
||||
'token_type_ids': ref_bert_inputs['token_type_ids'],
|
||||
'word2ph': ref_bert_inputs['word2ph']
|
||||
}
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
|
||||
my_bert_model = MyBertModel(bert_model)
|
||||
|
||||
my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
|
||||
my_bert_model.save("onnx/bert_model.pt")
|
||||
print('#### exported bert ####')
|
||||
|
||||
def export(gpt_path, vits_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
|
||||
ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
|
||||
ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
|
||||
ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
|
||||
|
||||
text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
|
||||
text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
|
||||
text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
|
||||
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
|
||||
|
||||
bert = MyBertModel(bert_model)
|
||||
|
||||
# export_bert(ref_bert_inputs)
|
||||
|
||||
ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
|
||||
ssl = SSLModel()
|
||||
s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
|
||||
torch.jit.script(s).save("onnx/xw/ssl_model.pt")
|
||||
print('#### exported ssl ####')
|
||||
|
||||
ref_bert = bert(**ref_bert_inputs)
|
||||
text_bert = bert(**text_berf_inputs)
|
||||
ssl_content = ssl(ref_audio)
|
||||
|
||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||
vits = VitsModel(vits_path)
|
||||
vits.eval()
|
||||
|
||||
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
||||
raw_t2s = get_raw_t2s_model(dict_s1)
|
||||
t2s_m = T2SModel(raw_t2s)
|
||||
t2s_m.eval()
|
||||
t2s = torch.jit.script(t2s_m)
|
||||
print('#### script t2s_m ####')
|
||||
|
||||
print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
|
||||
gpt_sovits = GPT_SoVITS(t2s,vits)
|
||||
gpt_sovits.eval()
|
||||
ref_audio_sr = s.resample(ref_audio,16000,32000)
|
||||
print('ref_audio_sr:',ref_audio_sr.shape)
|
||||
|
||||
gpt_sovits_export = torch.jit.trace(
|
||||
gpt_sovits,
|
||||
example_inputs=(
|
||||
ssl_content,
|
||||
ref_audio_sr,
|
||||
ref_seq,
|
||||
text_seq,
|
||||
ref_bert,
|
||||
text_bert),
|
||||
check_trace=False) # 默认是True 但是 check 的时候可能是随机生成的一个奇怪维度的值,导致报错
|
||||
|
||||
gpt_sovits_export.save("onnx/xw/gpt_sovits_model.pt")
|
||||
print('#### exported gpt_sovits ####')
|
||||
|
||||
@torch.jit.script
|
||||
def parse_audio(ref_audio):
|
||||
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()#.to(ref_audio.device)
|
||||
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,32000).float()#.to(ref_audio.device)
|
||||
return ref_audio_16k,ref_audio_sr
|
||||
|
||||
@torch.jit.script
|
||||
def resamplex(ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
|
||||
return torchaudio.functional.resample(ref_audio,src_sr,dst_sr).float()
|
||||
|
||||
class GPT_SoVITS(nn.Module):
|
||||
def __init__(self, t2s:T2SModel,vits:VitsModel):
|
||||
super().__init__()
|
||||
self.t2s = t2s
|
||||
self.vits = vits
|
||||
|
||||
def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor):
|
||||
codes = self.vits.vq_model.extract_latent(ssl_content.float())
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompts = prompt_semantic.unsqueeze(0)
|
||||
|
||||
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
|
||||
audio = self.vits(text_seq, pred_semantic, ref_audio_sr)
|
||||
return audio
|
||||
|
||||
def test(gpt_path, vits_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
|
||||
bert = MyBertModel(bert_model)
|
||||
# bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
|
||||
|
||||
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
||||
raw_t2s = get_raw_t2s_model(dict_s1)
|
||||
t2s = T2SModel(raw_t2s)
|
||||
t2s.eval()
|
||||
# t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')
|
||||
|
||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||
vits = VitsModel(vits_path)
|
||||
vits.eval()
|
||||
|
||||
ssl = ExportSSLModel(SSLModel())
|
||||
ssl.eval()
|
||||
|
||||
gpt_sovits = GPT_SoVITS(t2s,vits)
|
||||
|
||||
# vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda')
|
||||
# ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda')
|
||||
|
||||
|
||||
ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
|
||||
ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
|
||||
ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
|
||||
|
||||
text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
|
||||
text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
|
||||
text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
|
||||
|
||||
ref_bert = bert(
|
||||
ref_bert_inputs['input_ids'],
|
||||
ref_bert_inputs['attention_mask'],
|
||||
ref_bert_inputs['token_type_ids'],
|
||||
ref_bert_inputs['word2ph']
|
||||
)
|
||||
|
||||
text_bert = bert(text_berf_inputs['input_ids'],
|
||||
text_berf_inputs['attention_mask'],
|
||||
text_berf_inputs['token_type_ids'],
|
||||
text_berf_inputs['word2ph'])
|
||||
|
||||
#[1,N]
|
||||
ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
|
||||
print('ref_audio:',ref_audio.shape)
|
||||
|
||||
ref_audio_sr = ssl.resample(ref_audio,16000,32000)
|
||||
print('start ssl')
|
||||
ssl_content = ssl(ref_audio)
|
||||
|
||||
print('start gpt_sovits:')
|
||||
with torch.no_grad():
|
||||
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)
|
||||
print('start write wav')
|
||||
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
|
||||
|
||||
# audio = vits(text_seq, pred_semantic1, ref_audio)
|
||||
# soundfile.write("out.wav", audio, 32000)
|
||||
|
||||
import text
|
||||
import json
|
||||
|
||||
def export_symbel(version='v2'):
|
||||
if version=='v1':
|
||||
symbols = text._symbol_to_id_v1
|
||||
with open(f"onnx/symbols_v1.json", "w") as file:
|
||||
json.dump(symbols, file, indent=4)
|
||||
else:
|
||||
symbols = text._symbol_to_id_v2
|
||||
with open(f"onnx/symbols_v2.json", "w") as file:
|
||||
json.dump(symbols, file, indent=4)
|
||||
|
||||
if __name__ == "__main__":
|
||||
export(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
|
||||
# test(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
|
||||
# export_symbel()
|
@ -4,8 +4,8 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from module import commons
|
||||
from module.modules import LayerNorm
|
||||
|
||||
from typing import Optional
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-5):
|
||||
@ -59,6 +59,7 @@ class Encoder(nn.Module):
|
||||
# self.cond_layer = weight_norm(cond_layer, name='weight')
|
||||
# self.gin_channels = 256
|
||||
self.cond_layer_idx = self.n_layers
|
||||
self.spk_emb_linear = nn.Linear(256, self.hidden_channels)
|
||||
if "gin_channels" in kwargs:
|
||||
self.gin_channels = kwargs["gin_channels"]
|
||||
if self.gin_channels != 0:
|
||||
@ -98,22 +99,36 @@ class Encoder(nn.Module):
|
||||
)
|
||||
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||
|
||||
def forward(self, x, x_mask, g=None):
|
||||
# def forward(self, x, x_mask, g=None):
|
||||
# attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
# x = x * x_mask
|
||||
# for i in range(self.n_layers):
|
||||
# if i == self.cond_layer_idx and g is not None:
|
||||
# g = self.spk_emb_linear(g.transpose(1, 2))
|
||||
# g = g.transpose(1, 2)
|
||||
# x = x + g
|
||||
# x = x * x_mask
|
||||
# y = self.attn_layers[i](x, x, attn_mask)
|
||||
# y = self.drop(y)
|
||||
# x = self.norm_layers_1[i](x + y)
|
||||
|
||||
# y = self.ffn_layers[i](x, x_mask)
|
||||
# y = self.drop(y)
|
||||
# x = self.norm_layers_2[i](x + y)
|
||||
# x = x * x_mask
|
||||
# return x
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
x = x * x_mask
|
||||
for i in range(self.n_layers):
|
||||
if i == self.cond_layer_idx and g is not None:
|
||||
g = self.spk_emb_linear(g.transpose(1, 2))
|
||||
g = g.transpose(1, 2)
|
||||
x = x + g
|
||||
x = x * x_mask
|
||||
y = self.attn_layers[i](x, x, attn_mask)
|
||||
for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2):
|
||||
y = attn_layers(x, x, attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
x = norm_layers_1(x + y)
|
||||
|
||||
y = self.ffn_layers[i](x, x_mask)
|
||||
y = ffn_layers(x, x_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_2[i](x + y)
|
||||
x = norm_layers_2(x + y)
|
||||
x = x * x_mask
|
||||
return x
|
||||
|
||||
@ -172,17 +187,18 @@ class MultiHeadAttention(nn.Module):
|
||||
self.conv_k.weight.copy_(self.conv_q.weight)
|
||||
self.conv_k.bias.copy_(self.conv_q.bias)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
def forward(self, x, c, attn_mask:Optional[torch.Tensor]=None):
|
||||
q = self.conv_q(x)
|
||||
k = self.conv_k(c)
|
||||
v = self.conv_v(c)
|
||||
|
||||
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
||||
# x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
||||
x, _ = self.attention(q, k, v, mask=attn_mask)
|
||||
|
||||
x = self.conv_o(x)
|
||||
return x
|
||||
|
||||
def attention(self, query, key, value, mask=None):
|
||||
def attention(self, query, key, value, mask:Optional[torch.Tensor]=None):
|
||||
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
||||
b, d, t_s, _ = (*key.size(), query.size(2))
|
||||
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
||||
@ -304,7 +320,7 @@ class FFN(nn.Module):
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
p_dropout=0.0,
|
||||
activation=None,
|
||||
activation="",
|
||||
causal=False,
|
||||
):
|
||||
super().__init__()
|
||||
@ -316,10 +332,11 @@ class FFN(nn.Module):
|
||||
self.activation = activation
|
||||
self.causal = causal
|
||||
|
||||
if causal:
|
||||
self.padding = self._causal_padding
|
||||
else:
|
||||
self.padding = self._same_padding
|
||||
# 从上下文看这里一定是 False
|
||||
# if causal:
|
||||
# self.padding = self._causal_padding
|
||||
# else:
|
||||
# self.padding = self._same_padding
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
||||
@ -334,6 +351,9 @@ class FFN(nn.Module):
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(self.padding(x * x_mask))
|
||||
return x * x_mask
|
||||
|
||||
def padding(self, x):
|
||||
return self._same_padding(x)
|
||||
|
||||
def _causal_padding(self, x):
|
||||
if self.kernel_size == 1:
|
||||
@ -352,3 +372,35 @@ class FFN(nn.Module):
|
||||
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||
x = F.pad(x, commons.convert_pad_shape(padding))
|
||||
return x
|
||||
|
||||
|
||||
class MRTE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
content_enc_channels=192,
|
||||
hidden_size=512,
|
||||
out_channels=192,
|
||||
kernel_size=5,
|
||||
n_heads=4,
|
||||
ge_layer=2,
|
||||
):
|
||||
super(MRTE, self).__init__()
|
||||
self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
|
||||
self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
|
||||
self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
|
||||
self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
|
||||
|
||||
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge):
|
||||
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
|
||||
|
||||
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
|
||||
text_enc = self.text_pre(text * text_mask)
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ssl_enc
|
||||
+ ge
|
||||
)
|
||||
x = self.c_post(x * ssl_mask)
|
||||
return x
|
||||
|
@ -13,10 +13,10 @@ def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
# def convert_pad_shape(pad_shape):
|
||||
# l = pad_shape[::-1]
|
||||
# pad_shape = [item for sublist in l for item in sublist]
|
||||
# return pad_shape
|
||||
|
||||
|
||||
def intersperse(lst, item):
|
||||
|
@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import math
|
||||
from typing import Optional
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
@ -11,7 +12,6 @@ from module import attentions_onnx as attentions
|
||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from module.commons import init_weights, get_padding
|
||||
from module.mrte_model import MRTE
|
||||
from module.quantize import ResidualVectorQuantizer
|
||||
# from text import symbols
|
||||
from text import symbols as symbols_v1
|
||||
@ -218,7 +218,7 @@ class TextEncoder(nn.Module):
|
||||
symbols = symbols_v2.symbols
|
||||
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
|
||||
|
||||
self.mrte = MRTE()
|
||||
self.mrte = attentions.MRTE()
|
||||
|
||||
self.encoder2 = attentions.Encoder(
|
||||
hidden_channels,
|
||||
@ -249,25 +249,6 @@ class TextEncoder(nn.Module):
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return y, m, logs, y_mask
|
||||
|
||||
def extract_latent(self, x):
|
||||
x = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
|
||||
return codes.transpose(0, 1)
|
||||
|
||||
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
|
||||
quantized = self.quantizer.decode(codes)
|
||||
|
||||
y = self.vq_proj(quantized) * y_mask
|
||||
y = self.encoder_ssl(y * y_mask, y_mask)
|
||||
|
||||
y = self.mrte(y, y_mask, refer, refer_mask, ge)
|
||||
|
||||
y = self.encoder2(y * y_mask, y_mask)
|
||||
|
||||
stats = self.proj(y) * y_mask
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return y, m, logs, y_mask, quantized
|
||||
|
||||
|
||||
class ResidualCouplingBlock(nn.Module):
|
||||
def __init__(
|
||||
@ -448,7 +429,7 @@ class Generator(torch.nn.Module):
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
|
||||
def forward(self, x, g=None):
|
||||
def forward(self, x, g:Optional[torch.Tensor]=None):
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
@ -870,15 +851,15 @@ class SynthesizerTrn(nn.Module):
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.enc_q = PosteriorEncoder(
|
||||
spec_channels,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
5,
|
||||
1,
|
||||
16,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
# self.enc_q = PosteriorEncoder(
|
||||
# spec_channels,
|
||||
# inter_channels,
|
||||
# hidden_channels,
|
||||
# 5,
|
||||
# 1,
|
||||
# 16,
|
||||
# gin_channels=gin_channels,
|
||||
# )
|
||||
self.flow = ResidualCouplingBlock(
|
||||
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user