# Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-06-24 05:09:18 +08:00).
# File size: 1094 lines, 39 KiB, Python.
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||
# reference: https://github.com/lifeiteng/vall-e
|
||
import argparse
|
||
from io import BytesIO
|
||
from typing import Optional
|
||
from my_utils import load_audio
|
||
import torch
|
||
import torchaudio
|
||
|
||
from torch import IntTensor, LongTensor, Tensor, nn
|
||
from torch.nn import functional as F
|
||
|
||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||
from feature_extractor import cnhubert
|
||
|
||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||
from module.models_onnx import SynthesizerTrn
|
||
|
||
from inference_webui import get_phones_and_bert
|
||
|
||
from sv import SV
|
||
import kaldi as Kaldi
|
||
|
||
import os
|
||
import soundfile
|
||
|
||
# Default hyper-parameter table. The vocabulary has 1024 + 1 entries and the
# EOS id (1024) is the last one.
default_config = dict(
    embedding_dim=512,
    hidden_dim=512,
    num_head=8,
    num_layers=12,
    num_codebook=8,
    p_dropout=0.0,
    vocab_size=1024 + 1,
    phoneme_vocab_size=512,
    EOS=1024,
)
|
||
|
||
# Module-level speaker-verification model; stays None until init_sv_cn() runs.
sv_cn_model = None


def init_sv_cn(device, is_half):
    """Create an ``SV`` instance and store it in the module-level ``sv_cn_model``.

    Args:
        device: forwarded unchanged as the first argument of ``SV``.
        is_half: forwarded unchanged as the second argument of ``SV``.
    """
    global sv_cn_model
    model = SV(device, is_half)
    sv_cn_model = model
|
||
|
||
def load_sovits_new(sovits_path):
    """Load a SoVITS checkpoint whose first two bytes may have been replaced.

    A regular ``torch.save`` archive starts with the zip magic ``b"PK"``. When
    the file starts with anything else, the first two bytes are treated as a
    custom marker: they are swapped for ``b"PK"`` in memory and the patched
    bytes are loaded instead.

    Args:
        sovits_path: path of the checkpoint file.

    Returns:
        The object stored in the checkpoint, with tensors mapped to CPU.
    """
    # NOTE(security): ``weights_only=False`` unpickles arbitrary objects; only
    # load checkpoints that come from a trusted source.
    # The context manager guarantees the handle is closed on every path.
    with open(sovits_path, "rb") as f:
        meta = f.read(2)
        patched = None if meta == b"PK" else b"PK" + f.read()

    if patched is None:
        return torch.load(sovits_path, map_location="cpu", weights_only=False)
    return torch.load(BytesIO(patched), map_location="cpu", weights_only=False)
|
||
|
||
|
||
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
    """Rebuild the text-to-semantic Lightning module from a loaded checkpoint.

    Args:
        dict_s1: checkpoint mapping with a ``"config"`` entry and a
            ``"weight"`` state dict. ``config["model"]["dropout"]`` is
            converted to ``float`` in place.

    Returns:
        The module with weights loaded, switched to eval mode.
    """
    config = dict_s1["config"]
    model_cfg = config["model"]
    model_cfg["dropout"] = float(model_cfg["dropout"])

    module = Text2SemanticLightningModule(config, "****", is_train=False)
    module.load_state_dict(dict_s1["weight"])
    return module.eval()
