This commit is contained in:
chasonjiang 2024-03-16 22:00:20 +08:00
commit 8b548e1956
15 changed files with 3050 additions and 545 deletions

3
.gitignore vendored
View File

@ -10,5 +10,6 @@ reference
GPT_weights
SoVITS_weights
TEMP
ffmpeg.exe
ffprobe.exe

View File

@ -13,11 +13,11 @@ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):
def __init__(self, config, output_dir, is_train=True):
def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False):
super().__init__()
self.config = config
self.top_k = 3
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k,flash_attn_enabled=flash_attn_enabled)
pretrained_s1 = config.get("pretrained_s1")
if pretrained_s1 and is_train:
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))

View File

@ -1,5 +1,9 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import List
import torch
from tqdm import tqdm
@ -35,8 +39,144 @@ default_config = {
}
@torch.jit.script
class T2SMLP:
def __init__(self, w1, b1, w2, b2):
self.w1 = w1
self.b1 = b1
self.w2 = w2
self.b2 = b2
def forward(self, x):
x = F.relu(F.linear(x, self.w1, self.b1))
x = F.linear(x, self.w2, self.b2)
return x
@torch.jit.script
class T2SBlock:
def __init__(
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
):
self.num_heads = num_heads
self.mlp = mlp
self.hidden_dim: int = hidden_dim
self.qkv_w = qkv_w
self.qkv_b = qkv_b
self.out_w = out_w
self.out_b = out_b
self.norm_w1 = norm_w1
self.norm_b1 = norm_b1
self.norm_eps1 = norm_eps1
self.norm_w2 = norm_w2
self.norm_b2 = norm_b2
self.norm_eps2 = norm_eps2
def process_prompt(self, x, attn_mask : torch.Tensor):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k.shape[1]
k_cache = k
v_cache = v
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(attn, self.out_w, self.out_b)
x = F.layer_norm(
x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(
x + self.mlp.forward(x),
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
def decode_next_token(self, x, k_cache, v_cache):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
v_cache = torch.cat([v_cache, v], dim=1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k_cache.shape[1]
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(attn, self.out_w, self.out_b)
x = F.layer_norm(
x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(
x + self.mlp.forward(x),
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
@torch.jit.script
class T2STransformer:
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
self.num_blocks : int = num_blocks
self.blocks = blocks
def process_prompt(
self, x, attn_mask : torch.Tensor):
k_cache : List[torch.Tensor] = []
v_cache : List[torch.Tensor] = []
for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask)
k_cache.append(k_cache_)
v_cache.append(v_cache_)
return x, k_cache, v_cache
def decode_next_token(
self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
return x, k_cache, v_cache
class Text2SemanticDecoder(nn.Module):
def __init__(self, config, norm_first=False, top_k=3):
def __init__(self, config, norm_first=False, top_k=3, flash_attn_enabled:bool=False):
super(Text2SemanticDecoder, self).__init__()
self.model_dim = config["model"]["hidden_dim"]
self.embedding_dim = config["model"]["embedding_dim"]
@ -88,6 +228,47 @@ class Text2SemanticDecoder(nn.Module):
multidim_average="global",
ignore_index=self.EOS,
)
self.enable_flash_attn(flash_attn_enabled)
def enable_flash_attn(self, enable:bool=True):
if not enable:
print("Not Using Flash Attention")
self.infer_panel = self.infer_panel_batch_only
else:
self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
print("Using Flash Attention")
blocks = []
for i in range(self.num_layers):
layer = self.h.layers[i]
t2smlp = T2SMLP(
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
)
block = T2SBlock(
self.num_head,
self.model_dim,
t2smlp,
layer.self_attn.in_proj_weight,
layer.self_attn.in_proj_bias,
layer.self_attn.out_proj.weight,
layer.self_attn.out_proj.bias,
layer.norm1.weight,
layer.norm1.bias,
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
)
blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
x = self.ar_text_embedding(x)
@ -321,19 +502,197 @@ class Text2SemanticDecoder(nn.Module):
# 错位
return targets[:, :-1], targets[:, 1:]
def infer_panel(
def infer_panel_batch_infer_with_flash_attn(
self,
x, #####全部文本token
x_lens,
prompts, ####参考音频token
bert_feature,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
# 先对phones进行embedding、对bert_features进行project再pad到相同长度以缓解复读问题。可能还有其他因素导致复读
max_len = 0
for x_item, bert_item in zip(x, bert_feature):
max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
x_list = [self.ar_text_embedding(item) for item in x]
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
x = torch.stack(x_list, dim=0)
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
bert_feature = torch.stack(bert_features_list, dim=0)
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
# x = self.ar_text_embedding(x)
x = x + bert_feature
x = self.ar_text_position(x)
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
k_cache = None
v_cache = None
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
ref_free = False
else:
y_emb = None
y_len = 0
prefix_len = 0
y_pos = None
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
y_mask = make_pad_mask(y_lens)
x_mask = make_pad_mask(x_lens)
# (bsz, x_len + y_len)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
x_mask = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
# xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
###### decode #####
y_list = [None]*y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
if idx == 0:
xy_attn_mask = None
logits = logits[:, :-1]
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
####### 移除batch中已经生成完毕的序列,进一步优化计算量
reserved_idx_of_batch_for_y = None
if (self.EOS in samples[:, 0]) or \
(self.EOS in torch.argmax(logits, dim=-1)): ###如果生成到EOS则停止
l = samples[:, 0]==self.EOS
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx - 1
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# 只保留batch中未生成完毕的序列
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
print("use early stop num:", early_stop_num)
stop = True
for i, batch_index in enumerate(batch_idx_map):
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
if not (None in idx_list):
stop = True
if stop:
if y.shape[1]==0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
if (None in idx_list):
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ###如果没有生成到EOS就用最大长度代替
if ref_free:
return y_list, [0]*x.shape[0]
return y_list, idx_list
def infer_panel_batch_only(
self,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
# 先对phones进行embedding、对bert_features进行project再pad到相同长度以缓解复读问题。可能还有其他因素导致复读
max_len = 0
for x_item, bert_item in zip(x, bert_feature):
max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
x_list = [self.ar_text_embedding(item) for item in x]
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
x = torch.stack(x_list, dim=0)
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
bert_feature = torch.stack(bert_features_list, dim=0)
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
# x = self.ar_text_embedding(x)
x = x + bert_feature
x = self.ar_text_position(x)
# AR Decoder
@ -372,21 +731,37 @@ class Text2SemanticDecoder(nn.Module):
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
y_mask = make_pad_mask(y_lens)
x_mask = make_pad_mask(x_lens)
# (bsz, x_len + y_len)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
x_mask = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
x.device
)
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz*self.num_head, -1, -1).to(x.device)
# xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(bsz, src_len, src_len).repeat(self.num_head, 1, 1)
xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
y_list = [None]*y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)):
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
@ -397,17 +772,45 @@ class Text2SemanticDecoder(nn.Module):
if(idx==0):###第一次跑不能EOS否则没有了
logits = logits[:, :-1] ###刨除1024终止符号的概率
samples = sample(
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0].unsqueeze(0)
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0]
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
# print(samples.shape)#[1,1]#第一个1是bs
y = torch.concat([y, samples], dim=1)
# 移除已经生成完毕的序列
reserved_idx_of_batch_for_y = None
if (self.EOS in torch.argmax(logits, dim=-1)) or \
(self.EOS in samples[:, 0]): ###如果生成到EOS则停止
l = samples[:, 0]==self.EOS
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx - 1
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# 只保留未生成完毕的序列
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
if cache["y_emb"] is not None:
cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
if cache["k"] is not None:
for i in range(self.num_layers):
# 因为kv转置了所以batch dim是1
cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
if not (None in idx_list):
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
@ -443,6 +846,12 @@ class Text2SemanticDecoder(nn.Module):
xy_attn_mask = torch.zeros(
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
)
if (None in idx_list):
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ###如果没有生成到EOS就用最大长度代替
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx-1
return y_list, [0]*x.shape[0]
return y_list, idx_list

View File

@ -0,0 +1,483 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
import torch
from tqdm import tqdm
from AR.models.utils import make_pad_mask
from AR.models.utils import (
topk_sampling,
sample,
logits_to_probs,
multinomial_sample_one_no_sync,
dpo_loss,
make_reject_y,
get_batch_logps
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
from AR.modules.transformer import TransformerEncoder
from AR.modules.transformer import TransformerEncoderLayer
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
default_config = {
"embedding_dim": 512,
"hidden_dim": 512,
"num_head": 8,
"num_layers": 12,
"num_codebook": 8,
"p_dropout": 0.0,
"vocab_size": 1024 + 1,
"phoneme_vocab_size": 512,
"EOS": 1024,
}
class Text2SemanticDecoder(nn.Module):
def __init__(self, config, norm_first=False, top_k=3):
super(Text2SemanticDecoder, self).__init__()
self.model_dim = config["model"]["hidden_dim"]
self.embedding_dim = config["model"]["embedding_dim"]
self.num_head = config["model"]["head"]
self.num_layers = config["model"]["n_layer"]
self.norm_first = norm_first
self.vocab_size = config["model"]["vocab_size"]
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
self.p_dropout = config["model"]["dropout"]
self.EOS = config["model"]["EOS"]
self.norm_first = norm_first
assert self.EOS == self.vocab_size - 1
# should be same as num of kmeans bin
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.ar_audio_embedding = TokenEmbedding(
self.embedding_dim, self.vocab_size, self.p_dropout
)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.h = TransformerEncoder(
TransformerEncoderLayer(
d_model=self.model_dim,
nhead=self.num_head,
dim_feedforward=self.model_dim * 4,
dropout=0.1,
batch_first=True,
norm_first=norm_first,
),
num_layers=self.num_layers,
norm=LayerNorm(self.model_dim) if norm_first else None,
)
self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
self.ar_accuracy_metric = MulticlassAccuracy(
self.vocab_size,
top_k=top_k,
average="micro",
multidim_average="global",
ignore_index=self.EOS,
)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
x_mask = make_pad_mask(x_lens)
y_mask = make_pad_mask(y_lens)
y_mask_int = y_mask.type(torch.int64)
codes = y.type(torch.int64) * (1 - y_mask_int)
# Training
# AR Decoder
y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
x_len = x_lens.max()
y_len = y_lens.max()
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position(y_emb)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
ar_xy_padding_mask = xy_padding_mask
x_attn_mask = F.pad(
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
diagonal=1,
),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
bsz, src_len = x.shape[0], x_len + y_len
_xy_padding_mask = (
ar_xy_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, self.num_head, -1, -1)
.reshape(bsz * self.num_head, 1, src_len)
)
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
xy_attn_mask = new_attn_mask
# x 和完整的 y 一次性输入模型
xy_pos = torch.concat([x, y_pos], dim=1)
return xy_pos, xy_attn_mask, targets
def forward(self, x, x_lens, y, y_lens, bert_feature):
"""
x: phoneme_ids
y: semantic_ids
"""
reject_y, reject_y_lens = make_reject_y(y, y_lens)
xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
x_len = x_lens.max()
logits = self.ar_predict_layer(xy_dec[:, x_len:])
###### DPO #############
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
reject_xy_dec, _ = self.h(
(reject_xy_pos, None),
mask=reject_xy_attn_mask,
)
x_len = x_lens.max()
reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
# loss
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
loss = loss_1 + loss_2
return loss, acc
def forward_old(self, x, x_lens, y, y_lens, bert_feature):
"""
x: phoneme_ids
y: semantic_ids
"""
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
x_mask = make_pad_mask(x_lens)
y_mask = make_pad_mask(y_lens)
y_mask_int = y_mask.type(torch.int64)
codes = y.type(torch.int64) * (1 - y_mask_int)
# Training
# AR Decoder
y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
x_len = x_lens.max()
y_len = y_lens.max()
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position(y_emb)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
ar_xy_padding_mask = xy_padding_mask
x_attn_mask = F.pad(
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
diagonal=1,
),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
bsz, src_len = x.shape[0], x_len + y_len
_xy_padding_mask = (
ar_xy_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, self.num_head, -1, -1)
.reshape(bsz * self.num_head, 1, src_len)
)
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
xy_attn_mask = new_attn_mask
# x 和完整的 y 一次性输入模型
xy_pos = torch.concat([x, y_pos], dim=1)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
# loss
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
loss = F.cross_entropy(logits, targets, reduction="sum")
acc = self.ar_accuracy_metric(logits.detach(), targets).item()
return loss, acc
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
def infer(
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
y = prompts
prefix_len = y.shape[1]
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
for _ in tqdm(range(1500)):
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position(y_emb)
# x 和逐渐增长的 y 一起输入给模型
xy_pos = torch.concat([x, y_pos], dim=1)
y_len = y.shape[1]
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len),
value=True,
)
y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = topk_sampling(
logits, top_k=top_k, top_p=1.0, temperature=temperature
)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
# print(samples.shape)#[1,1]#第一个1是bs
# import os
# os._exit(2333)
y = torch.concat([y, samples], dim=1)
return y
def pad_y_eos(self, y, y_mask_int, eos_id):
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1
)
# 错位
return targets[:, :-1], targets[:, 1:]
def infer_panel(
self,
x, #####全部文本token
x_lens,
prompts, ####参考音频token
bert_feature,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
cache = {
"all_stage": self.num_layers,
"k": [None] * self.num_layers, ###根据配置自己手写
"v": [None] * self.num_layers,
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
"y_emb": None, ##只需要对最新的samples求emb再拼历史的就行
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
# "xy_dec":None,###不需要本来只需要最后一个做logits
"first_infer": 1,
"stage": 0,
}
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
cache["y_emb"] = y_emb
ref_free = False
else:
y_emb = None
y_len = 0
prefix_len = 0
y_pos = None
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
x.device
)
y_list = [None]*y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)):
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
) ##不用改如果用了cache的默认就是只有一帧取最后一帧一样的
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
if(idx==0):###第一次跑不能EOS否则没有了
logits = logits[:, :-1] ###刨除1024终止符号的概率
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0]
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
# print(samples.shape)#[1,1]#第一个1是bs
y = torch.concat([y, samples], dim=1)
# 移除已经生成完毕的序列
reserved_idx_of_batch_for_y = None
if (self.EOS in torch.argmax(logits, dim=-1)) or \
(self.EOS in samples[:, 0]): ###如果生成到EOS则停止
l = samples[:, 0]==self.EOS
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx - 1
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# 只保留未生成完毕的序列
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
if cache["y_emb"] is not None:
cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
if cache["k"] is not None:
for i in range(self.num_layers):
# 因为kv转置了所以batch dim是1
cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if not (None in idx_list):
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
# if prompts.shape[1] == y.shape[1]:
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
# print("bad zero prediction")
if y.shape[1]==0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
####################### update next step ###################################
cache["first_infer"] = 0
if cache["y_emb"] is not None:
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos[:, -1:]
else:
y_emb = self.ar_audio_embedding(y[:, -1:])
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos
y_len = y_pos.shape[1]
###最右边一列(是错的)
# xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
# xy_attn_mask[:,-1]=False
###最下面一行(是对的)
xy_attn_mask = torch.zeros(
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
)
if (None in idx_list):
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ###如果没有生成到EOS就用最大长度代替
if ref_free:
return y_list, [0]*x.shape[0]
return y_list, idx_list

