Merge a889187b84da1ed91b7d99d2a317f930926374f8 into 9da7e17efe05041e31d3c3f42c8730ae890397f2

This commit is contained in:
Ναρουσέ·μ·γιουμεμί·Χινακάννα 2025-04-03 16:43:23 +08:00 committed by GitHub
commit 46b56cdc1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 1082 additions and 591 deletions


@ -1,5 +1,4 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
# reference: https://github.com/lifeiteng/vall-e
import os, sys

now_dir = os.getcwd()


@ -1,81 +1,88 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import math
from typing import List, Optional
import torch
from tqdm import tqdm

from AR.models.utils import (
    sample,
)
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
from AR.modules.transformer import TransformerEncoder
from AR.modules.transformer import TransformerEncoderLayer
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from torch.distributions import Exponential

ISONNXEXPORT = False

default_config = {
    "model": {
        "vocab_size": 1025,
        "phoneme_vocab_size": 512,
        "embedding_dim": 1024,
        "hidden_dim": 1024,
        "head": 16,
        "linear_units": 2048,
        "n_layer": 16,
        "dropout": 0,
        "EOS": 1024,
    }
}
def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = -torch.log(torch.rand_like(probs_sort))  # https://github.com/RVC-Boss/GPT-SoVITS/pull/835
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.long)

def logits_to_probs(
    logits,
    previous_tokens: torch.Tensor,
    temperature: torch.Tensor,
    top_k: torch.Tensor,
    top_p: torch.Tensor,
    repetition_penalty: torch.Tensor
):
    # if previous_tokens is not None:
    #     previous_tokens = previous_tokens.squeeze()
    # print(logits.shape,previous_tokens.shape)
    # pdb.set_trace()
    previous_tokens = previous_tokens.long()
    score = torch.gather(logits, dim=1, index=previous_tokens)
    score = torch.where(
        score < 0, score * repetition_penalty, score / repetition_penalty
    )
    logits.scatter_(dim=1, index=previous_tokens, src=score)

    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(
        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
    )
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[:, 0] = False  # keep at least one option
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=1, index=sorted_indices, src=sorted_indices_to_remove
    )
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / torch.clamp_min(temperature, 1e-5)

    v, _ = torch.topk(logits, top_k)
    pivot = v[:, -1].unsqueeze(-1)
    logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
):
    probs = logits_to_probs(
@ -84,125 +91,326 @@ def sample(
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs
# @torch.jit.script  ## if enabled, the first inference is very slow and inference speed is unstable
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
if scale is None:
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
else:
scale_factor = scale
attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))
else:
attn_bias = attn_bias + attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_weight = attn_weight.masked_fill(attn_mask, 0)
else:
attn_mask = attn_mask.clone()
attn_mask[attn_mask!=float("-inf")] =0
attn_mask[attn_mask==float("-inf")] =1
attn_weight = attn_weight.masked_fill(attn_mask, 0)
return attn_weight @ value
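The hand-written attention above mirrors torch.nn.functional.scaled_dot_product_attention; a quick numerical equivalence check under the additive float-mask convention (shapes and values are illustrative only):

import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)   # [batch, heads, query_len, head_dim]
k = torch.randn(2, 8, 20, 64)   # [batch, heads, key_len, head_dim]
v = torch.randn(2, 8, 20, 64)

# additive float mask: 0 where attention is allowed, -inf where it is blocked
mask = torch.zeros(2, 8, 16, 20)
mask[..., 10:] = float("-inf")

manual = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(64) + mask, dim=-1) @ v
builtin = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.allclose(manual, builtin, atol=1e-5))  # expected: True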
@torch.jit.script
class T2SMLP:
def __init__(self, w1, b1, w2, b2):
self.w1 = w1
self.b1 = b1
self.w2 = w2
self.b2 = b2
def forward(self, x):
x = F.relu(F.linear(x, self.w1, self.b1))
x = F.linear(x, self.w2, self.b2)
return x
@torch.jit.script
class T2SBlock:
def __init__(
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
):
self.num_heads = num_heads
self.mlp = mlp
self.hidden_dim: int = hidden_dim
self.qkv_w = qkv_w
self.qkv_b = qkv_b
self.out_w = out_w
self.out_b = out_b
self.norm_w1 = norm_w1
self.norm_b1 = norm_b1
self.norm_eps1 = norm_eps1
self.norm_w2 = norm_w2
self.norm_b2 = norm_b2
self.norm_eps2 = norm_eps2
self.false = torch.tensor(False, dtype=torch.bool)
@torch.jit.ignore
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
if padding_mask is None:
return x
if padding_mask.dtype == torch.bool:
return x.masked_fill(padding_mask, 0)
else:
return x * padding_mask
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k.shape[1]
q = self.to_mask(q, padding_mask)
k_cache = self.to_mask(k, padding_mask)
v_cache = self.to_mask(v, padding_mask)
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
if torch_sdpa:
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
else:
attn = scaled_dot_product_attention(q, k, v, attn_mask)
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
def decode_next_token(self, x, k_cache, v_cache):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
v_cache = torch.cat([v_cache, v], dim=1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k_cache.shape[1]
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
attn = F.linear(attn, self.out_w, self.out_b)
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
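decode_next_token reuses the cached keys and values so each generation step attends with a single query instead of re-encoding the whole prefix; a shape-only sketch of that pattern, independent of T2SBlock's actual cache layout:

import torch
import torch.nn.functional as F

num_heads, head_dim = 8, 64
hidden = num_heads * head_dim

# prompt step: caches cover the whole prefix
k_cache = torch.randn(1, 12, num_heads, head_dim).transpose(1, 2)  # [B, H, prefix_len, D]
v_cache = torch.randn(1, 12, num_heads, head_dim).transpose(1, 2)

# decode step: a single new token produces one new k/v slice
x_new = torch.randn(1, 1, hidden)
q = x_new.view(1, 1, num_heads, head_dim).transpose(1, 2)          # [B, H, 1, D]
k_new = torch.randn(1, 1, num_heads, head_dim).transpose(1, 2)
v_new = torch.randn(1, 1, num_heads, head_dim).transpose(1, 2)

k_cache = torch.cat([k_cache, k_new], dim=2)                       # grow along the time axis
v_cache = torch.cat([v_cache, v_new], dim=2)
attn = F.scaled_dot_product_attention(q, k_cache, v_cache)         # no mask needed: one query over past keys
print(attn.shape)  # torch.Size([1, 8, 1, 64])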
@torch.jit.script
class T2STransformer:
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
self.num_blocks : int = num_blocks
self.blocks = blocks
def process_prompt(
self, x:torch.Tensor, attn_mask : torch.Tensor,
padding_mask : Optional[torch.Tensor]=None,
torch_sdpa:bool=True
):
k_cache : List[torch.Tensor] = []
v_cache : List[torch.Tensor] = []
for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
k_cache.append(k_cache_)
v_cache.append(v_cache_)
return x, k_cache, v_cache
def decode_next_token(
self, x:torch.Tensor,
k_cache,
v_cache,
):
K_Cache = []
V_Cache = []
for i in range(self.num_blocks):
x, k, v = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
K_Cache.append(k)
V_Cache.append(v)
K_Cache = torch.stack(K_Cache, dim=0)
V_Cache = torch.stack(V_Cache, dim=0)
return x, K_Cache, V_Cache
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
dropout: float = 0.0,
scale: bool = False,
alpha: bool = False,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
self.dropout = torch.nn.Dropout(p=dropout)
self.reverse = False
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, 114514))
def extend_pe(self, x):
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
def forward(self, x: torch.Tensor, x_size) -> torch.Tensor:
output = x.unsqueeze(-1) if x.ndim == 2 else x
output[:,:x_size,:] = output[:,:x_size,:] * self.x_scale + self.alpha * self.pe[:, : x_size]
return self.dropout(output)
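This module builds the standard sinusoidal table once (114514 positions above) and then scales and adds it to the input; a compact sketch of the same table construction, assuming only PyTorch:

import math
import torch

def sinusoidal_table(max_len, dim):
    # pe[t, 2i] = sin(t / 10000^(2i/dim)), pe[t, 2i+1] = cos(t / 10000^(2i/dim))
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * -(math.log(10000.0) / dim))
    pe = torch.zeros(max_len, dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)  # [1, max_len, dim]

# the module above then applies: out = x * x_scale + alpha * pe[:, :x_len]
print(sinusoidal_table(2048, 512).shape)  # torch.Size([1, 2048, 512])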
class PromptProcessor(nn.Module):
def __init__(self, cache_len, model, top_k):
super(PromptProcessor, self).__init__()
self.top_k = top_k
self.model = model
self.ar_text_embedding = model.ar_text_embedding
self.ar_text_position = model.ar_text_position
self.ar_audio_embedding = model.ar_audio_embedding
self.ar_audio_position = model.ar_audio_position
self.bert_proj = model.bert_proj
cache_len = torch.tensor(cache_len)
self.register_buffer("cache_len", cache_len, persistent=False)

def forward(self, x, x_len, y, y_len, bert_feature, top_p, repetition_penalty, temperature):
bsz = x.size(0)
src_len = x_len + y_len
x_emb = self.ar_text_embedding(x)
x_emb = x_emb + self.bert_proj(bert_feature)
x_pos = self.ar_text_position(x_emb, x_len)
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position(y_emb, y_len)
y_attn_mask = F.pad(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), (x_len, 0), value=False)

xy_pos = torch.concat([x_pos, y_pos], dim=1)
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).unsqueeze(0)\
.expand(bsz * self.model.num_head, -1, -1)\
.view(bsz, self.model.num_head, src_len, src_len)\
.to(device=x.device, dtype=torch.bool)

xy_dec, k_cache, v_cache = self.model.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)

logits = self.model.ar_predict_layer(
xy_dec[:, -1]
)
samples = sample(
logits, y, top_k=self.top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]

y = torch.concat([y, samples], dim=1)
y_emb = self.ar_audio_embedding(samples)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len].to(dtype=y_emb.dtype, device=y_emb.device)

k_cache = torch.stack(k_cache, dim=0)
v_cache = torch.stack(v_cache, dim=0)
return y, k_cache, v_cache, xy_pos, y_len + 1, samples
class DecodeNextToken(nn.Module):
def __init__(self, cache_len, model, top_k):
super(DecodeNextToken, self).__init__()
self.top_k = top_k
self.model = model
self.ar_text_embedding = model.ar_text_embedding
self.ar_text_position = model.ar_text_position
self.ar_audio_embedding = model.ar_audio_embedding
self.ar_audio_position = model.ar_audio_position
cache_len = torch.tensor(cache_len)
self.register_buffer("cache_len", cache_len, persistent=False)

def forward(self, y, k_cache, v_cache, xy_pos, y_idx, top_p, repetition_penalty, temperature):
xy_dec, k_cache, v_cache = self.model.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.model.ar_predict_layer(
xy_dec[:, -1]
)
samples = sample(
logits, y, top_k=self.top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]

y = torch.concat([y, samples], dim=1)
y_emb = self.ar_audio_embedding(samples)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_idx].to(dtype=y_emb.dtype, device=y_emb.device)
return y, k_cache, v_cache, xy_pos, y_idx + 1, samples
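Roughly how the two wrappers above are meant to be chained at inference time; this is an eager PyTorch sketch, the EOS id, step cap and the way the length inputs are packed follow the dummy inputs in export_onnx.py and are otherwise assumptions:

def generate_semantic(prompt_processor, decode_next_token, x, y, bert_feature,
                      top_p, repetition_penalty, temperature,
                      eos_id=1024, max_steps=1500):
    # prompt_processor runs once over text + reference tokens, decode_next_token loops
    x_len = torch.tensor([x.shape[1]]).long()
    y_len = torch.tensor([y.shape[1]]).long()
    y, k_cache, v_cache, xy_pos, y_idx, samples = prompt_processor(
        x, x_len, y, y_len, bert_feature, top_p, repetition_penalty, temperature
    )
    for _ in range(max_steps):
        y, k_cache, v_cache, xy_pos, y_idx, samples = decode_next_token(
            y, k_cache, v_cache, xy_pos, y_idx, top_p, repetition_penalty, temperature
        )
        if int(samples[0, 0]) == eos_id:  # stop once the EOS token is sampled
            break
    return y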
class Text2SemanticDecoder(nn.Module):
@ -215,15 +423,24 @@ class Text2SemanticDecoder(nn.Module):
self.norm_first = norm_first
self.vocab_size = config["model"]["vocab_size"]
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
self.p_dropout = config["model"]["dropout"]
self.EOS = config["model"]["EOS"]
self.norm_first = norm_first
assert self.EOS == self.vocab_size - 1
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.ar_audio_embedding = TokenEmbedding(
self.embedding_dim, self.vocab_size, self.p_dropout
)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.h = TransformerEncoder(
TransformerEncoderLayer(
d_model=self.model_dim,
@ -236,8 +453,10 @@ class Text2SemanticDecoder(nn.Module):
num_layers=self.num_layers,
norm=LayerNorm(self.model_dim) if norm_first else None,
)
self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
self.ar_accuracy_metric = MulticlassAccuracy(
self.vocab_size,
top_k=top_k,
@ -245,37 +464,126 @@ class Text2SemanticDecoder(nn.Module):
multidim_average="global",
ignore_index=self.EOS,
)

blocks = []
for i in range(self.num_layers):
layer = self.h.layers[i]
t2smlp = T2SMLP(
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
)
block = T2SBlock(
self.num_head,
self.model_dim,
t2smlp,
layer.self_attn.in_proj_weight,
layer.self_attn.in_proj_bias,
layer.self_attn.out_proj.weight,
layer.self_attn.out_proj.bias,
layer.norm1.weight,
layer.norm1.bias,
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
)
blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
def infer_panel_naive(
self,
x: torch.LongTensor,  ##### all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor,  #### reference audio tokens
bert_feature:torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
k_cache = None
v_cache = None

y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
bsz = x.shape[0]
src_len = x_len + y_len
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len),
value=True,
)
y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
.unsqueeze(0)\
.expand(bsz*self.num_head, -1, -1)\
.view(bsz, self.num_head, src_len, src_len)\
.to(device=x.device, dtype=torch.bool)
for idx in tqdm(range(1500)):
if xy_attn_mask is not None:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
if idx == 0:
xy_attn_mask = None
if (idx < 11):  ### require at least 10 predicted tokens (~0.4 s) before allowing a stop
logits = logits[:, :-1]
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break

####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype, device=y_emb.device)

return y[:, :-1], idx

def infer(self, x, prompts, bert_feature):
top_k = self.top_k


@ -48,16 +48,17 @@ class SinePositionalEmbedding(nn.Module):
self.dropout = torch.nn.Dropout(p=dropout)
self.reverse = False
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
self.pe = self.extend_pe(2000)

def extend_pe(self, x):
position = torch.cumsum(torch.ones((x, 1)), dim=0)
scpe = (position * self.div_term).unsqueeze(0)
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
pe = pe.contiguous().view(1, -1, self.embedding_dim)
return pe

def forward(self, x: torch.Tensor) -> torch.Tensor:
pe = self.pe[:, :x.size(1), :]
output = x.unsqueeze(-1) if x.ndim == 2 else x
output = output * self.x_scale + self.alpha * pe
return self.dropout(output)
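The change above precomputes the table for 2000 positions in __init__ and merely slices it in forward, so the exported graph needs no data-dependent recomputation; a quick sanity check that slicing the cached table matches recomputing it (illustrative):

import math
import torch

embedding_dim = 512
div_term = torch.exp(torch.arange(0, embedding_dim, 2) * -(math.log(10000.0) / embedding_dim))

def extend_pe(n):
    position = torch.cumsum(torch.ones((n, 1)), dim=0)  # positions 1, 2, ..., n
    scpe = (position * div_term).unsqueeze(0)
    pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
    return pe.contiguous().view(1, -1, embedding_dim)

cached = extend_pe(2000)
print(torch.allclose(cached[:, :100, :], extend_pe(100)))  # True: slicing equals recomputing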

GPT_SoVITS/export_onnx.py (new file, 306 lines)

@ -0,0 +1,306 @@
import os
import json
import onnx
import torch
import onnxsim
from torch.nn import Module
from feature_extractor import cnhubert
from onnxruntime import InferenceSession
from pytorch_lightning import LightningModule
from transformers import AutoTokenizer, AutoModelForMaskedLM
import AR.models.t2s_model_onnx as t2s
from module.models_onnx import SynthesizerTrn
root_path = os.path.dirname(os.path.abspath(__file__))
onnx_path = os.path.join(root_path, "onnx")
if not os.path.exists(onnx_path):
os.makedirs(onnx_path)
class BertWrapper(Module):
def __init__(self):
bert_path = os.environ.get(
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
)
super(BertWrapper, self).__init__()
self.model = AutoModelForMaskedLM.from_pretrained(bert_path)
self.tokenizer = AutoTokenizer.from_pretrained(bert_path)
def forward(self, input_ids):
attention_mask = torch.ones_like(input_ids)
token_type_ids = torch.zeros_like(input_ids)
res = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
return torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
def export_onnx(self):
vocab_dict = { k: v for k, v in self.tokenizer.get_vocab().items() }
vocab_path = os.path.join(onnx_path, "Vocab.json")
with open(vocab_path, "w") as f:
json.dump(vocab_dict, f, indent=4)
dummy_input = torch.randint(0, 100, (1, 20)).long()
torch.onnx.export(
self,
dummy_input,
os.path.join(onnx_path, "Bert.onnx"),
input_names=["input_ids"],
output_names=["output"],
dynamic_axes={"input_ids": {0: "batch_size", 1: "sequence_length"}},
opset_version=18,
)
sim, _ = onnxsim.simplify(os.path.join(onnx_path, "Bert.onnx"))
onnx.save_model(sim, os.path.join(onnx_path, "Bert.onnx"))
print("Exported BERT to ONNX format.")
class CnHubertWrapper(Module):
def __init__(self):
super(CnHubertWrapper, self).__init__()
cnhubert_base_path = os.environ.get(
"cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
)
cnhubert.cnhubert_base_path = cnhubert_base_path
self.model = cnhubert.get_model().model
def forward(self, signal):
return self.model(signal)["last_hidden_state"]
def export_onnx(self):
dummy_input = torch.randn(1, 16000 * 10)
torch.onnx.export(
self,
dummy_input,
os.path.join(onnx_path, "CnHubert.onnx"),
input_names=["signal"],
output_names=["output"],
dynamic_axes={"signal": {0: "batch_size", 1: "sequence_length"}},
opset_version=18,
)
sim, _ = onnxsim.simplify(os.path.join(onnx_path, "CnHubert.onnx"))
onnx.save_model(sim, os.path.join(onnx_path, "CnHubert.onnx"))
print("Exported CN-Hubert to ONNX format.")
class Text2SemanticLightningModule(LightningModule):
def __init__(self, path, top_k=20, cache_size=2000):
super().__init__()
dict_s1 = torch.load(path, map_location="cpu")
config = dict_s1["config"]
self.model = t2s.Text2SemanticDecoder(config=config)
self.load_state_dict(dict_s1["weight"])
self.cache_size = cache_size
self.top_k = top_k
def export_ar(path, top_k=20, cache_size=2000):
model_l = Text2SemanticLightningModule(path, top_k=top_k, cache_size=cache_size)
model = model_l.model
x = torch.randint(0, 100, (1, 20)).long()
x_len = torch.tensor([20]).long()
y = torch.randint(0, 100, (1, 20)).long()
y_len = torch.tensor([20]).long()
bert_feature = torch.randn(1, 20, 1024)
top_p = torch.tensor([0.8])
repetition_penalty = torch.tensor([1.35])
temperature = torch.tensor([0.6])
prompt_processor = t2s.PromptProcessor(cache_len=cache_size, model=model, top_k=top_k)
decode_next_token = t2s.DecodeNextToken(cache_len=cache_size, model=model, top_k=top_k)
torch.onnx.export(
prompt_processor,
(x, x_len, y, y_len, bert_feature, top_p, repetition_penalty, temperature),
os.path.join(onnx_path, "PromptProcessor.onnx"),
input_names=["x", "x_len", "y", "y_len", "bert_feature", "top_p", "repetition_penalty", "temperature"],
output_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "samples"],
dynamic_axes={
"x": {0: "batch_size", 1: "sequence_length"},
"y": {0: "batch_size", 1: "sequence_length"},
"bert_feature": {0: "batch_size", 1: "sequence_length"},
},
opset_version=18,
)
sim, _ = onnxsim.simplify(os.path.join(onnx_path, "PromptProcessor.onnx"))
onnx.save_model(sim, os.path.join(onnx_path, "PromptProcessor.onnx"))
y, k_cache, v_cache, xy_pos, y_idx, samples = prompt_processor(
x, x_len, y, y_len, bert_feature, top_p, repetition_penalty, temperature
)
torch.onnx.export(
decode_next_token,
(y, k_cache, v_cache, xy_pos, y_idx, top_p, repetition_penalty, temperature),
os.path.join(onnx_path, "DecodeNextToken.onnx"),
input_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "top_p", "repetition_penalty", "temperature"],
output_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "samples"],
dynamic_axes={
"y": {0: "batch_size", 1: "sequence_length"},
"k_cache": {1: "batch_size", 2: "sequence_length"},
"v_cache": {1: "batch_size", 2: "sequence_length"},
},
opset_version=18
)
sim, _ = onnxsim.simplify(os.path.join(onnx_path, "DecodeNextToken.onnx"))
onnx.save_model(sim, os.path.join(onnx_path, "DecodeNextToken.onnx"))
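A rough sketch of driving the two exported graphs with onnxruntime; the input and output names follow the export calls above, while the array dtypes, shapes and sampling values are assumptions:

import numpy as np
from onnxruntime import InferenceSession

prompt_sess = InferenceSession(os.path.join(onnx_path, "PromptProcessor.onnx"))
decode_sess = InferenceSession(os.path.join(onnx_path, "DecodeNextToken.onnx"))

def run_t2s(phoneme_ids, prompt_semantic, bert_feature, max_steps=1500, eos_id=1024):
    # assumed dtypes: int64 token ids, float32 features and sampling scalars
    feeds = {
        "x": phoneme_ids.astype(np.int64),
        "x_len": np.array([phoneme_ids.shape[1]], dtype=np.int64),
        "y": prompt_semantic.astype(np.int64),
        "y_len": np.array([prompt_semantic.shape[1]], dtype=np.int64),
        "bert_feature": bert_feature.astype(np.float32),
        "top_p": np.array([0.8], dtype=np.float32),
        "repetition_penalty": np.array([1.35], dtype=np.float32),
        "temperature": np.array([0.6], dtype=np.float32),
    }
    y, k_cache, v_cache, xy_pos, y_idx, samples = prompt_sess.run(None, feeds)
    for _ in range(max_steps):
        y, k_cache, v_cache, xy_pos, y_idx, samples = decode_sess.run(None, {
            "y": y, "k_cache": k_cache, "v_cache": v_cache, "xy_pos": xy_pos,
            "y_idx": y_idx, "top_p": feeds["top_p"],
            "repetition_penalty": feeds["repetition_penalty"],
            "temperature": feeds["temperature"],
        })
        if int(samples[0, 0]) == eos_id:  # stop on EOS
            break
    return y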
from io import BytesIO
def load_sovits_new(sovits_path):
    f = open(sovits_path, "rb")
    meta = f.read(2)
    if meta != b"PK":  # not a zip-style checkpoint: re-prepend the magic bytes
        data = b'PK' + f.read()
bio = BytesIO()
bio.write(data)
bio.seek(0)
return torch.load(bio, map_location="cpu", weights_only=False)
return torch.load(sovits_path,map_location="cpu", weights_only=False)
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
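DictToAttrRecursive just makes nested config dicts attribute-accessible, e.g.:

cfg = DictToAttrRecursive({"data": {"sampling_rate": 32000}, "model": {"version": "v2"}})
print(cfg.data.sampling_rate, cfg["model"].version)  # 32000 v2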
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
hann_window = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
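spectrogram_torch returns the linear magnitude spectrogram ([batch, n_fft // 2 + 1, frames]) that the reference encoder consumes; an illustrative call with typical v2 STFT settings (the real values come from hps.data):

import torch

wav = torch.randn(1, 32000)  # 1 s of audio at 32 kHz (illustrative)
spec = spectrogram_torch(wav, n_fft=2048, sampling_rate=32000, hop_size=640, win_size=2048, center=False)
print(spec.shape)  # [1, n_fft // 2 + 1, n_frames]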
class Extractor(Module):
def __init__(self, model):
super(Extractor, self).__init__()
self.model = model
def forward(self, x):
return self.model.extract_latent(x.transpose(1, 2))
class V1V2(Module):
def __init__(self, path):
super(V1V2, self).__init__()
dict_s2 = load_sovits_new(path)
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
        if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
            hps.model.version = "v2"  # v3 model, v2 symbols
elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps.model.version = "v1"
else:
hps.model.version = "v2"
version=hps.model.version
# print("sovits版本:",hps.model.version)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model
)
vq_model.load_state_dict(dict_s2["weight"], strict=False)
vq_model.eval()
self.vq_model = vq_model
self.hps = hps
self.ext = Extractor(self.vq_model)
def forward(self, text_seq, pred_semantic, ref_audio):
refer = spectrogram_torch(
ref_audio,
self.hps.data.filter_length,
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False
)
return self.vq_model(pred_semantic.unsqueeze(0), text_seq, refer)[0, 0]
def export(self):
test_seq = torch.randint(0, 100, (1, 20)).long()
pred_semantic = torch.randint(0, 100, (1, 20)).long()
ref_audio = torch.randn(1, 16000 * 10)
torch.onnx.export(
self,
(test_seq, pred_semantic, ref_audio),
os.path.join(onnx_path, "GptSoVitsV1V2.onnx"),
input_names=["text_seq", "pred_semantic", "ref_audio"],
output_names=["output"],
dynamic_axes={
"text_seq": {0: "batch_size", 1: "sequence_length"},
"pred_semantic": {0: "batch_size", 1: "sequence_length"},
"ref_audio": {0: "batch_size", 1: "sequence_length"},
},
opset_version=18,
)
sim, _ = onnxsim.simplify(os.path.join(onnx_path, "GptSoVitsV1V2.onnx"))
onnx.save_model(sim, os.path.join(onnx_path, "GptSoVitsV1V2.onnx"))
ref_units = torch.randn(1, 20, 768)
torch.onnx.export(
self.ext,
ref_units,
os.path.join(onnx_path, "Extractor.onnx"),
input_names=["ref_units"],
output_names=["output"],
dynamic_axes={
"ref_units": {0: "batch_size", 1: "sequence_length"},
},
opset_version=18,
)
if __name__ == "__main__":
#CnHubertWrapper().export_onnx()
#BertWrapper().export_onnx()
V1V2("D:\\VSGIT\GPT-SoVITS-main\\GPT_SoVITS\\GPT-SoVITS-v3lora-20250228\\GPT_SoVITS\\t\\SoVITS_weights\\小特.pth").export()
'''export_ar(
"D:\\VSGIT\GPT-SoVITS-main\\GPT_SoVITS\\GPT-SoVITS-v3lora-20250228\\GPT_SoVITS\\t\\GPT_weights\\小特.ckpt",
top_k=10,
cache_size=1500,
)'''


@ -1,24 +1,28 @@
import warnings
warnings.filterwarnings("ignore")
import copy
import math
import os
import pdb

import torch
from torch import nn
from torch.nn import functional as F

from module import commons
from module import modules
from module import attentions
# from f5_tts.model.backbones.dit import DiT
from f5_tts.model import DiT

from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
import contextlib, random


class StochasticDurationPredictor(nn.Module):
@ -186,7 +190,7 @@ class TextEncoder(nn.Module):
kernel_size,
p_dropout,
latent_channels=192,
version = "v2",
):
super().__init__()
self.out_channels = out_channels
@ -220,7 +224,7 @@ class TextEncoder(nn.Module):
symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)

self.mrte = MRTE()

self.encoder2 = attentions.Encoder(
hidden_channels,
@ -233,7 +237,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

def forward(self, y, text, ge, speed=1, test=None):
y_mask = torch.ones_like(y[:1, :1, :])
y = self.ssl_proj(y * y_mask) * y_mask
@ -254,6 +258,25 @@ class TextEncoder(nn.Module):
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask
def extract_latent(self, x):
x = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
return codes.transpose(0, 1)
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
quantized = self.quantizer.decode(codes)
y = self.vq_proj(quantized) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
y = self.mrte(y, y_mask, refer, refer_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask, quantized
class ResidualCouplingBlock(nn.Module):
def __init__(
@ -465,7 +488,7 @@ class Generator(torch.nn.Module):
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

def forward(self, x, g=None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
@ -923,7 +946,7 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)

def forward(self, codes, text, refer, noise_scale=0.5):
refer_mask = torch.ones_like(refer[:1, :1, :])
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask)
@ -935,79 +958,98 @@ class SynthesizerTrn(nn.Module):
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)

_, m_p, logs_p, y_mask = self.enc_p(
quantized, text, ge
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o

def extract_latent(self, x):
ssl = self.ssl_proj(x)
_, codes, _, _ = self.quantizer(ssl)
return codes.transpose(0, 1)
class CFM(torch.nn.Module):
def __init__(
self,
in_channels, dit
):
super().__init__()
self.sigma_min = 1e-6
self.estimator = dit
self.in_channels = in_channels
self.criterion = torch.nn.MSELoss()

@torch.inference_mode()
def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
"""Forward diffusion"""
B, T = mu.size(0), mu.size(1)
x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0
mu = mu.transpose(2, 1)
t = 0
d = 1 / n_timesteps
for j in range(n_timesteps):
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False).transpose(2, 1)
if inference_cfg_rate > 1e-5:
neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
x = x + d * v_pred
t = t + d
x[:, :, :prompt_len] = 0
return x
def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
b, _, t = x1.shape
t = torch.rand([b], device=mu.device, dtype=x1.dtype)
x0 = torch.randn_like(x1,device=mu.device)
vt = x1 - x0
xt = x0 + t[:, None, None] * vt
dt = torch.zeros_like(t,device=mu.device)
prompt = torch.zeros_like(x1)
for i in range(b):
prompt[i, :, :prompt_lens[i]] = x1[i, :, :prompt_lens[i]]
xt[i, :, :prompt_lens[i]] = 0
gailv=0.3# if ttime()>1736250488 else 0.1
if random.random() < gailv:
base = torch.randint(2, 8, (t.shape[0],), device=mu.device)
d = 1/torch.pow(2, base)
d_input = d.clone()
d_input[d_input < 1e-2] = 0
# with torch.no_grad():
v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
# v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
x_mid = xt + d[:, None, None] * v_pred_1
# v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
vt = (v_pred_1 + v_pred_2) / 2
vt = vt.detach()
dt = 2*d
vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1)
loss = 0
for i in range(b):
loss += self.criterion(vt_pred[i, :, prompt_lens[i]:x_lens[i]], vt[i, :, prompt_lens[i]:x_lens[i]])
loss /= b
return loss
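inference above integrates the learned velocity field with fixed Euler steps from t = 0 to t = 1; a stripped-down version of that loop with a toy velocity function (illustrative, not the DiT estimator):

import torch

def euler_flow(x0, velocity_fn, n_timesteps=32):
    # integrate dx/dt = v(x, t) from t = 0 to t = 1 with fixed step d = 1/n
    x, t, d = x0, 0.0, 1.0 / n_timesteps
    for _ in range(n_timesteps):
        x = x + d * velocity_fn(x, t)
        t = t + d
    return x

# toy velocity field v(x, t) = -x drives the state toward zero
x0 = torch.randn(2, 100, 50)
x1 = euler_flow(x0, lambda x, t: -x)
print(x1.norm() / x0.norm())  # ~ (1 - 1/32)^32 ~ 0.36: the flow contracts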
def set_no_grad(net_g):
for name, param in net_g.named_parameters():
param.requires_grad = False
@torch.jit.script_if_tracing
def compile_codes_length(codes):
y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device)
return y_lengths1 * 2.5 * 1.5
@torch.jit.script_if_tracing
def compile_ref_length(refer):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
return refer_lengths
class SynthesizerTrnV3(nn.Module):
"""
@ -1035,7 +1077,6 @@ class SynthesizerTrnV3(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
**kwargs):

super().__init__()
@ -1056,7 +1097,6 @@ class SynthesizerTrnV3(nn.Module):
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels

self.model_dim = 512
self.use_sdp = use_sdp
@ -1083,7 +1123,7 @@ class SynthesizerTrnV3(nn.Module):
n_q=1,
bins=1024
)
self.freeze_quantizer = freeze_quantizer
inter_channels2 = 512
self.bridge = nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
@ -1092,32 +1132,213 @@ class SynthesizerTrnV3(nn.Module):
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
self.cfm = CFM(100, DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)  # text_dim is condition feature dim
if self.freeze_quantizer == True:
set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer)
set_no_grad(self.enc_p)
def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths, use_grad_ckpt):  # ssl_lengths no need now
with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
if self.freeze_quantizer:
self.ssl_proj.eval()
self.quantizer.eval()
self.enc_p.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ## BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")  ## BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)  ## If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
B = ssl.shape[0]
prompt_len_max = mel_lengths * 2 / 3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)
minn = min(mel.shape[-1], fea.shape[-1])
mel = mel[:, :, :minn]
fea = fea[:, :, :minn]
cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
return cfm_loss

@torch.no_grad()
def decode_encp(self, codes, text, refer, ge=None, speed=1):
# print(2333333,refer.shape)
# ge=None
if (ge == None):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
if speed == 1:
sizee = int(codes.size(2) * 2.5 * 1.5)
else:
sizee = int(codes.size(2) * 2.5 * 1.5 / speed) + 1
y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ## BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")  ## BCT
#### more wn parameter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea, ge

def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1)
class SynthesizerTrnV3b(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
**kwargs):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.model_dim=512
self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
self.freeze_quantizer=freeze_quantizer
inter_channels2=512
self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
nn.LeakyReLU()
)
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now
with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
# ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k
# ge=None
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
if self.freeze_quantizer:
self.ssl_proj.eval()
self.quantizer.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)
learned_mel = self.linear_mel(fea)
B=ssl.shape[0]
prompt_len_max = mel_lengths*2/3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)#
minn=min(mel.shape[-1],fea.shape[-1])
mel=mel[:,:,:minn]
fea=fea[:,:,:minn]
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea)#fea==cond,y_lengths==target_mel_lengths#ge not need
return commit_loss,cfm_loss,F.mse_loss(learned_mel, mel),o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized
@torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None):
# print(2333333,refer.shape)
# ge=None
if(ge==None):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
y_lengths1 = torch.LongTensor([int(codes.size(2)*2.5*1.5)]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")  # (B, C, T)
# additional WN parameters (wns1) used to learn the mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea,ge
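# Inference helper: decode_encp maps semantic codes plus text (and a reference spectrogram,
# or a precomputed speaker embedding `ge`) to the conditioning features `fea` consumed by the
# CFM sampler. Note that y_lengths1 = codes * 2 * 1.875, matching the 1.875x interpolation of `fea`.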
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)

View File

@ -1,344 +0,0 @@
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
import torch
import torchaudio
from torch import nn
from feature_extractor import cnhubert
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence
import soundfile
from tools.my_utils import load_audio
import os
import json
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
hann_window = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
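# Returns a linear-magnitude spectrogram of shape [B, n_fft // 2 + 1, frames]; with center=False,
# the reflect padding of (n_fft - hop_size) / 2 on each side keeps the frame count at roughly T // hop_size.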
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
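# Small helper exposing nested config dicts as attributes (e.g. hps.data.sampling_rate)
# while still behaving like a plain dict.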
class T2SEncoder(nn.Module):
def __init__(self, t2s, vits):
super().__init__()
self.encoder = t2s.onnx_encoder
self.vits = vits
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
codes = self.vits.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
bert = bert.unsqueeze(0)
prompt = prompt_semantic.unsqueeze(0)
return self.encoder(all_phoneme_ids, bert), prompt
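# Bundles the text-to-semantic encoder with the VITS quantizer: reference and target
# phoneme/BERT sequences are concatenated, and the prompt semantic tokens are extracted
# from the reference SSL content via extract_latent.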
class T2SModel(nn.Module):
def __init__(self, t2s_path, vits_model):
super().__init__()
dict_s1 = torch.load(t2s_path, map_location="cpu")
self.config = dict_s1["config"]
self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
self.t2s_model.load_state_dict(dict_s1["weight"])
self.t2s_model.eval()
self.vits_model = vits_model.vq_model
self.hz = 50
self.max_sec = self.config["data"]["max_sec"]
self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
self.t2s_model = self.t2s_model.model
self.t2s_model.init_onnx()
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
self.first_stage_decoder = self.t2s_model.first_stage_decoder
self.stage_decoder = self.t2s_model.stage_decoder
#self.t2s_model = torch.jit.script(self.t2s_model)
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
early_stop_num = self.t2s_model.early_stop_num
#[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
prefix_len = prompts.shape[1]
#[1,N,512] [1,N]
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
stop = False
for idx in range(1, 1500):
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
enco = self.stage_decoder(y, k, v, y_emb, x_example)
y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
stop = True
if stop:
break
y[0, -1] = 0
return y[:, -idx:].unsqueeze(0)
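# Autoregressive decode: first_stage_decoder primes the KV cache from the encoder output and
# the prompt tokens; stage_decoder then appends one semantic token per step until EOS is sampled
# (or early_stop_num new tokens are produced). The final slice keeps roughly the newly generated
# tokens, with the EOS position zeroed out.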
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
#self.onnx_encoder = torch.jit.script(self.onnx_encoder)
if dynamo:
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_encoder_export_output = torch.onnx.dynamo_export(
self.onnx_encoder,
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
export_options=export_options
)
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
return
torch.onnx.export(
self.onnx_encoder,
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
output_names=["x", "prompts"],
dynamic_axes={
"ref_seq": {1 : "ref_length"},
"text_seq": {1 : "text_length"},
"ref_bert": {0 : "ref_length"},
"text_bert": {0 : "text_length"},
"ssl_content": {2 : "ssl_length"},
},
opset_version=16
)
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
torch.onnx.export(
self.first_stage_decoder,
(x, prompts),
f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
input_names=["x", "prompts"],
output_names=["y", "k", "v", "y_emb", "x_example"],
dynamic_axes={
"x": {1 : "x_length"},
"prompts": {1 : "prompts_length"},
},
verbose=False,
opset_version=16
)
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
torch.onnx.export(
self.stage_decoder,
(y, k, v, y_emb, x_example),
f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
dynamic_axes={
"iy": {1 : "iy_length"},
"ik": {1 : "ik_length"},
"iv": {1 : "iv_length"},
"iy_emb": {1 : "iy_emb_length"},
"ix_example": {1 : "ix_example_length"},
},
verbose=False,
opset_version=16
)
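# export() writes three ONNX graphs: the encoder (phoneme ids + BERT features -> x, prompts),
# the first-stage decoder that initializes the KV cache, and the single-step decoder that is
# re-run for every subsequent token at inference time.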
class VitsModel(nn.Module):
def __init__(self, vits_path):
super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu")
self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
self.vq_model = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
def forward(self, text_seq, pred_semantic, ref_audio):
refer = spectrogram_torch(
ref_audio,
self.hps.data.filter_length,
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False
)
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
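# The reference spectrogram is recomputed inside forward() so the exported VITS graph takes a
# raw reference waveform (at the model sampling rate) rather than precomputed spectral features.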
class GptSoVits(nn.Module):
def __init__(self, vits, t2s):
super().__init__()
self.vits = vits
self.t2s = t2s
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
audio = self.vits(text_seq, pred_semantic, ref_audio)
if debug:
import onnxruntime
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPUExecutionProvider"])
audio1 = sess.run(None, {
"text_seq" : text_seq.detach().cpu().numpy(),
"pred_semantic" : pred_semantic.detach().cpu().numpy(),
"ref_audio" : ref_audio.detach().cpu().numpy()
})
return audio, audio1
return audio
def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
torch.onnx.export(
self.vits,
(text_seq, pred_semantic, ref_audio),
f"onnx/{project_name}/{project_name}_vits.onnx",
input_names=["text_seq", "pred_semantic", "ref_audio"],
output_names=["audio"],
dynamic_axes={
"text_seq": {1 : "text_length"},
"pred_semantic": {2 : "pred_length"},
"ref_audio": {1 : "audio_length"},
},
opset_version=17,
verbose=False
)
class SSLModel(nn.Module):
def __init__(self):
super().__init__()
self.ssl = ssl_model
def forward(self, ref_audio_16k):
return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
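# cnhubert's last_hidden_state is [B, T, 768]; the transpose yields the [B, 768, T] layout
# expected by ssl_proj / extract_latent in the VITS model.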
def export(vits_path, gpt_path, project_name, vits_model="v2"):
vits = VitsModel(vits_path)
gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel()
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float()
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
os.makedirs(f"onnx/{project_name}", exist_ok=True)
ssl_content = ssl(ref_audio_16k).float()
# debug = False
debug = True
# gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
if debug:
a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
else:
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
if vits_model == "v1":
symbols = symbols_v1
else:
symbols = symbols_v2
MoeVSConf = {
"Folder": f"{project_name}",
"Name": f"{project_name}",
"Type": "GPT-SoVits",
"Rate": vits.hps.data.sampling_rate,
"NumLayers": gpt.t2s_model.num_layers,
"EmbeddingDim": gpt.t2s_model.embedding_dim,
"Dict": "BasicDict",
"BertPath": "chinese-roberta-wwm-ext-large",
# "Symbol": symbols,
"AddBlank": False,
}
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent=4)
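# Note: as the export() body above stands, the gpt_sovits.export(...) call is commented out and
# debug=True, so a run only performs a sanity-check inference (writing out1.wav / out2.wav) and
# dumps the MoeVS-style JSON config under onnx/.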
if __name__ == "__main__":
os.makedirs("onnx", exist_ok=True)
gpt_path = "GPT_weights/nahida-e25.ckpt"
vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
exp_path = "nahida"
export(vits_path, gpt_path, exp_path)
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)