|
||
|
||
|
||
@torch.jit.script
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[int] = None,
    repetition_penalty: float = 1.0,
):
    """Turn raw logits into a sampling distribution.

    Steps, in order: repetition penalty, top-p filtering, temperature scaling,
    top-k filtering, softmax over the last dimension.

    Note: when the repetition penalty is applied, ``logits`` is modified in
    place (``scatter_``).

    Args:
        logits: 2-D tensor; dimension 1 is indexed by token id.
        previous_tokens: token ids (same leading dimension as ``logits``) whose
            logits get penalised; ignored when ``None``.
        temperature: divisor applied to the logits, floored at ``1e-5``.
        top_k: keep only the ``top_k`` largest logits per row when given.
        top_p: cumulative-probability cut-off, active only when ``< 1.0``.
        repetition_penalty: ``1.0`` disables the penalty.

    Returns:
        Probabilities with the same shape as ``logits``.
    """
    if previous_tokens is not None and repetition_penalty != 1.0:
        seen = previous_tokens.long()
        seen_logits = torch.gather(logits, dim=1, index=seen)
        # Negative logits are multiplied and non-negative ones divided, so the
        # penalty always pushes an already-seen token towards "less likely".
        penalised = torch.where(
            seen_logits < 0,
            seen_logits * repetition_penalty,
            seen_logits / repetition_penalty,
        )
        logits.scatter_(dim=1, index=seen, src=penalised)

    if top_p is not None and top_p < 1.0:
        ranked_logits, ranked_indices = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(torch.nn.functional.softmax(ranked_logits, dim=-1), dim=-1)
        drop_ranked = cumulative > top_p
        drop_ranked[:, 0] = False  # keep at least one option
        # Map the mask from sorted order back to vocabulary order.
        drop = drop_ranked.scatter(dim=1, index=ranked_indices, src=drop_ranked)
        logits = logits.masked_fill(drop, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        kept, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        threshold = kept[:, -1].unsqueeze(-1)
        logits = torch.where(logits < threshold, -float("Inf"), logits)

    return torch.nn.functional.softmax(logits, dim=-1)
|
||
|
||
|
||
@torch.jit.script
def multinomial_sample_one_no_sync(probs_sort):
    """Draw one index per row from ``probs_sort`` without a CUDA synchronisation.

    Uses the exponential-race trick: for independent ``E_i ~ Exp(1)``,
    ``argmax_i(p_i / E_i)`` is distributed as ``Categorical(p)``. The noise
    must be strictly positive. Gaussian noise (as used previously) is signed,
    so the argmax could land on an entry whose probability is exactly zero,
    for example a token removed by top-k filtering.

    Args:
        probs_sort: probabilities; sampling is along the last dimension.

    Returns:
        ``int32`` tensor of sampled indices with the last dimension kept
        (size 1).
    """
    # Exp(1) via inverse CDF, generated in float32 so half-precision inputs do
    # not round the uniform sample to 1.0. The clamp keeps the noise finite,
    # which keeps p / noise > 0 for every p > 0.
    uniform = torch.rand_like(probs_sort, dtype=torch.float32).clamp_min(1e-20)
    noise = -torch.log(uniform)
    return torch.argmax(probs_sort.to(torch.float32) / noise, dim=-1, keepdim=True).to(dtype=torch.int)
|
||
|
||
|
||
@torch.jit.script
def sample(
    logits,
    previous_tokens,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[int] = None,
    repetition_penalty: float = 1.0,
):
    """Filter ``logits`` into probabilities and draw one token per row.

    All arguments are forwarded to ``logits_to_probs``.

    Returns:
        ``(idx_next, probs)``: the sampled indices and the distribution they
        were drawn from.
    """
    probs = logits_to_probs(logits, previous_tokens, temperature, top_k, top_p, repetition_penalty)
    return multinomial_sample_one_no_sync(probs), probs
|
||
|
||
|
||
@torch.jit.script
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
    """Compute a linear magnitude spectrogram with a Hann window.

    The signal is reflect-padded by ``(n_fft - hop_size) / 2`` samples on each
    side before the STFT.

    Args:
        y: waveform batch; a channel axis is inserted at dim 1 for padding and
            removed again afterwards.
        n_fft: FFT size.
        sampling_rate: accepted for interface compatibility; not used here.
        hop_size: hop length in samples.
        win_size: window length in samples.
        center: forwarded to ``torch.stft``.

    Returns:
        ``sqrt(re^2 + im^2 + 1e-6)`` of the one-sided STFT.
    """
    edge = int((n_fft - hop_size) / 2)
    window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)

    padded = torch.nn.functional.pad(y.unsqueeze(1), (edge, edge), mode="reflect").squeeze(1)
    real_imag = torch.stft(
        padded,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )
    # The last axis holds (real, imag); the epsilon keeps the sqrt away from 0.
    return torch.sqrt(real_imag.pow(2).sum(-1) + 1e-6)
|
||
|
||
|
||
class DictToAttrRecursive(dict):
    """A ``dict`` whose entries are also available as attributes, recursively.

    Nested plain dicts are converted to ``DictToAttrRecursive`` on the way in,
    so ``cfg.a.b`` and ``cfg["a"]["b"]`` address the same value. Each value set
    through attribute assignment is stored both as a dict item and as an
    instance attribute.
    """

    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            # __setattr__ wraps nested dicts and stores the value as both an
            # item and an attribute, so one call per key is enough. This also
            # avoids wrapping every nested dict twice.
            setattr(self, key, value)

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails.
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")
        # Also drop the mirrored instance attribute. Otherwise ``obj.item``
        # would keep resolving to the stale value after deletion.
        self.__dict__.pop(item, None)
|
||
|
||
|
||
@torch.jit.script
class T2SMLP:
    # Two-layer feed-forward network (linear -> ReLU -> linear) that holds its
    # weights and biases as plain tensors.
    def __init__(self, w1, b1, w2, b2):
        self.w1 = w1
        self.b1 = b1
        self.w2 = w2
        self.b2 = b2

    def forward(self, x):
        """Apply ``linear(w2, b2)`` to ``relu(linear(w1, b1)(x))``."""
        hidden = F.linear(x, self.w1, self.b1)
        return F.linear(F.relu(hidden), self.w2, self.b2)