View File

@ -115,17 +115,17 @@ def logits_to_probs(
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
if previous_tokens is not None:
previous_tokens = previous_tokens.squeeze()
# if previous_tokens is not None:
# previous_tokens = previous_tokens.squeeze()
# print(logits.shape,previous_tokens.shape)
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@ -133,9 +133,9 @@ def logits_to_probs(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0, index=sorted_indices, src=sorted_indices_to_remove
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@ -143,7 +143,7 @@ def logits_to_probs(
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v.select(-1, -1).unsqueeze(-1)
pivot = v[: , -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)

View File

@ -0,0 +1,920 @@
from copy import deepcopy
import math
import os, sys
import random
import traceback
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
import ffmpeg
import os
from typing import Generator, List, Union
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from transformers import AutoModelForMaskedLM, AutoTokenizer
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from feature_extractor.cnhubert import CNHubert
from module.models import SynthesizerTrn
import librosa
from time import time as ttime
from tools.i18n.i18n import I18nAuto
from my_utils import load_audio
from module.mel_processing import spectrogram_torch
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
i18n = I18nAuto()
# configs/tts_infer.yaml
"""
default:
device: cpu
is_half: false
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
flash_attn_enabled: true
custom:
device: cuda
is_half: true
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
flash_attn_enabled: true
"""
def set_seed(seed:int):
seed = int(seed)
seed = seed if seed != -1 else random.randrange(1 << 32)
print(f"Set seed to {seed}")
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
try:
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.enabled = True
except:
pass
return seed
class TTS_Config:
default_configs={
"device": "cpu",
"is_half": False,
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
"flash_attn_enabled": True
}
configs:dict = None
def __init__(self, configs: Union[dict, str]=None):
# 设置默认配置文件路径
configs_base_path:str = "GPT_SoVITS/configs/"
os.makedirs(configs_base_path, exist_ok=True)
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
if configs in ["", None]:
if not os.path.exists(self.configs_path):
self.save_configs()
print(f"Create default config file at {self.configs_path}")
configs:dict = {"default": deepcopy(self.default_configs)}
if isinstance(configs, str):
self.configs_path = configs
configs:dict = self._load_configs(self.configs_path)
assert isinstance(configs, dict)
default_configs:dict = configs.get("default", None)
if default_configs is not None:
self.default_configs = default_configs
self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
self.device = self.configs.get("device", torch.device("cpu"))
self.is_half = self.configs.get("is_half", False)
self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
self.vits_weights_path = self.configs.get("vits_weights_path", None)
self.bert_base_path = self.configs.get("bert_base_path", None)
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
self.t2s_weights_path = self.default_configs['t2s_weights_path']
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
self.vits_weights_path = self.default_configs['vits_weights_path']
print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
self.bert_base_path = self.default_configs['bert_base_path']
print(f"fall back to default bert_base_path: {self.bert_base_path}")
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
self.update_configs()
self.max_sec = None
self.hz:int = 50
self.semantic_frame_rate:str = "25hz"
self.segment_size:int = 20480
self.filter_length:int = 2048
self.sampling_rate:int = 32000
self.hop_length:int = 640
self.win_length:int = 2048
self.n_speakers:int = 300
self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
# print(self)
def _load_configs(self, configs_path: str)->dict:
with open(configs_path, 'r') as f:
configs = yaml.load(f, Loader=yaml.FullLoader)
return configs
def save_configs(self, configs_path:str=None)->None:
configs={
"default":self.default_configs,
}
if self.configs is not None:
configs["custom"] = self.update_configs()
if configs_path is None:
configs_path = self.configs_path
with open(configs_path, 'w') as f:
yaml.dump(configs, f)
def update_configs(self):
self.config = {
"device" : str(self.device),
"is_half" : self.is_half,
"t2s_weights_path" : self.t2s_weights_path,
"vits_weights_path" : self.vits_weights_path,
"bert_base_path" : self.bert_base_path,
"cnhuhbert_base_path": self.cnhuhbert_base_path,
"flash_attn_enabled" : self.flash_attn_enabled
}
return self.config
def __str__(self):
self.configs = self.update_configs()
string = "TTS Config".center(100, '-') + '\n'
for k, v in self.configs.items():
string += f"{str(k).ljust(20)}: {str(v)}\n"
string += "-" * 100 + '\n'
return string
def __repr__(self):
return self.__str__()
class TTS:
def __init__(self, configs: Union[dict, str, TTS_Config]):
if isinstance(configs, TTS_Config):
self.configs = configs
else:
self.configs:TTS_Config = TTS_Config(configs)
self.t2s_model:Text2SemanticLightningModule = None
self.vits_model:SynthesizerTrn = None
self.bert_tokenizer:AutoTokenizer = None
self.bert_model:AutoModelForMaskedLM = None
self.cnhuhbert_model:CNHubert = None
self._init_models()
self.text_preprocessor:TextPreprocessor = \
TextPreprocessor(self.bert_model,
self.bert_tokenizer,
self.configs.device)
self.prompt_cache:dict = {
"ref_audio_path":None,
"prompt_semantic":None,
"refer_spepc":None,
"prompt_text":None,
"prompt_lang":None,
"phones":None,
"bert_features":None,
"norm_text":None,
}
self.stop_flag:bool = False
self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
def _init_models(self,):
self.init_t2s_weights(self.configs.t2s_weights_path)
self.init_vits_weights(self.configs.vits_weights_path)
self.init_bert_weights(self.configs.bert_base_path)
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
# self.enable_half_precision(self.configs.is_half)
def init_cnhuhbert_weights(self, base_path: str):
print(f"Loading CNHuBERT weights from {base_path}")
self.cnhuhbert_model = CNHubert(base_path)
self.cnhuhbert_model=self.cnhuhbert_model.eval()
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
if self.configs.is_half and str(self.configs.device)!="cpu":
self.cnhuhbert_model = self.cnhuhbert_model.half()
def init_bert_weights(self, base_path: str):
print(f"Loading BERT weights from {base_path}")
self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
self.bert_model=self.bert_model.eval()
self.bert_model = self.bert_model.to(self.configs.device)
if self.configs.is_half and str(self.configs.device)!="cpu":
self.bert_model = self.bert_model.half()
def init_vits_weights(self, weights_path: str):
print(f"Loading VITS weights from {weights_path}")
self.configs.vits_weights_path = weights_path
self.configs.save_configs()
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
hps = dict_s2["config"]
self.configs.filter_length = hps["data"]["filter_length"]
self.configs.segment_size = hps["train"]["segment_size"]
self.configs.sampling_rate = hps["data"]["sampling_rate"]
self.configs.hop_length = hps["data"]["hop_length"]
self.configs.win_length = hps["data"]["win_length"]
self.configs.n_speakers = hps["data"]["n_speakers"]
self.configs.semantic_frame_rate = "25hz"
kwargs = hps["model"]
vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers,
**kwargs
)
# if ("pretrained" not in weights_path):
if hasattr(vits_model, "enc_q"):
del vits_model.enc_q
vits_model = vits_model.to(self.configs.device)
vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model
if self.configs.is_half and str(self.configs.device)!="cpu":
self.vits_model = self.vits_model.half()
def init_t2s_weights(self, weights_path: str):
print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path
self.configs.save_configs()
self.configs.hz = 50
dict_s1 = torch.load(weights_path, map_location=self.configs.device)
config = dict_s1["config"]
self.configs.max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
flash_attn_enabled=self.configs.flash_attn_enabled)
t2s_model.load_state_dict(dict_s1["weight"])
t2s_model = t2s_model.to(self.configs.device)
t2s_model = t2s_model.eval()
self.t2s_model = t2s_model
if self.configs.is_half and str(self.configs.device)!="cpu":
self.t2s_model = self.t2s_model.half()
def enable_half_precision(self, enable: bool = True):
'''
To enable half precision for the TTS model.
Args:
enable: bool, whether to enable half precision.
'''
if str(self.configs.device) == "cpu" and enable:
print("Half precision is not supported on CPU.")
return
self.configs.is_half = enable
self.precison = torch.float16 if enable else torch.float32
self.configs.save_configs()
if enable:
if self.t2s_model is not None:
self.t2s_model =self.t2s_model.half()
if self.vits_model is not None:
self.vits_model = self.vits_model.half()
if self.bert_model is not None:
self.bert_model =self.bert_model.half()
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.half()
else:
if self.t2s_model is not None:
self.t2s_model = self.t2s_model.float()
if self.vits_model is not None:
self.vits_model = self.vits_model.float()
if self.bert_model is not None:
self.bert_model = self.bert_model.float()
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.float()
def set_device(self, device: torch.device):
'''
To set the device for all models.
Args:
device: torch.device, the device to use for all models.
'''
self.configs.device = device
self.configs.save_configs()
if self.t2s_model is not None:
self.t2s_model = self.t2s_model.to(device)
if self.vits_model is not None:
self.vits_model = self.vits_model.to(device)
if self.bert_model is not None:
self.bert_model = self.bert_model.to(device)
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.to(device)
def set_ref_audio(self, ref_audio_path:str):
'''
To set the reference audio for the TTS model,
including the prompt_semantic and refer_spepc.
Args:
ref_audio_path: str, the path of the reference audio.
'''
self._set_prompt_semantic(ref_audio_path)
self._set_ref_spepc(ref_audio_path)
def _set_ref_spepc(self, ref_audio_path):
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(
audio_norm,
self.configs.filter_length,
self.configs.sampling_rate,
self.configs.hop_length,
self.configs.win_length,
center=False,
)
spec = spec.to(self.configs.device)
if self.configs.is_half:
spec = spec.half()
# self.refer_spepc = spec
self.prompt_cache["refer_spepc"] = spec
def _set_prompt_semantic(self, ref_wav_path:str):
zero_wav = np.zeros(
int(self.configs.sampling_rate * 0.3),
dtype=np.float16 if self.configs.is_half else np.float32,
)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
wav16k = wav16k.to(self.configs.device)
zero_wav_torch = zero_wav_torch.to(self.configs.device)
if self.configs.is_half:
wav16k = wav16k.half()
zero_wav_torch = zero_wav_torch.half()
wav16k = torch.cat([wav16k, zero_wav_torch])
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = self.vits_model.extract_latent(hubert_feature)
prompt_semantic = codes[0, 0].to(self.configs.device)
self.prompt_cache["prompt_semantic"] = prompt_semantic
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None):
seq = sequences[0]
ndim = seq.dim()
if axis < 0:
axis += ndim
dtype:torch.dtype = seq.dtype
pad_value = torch.tensor(pad_value, dtype=dtype)
seq_lengths = [seq.shape[axis] for seq in sequences]
if max_length is None:
max_length = max(seq_lengths)
else:
max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
padded_seq = torch.nn.functional.pad(seq, padding, value=pad_value)
padded_sequences.append(padded_seq)
batch = torch.stack(padded_sequences)
return batch
def to_batch(self, data:list,
prompt_data:dict=None,
batch_size:int=5,
threshold:float=0.75,
split_bucket:bool=True,
device:torch.device=torch.device("cpu"),
precison:torch.dtype=torch.float32,
):
_data:list = []
index_and_len_list = []
for idx, item in enumerate(data):
norm_text_len = len(item["norm_text"])
index_and_len_list.append([idx, norm_text_len])
batch_index_list = []
if split_bucket:
index_and_len_list.sort(key=lambda x: x[1])
index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
batch_index_list_len = 0
pos = 0
while pos <index_and_len_list.shape[0]:
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
while pos < pos_end:
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
if (score>=threshold) or (pos_end-pos==1):
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
batch_index_list_len += len(batch_index)
batch_index_list.append(batch_index)
pos = pos_end
break
pos_end=pos_end-1
assert batch_index_list_len == len(data)
else:
for i in range(len(data)):
if i%batch_size == 0:
batch_index_list.append([])
batch_index_list[-1].append(i)
for batch_idx, index_list in enumerate(batch_index_list):
item_list = [data[idx] for idx in index_list]
phones_list = []
phones_len_list = []
# bert_features_list = []
all_phones_list = []
all_phones_len_list = []
all_bert_features_list = []
norm_text_batch = []
bert_max_len = 0
phones_max_len = 0
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
.to(dtype=precison, device=device)
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
phones = torch.LongTensor(item["phones"]).to(device)
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]\
.to(dtype=precison, device=device)
phones = torch.LongTensor(item["phones"]).to(device)
all_phones = phones
# norm_text = item["norm_text"]
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
phones_max_len = max(phones_max_len, phones.shape[-1])
phones_list.append(phones)
phones_len_list.append(phones.shape[-1])
all_phones_list.append(all_phones)
all_phones_len_list.append(all_phones.shape[-1])
all_bert_features_list.append(all_bert_features)
norm_text_batch.append(item["norm_text"])
phones_batch = phones_list
all_phones_batch = all_phones_list
all_bert_features_batch = all_bert_features_list
# max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
#### 直接对phones和bert_features进行pad会增大复读概率。
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
# all_bert_features_batch = all_bert_features_list
# all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precison, device=device)
# for idx, item in enumerate(all_bert_features_list):
# all_bert_features_batch[idx, :, : item.shape[-1]] = item
# #### 先对phones进行embedding、对bert_features进行project再pad到相同长度以缓解复读问题。可能还有其他因素导致复读
# all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
# all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
# all_phones_batch = torch.stack(all_phones_list, dim=0)
# all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
# all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
# all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)
batch = {
"phones": phones_batch,
"phones_len": torch.LongTensor(phones_len_list).to(device),
"all_phones": all_phones_batch,
"all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
"all_bert_features": all_bert_features_batch,
"norm_text": norm_text_batch
}
_data.append(batch)
return _data, batch_index_list
def recovery_order(self, data:list, batch_index_list:list)->list:
'''
Recovery the order of the audio according to the batch_index_list.
Args:
data (List[list(np.ndarray)]): the out of order audio .
batch_index_list (List[list[int]]): the batch index list.
Returns:
list (List[np.ndarray]): the data in the original order.
'''
lenght = len(sum(batch_index_list, []))
_data = [None]*lenght
for i, index_list in enumerate(batch_index_list):
for j, index in enumerate(index_list):
_data[index] = data[i][j]
return _data
def stop(self,):
'''
Stop the inference process.
'''
self.stop_flag = True
def run(self, inputs:dict):
"""
Text to speech inference.
Args:
inputs (dict):
{
"text": "", # str. text to be synthesized
"text_lang: "", # str. language of the text to be synthesized
"ref_audio_path": "", # str. reference audio path
"prompt_text": "", # str. prompt text for the reference audio
"prompt_lang": "", # str. language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "", # str. text split method, see text_segmentaion_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
}
returns:
tulpe[int, np.ndarray]: sampling rate and audio data.
"""
########## variables initialization ###########
self.stop_flag:bool = False
text:str = inputs.get("text", "")
text_lang:str = inputs.get("text_lang", "")
ref_audio_path:str = inputs.get("ref_audio_path", "")
prompt_text:str = inputs.get("prompt_text", "")
prompt_lang:str = inputs.get("prompt_lang", "")
top_k:int = inputs.get("top_k", 5)
top_p:float = inputs.get("top_p", 1)
temperature:float = inputs.get("temperature", 1)
text_split_method:str = inputs.get("text_split_method", "")
batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75)
speed_factor = inputs.get("speed_factor", 1.0)
split_bucket = inputs.get("split_bucket", True)
return_fragment = inputs.get("return_fragment", False)
fragment_interval = inputs.get("fragment_interval", 0.3)
seed = inputs.get("seed", -1)
seed = -1 if seed in ["", None] else seed
set_seed(seed)
if return_fragment:
# split_bucket = False
print(i18n("分段返回模式已开启"))
if split_bucket:
split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
if split_bucket:
print(i18n("分桶处理模式已开启"))
if fragment_interval<0.01:
fragment_interval = 0.01
print(i18n("分段间隔过小已自动设置为0.01"))
no_prompt_text = False
if prompt_text in [None, ""]:
no_prompt_text = True
assert text_lang in self.configs.langauges
if not no_prompt_text:
assert prompt_lang in self.configs.langauges
if ref_audio_path in [None, ""] and \
((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spepc"] is None)):
raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
###### setting reference audio and prompt text preprocessing ########
t0 = ttime()
if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
self.set_ref_audio(ref_audio_path)
if not no_prompt_text:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_lang != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text)
if self.prompt_cache["prompt_text"] != prompt_text:
self.prompt_cache["prompt_text"] = prompt_text
self.prompt_cache["prompt_lang"] = prompt_lang
phones, bert_features, norm_text = \
self.text_preprocessor.segment_and_extract_feature_for_text(
prompt_text,
prompt_lang)
self.prompt_cache["phones"] = phones
self.prompt_cache["bert_features"] = bert_features
self.prompt_cache["norm_text"] = norm_text
###### text preprocessing ########
t1 = ttime()
data:list = None
if not return_fragment:
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
if len(data) == 0:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
dtype=np.int16)
return
batch_index_list:list = None
data, batch_index_list = self.to_batch(data,
prompt_data=self.prompt_cache if not no_prompt_text else None,
batch_size=batch_size,
threshold=batch_threshold,
split_bucket=split_bucket,
device=self.configs.device,
precison=self.precison
)
else:
print(i18n("############ 切分文本 ############"))
texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
data = []
for i in range(len(texts)):
if i%batch_size == 0:
data.append([])
data[-1].append(texts[i])
def make_batch(batch_texts):
batch_data = []
print(i18n("############ 提取文本Bert特征 ############"))
for text in tqdm(batch_texts):
phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang)
if phones is None:
continue
res={
"phones": phones,
"bert_features": bert_features,
"norm_text": norm_text,
}
batch_data.append(res)
if len(batch_data) == 0:
return None
batch, _ = self.to_batch(batch_data,
prompt_data=self.prompt_cache if not no_prompt_text else None,
batch_size=batch_size,
threshold=batch_threshold,
split_bucket=False,
device=self.configs.device,
precison=self.precison
)
return batch[0]
t2 = ttime()
try:
print("############ 推理 ############")
###### inference ######
t_34 = 0.0
t_45 = 0.0
audio = []
for item in data:
t3 = ttime()
if return_fragment:
item = make_batch(item)
if item is None:
continue
batch_phones:List[torch.LongTensor] = item["phones"]
batch_phones_len:torch.LongTensor = item["phones_len"]
all_phoneme_ids:List[torch.LongTensor] = item["all_phones"]
all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
all_bert_features:List[torch.LongTensor] = item["all_bert_features"]
norm_text:str = item["norm_text"]
print(i18n("前端处理后的文本(每句):"), norm_text)
if no_prompt_text :
prompt = None
else:
prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
with torch.no_grad():
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_lens,
prompt,
all_bert_features,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
)
t4 = ttime()
t_34 += t4 - t3
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
.to(dtype=self.precison, device=self.configs.device)
batch_audio_fragment = []
# ## vits并行推理 method 1
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
# pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
# max_len = 0
# for i in range(0, len(batch_phones)):
# max_len = max(max_len, batch_phones[i].shape[-1])
# batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
# batch_phones = batch_phones.to(self.configs.device)
# batch_audio_fragment = (self.vits_model.batched_decode(
# pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
# ))
# ## vits并行推理 method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = (self.vits_model.decode(
all_pred_semantic, _batch_phones,refer_audio_spepc
).detach()[0, 0, :])
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
# ## vits串行推理
# for i, idx in enumerate(idx_list):
# phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
# _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
# audio_fragment =(self.vits_model.decode(
# _pred_semantic, phones, refer_audio_spepc
# ).detach()[0, 0, :])
# batch_audio_fragment.append(
# audio_fragment
# ) ###试试重建不带上prompt部分
t5 = ttime()
t_45 += t5 - t4
if return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
yield self.audio_postprocess([batch_audio_fragment],
self.configs.sampling_rate,
None,
speed_factor,
False,
fragment_interval
)
else:
audio.append(batch_audio_fragment)
if self.stop_flag:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
dtype=np.int16)
return
if not return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
yield self.audio_postprocess(audio,
self.configs.sampling_rate,
batch_index_list,
speed_factor,
split_bucket,
fragment_interval
)
except Exception as e:
traceback.print_exc()
# 必须返回一个空音频, 否则会导致显存不释放。
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
dtype=np.int16)
# 重置模型, 否则会导致显存释放不完全。
del self.t2s_model
del self.vits_model
self.t2s_model = None
self.vits_model = None
self.init_t2s_weights(self.configs.t2s_weights_path)
self.init_vits_weights(self.configs.vits_weights_path)
finally:
self.empty_cache()
def empty_cache(self):
try:
if "cuda" in str(self.configs.device):
torch.cuda.empty_cache()
elif str(self.configs.device) == "mps":
torch.mps.empty_cache()
except:
pass
def audio_postprocess(self,
audio:List[torch.Tensor],
sr:int,
batch_index_list:list=None,
speed_factor:float=1.0,
split_bucket:bool=True,
fragment_interval:float=0.3
)->tuple[int, np.ndarray]:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval),
dtype=self.precison,
device=self.configs.device
)
for i, batch in enumerate(audio):
for j, audio_fragment in enumerate(batch):
max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
if max_audio>1: audio_fragment/=max_audio
audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
audio[i][j] = audio_fragment.cpu().numpy()
if split_bucket:
audio = self.recovery_order(audio, batch_index_list)
else:
# audio = [item for batch in audio for item in batch]
audio = sum(audio, [])
audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16)
try:
if speed_factor != 1.0:
audio = speed_change(audio, speed=speed_factor, sr=int(sr))
except Exception as e:
print(f"Failed to change speed of audio: \n{e}")
return sr, audio
def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# 将 NumPy 数组转换为原始 PCM 流
raw_audio = input_audio.astype(np.int16).tobytes()
# 设置 ffmpeg 输入流
input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
# 变速处理
output_stream = input_stream.filter('atempo', speed)
# 输出流到管道
out, _ = (
output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
.run(input=raw_audio, capture_stdout=True, capture_stderr=True)
)
# 将管道输出解码为 NumPy 数组
processed_audio = np.frombuffer(out, np.int16)
return processed_audio