|
||
|
||
|
||
@torch.jit.script
class T2SBlock:
    # One transformer layer operating on raw weight tensors: fused QKV
    # projection, scaled-dot-product attention, output projection, then
    # residual + layer norm, MLP, residual + layer norm.
    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp: T2SMLP,
        qkv_w,
        qkv_b,
        out_w,
        out_b,
        norm_w1,
        norm_b1,
        norm_eps1: float,
        norm_w2,
        norm_b2,
        norm_eps2: float,
    ):
        self.num_heads = num_heads
        self.mlp = mlp
        self.hidden_dim: int = hidden_dim
        self.qkv_w = qkv_w
        self.qkv_b = qkv_b
        self.out_w = out_w
        self.out_b = out_b
        self.norm_w1 = norm_w1
        self.norm_b1 = norm_b1
        self.norm_eps1 = norm_eps1
        self.norm_w2 = norm_w2
        self.norm_b2 = norm_b2
        self.norm_eps2 = norm_eps2

        # Scalar ``False`` tensor, compared against padding-mask entries.
        self.false = torch.tensor(False, dtype=torch.bool)

    @torch.jit.ignore
    def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
        """Zero out padded positions of ``x``; no-op when there is no mask.

        A boolean mask marks positions to fill with 0; any other dtype is
        multiplied into ``x``.
        """
        if padding_mask is None:
            return x
        if padding_mask.dtype != torch.bool:
            return x * padding_mask
        return x.masked_fill(padding_mask, 0)

    def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
        """Run the layer over a full prompt and return the K/V cache.

        Args:
            x: input of shape (batch, length, hidden_dim).
            attn_mask: boolean mask, inverted (``~``) before being handed to
                ``scaled_dot_product_attention``.
            padding_mask: optional padding mask; see ``to_mask``.

        Returns:
            ``(x, k_cache, v_cache)``. The caches are the un-split key/value
            projections of shape (batch, length, hidden_dim).
        """
        qkv = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b)
        q, k, v = qkv.chunk(3, dim=-1)

        bsz = q.shape[0]
        tgt_len = q.shape[1]
        src_len = k.shape[1]

        q = self.to_mask(q, padding_mask)
        k_cache = self.to_mask(k, padding_mask)
        v_cache = self.to_mask(v, padding_mask)

        # (batch, length, hidden) -> (batch, heads, length, head_dim)
        q = q.view(bsz, tgt_len, self.num_heads, -1).transpose(1, 2)
        k = k_cache.view(bsz, src_len, self.num_heads, -1).transpose(1, 2)
        v = v_cache.view(bsz, src_len, self.num_heads, -1).transpose(1, 2)

        ctx = F.scaled_dot_product_attention(q, k, v, ~attn_mask)

        # Merge the heads back into (batch, length, hidden).
        ctx = ctx.permute(2, 0, 1, 3).reshape(bsz * tgt_len, self.hidden_dim)
        ctx = ctx.view(tgt_len, bsz, self.hidden_dim).transpose(1, 0)
        ctx = F.linear(self.to_mask(ctx, padding_mask), self.out_w, self.out_b)

        if padding_mask is None:
            x = x + ctx
            x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
            x = x + self.mlp.forward(x)
            x = F.layer_norm(x, [self.hidden_dim], self.norm_w2, self.norm_b2, self.norm_eps2)
        else:
            # Apply norm/MLP only to the non-padded positions of each sample,
            # then write the result back into ``x`` in place.
            for i in range(bsz):
                if self.false.device != padding_mask.device:
                    self.false = self.false.to(padding_mask.device)
                idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
                x_item = x[i, idx, :].unsqueeze(0)
                ctx_item = ctx[i, idx, :].unsqueeze(0)
                x_item = x_item + ctx_item
                x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
                x_item = x_item + self.mlp.forward(x_item)
                x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w2, self.norm_b2, self.norm_eps2)
                x[i, idx, :] = x_item.squeeze(0)
            x = self.to_mask(x, padding_mask)
        return x, k_cache, v_cache

    def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
        """Run the layer for newly appended position(s) using cached K/V.

        The new keys/values are concatenated to the caches along dim 1 and the
        grown caches are returned together with the layer output.
        """
        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)

        k_cache = torch.cat([k_cache, k], dim=1)
        v_cache = torch.cat([v_cache, v], dim=1)

        bsz = q.shape[0]
        tgt_len = q.shape[1]
        src_len = k_cache.shape[1]

        q = q.view(bsz, tgt_len, self.num_heads, -1).transpose(1, 2)
        k = k_cache.view(bsz, src_len, self.num_heads, -1).transpose(1, 2)
        v = v_cache.view(bsz, src_len, self.num_heads, -1).transpose(1, 2)

        # No mask here: the query attends to every cached position.
        ctx = F.scaled_dot_product_attention(q, k, v)

        ctx = ctx.permute(2, 0, 1, 3).reshape(bsz * tgt_len, self.hidden_dim)
        ctx = ctx.view(tgt_len, bsz, self.hidden_dim).transpose(1, 0)
        ctx = F.linear(ctx, self.out_w, self.out_b)

        x = x + ctx
        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
        x = x + self.mlp.forward(x)
        x = F.layer_norm(x, [self.hidden_dim], self.norm_w2, self.norm_b2, self.norm_eps2)
        return x, k_cache, v_cache
|
||
|
||
|
||
@torch.jit.script
class T2STransformer:
    # Stack of T2SBlock layers applied in order, with one K/V cache per layer.
    def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
        self.num_blocks: int = num_blocks
        self.blocks = blocks

    def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
        """Feed the prompt through every block and collect the per-layer caches.

        Returns:
            ``(x, k_cache, v_cache)`` where both caches are lists with one
            tensor per block.
        """
        keys: list[torch.Tensor] = []
        values: list[torch.Tensor] = []
        for layer in range(self.num_blocks):
            x, layer_k, layer_v = self.blocks[layer].process_prompt(x, attn_mask, padding_mask)
            keys.append(layer_k)
            values.append(layer_v)
        return x, keys, values

    def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
        """Advance every block by one step, updating the cache lists in place."""
        for layer in range(self.num_blocks):
            x, grown_k, grown_v = self.blocks[layer].decode_next_token(x, k_cache[layer], v_cache[layer])
            k_cache[layer] = grown_k
            v_cache[layer] = grown_v
        return x, k_cache, v_cache
|
||
|
||
|
||
class VitsModel(nn.Module):
    """Wraps a ``SynthesizerTrn`` built from a SoVITS checkpoint.

    Args:
        vits_path: checkpoint path, read with ``load_sovits_new``.
        version: model version string. When ``None`` it is derived from the
            row count of ``enc_p.text_embedding.weight`` (322 rows -> "v1",
            anything else -> "v2").

    Raises:
        ValueError: if ``version`` is given but not one of the supported names.
    """

    def __init__(self, vits_path, version=None):
        super().__init__()
        checkpoint = load_sovits_new(vits_path)
        hps = checkpoint["config"]

        if version is None:
            n_text_rows = checkpoint["weight"]["enc_p.text_embedding.weight"].shape[0]
            version = "v1" if n_text_rows == 322 else "v2"
        elif version not in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]:
            raise ValueError(f"Unsupported version: {version}")
        hps["model"]["version"] = version

        self.hps = DictToAttrRecursive(hps)
        self.hps.model.semantic_frame_rate = "25hz"

        data_cfg = self.hps.data
        self.vq_model = SynthesizerTrn(
            data_cfg.filter_length // 2 + 1,
            self.hps.train.segment_size // data_cfg.hop_length,
            n_speakers=data_cfg.n_speakers,
            **self.hps.model,
        )
        self.vq_model.eval()
        # strict=False: checkpoint and model keys are allowed to differ.
        self.vq_model.load_state_dict(checkpoint["weight"], strict=False)

    def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
        """Synthesize audio from predicted semantic tokens.

        The reference audio is converted to a magnitude spectrogram using the
        STFT settings in ``hps.data``. Element ``[0, 0]`` of the synthesizer
        output is returned.
        """
        data_cfg = self.hps.data
        refer = spectrogram_torch(
            ref_audio,
            data_cfg.filter_length,
            data_cfg.sampling_rate,
            data_cfg.hop_length,
            data_cfg.win_length,
            center=False,
        )
        synthesized = self.vq_model(pred_semantic, text_seq, refer, speed=speed, sv_emb=sv_emb)
        return synthesized[0, 0]
|
||
|
||
|
||
class T2SModel(nn.Module):
    """Autoregressive text-to-semantic decoder rebuilt for TorchScript.

    Sub-modules and hyper-parameters are taken from an existing
    ``Text2SemanticLightningModule``. Its encoder layers are repacked into
    ``T2SBlock`` objects that share the original weight tensors.
    """

    def __init__(self, raw_t2s: Text2SemanticLightningModule):
        super().__init__()
        core = raw_t2s.model
        self.model_dim = core.model_dim
        self.embedding_dim = core.embedding_dim
        self.num_head = core.num_head
        self.num_layers = core.num_layers
        self.vocab_size = core.vocab_size
        self.phoneme_vocab_size = core.phoneme_vocab_size
        self.EOS: int = int(core.EOS)
        self.norm_first = core.norm_first
        assert self.EOS == self.vocab_size - 1
        self.hz = 50

        self.bert_proj = core.bert_proj
        self.ar_text_embedding = core.ar_text_embedding
        self.ar_text_position = core.ar_text_position
        self.ar_audio_embedding = core.ar_audio_embedding
        self.ar_audio_position = core.ar_audio_position

        # Repack every encoder layer into a weight-tensor based T2SBlock.
        encoder = core.h
        blocks = []
        for layer_idx in range(self.num_layers):
            layer = encoder.layers[layer_idx]
            feed_forward = T2SMLP(
                layer.linear1.weight,
                layer.linear1.bias,
                layer.linear2.weight,
                layer.linear2.bias,
            )
            attention = layer.self_attn
            blocks.append(
                T2SBlock(
                    self.num_head,
                    self.model_dim,
                    feed_forward,
                    attention.in_proj_weight,
                    attention.in_proj_bias,
                    attention.out_proj.weight,
                    attention.out_proj.bias,
                    layer.norm1.weight,
                    layer.norm1.bias,
                    layer.norm1.eps,
                    layer.norm2.weight,
                    layer.norm2.bias,
                    layer.norm2.eps,
                )
            )
        self.t2s_transformer = T2STransformer(self.num_layers, blocks)

        self.ar_predict_layer = core.ar_predict_layer
        self.max_sec = raw_t2s.config["data"]["max_sec"]
        self.top_k = int(raw_t2s.config["inference"]["top_k"])
        # Token budget: ``hz`` tokens per second times the configured maximum.
        self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])

    def forward(
        self,
        prompts: LongTensor,
        ref_seq: LongTensor,
        text_seq: LongTensor,
        ref_bert: torch.Tensor,
        text_bert: torch.Tensor,
        top_k: LongTensor,
    ):
        """Generate semantic tokens for ``text_seq`` conditioned on the prompt.

        Args:
            prompts: semantic tokens of the reference audio, used as prefix.
            ref_seq: phoneme ids of the reference text.
            text_seq: phoneme ids of the target text.
            ref_bert: BERT features of the reference text (transposed here).
            text_bert: BERT features of the target text (transposed here).
            top_k: single-element tensor holding the top-k sampling size.

        Returns:
            The tokens generated after the prefix (last one overwritten with
            0), with an extra leading dimension.
        """
        bert_feats = torch.cat([ref_bert.T, text_bert.T], 1)
        phoneme_ids = torch.cat([ref_seq, text_seq], 1)
        bert_feats = bert_feats.unsqueeze(0)

        x = self.ar_text_embedding(phoneme_ids)
        x = x + self.bert_proj(bert_feats.transpose(1, 2))
        x: torch.Tensor = self.ar_text_position(x)

        early_stop_num = self.early_stop_num

        y = prompts

        x_len = x.shape[1]
        text_block_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

        y_emb = self.ar_audio_embedding(y)
        y_len = y_emb.shape[1]
        prefix_len = y.shape[1]
        y_pos = self.ar_audio_position(y_emb)
        xy_pos = torch.concat([x, y_pos], dim=1)

        bsz = x.shape[0]
        src_len = x_len + y_len
        # Text rows: zeros over the text block, extended with ones over the
        # audio columns -> shape (x, x + y).
        text_rows = F.pad(text_block_mask, (0, y_len), value=True)
        # Audio rows: strictly-upper-triangular ones over the audio block,
        # extended with zeros over the text columns -> shape (y, x + y).
        audio_rows = F.pad(
            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = (
            torch.concat([text_rows, audio_rows], dim=0)
            .unsqueeze(0)
            .expand(bsz * self.num_head, -1, -1)
            .view(bsz, self.num_head, src_len, src_len)
            .to(device=x.device, dtype=torch.bool)
        )

        # ``idx`` is defined before the loop because it is read after it.
        idx = 0
        top_k = int(top_k)

        dec_out, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)

        # First token: the last logit column (EOS) is dropped so it cannot be chosen.
        logits = self.ar_predict_layer(dec_out[:, -1])
        logits = logits[:, :-1]
        samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
        y = torch.concat([y, samples], dim=1)
        y_emb = self.ar_audio_embedding(y[:, -1:])
        xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
            :, y_len + idx
        ].to(dtype=y_emb.dtype, device=y_emb.device)

        stop = False
        for idx in range(1, 1500):
            dec_out, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
            logits = self.ar_predict_layer(dec_out[:, -1])

            if idx < 11:
                # At least 10 tokens must be predicted before stopping is
                # allowed (0.4 s), so the EOS column is removed.
                logits = logits[:, :-1]

            samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]

            y = torch.concat([y, samples], dim=1)

            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                stop = True
            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                stop = True
            if stop:
                if y.shape[1] == 0:
                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                break

            y_emb = self.ar_audio_embedding(y[:, -1:])
            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
                :, y_len + idx
            ].to(dtype=y_emb.dtype, device=y_emb.device)

        # Overwrite the final token (the stop position) with 0.
        y[0, -1] = 0

        return y[:, -idx:].unsqueeze(0)