View File

@ -0,0 +1,210 @@
import os, sys
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
import re
import torch
import LangSegment
from typing import Dict, List, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
def get_first(text:str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip()
return text
def merge_short_text_in_array(texts:str, threshold:int) -> list:
if (len(texts)) < 2:
return texts
result = []
text = ""
for ele in texts:
text += ele
if len(text) >= threshold:
result.append(text)
text = ""
if (len(text) > 0):
if len(result) == 0:
result.append(text)
else:
result[len(result) - 1] += text
return result
class TextPreprocessor:
def __init__(self, bert_model:AutoModelForMaskedLM,
tokenizer:AutoTokenizer, device:torch.device):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
print(i18n("############ 切分文本 ############"))
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
print(i18n("############ 提取文本Bert特征 ############"))
for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
if phones is None:
continue
res={
"phones": phones,
"bert_features": bert_features,
"norm_text": norm_text,
}
result.append(res)
return result
def pre_seg_text(self, text:str, lang:str, text_split_method:str):
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4):
text = "" + text if lang != "en" else "." + text
print(i18n("实际输入的目标文本:"))
print(text)
seg_method = get_seg_method(text_split_method)
text = seg_method(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
_texts = text.split("\n")
_texts = merge_short_text_in_array(_texts, 5)
texts = []
for text in _texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
if (text[-1] not in splits): text += "" if lang != "en" else "."
# 解决句子过长导致Bert报错的问题
if (len(text) > 510):
texts.extend(split_big_text(text))
else:
texts.append(text)
print(i18n("实际输入的目标文本(切句后):"))
print(texts)
return texts
def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]:
textlist, langlist = self.seg_text(texts, language)
if len(textlist) == 0:
return None, None, None
phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
return phones, bert_features, norm_text
def seg_text(self, text:str, language:str)->Tuple[list, list]:
textlist=[]
langlist=[]
if language in ["auto", "zh", "ja"]:
LangSegment.setfilters(["zh","ja","en","ko"])
for tmp in LangSegment.getTexts(text):
if tmp["text"] == "":
continue
if tmp["lang"] == "ko":
langlist.append("zh")
elif tmp["lang"] == "en":
langlist.append("en")
else:
# 因无法区别中日文汉字,以用户输入为准
langlist.append(language if language!="auto" else tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
while " " in formattext:
formattext = formattext.replace(" ", " ")
if formattext != "":
textlist.append(formattext)
langlist.append("en")
elif language in ["all_zh","all_ja"]:
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
language = language.replace("all_","")
if text == "":
return [],[]
textlist.append(formattext)
langlist.append(language)
else:
raise ValueError(f"language {language} not supported")
return textlist, langlist
def extract_bert_feature(self, textlist:list, langlist:list):
phones_list = []
bert_feature_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang)
_bert_feature = self.get_bert_inf(phones, word2ph, norm_text, lang)
# phones_list.append(phones)
phones_list.extend(phones)
norm_text_list.append(norm_text)
bert_feature_list.append(_bert_feature)
bert_feature = torch.cat(bert_feature_list, dim=1)
# phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
return phones_list, bert_feature, norm_text
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
def clean_text_inf(self, text:str, language:str):
phones, word2ph, norm_text = clean_text(text, language)
phones = cleaned_text_to_sequence(phones)
return phones, word2ph, norm_text
def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
language=language.replace("all_","")
if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
else:
feature = torch.zeros(
(1024, len(phones)),
dtype=torch.float32,
).to(self.device)
return feature