|
||
|
||
|
||
# BERT checkpoint location; the "bert_path" environment variable overrides
# the default relative path.
bert_path = os.environ.get(
    "bert_path",
    "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
)
# HuBERT checkpoint location, also published on the cnhubert module.
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||
|
||
|
||
@torch.jit.script
def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
    """Expand word-level features to phone level.

    Row ``i`` of ``res`` is repeated ``word2ph[i]`` times. Rows of ``res``
    beyond ``len(word2ph)`` are ignored.

    This is done with a single ``repeat_interleave`` call instead of a Python
    loop with one ``.item()`` read per word. Unlike the loop-and-``cat``
    version, an empty ``word2ph`` gives an empty result instead of an error.

    Args:
        res: 2-D feature matrix, one row per word.
        word2ph: 1-D integer tensor with the number of phones per word.

    Returns:
        Tensor of shape ``[sum(word2ph), res.shape[1]]``.
    """
    n_words = word2ph.shape[0]
    return torch.repeat_interleave(res[:n_words], word2ph.to(torch.long), dim=0)
|
||
|
||
|
||
class MyBertModel(torch.nn.Module):
    """Traceable wrapper that returns phone-level BERT features."""

    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
    ):
        """Run BERT and expand the selected hidden state to phone level.

        Takes element 1 of the model output, picks its third-from-last entry,
        uses the first batch item and drops the first and last positions.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_states = outputs[1]
        selected = torch.cat(hidden_states[-3:-2], -1)
        res = selected[0][1:-1]
        return build_phone_level_feature(res, word2ph)
|
||
|
||
|
||
class SSLModel(torch.nn.Module):
    """Thin wrapper around the model returned by ``cnhubert.get_model()``."""

    def __init__(self):
        super().__init__()
        self.ssl = cnhubert.get_model().model

    def forward(self, ref_audio_16k) -> torch.Tensor:
        """Return the model's ``last_hidden_state`` with dims 1 and 2 swapped."""
        hidden = self.ssl(ref_audio_16k)["last_hidden_state"]
        return hidden.transpose(1, 2)
|
||
|
||
|
||
class ExportSSLModel(torch.nn.Module):
    """Scriptable container exposing the SSL model plus a resampling method."""

    def __init__(self, ssl: SSLModel):
        super().__init__()
        self.ssl = ssl

    def forward(self, ref_audio: torch.Tensor):
        """Delegate to the wrapped SSL model."""
        features = self.ssl(ref_audio)
        return features

    @torch.jit.export
    def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
        """Resample ``ref_audio`` from ``src_sr`` to ``dst_sr`` and return float32."""
        return resamplex(ref_audio, src_sr, dst_sr).float()
|
||
|
||
|
||
def export_bert(output_path):
    """Trace ``MyBertModel`` on a sample sentence and save it as ``bert_model.pt``.

    Args:
        output_path: directory that receives ``bert_model.pt``.
    """
    tokenizer = AutoTokenizer.from_pretrained(bert_path)

    text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
    encoded = tokenizer(text, return_tensors="pt")

    # Listed punctuation maps to one phone, every other character to two.
    one_phone_chars = [",", "。", ":", "?", ",", ".", "?"]
    word2ph = [1 if ch in one_phone_chars else 2 for ch in text]

    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
    my_bert_model = MyBertModel(bert_model)

    ref_bert_inputs = {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "token_type_ids": encoded["token_type_ids"],
        "word2ph": torch.Tensor(word2ph).int(),
    }

    # Mark the sequence-length axis of every example input as dynamic.
    for input_name, dynamic_dim in (
        ("input_ids", 1),
        ("attention_mask", 1),
        ("token_type_ids", 1),
        ("word2ph", 0),
    ):
        torch._dynamo.mark_dynamic(ref_bert_inputs[input_name], dynamic_dim)

    my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
    output_path = os.path.join(output_path, "bert_model.pt")
    my_bert_model.save(output_path)
    print("#### exported bert ####")
|
||
|
||
|
||
def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
    """Trace the GPT-SoVITS pipeline and save it as ``gpt_sovits_model.pt``.

    Args:
        gpt_path: path of the text-to-semantic (GPT) checkpoint.
        vits_path: path of the SoVITS checkpoint.
        ref_audio_path: reference audio file, loaded at 16 kHz.
        ref_text: transcript of the reference audio.
        output_path: directory for the exported files (created if missing).
        export_bert_and_ssl: also export ``ssl_model.pt`` and ``bert_model.pt``.
        device: device used for tracing.
    """
    if not os.path.exists(output_path):
        os.makedirs(output_path)
        print(f"目录已创建: {output_path}")
    else:
        print(f"目录已存在: {output_path}")

    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
    ssl = SSLModel()
    if export_bert_and_ssl:
        s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
        ssl_path = os.path.join(output_path, "ssl_model.pt")
        torch.jit.script(s).save(ssl_path)
        print("#### exported ssl ####")
        export_bert(output_path)
    else:
        s = ExportSSLModel(ssl)

    print(f"device: {device}")

    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
    ref_bert = ref_bert_T.T.to(ref_seq.device)
    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
        "这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
    )
    text_seq = torch.LongTensor([text_seq_id]).to(device)
    text_bert = text_bert_T.T.to(text_seq.device)

    ssl_content = ssl(ref_audio).to(device)

    vits = VitsModel(vits_path).to(device)
    vits.eval()

    # Load onto CPU first so a checkpoint saved from a GPU also loads on a
    # CPU-only host; the model is moved to ``device`` right after.
    # NOTE(security): ``weights_only=False`` unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
    print("#### get_raw_t2s_model ####")
    print(raw_t2s.config)
    t2s_m = T2SModel(raw_t2s)
    t2s_m.eval()
    t2s = torch.jit.script(t2s_m).to(device)
    print("#### script t2s_m ####")

    print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
    gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
    gpt_sovits.eval()

    ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device)

    # Mark the variable-length axes of the example inputs as dynamic.
    torch._dynamo.mark_dynamic(ssl_content, 2)
    torch._dynamo.mark_dynamic(ref_audio_sr, 1)
    torch._dynamo.mark_dynamic(ref_seq, 1)
    torch._dynamo.mark_dynamic(text_seq, 1)
    torch._dynamo.mark_dynamic(ref_bert, 0)
    torch._dynamo.mark_dynamic(text_bert, 0)

    top_k = torch.LongTensor([5]).to(device)

    with torch.no_grad():
        gpt_sovits_export = torch.jit.trace(
            gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
        )

    gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
    gpt_sovits_export.save(gpt_sovits_path)
    print("#### exported gpt_sovits ####")