View File

@ -0,0 +1 @@
from . import TTS, text_segmentation_method

View File

@ -0,0 +1,152 @@
import re
from typing import Callable
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
METHODS = dict()
def get_method(name:str)->Callable:
method = METHODS.get(name, None)
if method is None:
raise ValueError(f"Method {name} not found")
return method
def register_method(name):
def decorator(func):
METHODS[name] = func
return func
return decorator
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def split_big_text(text, max_len=510):
# 定义全角和半角标点符号
punctuation = "".join(splits)
# 切割文本
segments = re.split('([' + punctuation + '])', text)
# 初始化结果列表和当前片段
result = []
current_segment = ''
for segment in segments:
# 如果当前片段加上新的片段长度超过max_len就将当前片段加入结果列表并重置当前片段
if len(current_segment + segment) > max_len:
result.append(current_segment)
current_segment = segment
else:
current_segment += segment
# 将最后一个片段加入结果列表
if current_segment:
result.append(current_segment)
return result
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while 1:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
# 不切
@register_method("cut0")
def cut0(inp):
return inp
# 凑四句一切
@register_method("cut1")
def cut1(inp):
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 4))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
else:
opts = [inp]
return "\n".join(opts)
# 凑50字一切
@register_method("cut2")
def cut2(inp):
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return inp
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
# print(opts)
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
return "\n".join(opts)
# 按中文句号。切
@register_method("cut3")
def cut3(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")])
#按英文句号.切
@register_method("cut4")
def cut4(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
# 按标点符号切
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@register_method("cut5")
def cut5(inp):
# if not re.search(r'[^\w\s]', inp[-1]):
# inp += '。'
inp = inp.strip("\n")
punds = r'[,.;?!、,。?!;:…]'
items = re.split(f'({punds})', inp)
mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
# 在句子不存在符号或句尾无符号的时候保证文本完整
if len(items)%2 == 1:
mergeitems.append(items[-1])
opt = "\n".join(mergeitems)
return opt
if __name__ == '__main__':
method = get_method("cut5")
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))