|
||
|
||
|
||
def export_prov2(
    gpt_path,
    vits_path,
    version,
    ref_audio_path,
    ref_text,
    output_path,
    export_bert_and_ssl=False,
    device="cpu",
    is_half=True,
):
    """Trace the v2Pro pipeline (with speaker-verification embedding) and save it.

    Besides writing ``gpt_sovits_model.pt`` into ``output_path``, the traced
    model is run once on the example inputs and the result is written to
    ``out.wav`` in the current working directory.

    Args:
        gpt_path: path of the text-to-semantic (GPT) checkpoint.
        vits_path: path of the SoVITS checkpoint.
        version: model version string passed to ``VitsModel``.
        ref_audio_path: reference audio file, loaded at 16 kHz.
        ref_text: transcript of the reference audio.
        output_path: directory for the exported files (created if missing).
        export_bert_and_ssl: also export ``ssl_model.pt`` and ``bert_model.pt``.
        device: device used for tracing.
        is_half: convert models and example inputs to float16.
    """
    if sv_cn_model is None:
        init_sv_cn(device, is_half)

    if not os.path.exists(output_path):
        os.makedirs(output_path)
        print(f"目录已创建: {output_path}")
    else:
        print(f"目录已存在: {output_path}")

    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
    ssl = SSLModel()
    if export_bert_and_ssl:
        s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
        ssl_path = os.path.join(output_path, "ssl_model.pt")
        torch.jit.script(s).save(ssl_path)
        print("#### exported ssl ####")
        export_bert(output_path)
    else:
        s = ExportSSLModel(ssl)

    print(f"device: {device}")

    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(
        ref_text, "all_zh", "v2"
    )
    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
    ref_bert = ref_bert_T.T
    if is_half:
        ref_bert = ref_bert.half()
    ref_bert = ref_bert.to(ref_seq.device)

    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
        "这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "auto", "v2"
    )
    text_seq = torch.LongTensor([text_seq_id]).to(device)
    text_bert = text_bert_T.T
    if is_half:
        text_bert = text_bert.half()
    text_bert = text_bert.to(text_seq.device)

    ssl_content = ssl(ref_audio)
    if is_half:
        ssl_content = ssl_content.half()
    ssl_content = ssl_content.to(device)

    sv_model = ExportERes2NetV2(sv_cn_model)

    vits = VitsModel(vits_path, version)
    if is_half:
        vits.vq_model = vits.vq_model.half()
    vits.to(device)
    vits.eval()

    # Load onto CPU first so a checkpoint saved from a GPU also loads on a
    # CPU-only host; the model is moved to ``device`` right after.
    # NOTE(security): ``weights_only=False`` unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
    print("#### get_raw_t2s_model ####")
    print(raw_t2s.config)
    if is_half:
        raw_t2s = raw_t2s.half()
    t2s_m = T2SModel(raw_t2s)
    t2s_m.eval()
    t2s = torch.jit.script(t2s_m).to(device)
    print("#### script t2s_m ####")

    print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
    gpt_sovits = GPT_SoVITS_V2Pro(t2s, vits, sv_model).to(device)
    gpt_sovits.eval()

    ref_audio_sr = s.resample(ref_audio, 16000, 32000)
    if is_half:
        ref_audio_sr = ref_audio_sr.half()
    ref_audio_sr = ref_audio_sr.to(device)

    # Mark the variable-length axes of the example inputs as dynamic.
    torch._dynamo.mark_dynamic(ssl_content, 2)
    torch._dynamo.mark_dynamic(ref_audio_sr, 1)
    torch._dynamo.mark_dynamic(ref_seq, 1)
    torch._dynamo.mark_dynamic(text_seq, 1)
    torch._dynamo.mark_dynamic(ref_bert, 0)
    torch._dynamo.mark_dynamic(text_bert, 0)

    top_k = torch.LongTensor([5]).to(device)
    # Run sv_model once before tracing so that the cache used by its fbank
    # call is already populated (see the comment in ExportERes2NetV2.forward).
    gpt_sovits.sv_model(ref_audio_sr)

    with torch.no_grad():
        gpt_sovits_export = torch.jit.trace(
            gpt_sovits,
            example_inputs=(
                ssl_content,
                ref_audio_sr,
                ref_seq,
                text_seq,
                ref_bert,
                text_bert,
                top_k,
            ),
        )

    gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
    gpt_sovits_export.save(gpt_sovits_path)
    print("#### exported gpt_sovits ####")
    audio = gpt_sovits_export(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
    print("start write wav")
    soundfile.write("out.wav", audio.float().detach().cpu().numpy(), 32000)
|
||
|
||
|
||
@torch.jit.script
def parse_audio(ref_audio):
    """Resample 48 kHz audio to 16 kHz and 32 kHz float32 copies.

    Returns:
        ``(ref_audio_16k, ref_audio_sr)`` where ``ref_audio_sr`` is the 32 kHz
        version.
    """
    audio_32k = torchaudio.functional.resample(ref_audio, 48000, 32000).float()
    audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
    return audio_16k, audio_32k
|
||
|
||
|
||
@torch.jit.script
def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
    """Resample ``ref_audio`` from ``src_sr`` Hz to ``dst_sr`` Hz as float32."""
    resampled = torchaudio.functional.resample(ref_audio, src_sr, dst_sr)
    return resampled.float()
|
||
|
||
|
||
class GPT_SoVITS(nn.Module):
    """End-to-end pipeline: prompt extraction -> semantic tokens -> audio."""

    def __init__(self, t2s: T2SModel, vits: VitsModel):
        super().__init__()
        self.t2s = t2s
        self.vits = vits

    def forward(
        self,
        ssl_content: torch.Tensor,
        ref_audio_sr: torch.Tensor,
        ref_seq: Tensor,
        text_seq: Tensor,
        ref_bert: Tensor,
        text_bert: Tensor,
        top_k: LongTensor,
        speed=1.0,
    ):
        """Synthesize audio for ``text_seq`` in the voice of the reference.

        The prompt is element ``[0, 0]`` of the latent codes extracted from
        ``ssl_content``, with a leading dimension of size 1 added.
        """
        codes = self.vits.vq_model.extract_latent(ssl_content)
        prompts = codes[0, 0].unsqueeze(0)

        pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
        return self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
|
||
|
||
|
||
class ExportERes2NetV2(nn.Module):
    """Export wrapper that reuses the layers of ``sv_cn_model.embedding_model``."""

    def __init__(self, sv_cn_model: SV):
        super().__init__()
        backbone = sv_cn_model.embedding_model
        self.bn1 = backbone.bn1
        self.conv1 = backbone.conv1
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4
        self.layer3_ds = backbone.layer3_ds
        self.fuse34 = backbone.fuse34

    # audio_16k.shape: [1,N]
    def forward(self, audio_16k):
        """Compute the fused layer-3/layer-4 feature, averaged over the last axis."""
        # The fbank function keeps an internal cache. That is harmless here:
        # the cache does not depend on the length of audio_16k, only on
        # device and dtype.
        feats = Kaldi.fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0)
        x = torch.stack([feats]).permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)

        stem = F.relu(self.bn1(self.conv1(x)))
        out3 = self.layer3(self.layer2(self.layer1(stem)))
        out4 = self.layer4(out3)
        out3_ds = self.layer3_ds(out3)
        fused = self.fuse34(out4, out3_ds)
        return fused.flatten(start_dim=1, end_dim=2).mean(-1)
|
||
|
||
|
||
class GPT_SoVITS_V2Pro(nn.Module):
    """Pipeline variant that also feeds a speaker embedding to the vocoder."""

    def __init__(self, t2s: T2SModel, vits: VitsModel, sv_model: ExportERes2NetV2):
        super().__init__()
        self.t2s = t2s
        self.vits = vits
        self.sv_model = sv_model

    def forward(
        self,
        ssl_content: torch.Tensor,
        ref_audio_sr: torch.Tensor,
        ref_seq: Tensor,
        text_seq: Tensor,
        ref_bert: Tensor,
        text_bert: Tensor,
        top_k: LongTensor,
        speed=1.0,
    ):
        """Synthesize audio using a speaker embedding of the reference audio.

        The reference audio is resampled from 32 kHz to 16 kHz (keeping its
        dtype) before it is passed to the speaker-verification model.
        """
        codes = self.vits.vq_model.extract_latent(ssl_content)
        prompts = codes[0, 0].unsqueeze(0)

        audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
        sv_emb = self.sv_model(audio_16k)

        pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
        return self.vits(text_seq, pred_semantic, ref_audio_sr, speed, sv_emb)
|
||
|
||
def test():
    """Manual smoke test: run already-exported TorchScript models end to end.

    The function parses the exporter's required flags, but only ``--ref_audio``
    and ``--ref_text`` are actually consumed. ``--gpt_model``,
    ``--sovits_model`` and ``--output_path`` are parsed and then ignored: the
    models are loaded from hard-coded ``onnx/...`` paths and the synthesized
    audio is written to ``out.wav`` in the current working directory.
    Everything runs on the ``"cuda"`` device.

    NOTE(review): ``bert_path`` is not defined in this function; it is assumed
    to be a module-level name defined elsewhere in the file.
    """
    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
    for flag, description in (
        ("--gpt_model", "Path to the GPT model file"),
        ("--sovits_model", "Path to the SoVITS model file"),
        ("--ref_audio", "Path to the reference audio file"),
        ("--ref_text", "Path to the reference text file"),
        ("--output_path", "Path to the output directory"),
    ):
        parser.add_argument(flag, required=True, help=description)
    args = parser.parse_args()

    # Pre-exported artifacts are used instead of rebuilding the models from
    # the checkpoints given on the command line.
    tokenizer = AutoTokenizer.from_pretrained(bert_path)
    my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")
    ssl_model = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")
    gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")

    # Reference text -> phoneme ids and BERT features via the Python pipeline.
    ref_seq_id, ref_bert_T, _ = get_phones_and_bert(args.ref_text, "all_zh", "v2")
    ref_seq = torch.LongTensor([ref_seq_id])
    ref_bert = ref_bert_T.T.to(ref_seq.device)

    target_text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
    text_seq_id, text_bert_T, _ = get_phones_and_bert(target_text, "all_zh", "v2")

    # Recompute the target BERT features with the exported BERT module.
    # word2ph: the listed punctuation marks count 1, every other character 2.
    single_marks = [",", "。", ":", "?", "?", ",", "."]
    word2ph = [1 if ch in single_marks else 2 for ch in target_text]

    encoded = tokenizer(target_text, return_tensors="pt")
    encoded["word2ph"] = torch.Tensor(word2ph).int()
    test_bert = my_bert(
        *(
            encoded[key].to("cuda")
            for key in ("input_ids", "attention_mask", "token_type_ids", "word2ph")
        )
    )

    text_seq = torch.LongTensor([text_seq_id])
    text_bert = text_bert_T.T.to(text_seq.device)

    # Compare the Python-pipeline features with the exported-module features.
    print("text_bert:", text_bert.shape, text_bert)
    print("test_bert:", test_bert.shape, test_bert)
    print(torch.allclose(text_bert.to("cuda"), test_bert))

    print("text_seq:", text_seq.shape)
    print("text_bert:", text_bert.shape, text_bert.type())

    # Reference waveform loaded at 16000 Hz, wrapped to shape [1, N].
    waveform = load_audio(args.ref_audio, 16000)
    ref_audio = torch.tensor([waveform]).float().to("cuda")
    print("ref_audio:", ref_audio.shape)

    ref_audio_sr = ssl_model.resample(ref_audio, 16000, 32000)
    print("start ssl")
    ssl_content = ssl_model(ref_audio)

    print("start gpt_sovits:")
    print("ssl_content:", ssl_content.shape)
    print("ref_audio_sr:", ref_audio_sr.shape)

    # Report each remaining input's shape, then move it to the GPU.
    on_gpu = {}
    for label, tensor in (
        ("ref_seq", ref_seq),
        ("text_seq", text_seq),
        ("ref_bert", ref_bert),
        ("text_bert", text_bert),
    ):
        print(f"{label}:", tensor.shape)
        on_gpu[label] = tensor.to("cuda")

    top_k = torch.LongTensor([5]).to("cuda")

    # The exported-BERT features (test_bert) are what is fed to the model.
    with torch.no_grad():
        audio = gpt_sovits(
            ssl_content,
            ref_audio_sr,
            on_gpu["ref_seq"],
            on_gpu["text_seq"],
            on_gpu["ref_bert"],
            test_bert,
            top_k,
        )
    print("start write wav")
    soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
|
||
|
||
|
||
import text
|
||
import json
|
||
|
||
|
||
def export_symbel(version="v2"):
    """Write the symbol-to-id table for ``version`` to a JSON file under ``onnx/``.

    ``"v1"`` selects ``text._symbol_to_id_v1`` and ``onnx/symbols_v1.json``;
    any other value selects the v2 table and ``onnx/symbols_v2.json``.
    """
    if version == "v1":
        table, target = text._symbol_to_id_v1, "onnx/symbols_v1.json"
    else:
        table, target = text._symbol_to_id_v2, "onnx/symbols_v2.json"

    with open(target, "w") as out_file:
        json.dump(table, out_file, indent=4)
|
||
|
||
|
||
def main():
    """Parse the command line and dispatch to the matching export routine.

    Versions ``v2Pro`` and ``v2ProPlus`` go through ``export_prov2`` (with
    half precision unless ``--no-half`` is given); every other version goes
    through ``export``.
    """
    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")

    # Required path/text options, registered in the order shown by --help.
    required_options = (
        ("--gpt_model", "Path to the GPT model file"),
        ("--sovits_model", "Path to the SoVITS model file"),
        ("--ref_audio", "Path to the reference audio file"),
        ("--ref_text", "Path to the reference text file"),
        ("--output_path", "Path to the output directory"),
    )
    for flag, description in required_options:
        parser.add_argument(flag, required=True, help=description)

    parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
    parser.add_argument("--device", help="Device to use")
    parser.add_argument("--version", help="version of the model", default="v2")
    parser.add_argument("--no-half", action="store_true", help="Do not use half precision for model weights")

    args = parser.parse_args()

    # Keyword arguments accepted by both export entry points.
    shared_kwargs = {
        "gpt_path": args.gpt_model,
        "vits_path": args.sovits_model,
        "ref_audio_path": args.ref_audio,
        "ref_text": args.ref_text,
        "output_path": args.output_path,
        "export_bert_and_ssl": args.export_common_model,
        "device": args.device,
    }

    if args.version not in ["v2Pro", "v2ProPlus"]:
        export(**shared_kwargs)
        return

    is_half = not args.no_half
    print(f"Using half precision: {is_half}")
    export_prov2(version=args.version, is_half=is_half, **shared_kwargs)
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the whole export with autograd disabled.
    with torch.no_grad():
        main()
    # test() is a manual debugging entry point and is left disabled here.
|