View File

@ -0,0 +1,16 @@
custom:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cuda
flash_attn_enabled: true
is_half: true
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
default:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
flash_attn_enabled: true
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth

View File

@ -20,13 +20,16 @@ cnhubert_base_path = None
class CNHubert(nn.Module):
def __init__(self):
def __init__(self, base_path:str=None):
super().__init__()
self.model = HubertModel.from_pretrained(cnhubert_base_path)
if base_path is None:
base_path = cnhubert_base_path
self.model = HubertModel.from_pretrained(base_path)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
cnhubert_base_path
base_path
)
def forward(self, x):
input_values = self.feature_extractor(
x, return_tensors="pt", sampling_rate=16000

View File

@ -7,7 +7,7 @@ import soundfile as sf
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
from GPT_SoVITS.inference_webui_old import change_gpt_weights, change_sovits_weights, get_tts_wav
class GPTSoVITSGUI(QMainWindow):

View File

@ -6,8 +6,11 @@
全部按英文识别
全部按日文识别
'''
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
import os, re, logging
import LangSegment
logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
@ -18,31 +21,7 @@ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
import torch
if os.path.exists("./gweight.txt"):
with open("./gweight.txt", 'r', encoding="utf-8") as file:
gweight_data = file.read()
gpt_path = os.environ.get(
"gpt_path", gweight_data)
else:
gpt_path = os.environ.get(
"gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
if os.path.exists("./sweight.txt"):
with open("./sweight.txt", 'r', encoding="utf-8") as file:
sweight_data = file.read()
sovits_path = os.environ.get("sovits_path", sweight_data)
else:
sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
# gpt_path = os.environ.get(
# "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
# )
# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path = os.environ.get(
"cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
)
bert_path = os.environ.get(
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
)
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
is_share = os.environ.get("is_share", "False")
@ -50,21 +29,14 @@ is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
bert_path = os.environ.get("bert_path", None)
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
import librosa
from feature_extractor import cnhubert
cnhubert.cnhubert_base_path = cnhubert_base_path
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
@ -73,131 +45,11 @@ i18n = I18nAuto()
if torch.cuda.is_available():
device = "cuda"
# elif torch.backends.mps.is_available():
# device = "mps"
else:
device = "cpu"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half == True:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
ssl_model = cnhubert.get_model()
if is_half == True:
ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
def change_sovits_weights(sovits_path):
global vq_model, hps
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model
)
if ("pretrained" not in sovits_path):
del vq_model.enc_q
if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
with open("./sweight.txt", "w", encoding="utf-8") as f:
f.write(sovits_path)
change_sovits_weights(sovits_path)
def change_gpt_weights(gpt_path):
global hz, max_sec, t2s_model, config
hz = 50
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half == True:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
change_gpt_weights(gpt_path)
def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(
audio_norm,
hps.data.filter_length,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
center=False,
)
return spec
dict_language = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
@ -207,314 +59,62 @@ dict_language = {
i18n("多语种混合"): "auto",#多语种启动切分识别语种
}
cut_method = {
i18n("不切"):"cut0",
i18n("凑四句一切"): "cut1",
i18n("凑50字一切"): "cut2",
i18n("按中文句号。切"): "cut3",
i18n("按英文句号.切"): "cut4",
i18n("按标点符号切"): "cut5",
}
def clean_text_inf(text, language):
phones, word2ph, norm_text = clean_text(text, language)
phones = cleaned_text_to_sequence(phones)
return phones, word2ph, norm_text
dtype=torch.float16 if is_half == True else torch.float32
def get_bert_inf(phones, word2ph, norm_text, language):
language=language.replace("all_","")
if language == "zh":
bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
else:
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
return bert
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def get_first(text):
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip()
return text
def get_phones_and_bert(text,language):
if language in {"en","all_zh","all_ja"}:
language = language.replace("all_","")
if language == "en":
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
else:
# 因无法区别中日文汉字,以用户输入为准
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
phones, word2ph, norm_text = clean_text_inf(formattext, language)
if language == "zh":
bert = get_bert_feature(norm_text, word2ph).to(device)
else:
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
elif language in {"zh", "ja","auto"}:
textlist=[]
langlist=[]
LangSegment.setfilters(["zh","ja","en","ko"])
if language == "auto":
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "ko":
langlist.append("zh")
textlist.append(tmp["text"])
else:
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
print(textlist)
print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
return phones,bert.to(dtype),norm_text
def merge_short_text_in_array(texts, threshold):
if (len(texts)) < 2:
return texts
result = []
text = ""
for ele in texts:
text += ele
if len(text) >= threshold:
result.append(text)
text = ""
if (len(text) > 0):
if len(result) == 0:
result.append(text)
else:
result[len(result) - 1] += text
return result
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
if prompt_text is None or len(prompt_text) == 0:
ref_free = True
t0 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
if not ref_free:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text)
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4): text = "" + text if text_language != "en" else "." + text
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device
tts_config.is_half = is_half
if gpt_path is not None:
tts_config.t2s_weights_path = gpt_path
if sovits_path is not None:
tts_config.vits_weights_path = sovits_path
if cnhubert_base_path is not None:
tts_config.cnhuhbert_base_path = cnhubert_base_path
if bert_path is not None:
tts_config.bert_base_path = bert_path
print(i18n("实际输入的目标文本:"), text)
zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
dtype=np.float16 if is_half == True else np.float32,
)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if is_half == True:
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
if (how_to_cut == i18n("凑四句一切")):
text = cut1(text)
elif (how_to_cut == i18n("凑50字一切")):
text = cut2(text)
elif (how_to_cut == i18n("按中文句号。切")):
text = cut3(text)
elif (how_to_cut == i18n("按英文句号.切")):
text = cut4(text)
elif (how_to_cut == i18n("按标点符号切")):
text = cut5(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
print(i18n("实际输入的目标文本(切句后):"), text)
texts = text.split("\n")
texts = merge_short_text_in_array(texts, 5)
audio_opt = []
if not ref_free:
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
for text in texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
if (text[-1] not in splits): text += "" if text_language != "en" else "."
print(i18n("实际输入的目标文本(每句):"), text)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
print(i18n("前端处理后的文本(每句):"), norm_text2)
if not ref_free:
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
else:
bert = bert2
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
None if ref_free else prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=hz * max_sec,
)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
0
) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if is_half == True:
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
)
.detach()
.cpu()
.numpy()[0, 0]
) ###试试重建不带上prompt部分
max_audio=np.abs(audio).max()#简单防止16bit爆音
if max_audio>1:audio/=max_audio
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
np.int16
)
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while 1:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
def cut1(inp):
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 4))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
else:
opts = [inp]
return "\n".join(opts)
def cut2(inp):
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return inp
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
# print(opts)
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
return "\n".join(opts)
def cut3(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")])
def cut4(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
def cut5(inp):
# if not re.search(r'[^\w\s]', inp[-1]):
# inp += '。'
inp = inp.strip("\n")
punds = r'[,.;?!、,。?!;:…]'
items = re.split(f'({punds})', inp)
mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
# 在句子不存在符号或句尾无符号的时候保证文本完整
if len(items)%2 == 1:
mergeitems.append(items[-1])
opt = "\n".join(mergeitems)
return opt
print(tts_config)
tts_pipline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
def inference(text, text_lang,
ref_audio_path, prompt_text,
prompt_lang, top_k,
top_p, temperature,
text_split_method, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed,
):
inputs={
"text": text,
"text_lang": dict_language[text_lang],
"ref_audio_path": ref_audio_path,
"prompt_text": prompt_text if not ref_text_free else "",
"prompt_lang": dict_language[prompt_lang],
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"text_split_method": cut_method[text_split_method],
"batch_size":int(batch_size),
"speed_factor":float(speed_factor),
"split_bucket":split_bucket,
"return_fragment":False,
"fragment_interval":fragment_interval,
"seed":seed,
}
for item in tts_pipline.run(inputs):
yield item
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分
parts = re.split('(\d+)', s)
@ -552,65 +152,103 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
)
with gr.Group():
with gr.Column():
# with gr.Group():
gr.Markdown(value=i18n("模型切换"))
with gr.Row():
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown], [])
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
SoVITS_dropdown.change(tts_pipline.init_vits_weights, [SoVITS_dropdown], [])
GPT_dropdown.change(tts_pipline.init_t2s_weights, [GPT_dropdown], [])
with gr.Row():
with gr.Column():
gr.Markdown(value=i18n("*请上传并填写参考信息"))
inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频超过会报错"), type="filepath")
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
with gr.Row():
text = gr.Textbox(label=i18n("需要合成的文本"), value="")
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=2)
with gr.Row():
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
with gr.Column():
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=16, max_lines=16)
text_language = gr.Dropdown(
label=i18n("需要合成的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
with gr.Row():
gr.Markdown(value=i18n("gpt采样参数(无参考文本时不要太低)"))
with gr.Group():
gr.Markdown(value=i18n("推理设置"))
with gr.Row():
with gr.Column():
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
inference_button = gr.Button(i18n("合成语音"), variant="primary")
output = gr.Audio(label=i18n("输出的语音"))
with gr.Column():
how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
with gr.Row():
split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
seed = gr.Number(label=i18n("随机种子"),value=-1)
# with gr.Column():
output = gr.Audio(label=i18n("输出的语音"))
with gr.Row():
inference_button = gr.Button(i18n("合成语音"), variant="primary")
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
inference_button.click(
get_tts_wav,
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
inference,
[
text,text_language, inp_ref,
prompt_text, prompt_language,
top_k, top_p, temperature,
how_to_cut, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed
],
[output],
)
stop_infer.click(tts_pipline.stop, [], [])
with gr.Group():
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
with gr.Row():
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
button1 = gr.Button(i18n("凑四句一切"), variant="primary")
button2 = gr.Button(i18n("凑50字一切"), variant="primary")
button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
button5 = gr.Button(i18n("按标点符号切"), variant="primary")
text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
button1.click(cut1, [text_inp], [text_opt])
button2.click(cut2, [text_inp], [text_opt])
button3.click(cut3, [text_inp], [text_opt])
button4.click(cut4, [text_inp], [text_opt])
button5.click(cut5, [text_inp], [text_opt])
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
with gr.Column():
_how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
cut_text= gr.Button(i18n("切分"), variant="primary")
def to_cut(text_inp, how_to_cut):
if len(text_inp.strip()) == 0 or text_inp==[]:
return ""
method = get_method(cut_method[how_to_cut])
return method(text_inp)
text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
app.queue(concurrency_count=511, max_size=1022).launch(

View File

@ -0,0 +1,622 @@
'''
按中英混合识别
按日英混合识别
多语种启动切分识别语种
全部按中文识别
全部按英文识别
全部按日文识别
'''
import os, re, logging
import LangSegment
logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
import torch
if os.path.exists("./gweight.txt"):
with open("./gweight.txt", 'r', encoding="utf-8") as file:
gweight_data = file.read()
gpt_path = os.environ.get(
"gpt_path", gweight_data)
else:
gpt_path = os.environ.get(
"gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
if os.path.exists("./sweight.txt"):
with open("./sweight.txt", 'r', encoding="utf-8") as file:
sweight_data = file.read()
sovits_path = os.environ.get("sovits_path", sweight_data)
else:
sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
# gpt_path = os.environ.get(
# "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
# )
# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path = os.environ.get(
"cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
)
bert_path = os.environ.get(
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
)
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
is_share = os.environ.get("is_share", "False")
is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_available()
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
import librosa
from feature_extractor import cnhubert
cnhubert.cnhubert_base_path = cnhubert_base_path
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half == True:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
ssl_model = cnhubert.get_model()
if is_half == True:
ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
def change_sovits_weights(sovits_path):
global vq_model, hps
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model
)
if ("pretrained" not in sovits_path):
del vq_model.enc_q
if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
with open("./sweight.txt", "w", encoding="utf-8") as f:
f.write(sovits_path)
change_sovits_weights(sovits_path)
def change_gpt_weights(gpt_path):
global hz, max_sec, t2s_model, config
hz = 50
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half == True:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
change_gpt_weights(gpt_path)
def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(
audio_norm,
hps.data.filter_length,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
center=False,
)
return spec
dict_language = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
i18n("日文"): "all_ja",#全部按日文识别
i18n("中英混合"): "zh",#按中英混合识别####不变
i18n("日英混合"): "ja",#按日英混合识别####不变
i18n("多语种混合"): "auto",#多语种启动切分识别语种
}
def clean_text_inf(text, language):
phones, word2ph, norm_text = clean_text(text, language)
phones = cleaned_text_to_sequence(phones)
return phones, word2ph, norm_text
dtype=torch.float16 if is_half == True else torch.float32
def get_bert_inf(phones, word2ph, norm_text, language):
language=language.replace("all_","")
if language == "zh":
bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
else:
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
return bert
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def get_first(text):
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip()
return text
def get_phones_and_bert(text,language):
if language in {"en","all_zh","all_ja"}:
language = language.replace("all_","")
if language == "en":
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
else:
# 因无法区别中日文汉字,以用户输入为准
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
phones, word2ph, norm_text = clean_text_inf(formattext, language)
if language == "zh":
bert = get_bert_feature(norm_text, word2ph).to(device)
else:
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
elif language in {"zh", "ja","auto"}:
textlist=[]
langlist=[]
LangSegment.setfilters(["zh","ja","en","ko"])
if language == "auto":
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "ko":
langlist.append("zh")
textlist.append(tmp["text"])
else:
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
print(textlist)
print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
return phones,bert.to(dtype),norm_text
def merge_short_text_in_array(texts, threshold):
if (len(texts)) < 2:
return texts
result = []
text = ""
for ele in texts:
text += ele
if len(text) >= threshold:
result.append(text)
text = ""
if (len(text) > 0):
if len(result) == 0:
result.append(text)
else:
result[len(result) - 1] += text
return result
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
if prompt_text is None or len(prompt_text) == 0:
ref_free = True
t0 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
if not ref_free:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text)
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4): text = "" + text if text_language != "en" else "." + text
print(i18n("实际输入的目标文本:"), text)
zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
dtype=np.float16 if is_half == True else np.float32,
)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if is_half == True:
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
if (how_to_cut == i18n("凑四句一切")):
text = cut1(text)
elif (how_to_cut == i18n("凑50字一切")):
text = cut2(text)
elif (how_to_cut == i18n("按中文句号。切")):
text = cut3(text)
elif (how_to_cut == i18n("按英文句号.切")):
text = cut4(text)
elif (how_to_cut == i18n("按标点符号切")):
text = cut5(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
print(i18n("实际输入的目标文本(切句后):"), text)
texts = text.split("\n")
texts = merge_short_text_in_array(texts, 5)
audio_opt = []
if not ref_free:
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
for text in texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
if (text[-1] not in splits): text += "" if text_language != "en" else "."
print(i18n("实际输入的目标文本(每句):"), text)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
print(i18n("前端处理后的文本(每句):"), norm_text2)
if not ref_free:
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
else:
bert = bert2
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
None if ref_free else prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=hz * max_sec,
)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
0
) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if is_half == True:
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
)
.detach()
.cpu()
.numpy()[0, 0]
) ###试试重建不带上prompt部分
max_audio=np.abs(audio).max()#简单防止16bit爆音
if max_audio>1:audio/=max_audio
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
np.int16
)
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while 1:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
def cut1(inp):
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 4))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
else:
opts = [inp]
return "\n".join(opts)
def cut2(inp):
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return inp
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
# print(opts)
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
return "\n".join(opts)
def cut3(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")])
def cut4(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
def cut5(inp):
# if not re.search(r'[^\w\s]', inp[-1]):
# inp += '。'
inp = inp.strip("\n")
punds = r'[,.;?!、,。?!;:…]'
items = re.split(f'({punds})', inp)
mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
# 在句子不存在符号或句尾无符号的时候保证文本完整
if len(items)%2 == 1:
mergeitems.append(items[-1])
opt = "\n".join(mergeitems)
return opt
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分
parts = re.split('(\d+)', s)
# 将数字部分转换为整数,非数字部分保持不变
parts = [int(part) if part.isdigit() else part for part in parts]
return parts
def change_choices():
SoVITS_names, GPT_names = get_weights_names()
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
pretrained_sovits_name = "GPT_SoVITS/pretrained_models/s2G488k.pth"
pretrained_gpt_name = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
SoVITS_weight_root = "SoVITS_weights"
GPT_weight_root = "GPT_weights"
os.makedirs(SoVITS_weight_root, exist_ok=True)
os.makedirs(GPT_weight_root, exist_ok=True)
def get_weights_names():
SoVITS_names = [pretrained_sovits_name]
for name in os.listdir(SoVITS_weight_root):
if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (SoVITS_weight_root, name))
GPT_names = [pretrained_gpt_name]
for name in os.listdir(GPT_weight_root):
if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (GPT_weight_root, name))
return SoVITS_names, GPT_names
SoVITS_names, GPT_names = get_weights_names()
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
)
with gr.Group():
gr.Markdown(value=i18n("模型切换"))
with gr.Row():
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown], [])
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频超过会报错"), type="filepath")
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
with gr.Row():
text = gr.Textbox(label=i18n("需要合成的文本"), value="")
text_language = gr.Dropdown(
label=i18n("需要合成的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
with gr.Row():
gr.Markdown(value=i18n("gpt采样参数(无参考文本时不要太低)"))
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
inference_button = gr.Button(i18n("合成语音"), variant="primary")
output = gr.Audio(label=i18n("输出的语音"))
inference_button.click(
get_tts_wav,
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
[output],
)
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
with gr.Row():
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
button1 = gr.Button(i18n("凑四句一切"), variant="primary")
button2 = gr.Button(i18n("凑50字一切"), variant="primary")
button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
button5 = gr.Button(i18n("按标点符号切"), variant="primary")
text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
button1.click(cut1, [text_inp], [text_opt])
button2.click(cut2, [text_inp], [text_opt])
button3.click(cut3, [text_inp], [text_opt])
button4.click(cut4, [text_inp], [text_opt])
button5.click(cut5, [text_inp], [text_opt])
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
app.queue(concurrency_count=511, max_size=1022).launch(
server_name="0.0.0.0",
inbrowser=True,
share=is_share,
server_port=infer_ttswebui,
quiet=True,
)

View File

@ -1,5 +1,6 @@
import copy
import math
from typing import List
import torch
from torch import nn
from torch.nn import functional as F
@ -986,6 +987,55 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o
@torch.no_grad()
def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
ge = self.ref_enc(refer * refer_mask, refer_mask)
# y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to(
# codes.dtype
# )
y_lengths = (y_lengths * 2).long().to(codes.device)
text_lengths = text_lengths.long().to(text.device)
# y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
# text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
# 假设padding之后再decode没有问题, 影响未知,但听起来好像没问题?
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
z_masked = (z * y_mask)[:, :, :]
# 串行。把padding部分去掉再decode
o_list:List[torch.Tensor] = []
for i in range(z_masked.shape[0]):
z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0)
o = self.dec(z_slice, g=ge)[0, 0, :].detach()
o_list.append(o)
# 并行会有问题。先decode再把padding的部分去掉
# o = self.dec(z_masked, g=ge)
# upsample_rate = int(math.prod(self.upsample_rates))
# o_lengths = y_lengths*upsample_rate
# o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)]
return o_list
def extract_latent(self, x):
ssl = self.ssl_proj(x)