Merge pull request #721 from ChasonJiang/rebuild-and-optimize

tts infer重构优化和批量推理支持
This commit is contained in:
RVC-Boss 2024-03-10 12:24:40 +08:00 committed by GitHub
commit 8fd56afe91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 2495 additions and 579 deletions

3
.gitignore vendored
View File

@ -10,5 +10,6 @@ reference
GPT_weights
SoVITS_weights
TEMP
ffmpeg.exe
ffprobe.exe

View File

@ -1,5 +1,7 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
from typing import List
import torch
from tqdm import tqdm
@ -35,6 +37,142 @@ default_config = {
}
@torch.jit.script
class T2SMLP:
def __init__(self, w1, b1, w2, b2):
self.w1 = w1
self.b1 = b1
self.w2 = w2
self.b2 = b2
def forward(self, x):
x = F.relu(F.linear(x, self.w1, self.b1))
x = F.linear(x, self.w2, self.b2)
return x
@torch.jit.script
class T2SBlock:
def __init__(
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
):
self.num_heads = num_heads
self.mlp = mlp
self.hidden_dim: int = hidden_dim
self.qkv_w = qkv_w
self.qkv_b = qkv_b
self.out_w = out_w
self.out_b = out_b
self.norm_w1 = norm_w1
self.norm_b1 = norm_b1
self.norm_eps1 = norm_eps1
self.norm_w2 = norm_w2
self.norm_b2 = norm_b2
self.norm_eps2 = norm_eps2
def process_prompt(self, x, attn_mask : torch.Tensor):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k.shape[1]
k_cache = k
v_cache = v
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(attn, self.out_w, self.out_b)
x = F.layer_norm(
x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(
x + self.mlp.forward(x),
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
def decode_next_token(self, x, k_cache, v_cache):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
v_cache = torch.cat([v_cache, v], dim=1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k_cache.shape[1]
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(attn, self.out_w, self.out_b)
x = F.layer_norm(
x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(
x + self.mlp.forward(x),
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
@torch.jit.script
class T2STransformer:
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
self.num_blocks : int = num_blocks
self.blocks = blocks
def process_prompt(
self, x, attn_mask : torch.Tensor):
k_cache : List[torch.Tensor] = []
v_cache : List[torch.Tensor] = []
for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask)
k_cache.append(k_cache_)
v_cache.append(v_cache_)
return x, k_cache, v_cache
def decode_next_token(
self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
return x, k_cache, v_cache
class Text2SemanticDecoder(nn.Module):
def __init__(self, config, norm_first=False, top_k=3):
super(Text2SemanticDecoder, self).__init__()
@ -89,6 +227,37 @@ class Text2SemanticDecoder(nn.Module):
ignore_index=self.EOS,
)
blocks = []
for i in range(self.num_layers):
layer = self.h.layers[i]
t2smlp = T2SMLP(
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
)
block = T2SBlock(
self.num_head,
self.model_dim,
t2smlp,
layer.self_attn.in_proj_weight,
layer.self_attn.in_proj_bias,
layer.self_attn.out_proj.weight,
layer.self_attn.out_proj.bias,
layer.norm1.weight,
layer.norm1.bias,
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
)
blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@ -343,17 +512,9 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
cache = {
"all_stage": self.num_layers,
"k": [None] * self.num_layers, ###根据配置自己手写
"v": [None] * self.num_layers,
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
"y_emb": None, ##只需要对最新的samples求emb再拼历史的就行
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
# "xy_dec":None,###不需要本来只需要最后一个做logits
"first_infer": 1,
"stage": 0,
}
k_cache = None
v_cache = None
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
@ -361,7 +522,6 @@ class Text2SemanticDecoder(nn.Module):
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
cache["y_emb"] = y_emb
ref_free = False
else:
y_emb = None
@ -372,11 +532,25 @@ class Text2SemanticDecoder(nn.Module):
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
y_mask = make_pad_mask(y_lens)
x_mask = make_pad_mask(x_lens)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
_xy_padding_mask = (
xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1)
)
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
@ -385,64 +559,87 @@ class Text2SemanticDecoder(nn.Module):
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
x.device
)
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
xy_attn_mask = new_attn_mask
###### decode #####
y_list = [None]*y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)):
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
) ##不用改如果用了cache的默认就是只有一帧取最后一帧一样的
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
if(idx==0):###第一次跑不能EOS否则没有了
logits = logits[:, :-1] ###刨除1024终止符号的概率
samples = sample(
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0].unsqueeze(0)
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
# print(samples.shape)#[1,1]#第一个1是bs
y = torch.concat([y, samples], dim=1)
)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
if idx == 0:
xy_attn_mask = None
logits = logits[:, :-1]
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
####### 移除batch中已经生成完毕的序列,进一步优化计算量
reserved_idx_of_batch_for_y = None
if (self.EOS in samples[:, 0]) or \
(self.EOS in torch.argmax(logits, dim=-1)): ###如果生成到EOS则停止
l = samples[:, 0]==self.EOS
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx - 1
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# 只保留batch中未生成完毕的序列
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
for i, batch_index in enumerate(batch_idx_map):
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
if not (None in idx_list):
stop = True
if stop:
# if prompts.shape[1] == y.shape[1]:
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
# print("bad zero prediction")
if y.shape[1]==0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
####################### update next step ###################################
cache["first_infer"] = 0
if cache["y_emb"] is not None:
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos[:, -1:]
else:
y_emb = self.ar_audio_embedding(y[:, -1:])
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos
y_len = y_pos.shape[1]
###最右边一列(是错的)
# xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
# xy_attn_mask[:,-1]=False
###最下面一行(是对的)
xy_attn_mask = torch.zeros(
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
)
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
if (None in idx_list):
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ###如果没有生成到EOS就用最大长度代替
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx-1
return y_list, [0]*x.shape[0]
return y_list, idx_list

View File

@ -0,0 +1,483 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
import torch
from tqdm import tqdm
from AR.models.utils import make_pad_mask
from AR.models.utils import (
topk_sampling,
sample,
logits_to_probs,
multinomial_sample_one_no_sync,
dpo_loss,
make_reject_y,
get_batch_logps
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
from AR.modules.transformer import TransformerEncoder
from AR.modules.transformer import TransformerEncoderLayer
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
default_config = {
"embedding_dim": 512,
"hidden_dim": 512,
"num_head": 8,
"num_layers": 12,
"num_codebook": 8,
"p_dropout": 0.0,
"vocab_size": 1024 + 1,
"phoneme_vocab_size": 512,
"EOS": 1024,
}
class Text2SemanticDecoder(nn.Module):
def __init__(self, config, norm_first=False, top_k=3):
super(Text2SemanticDecoder, self).__init__()
self.model_dim = config["model"]["hidden_dim"]
self.embedding_dim = config["model"]["embedding_dim"]
self.num_head = config["model"]["head"]
self.num_layers = config["model"]["n_layer"]
self.norm_first = norm_first
self.vocab_size = config["model"]["vocab_size"]
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
self.p_dropout = config["model"]["dropout"]
self.EOS = config["model"]["EOS"]
self.norm_first = norm_first
assert self.EOS == self.vocab_size - 1
# should be same as num of kmeans bin
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.ar_audio_embedding = TokenEmbedding(
self.embedding_dim, self.vocab_size, self.p_dropout
)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.h = TransformerEncoder(
TransformerEncoderLayer(
d_model=self.model_dim,
nhead=self.num_head,
dim_feedforward=self.model_dim * 4,
dropout=0.1,
batch_first=True,
norm_first=norm_first,
),
num_layers=self.num_layers,
norm=LayerNorm(self.model_dim) if norm_first else None,
)
self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
self.ar_accuracy_metric = MulticlassAccuracy(
self.vocab_size,
top_k=top_k,
average="micro",
multidim_average="global",
ignore_index=self.EOS,
)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
x_mask = make_pad_mask(x_lens)
y_mask = make_pad_mask(y_lens)
y_mask_int = y_mask.type(torch.int64)
codes = y.type(torch.int64) * (1 - y_mask_int)
# Training
# AR Decoder
y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
x_len = x_lens.max()
y_len = y_lens.max()
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position(y_emb)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
ar_xy_padding_mask = xy_padding_mask
x_attn_mask = F.pad(
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
diagonal=1,
),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
bsz, src_len = x.shape[0], x_len + y_len
_xy_padding_mask = (
ar_xy_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, self.num_head, -1, -1)
.reshape(bsz * self.num_head, 1, src_len)
)
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
xy_attn_mask = new_attn_mask
# x 和完整的 y 一次性输入模型
xy_pos = torch.concat([x, y_pos], dim=1)
return xy_pos, xy_attn_mask, targets
def forward(self, x, x_lens, y, y_lens, bert_feature):
"""
x: phoneme_ids
y: semantic_ids
"""
reject_y, reject_y_lens = make_reject_y(y, y_lens)
xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
x_len = x_lens.max()
logits = self.ar_predict_layer(xy_dec[:, x_len:])
###### DPO #############
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
reject_xy_dec, _ = self.h(
(reject_xy_pos, None),
mask=reject_xy_attn_mask,
)
x_len = x_lens.max()
reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
# loss
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
loss = loss_1 + loss_2
return loss, acc
def forward_old(self, x, x_lens, y, y_lens, bert_feature):
"""
x: phoneme_ids
y: semantic_ids
"""
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
x_mask = make_pad_mask(x_lens)
y_mask = make_pad_mask(y_lens)
y_mask_int = y_mask.type(torch.int64)
codes = y.type(torch.int64) * (1 - y_mask_int)
# Training
# AR Decoder
y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
x_len = x_lens.max()
y_len = y_lens.max()
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position(y_emb)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
ar_xy_padding_mask = xy_padding_mask
x_attn_mask = F.pad(
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
diagonal=1,
),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
bsz, src_len = x.shape[0], x_len + y_len
_xy_padding_mask = (
ar_xy_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, self.num_head, -1, -1)
.reshape(bsz * self.num_head, 1, src_len)
)
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
xy_attn_mask = new_attn_mask
# x 和完整的 y 一次性输入模型
xy_pos = torch.concat([x, y_pos], dim=1)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
# loss
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
loss = F.cross_entropy(logits, targets, reduction="sum")
acc = self.ar_accuracy_metric(logits.detach(), targets).item()
return loss, acc
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
def infer(
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
y = prompts
prefix_len = y.shape[1]
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
for _ in tqdm(range(1500)):
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position(y_emb)
# x 和逐渐增长的 y 一起输入给模型
xy_pos = torch.concat([x, y_pos], dim=1)
y_len = y.shape[1]
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len),
value=True,
)
y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = topk_sampling(
logits, top_k=top_k, top_p=1.0, temperature=temperature
)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
# print(samples.shape)#[1,1]#第一个1是bs
# import os
# os._exit(2333)
y = torch.concat([y, samples], dim=1)
return y
def pad_y_eos(self, y, y_mask_int, eos_id):
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1
)
# 错位
return targets[:, :-1], targets[:, 1:]
def infer_panel(
self,
x, #####全部文本token
x_lens,
prompts, ####参考音频token
bert_feature,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
cache = {
"all_stage": self.num_layers,
"k": [None] * self.num_layers, ###根据配置自己手写
"v": [None] * self.num_layers,
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
"y_emb": None, ##只需要对最新的samples求emb再拼历史的就行
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
# "xy_dec":None,###不需要本来只需要最后一个做logits
"first_infer": 1,
"stage": 0,
}
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
cache["y_emb"] = y_emb
ref_free = False
else:
y_emb = None
y_len = 0
prefix_len = 0
y_pos = None
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
x.device
)
y_list = [None]*y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)):
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
) ##不用改如果用了cache的默认就是只有一帧取最后一帧一样的
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
if(idx==0):###第一次跑不能EOS否则没有了
logits = logits[:, :-1] ###刨除1024终止符号的概率
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0]
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
# print(samples.shape)#[1,1]#第一个1是bs
y = torch.concat([y, samples], dim=1)
# 移除已经生成完毕的序列
reserved_idx_of_batch_for_y = None
if (self.EOS in torch.argmax(logits, dim=-1)) or \
(self.EOS in samples[:, 0]): ###如果生成到EOS则停止
l = samples[:, 0]==self.EOS
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx - 1
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# 只保留未生成完毕的序列
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
if cache["y_emb"] is not None:
cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
if cache["k"] is not None:
for i in range(self.num_layers):
# 因为kv转置了所以batch dim是1
cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if not (None in idx_list):
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
# if prompts.shape[1] == y.shape[1]:
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
# print("bad zero prediction")
if y.shape[1]==0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
####################### update next step ###################################
cache["first_infer"] = 0
if cache["y_emb"] is not None:
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos[:, -1:]
else:
y_emb = self.ar_audio_embedding(y[:, -1:])
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos
y_len = y_pos.shape[1]
###最右边一列(是错的)
# xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
# xy_attn_mask[:,-1]=False
###最下面一行(是对的)
xy_attn_mask = torch.zeros(
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
)
if (None in idx_list):
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ###如果没有生成到EOS就用最大长度代替
if ref_free:
return y_list, [0]*x.shape[0]
return y_list, idx_list

View File

@ -115,17 +115,17 @@ def logits_to_probs(
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
if previous_tokens is not None:
previous_tokens = previous_tokens.squeeze()
# if previous_tokens is not None:
# previous_tokens = previous_tokens.squeeze()
# print(logits.shape,previous_tokens.shape)
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@ -133,9 +133,9 @@ def logits_to_probs(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0, index=sorted_indices, src=sorted_indices_to_remove
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))

View File

@ -0,0 +1,680 @@
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
import ffmpeg
import os
from typing import Generator, List, Union
import numpy as np
import torch
import yaml
from transformers import AutoModelForMaskedLM, AutoTokenizer
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from feature_extractor.cnhubert import CNHubert
from module.models import SynthesizerTrn
import librosa
from time import time as ttime
from tools.i18n.i18n import I18nAuto
from my_utils import load_audio
from module.mel_processing import spectrogram_torch
from .text_segmentation_method import splits
from .TextPreprocessor import TextPreprocessor
i18n = I18nAuto()
# configs/tts_infer.yaml
"""
default:
device: cpu
is_half: false
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
custom:
device: cuda
is_half: true
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
"""
class TTS_Config:
def __init__(self, configs: Union[dict, str]):
configs_base_path:str = "GPT_SoVITS/configs/"
os.makedirs(configs_base_path, exist_ok=True)
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
if isinstance(configs, str):
self.configs_path = configs
configs:dict = self._load_configs(configs)
# assert isinstance(configs, dict)
self.default_configs:dict = configs.get("default", None)
if self.default_configs is None:
self.default_configs={
"device": "cpu",
"is_half": False,
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
}
self.configs:dict = configs.get("custom", self.default_configs)
self.device = self.configs.get("device")
self.is_half = self.configs.get("is_half")
self.t2s_weights_path = self.configs.get("t2s_weights_path")
self.vits_weights_path = self.configs.get("vits_weights_path")
self.bert_base_path = self.configs.get("bert_base_path")
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
self.max_sec = None
self.hz:int = 50
self.semantic_frame_rate:str = "25hz"
self.segment_size:int = 20480
self.filter_length:int = 2048
self.sampling_rate:int = 32000
self.hop_length:int = 640
self.win_length:int = 2048
self.n_speakers:int = 300
self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
print(self)
def _load_configs(self, configs_path: str)->dict:
with open(configs_path, 'r') as f:
configs = yaml.load(f, Loader=yaml.FullLoader)
return configs
def save_configs(self, configs_path:str=None)->None:
configs={
"default": {
"device": "cpu",
"is_half": False,
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
},
"custom": {
"device": str(self.device),
"is_half": self.is_half,
"t2s_weights_path": self.t2s_weights_path,
"vits_weights_path": self.vits_weights_path,
"bert_base_path": self.bert_base_path,
"cnhuhbert_base_path": self.cnhuhbert_base_path
}
}
if configs_path is None:
configs_path = self.configs_path
with open(configs_path, 'w') as f:
yaml.dump(configs, f)
def __str__(self):
string = "----------------TTS Config--------------\n"
string += "device: {}\n".format(self.device)
string += "is_half: {}\n".format(self.is_half)
string += "bert_base_path: {}\n".format(self.bert_base_path)
string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
string += "vits_weights_path: {}\n".format(self.vits_weights_path)
string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
string += "----------------------------------------\n"
return string
class TTS:
def __init__(self, configs: Union[dict, str, TTS_Config]):
if isinstance(configs, TTS_Config):
self.configs = configs
else:
self.configs:TTS_Config = TTS_Config(configs)
self.t2s_model:Text2SemanticLightningModule = None
self.vits_model:SynthesizerTrn = None
self.bert_tokenizer:AutoTokenizer = None
self.bert_model:AutoModelForMaskedLM = None
self.cnhuhbert_model:CNHubert = None
self._init_models()
self.text_preprocessor:TextPreprocessor = \
TextPreprocessor(self.bert_model,
self.bert_tokenizer,
self.configs.device)
self.prompt_cache:dict = {
"ref_audio_path":None,
"prompt_semantic":None,
"refer_spepc":None,
"prompt_text":None,
"prompt_lang":None,
"phones":None,
"bert_features":None,
"norm_text":None,
}
self.stop_flag:bool = False
def _init_models(self,):
self.init_t2s_weights(self.configs.t2s_weights_path)
self.init_vits_weights(self.configs.vits_weights_path)
self.init_bert_weights(self.configs.bert_base_path)
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
def init_cnhuhbert_weights(self, base_path: str):
self.cnhuhbert_model = CNHubert(base_path)
self.cnhuhbert_model.eval()
if self.configs.is_half == True:
self.cnhuhbert_model = self.cnhuhbert_model.half()
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
def init_bert_weights(self, base_path: str):
self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
if self.configs.is_half:
self.bert_model = self.bert_model.half()
self.bert_model = self.bert_model.to(self.configs.device)
def init_vits_weights(self, weights_path: str):
self.configs.vits_weights_path = weights_path
self.configs.save_configs()
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
hps = dict_s2["config"]
self.configs.filter_length = hps["data"]["filter_length"]
self.configs.segment_size = hps["train"]["segment_size"]
self.configs.sampling_rate = hps["data"]["sampling_rate"]
self.configs.hop_length = hps["data"]["hop_length"]
self.configs.win_length = hps["data"]["win_length"]
self.configs.n_speakers = hps["data"]["n_speakers"]
self.configs.semantic_frame_rate = "25hz"
kwargs = hps["model"]
vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
n_speakers=self.configs.n_speakers,
**kwargs
)
# if ("pretrained" not in weights_path):
if hasattr(vits_model, "enc_q"):
del vits_model.enc_q
if self.configs.is_half:
vits_model = vits_model.half()
vits_model = vits_model.to(self.configs.device)
vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model
def init_t2s_weights(self, weights_path: str):
self.configs.t2s_weights_path = weights_path
self.configs.save_configs()
self.configs.hz = 50
dict_s1 = torch.load(weights_path, map_location=self.configs.device)
config = dict_s1["config"]
self.configs.max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if self.configs.is_half:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(self.configs.device)
t2s_model.eval()
self.t2s_model = t2s_model
def set_ref_audio(self, ref_audio_path:str):
'''
To set the reference audio for the TTS model,
including the prompt_semantic and refer_spepc.
Args:
ref_audio_path: str, the path of the reference audio.
'''
self._set_prompt_semantic(ref_audio_path)
self._set_ref_spepc(ref_audio_path)
def _set_ref_spepc(self, ref_audio_path):
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(
audio_norm,
self.configs.filter_length,
self.configs.sampling_rate,
self.configs.hop_length,
self.configs.win_length,
center=False,
)
spec = spec.to(self.configs.device)
if self.configs.is_half:
spec = spec.half()
# self.refer_spepc = spec
self.prompt_cache["refer_spepc"] = spec
def _set_prompt_semantic(self, ref_wav_path:str):
zero_wav = np.zeros(
int(self.configs.sampling_rate * 0.3),
dtype=np.float16 if self.configs.is_half else np.float32,
)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
wav16k = wav16k.to(self.configs.device)
zero_wav_torch = zero_wav_torch.to(self.configs.device)
if self.configs.is_half:
wav16k = wav16k.half()
zero_wav_torch = zero_wav_torch.half()
wav16k = torch.cat([wav16k, zero_wav_torch])
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = self.vits_model.extract_latent(hubert_feature)
prompt_semantic = codes[0, 0].to(self.configs.device)
self.prompt_cache["prompt_semantic"] = prompt_semantic
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None):
seq = sequences[0]
ndim = seq.dim()
if axis < 0:
axis += ndim
dtype:torch.dtype = seq.dtype
pad_value = torch.tensor(pad_value, dtype=dtype)
seq_lengths = [seq.shape[axis] for seq in sequences]
if max_length is None:
max_length = max(seq_lengths)
else:
max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
padded_seq = torch.nn.functional.pad(seq, padding, value=pad_value)
padded_sequences.append(padded_seq)
batch = torch.stack(padded_sequences)
return batch
def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold:float=0.75, split_bucket:bool=True):
_data:list = []
index_and_len_list = []
for idx, item in enumerate(data):
norm_text_len = len(item["norm_text"])
index_and_len_list.append([idx, norm_text_len])
batch_index_list = []
if split_bucket:
index_and_len_list.sort(key=lambda x: x[1])
index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
batch_index_list_len = 0
pos = 0
while pos <index_and_len_list.shape[0]:
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
while pos < pos_end:
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
score=batch[(pos_end-pos)//2]/batch.mean()
if (score>=threshold) or (pos_end-pos==1):
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
batch_index_list_len += len(batch_index)
batch_index_list.append(batch_index)
pos = pos_end
break
pos_end=pos_end-1
assert batch_index_list_len == len(data)
else:
for i in range(len(data)):
if i%batch_size == 0:
batch_index_list.append([])
batch_index_list[-1].append(i)
for batch_idx, index_list in enumerate(batch_index_list):
item_list = [data[idx] for idx in index_list]
phones_list = []
# bert_features_list = []
all_phones_list = []
all_phones_len_list = []
all_bert_features_list = []
norm_text_batch = []
bert_max_len = 0
phones_max_len = 0
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1)
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
phones = torch.LongTensor(item["phones"])
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]
phones = torch.LongTensor(item["phones"])
all_phones = phones.clone()
# norm_text = item["norm_text"]
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
phones_max_len = max(phones_max_len, phones.shape[-1])
phones_list.append(phones)
all_phones_list.append(all_phones)
all_phones_len_list.append(all_phones.shape[-1])
all_bert_features_list.append(all_bert_features)
norm_text_batch.append(item["norm_text"])
phones_batch = phones_list
max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
all_bert_features_batch.zero_()
for idx, item in enumerate(all_bert_features_list):
if item != None:
all_bert_features_batch[idx, :, : item.shape[-1]] = item
batch = {
"phones": phones_batch,
"all_phones": all_phones_batch,
"all_phones_len": torch.LongTensor(all_phones_len_list),
"all_bert_features": all_bert_features_batch,
"norm_text": norm_text_batch
}
_data.append(batch)
return _data, batch_index_list
def recovery_order(self, data:list, batch_index_list:list)->list:
'''
Recovery the order of the audio according to the batch_index_list.
Args:
data (List[list(np.ndarray)]): the out of order audio .
batch_index_list (List[list[int]]): the batch index list.
Returns:
list (List[np.ndarray]): the data in the original order.
'''
lenght = len(sum(batch_index_list, []))
_data = [None]*lenght
for i, index_list in enumerate(batch_index_list):
for j, index in enumerate(index_list):
_data[index] = data[i][j]
return _data
def stop(self,):
'''
Stop the inference process.
'''
self.stop_flag = True
def run(self, inputs:dict):
"""
Text to speech inference.
Args:
inputs (dict):
{
"text": "", # str. text to be synthesized
"text_lang: "", # str. language of the text to be synthesized
"ref_audio_path": "", # str. reference audio path
"prompt_text": "", # str. prompt text for the reference audio
"prompt_lang": "", # str. language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_p": 0.9, # float. top p sampling
"temperature": 0.6, # float. temperature for sampling
"text_split_method": "", # str. text split method, see text_segmentaion_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
}
returns:
tulpe[int, np.ndarray]: sampling rate and audio data.
"""
########## variables initialization ###########
self.stop_flag:bool = False
text:str = inputs.get("text", "")
text_lang:str = inputs.get("text_lang", "")
ref_audio_path:str = inputs.get("ref_audio_path", "")
prompt_text:str = inputs.get("prompt_text", "")
prompt_lang:str = inputs.get("prompt_lang", "")
top_k:int = inputs.get("top_k", 20)
top_p:float = inputs.get("top_p", 0.9)
temperature:float = inputs.get("temperature", 0.6)
text_split_method:str = inputs.get("text_split_method", "")
batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75)
speed_factor = inputs.get("speed_factor", 1.0)
split_bucket = inputs.get("split_bucket", True)
return_fragment = inputs.get("return_fragment", False)
if return_fragment:
split_bucket = False
print(i18n("分段返回模式已开启"))
if split_bucket:
split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
if split_bucket:
print(i18n("分桶处理模式已开启"))
no_prompt_text = False
if prompt_text in [None, ""]:
no_prompt_text = True
assert text_lang in self.configs.langauges
if not no_prompt_text:
assert prompt_lang in self.configs.langauges
if ref_audio_path in [None, ""] and \
((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spepc"] is None)):
raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
###### setting reference audio and prompt text preprocessing ########
t0 = ttime()
if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
self.set_ref_audio(ref_audio_path)
if not no_prompt_text:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_lang != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text)
if self.prompt_cache["prompt_text"] != prompt_text:
self.prompt_cache["prompt_text"] = prompt_text
self.prompt_cache["prompt_lang"] = prompt_lang
phones, bert_features, norm_text = \
self.text_preprocessor.segment_and_extract_feature_for_text(
prompt_text,
prompt_lang)
self.prompt_cache["phones"] = phones
self.prompt_cache["bert_features"] = bert_features
self.prompt_cache["norm_text"] = norm_text
###### text preprocessing ########
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
audio = []
t1 = ttime()
data, batch_index_list = self.to_batch(data,
prompt_data=self.prompt_cache if not no_prompt_text else None,
batch_size=batch_size,
threshold=batch_threshold,
split_bucket=split_bucket
)
t2 = ttime()
zero_wav = torch.zeros(
int(self.configs.sampling_rate * 0.3),
dtype=torch.float16 if self.configs.is_half else torch.float32,
device=self.configs.device
)
###### inference ######
t_34 = 0.0
t_45 = 0.0
for item in data:
t3 = ttime()
batch_phones = item["phones"]
all_phoneme_ids = item["all_phones"]
all_phoneme_lens = item["all_phones_len"]
all_bert_features = item["all_bert_features"]
norm_text = item["norm_text"]
all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
all_bert_features = all_bert_features.to(self.configs.device)
if self.configs.is_half:
all_bert_features = all_bert_features.half()
print(i18n("前端处理后的文本(每句):"), norm_text)
if no_prompt_text :
prompt = None
else:
prompt = self.prompt_cache["prompt_semantic"].clone().repeat(all_phoneme_ids.shape[0], 1).to(self.configs.device)
with torch.no_grad():
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_lens,
prompt,
all_bert_features,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
)
t4 = ttime()
t_34 += t4 - t3
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].clone().to(self.configs.device)
if self.configs.is_half:
refer_audio_spepc = refer_audio_spepc.half()
## 直接对batch进行decode 生成的音频会有问题
# pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
# pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
# batch_phones = batch_phones.to(self.configs.device)
# batch_audio_fragment =(self.vits_model.decode(
# pred_semantic, batch_phones, refer_audio_spepc
# ).detach()[:, 0, :])
# max_audio=torch.abs(batch_audio_fragment).max()#简单防止16bit爆音
# if max_audio>1: batch_audio_fragment/=max_audio
# batch_audio_fragment = batch_audio_fragment.cpu().numpy()
## 改成串行处理
batch_audio_fragment = []
for i, idx in enumerate(idx_list):
phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
_pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment =(self.vits_model.decode(
_pred_semantic, phones, refer_audio_spepc
).detach()[0, 0, :])
max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
if max_audio>1: audio_fragment/=max_audio
audio_fragment = torch.cat([audio_fragment, zero_wav], dim=0)
batch_audio_fragment.append(
audio_fragment.cpu().numpy()
) ###试试重建不带上prompt部分
t5 = ttime()
t_45 += t5 - t4
if return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
yield self.audio_postprocess(batch_audio_fragment,
self.configs.sampling_rate,
batch_index_list,
speed_factor,
split_bucket)
else:
audio.append(batch_audio_fragment)
if self.stop_flag:
yield self.configs.sampling_rate, (zero_wav.cpu().numpy()).astype(np.int16)
return
if not return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
yield self.audio_postprocess(audio,
self.configs.sampling_rate,
batch_index_list,
speed_factor,
split_bucket)
def audio_postprocess(self,
audio:np.ndarray,
sr:int,
batch_index_list:list=None,
speed_factor:float=1.0,
split_bucket:bool=True)->tuple[int, np.ndarray]:
if split_bucket:
audio = self.recovery_order(audio, batch_index_list)
else:
audio = [item for batch in audio for item in batch]
audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16)
try:
if speed_factor != 1.0:
audio = speed_change(audio, speed=speed_factor, sr=int(sr))
except Exception as e:
print(f"Failed to change speed of audio: \n{e}")
return sr, audio
def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# 将 NumPy 数组转换为原始 PCM 流
raw_audio = input_audio.astype(np.int16).tobytes()
# 设置 ffmpeg 输入流
input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
# 变速处理
output_stream = input_stream.filter('atempo', speed)
# 输出流到管道
out, _ = (
output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
.run(input=raw_audio, capture_stdout=True, capture_stderr=True)
)
# 将管道输出解码为 NumPy 数组
processed_audio = np.frombuffer(out, np.int16)
return processed_audio

View File

@ -0,0 +1,176 @@
import re
import torch
import LangSegment
from typing import Dict, List, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from .text_segmentation_method import splits, get_method as get_seg_method
# from tools.i18n.i18n import I18nAuto
# i18n = I18nAuto()
def get_first(text:str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip()
return text
def merge_short_text_in_array(texts:str, threshold:int) -> list:
if (len(texts)) < 2:
return texts
result = []
text = ""
for ele in texts:
text += ele
if len(text) >= threshold:
result.append(text)
text = ""
if (len(text) > 0):
if len(result) == 0:
result.append(text)
else:
result[len(result) - 1] += text
return result
class TextPreprocessor:
def __init__(self, bert_model:AutoModelForMaskedLM,
tokenizer:AutoTokenizer, device:torch.device):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
for text in texts:
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
res={
"phones": phones,
"bert_features": bert_features,
"norm_text": norm_text,
}
result.append(res)
return result
def pre_seg_text(self, text:str, lang:str, text_split_method:str):
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4):
text = "" + text if lang != "en" else "." + text
# print(i18n("实际输入的目标文本:"), text)
seg_method = get_seg_method(text_split_method)
text = seg_method(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
# print(i18n("实际输入的目标文本(切句后):"), text)
_texts = text.split("\n")
_texts = merge_short_text_in_array(_texts, 5)
texts = []
for text in _texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
if (text[-1] not in splits): text += "" if lang != "en" else "."
texts.append(text)
return texts
def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]:
textlist, langlist = self.seg_text(texts, language)
phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
return phones, bert_features, norm_text
def seg_text(self, text:str, language:str)->Tuple[list, list]:
textlist=[]
langlist=[]
if language in ["auto", "zh", "ja"]:
# LangSegment.setfilters(["zh","ja","en","ko"])
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "ko":
langlist.append("zh")
elif tmp["lang"] == "en":
langlist.append("en")
else:
# 因无法区别中日文汉字,以用户输入为准
langlist.append(language if language!="auto" else tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
# LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
while " " in formattext:
formattext = formattext.replace(" ", " ")
textlist.append(formattext)
langlist.append("en")
elif language in ["all_zh","all_ja"]:
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
language = language.replace("all_","")
textlist.append(formattext)
langlist.append(language)
else:
raise ValueError(f"language {language} not supported")
return textlist, langlist
def extract_bert_feature(self, textlist:list, langlist:list):
phones_list = []
bert_feature_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang)
_bert_feature = self.get_bert_inf(phones, word2ph, norm_text, lang)
# phones_list.append(phones)
phones_list.extend(phones)
norm_text_list.append(norm_text)
bert_feature_list.append(_bert_feature)
bert_feature = torch.cat(bert_feature_list, dim=1)
# phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
return phones, bert_feature, norm_text
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
def clean_text_inf(self, text:str, language:str):
phones, word2ph, norm_text = clean_text(text, language)
phones = cleaned_text_to_sequence(phones)
return phones, word2ph, norm_text
def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
language=language.replace("all_","")
if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
else:
feature = torch.zeros(
(1024, len(phones)),
dtype=torch.float32,
).to(self.device)
return feature

View File

@ -0,0 +1 @@
from . import TTS, text_segmentation_method

View File

@ -0,0 +1,126 @@
import re
from typing import Callable
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
METHODS = dict()
def get_method(name:str)->Callable:
method = METHODS.get(name, None)
if method is None:
raise ValueError(f"Method {name} not found")
return method
def register_method(name):
def decorator(func):
METHODS[name] = func
return func
return decorator
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while 1:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
# 不切
@register_method("cut0")
def cut0(inp):
return inp
# 凑四句一切
@register_method("cut1")
def cut1(inp):
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 4))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
else:
opts = [inp]
return "\n".join(opts)
# 凑50字一切
@register_method("cut2")
def cut2(inp):
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return inp
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
# print(opts)
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
return "\n".join(opts)
# 按中文句号。切
@register_method("cut3")
def cut3(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")])
#按英文句号.切
@register_method("cut4")
def cut4(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
# 按标点符号切
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@register_method("cut5")
def cut5(inp):
# if not re.search(r'[^\w\s]', inp[-1]):
# inp += '。'
inp = inp.strip("\n")
punds = r'[,.;?!、,。?!;:…]'
items = re.split(f'({punds})', inp)
mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
# 在句子不存在符号或句尾无符号的时候保证文本完整
if len(items)%2 == 1:
mergeitems.append(items[-1])
opt = "\n".join(mergeitems)
return opt
if __name__ == '__main__':
method = get_method("cut1")
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))

View File

@ -0,0 +1,14 @@
custom:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cuda
is_half: true
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
default:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth

View File

@ -20,13 +20,16 @@ cnhubert_base_path = None
class CNHubert(nn.Module):
def __init__(self):
def __init__(self, base_path:str=None):
super().__init__()
self.model = HubertModel.from_pretrained(cnhubert_base_path)
if base_path is None:
base_path = cnhubert_base_path
self.model = HubertModel.from_pretrained(base_path)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
cnhubert_base_path
base_path
)
def forward(self, x):
input_values = self.feature_extractor(
x, return_tensors="pt", sampling_rate=16000

View File

@ -7,7 +7,7 @@ import soundfile as sf
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
from GPT_SoVITS.inference_webui_old import change_gpt_weights, change_sovits_weights, get_tts_wav
class GPTSoVITSGUI(QMainWindow):

View File

@ -6,8 +6,11 @@
全部按英文识别
全部按日文识别
'''
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
import os, re, logging
import LangSegment
logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
@ -17,32 +20,9 @@ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
import torch
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
if os.path.exists("./gweight.txt"):
with open("./gweight.txt", 'r', encoding="utf-8") as file:
gweight_data = file.read()
gpt_path = os.environ.get(
"gpt_path", gweight_data)
else:
gpt_path = os.environ.get(
"gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
if os.path.exists("./sweight.txt"):
with open("./sweight.txt", 'r', encoding="utf-8") as file:
sweight_data = file.read()
sovits_path = os.environ.get("sovits_path", sweight_data)
else:
sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
# gpt_path = os.environ.get(
# "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
# )
# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path = os.environ.get(
"cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
)
bert_path = os.environ.get(
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
)
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
is_share = os.environ.get("is_share", "False")
@ -51,22 +31,10 @@ if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_available()
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
import librosa
from feature_extractor import cnhubert
cnhubert.cnhubert_base_path = cnhubert_base_path
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.text_segmentation_method import cut1, cut2, cut3, cut4, cut5
from tools.i18n.i18n import I18nAuto
from TTS_infer_pack.text_segmentation_method import get_method
i18n = I18nAuto()
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
@ -76,128 +44,6 @@ if torch.cuda.is_available():
else:
device = "cpu"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half == True:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
ssl_model = cnhubert.get_model()
if is_half == True:
ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
def change_sovits_weights(sovits_path):
global vq_model, hps
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model
)
if ("pretrained" not in sovits_path):
del vq_model.enc_q
if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
with open("./sweight.txt", "w", encoding="utf-8") as f:
f.write(sovits_path)
change_sovits_weights(sovits_path)
def change_gpt_weights(gpt_path):
global hz, max_sec, t2s_model, config
hz = 50
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half == True:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
change_gpt_weights(gpt_path)
def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(
audio_norm,
hps.data.filter_length,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
center=False,
)
return spec
dict_language = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
@ -207,313 +53,46 @@ dict_language = {
i18n("多语种混合"): "auto",#多语种启动切分识别语种
}
cut_method = {
i18n("不切"):"cut0",
i18n("凑四句一切"): "cut1",
i18n("凑50字一切"): "cut2",
i18n("按中文句号。切"): "cut3",
i18n("按英文句号.切"): "cut4",
i18n("按标点符号切"): "cut5",
}
def clean_text_inf(text, language):
phones, word2ph, norm_text = clean_text(text, language)
phones = cleaned_text_to_sequence(phones)
return phones, word2ph, norm_text
dtype=torch.float16 if is_half == True else torch.float32
def get_bert_inf(phones, word2ph, norm_text, language):
language=language.replace("all_","")
if language == "zh":
bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
else:
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
return bert
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def get_first(text):
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip()
return text
def get_phones_and_bert(text,language):
if language in {"en","all_zh","all_ja"}:
language = language.replace("all_","")
if language == "en":
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
else:
# 因无法区别中日文汉字,以用户输入为准
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
phones, word2ph, norm_text = clean_text_inf(formattext, language)
if language == "zh":
bert = get_bert_feature(norm_text, word2ph).to(device)
else:
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
elif language in {"zh", "ja","auto"}:
textlist=[]
langlist=[]
LangSegment.setfilters(["zh","ja","en","ko"])
if language == "auto":
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "ko":
langlist.append("zh")
textlist.append(tmp["text"])
else:
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
print(textlist)
print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
return phones,bert.to(dtype),norm_text
def merge_short_text_in_array(texts, threshold):
if (len(texts)) < 2:
return texts
result = []
text = ""
for ele in texts:
text += ele
if len(text) >= threshold:
result.append(text)
text = ""
if (len(text) > 0):
if len(result) == 0:
result.append(text)
else:
result[len(result) - 1] += text
return result
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
if prompt_text is None or len(prompt_text) == 0:
ref_free = True
t0 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
if not ref_free:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text)
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4): text = "" + text if text_language != "en" else "." + text
print(i18n("实际输入的目标文本:"), text)
zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
dtype=np.float16 if is_half == True else np.float32,
)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if is_half == True:
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
if (how_to_cut == i18n("凑四句一切")):
text = cut1(text)
elif (how_to_cut == i18n("凑50字一切")):
text = cut2(text)
elif (how_to_cut == i18n("按中文句号。切")):
text = cut3(text)
elif (how_to_cut == i18n("按英文句号.切")):
text = cut4(text)
elif (how_to_cut == i18n("按标点符号切")):
text = cut5(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
print(i18n("实际输入的目标文本(切句后):"), text)
texts = text.split("\n")
texts = merge_short_text_in_array(texts, 5)
audio_opt = []
if not ref_free:
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
for text in texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
if (text[-1] not in splits): text += "" if text_language != "en" else "."
print(i18n("实际输入的目标文本(每句):"), text)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
print(i18n("前端处理后的文本(每句):"), norm_text2)
if not ref_free:
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
else:
bert = bert2
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
None if ref_free else prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=hz * max_sec,
)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
0
) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if is_half == True:
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
)
.detach()
.cpu()
.numpy()[0, 0]
) ###试试重建不带上prompt部分
max_audio=np.abs(audio).max()#简单防止16bit爆音
if max_audio>1:audio/=max_audio
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
np.int16
)
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while 1:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
def cut1(inp):
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 4))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
else:
opts = [inp]
return "\n".join(opts)
def cut2(inp):
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return inp
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
# print(opts)
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
return "\n".join(opts)
def cut3(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")])
def cut4(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
def cut5(inp):
# if not re.search(r'[^\w\s]', inp[-1]):
# inp += '。'
inp = inp.strip("\n")
punds = r'[,.;?!、,。?!;:…]'
items = re.split(f'({punds})', inp)
mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
# 在句子不存在符号或句尾无符号的时候保证文本完整
if len(items)%2 == 1:
mergeitems.append(items[-1])
opt = "\n".join(mergeitems)
return opt
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device
tts_config.is_half = is_half
tts_pipline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
def inference(text, text_lang,
ref_audio_path, prompt_text,
prompt_lang, top_k,
top_p, temperature,
text_split_method, batch_size,
speed_factor, ref_text_free,
split_bucket
):
inputs={
"text": text,
"text_lang": dict_language[text_lang],
"ref_audio_path": ref_audio_path,
"prompt_text": prompt_text if not ref_text_free else "",
"prompt_lang": dict_language[prompt_lang],
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"text_split_method": cut_method[text_split_method],
"batch_size":int(batch_size),
"speed_factor":float(speed_factor),
"split_bucket":split_bucket,
"return_fragment":False,
}
yield next(tts_pipline.run(inputs))
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分
@ -552,65 +131,99 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
)
with gr.Group():
with gr.Column():
# with gr.Group():
gr.Markdown(value=i18n("模型切换"))
with gr.Row():
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown], [])
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
SoVITS_dropdown.change(tts_pipline.init_vits_weights, [SoVITS_dropdown], [])
GPT_dropdown.change(tts_pipline.init_t2s_weights, [GPT_dropdown], [])
with gr.Row():
with gr.Column():
gr.Markdown(value=i18n("*请上传并填写参考信息"))
inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频超过会报错"), type="filepath")
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
with gr.Row():
text = gr.Textbox(label=i18n("需要合成的文本"), value="")
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=2)
with gr.Row():
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
with gr.Column():
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=16, max_lines=16)
text_language = gr.Dropdown(
label=i18n("需要合成的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
with gr.Row():
gr.Markdown(value=i18n("gpt采样参数(无参考文本时不要太低)"))
with gr.Group():
gr.Markdown(value=i18n("推理设置"))
with gr.Row():
with gr.Column():
batch_size = gr.Slider(minimum=1,maximum=20,step=1,label=i18n("batch_size"),value=1,interactive=True)
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
inference_button = gr.Button(i18n("合成语音"), variant="primary")
output = gr.Audio(label=i18n("输出的语音"))
with gr.Column():
how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
# with gr.Column():
output = gr.Audio(label=i18n("输出的语音"))
with gr.Row():
inference_button = gr.Button(i18n("合成语音"), variant="primary")
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
inference_button.click(
get_tts_wav,
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
inference,
[
text,text_language, inp_ref,
prompt_text, prompt_language,
top_k, top_p, temperature,
how_to_cut, batch_size,
speed_factor, ref_text_free,
split_bucket
],
[output],
)
stop_infer.click(tts_pipline.stop, [], [])
with gr.Group():
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
with gr.Row():
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
button1 = gr.Button(i18n("凑四句一切"), variant="primary")
button2 = gr.Button(i18n("凑50字一切"), variant="primary")
button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
button5 = gr.Button(i18n("按标点符号切"), variant="primary")
text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
button1.click(cut1, [text_inp], [text_opt])
button2.click(cut2, [text_inp], [text_opt])
button3.click(cut3, [text_inp], [text_opt])
button4.click(cut4, [text_inp], [text_opt])
button5.click(cut5, [text_inp], [text_opt])
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
with gr.Column():
_how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
cut_text= gr.Button(i18n("切分"), variant="primary")
def to_cut(text_inp, how_to_cut):
if len(text_inp.strip()) == 0 or text_inp==[]:
return ""
method = get_method(cut_method[how_to_cut])
return method(text_inp)
text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
app.queue(concurrency_count=511, max_size=1022).launch(

View File

@ -0,0 +1,622 @@
'''
按中英混合识别
按日英混合识别
多语种启动切分识别语种
全部按中文识别
全部按英文识别
全部按日文识别
'''
import os, re, logging
import LangSegment
logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
import torch
if os.path.exists("./gweight.txt"):
with open("./gweight.txt", 'r', encoding="utf-8") as file:
gweight_data = file.read()
gpt_path = os.environ.get(
"gpt_path", gweight_data)
else:
gpt_path = os.environ.get(
"gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
if os.path.exists("./sweight.txt"):
with open("./sweight.txt", 'r', encoding="utf-8") as file:
sweight_data = file.read()
sovits_path = os.environ.get("sovits_path", sweight_data)
else:
sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
# gpt_path = os.environ.get(
# "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
# )
# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path = os.environ.get(
"cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
)
bert_path = os.environ.get(
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
)
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
is_share = os.environ.get("is_share", "False")
is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_available()
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
import librosa
from feature_extractor import cnhubert
cnhubert.cnhubert_base_path = cnhubert_base_path
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half == True:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
ssl_model = cnhubert.get_model()
if is_half == True:
ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
def change_sovits_weights(sovits_path):
global vq_model, hps
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model
)
if ("pretrained" not in sovits_path):
del vq_model.enc_q
if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
with open("./sweight.txt", "w", encoding="utf-8") as f:
f.write(sovits_path)
change_sovits_weights(sovits_path)
def change_gpt_weights(gpt_path):
global hz, max_sec, t2s_model, config
hz = 50
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half == True:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
change_gpt_weights(gpt_path)
def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(
audio_norm,
hps.data.filter_length,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
center=False,
)
return spec
dict_language = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
i18n("日文"): "all_ja",#全部按日文识别
i18n("中英混合"): "zh",#按中英混合识别####不变
i18n("日英混合"): "ja",#按日英混合识别####不变
i18n("多语种混合"): "auto",#多语种启动切分识别语种
}
def clean_text_inf(text, language):
phones, word2ph, norm_text = clean_text(text, language)
phones = cleaned_text_to_sequence(phones)
return phones, word2ph, norm_text
dtype=torch.float16 if is_half == True else torch.float32
def get_bert_inf(phones, word2ph, norm_text, language):
language=language.replace("all_","")
if language == "zh":
bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
else:
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
return bert
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def get_first(text):
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip()
return text
def get_phones_and_bert(text,language):
if language in {"en","all_zh","all_ja"}:
language = language.replace("all_","")
if language == "en":
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
else:
# 因无法区别中日文汉字,以用户输入为准
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
phones, word2ph, norm_text = clean_text_inf(formattext, language)
if language == "zh":
bert = get_bert_feature(norm_text, word2ph).to(device)
else:
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
elif language in {"zh", "ja","auto"}:
textlist=[]
langlist=[]
LangSegment.setfilters(["zh","ja","en","ko"])
if language == "auto":
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "ko":
langlist.append("zh")
textlist.append(tmp["text"])
else:
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
print(textlist)
print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
return phones,bert.to(dtype),norm_text
def merge_short_text_in_array(texts, threshold):
if (len(texts)) < 2:
return texts
result = []
text = ""
for ele in texts:
text += ele
if len(text) >= threshold:
result.append(text)
text = ""
if (len(text) > 0):
if len(result) == 0:
result.append(text)
else:
result[len(result) - 1] += text
return result
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
if prompt_text is None or len(prompt_text) == 0:
ref_free = True
t0 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
if not ref_free:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text)
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4): text = "" + text if text_language != "en" else "." + text
print(i18n("实际输入的目标文本:"), text)
zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
dtype=np.float16 if is_half == True else np.float32,
)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if is_half == True:
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
if (how_to_cut == i18n("凑四句一切")):
text = cut1(text)
elif (how_to_cut == i18n("凑50字一切")):
text = cut2(text)
elif (how_to_cut == i18n("按中文句号。切")):
text = cut3(text)
elif (how_to_cut == i18n("按英文句号.切")):
text = cut4(text)
elif (how_to_cut == i18n("按标点符号切")):
text = cut5(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
print(i18n("实际输入的目标文本(切句后):"), text)
texts = text.split("\n")
texts = merge_short_text_in_array(texts, 5)
audio_opt = []
if not ref_free:
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
for text in texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
if (text[-1] not in splits): text += "" if text_language != "en" else "."
print(i18n("实际输入的目标文本(每句):"), text)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
print(i18n("前端处理后的文本(每句):"), norm_text2)
if not ref_free:
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
else:
bert = bert2
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
None if ref_free else prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=hz * max_sec,
)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
0
) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if is_half == True:
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
)
.detach()
.cpu()
.numpy()[0, 0]
) ###试试重建不带上prompt部分
max_audio=np.abs(audio).max()#简单防止16bit爆音
if max_audio>1:audio/=max_audio
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
np.int16
)
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while 1:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
def cut1(inp):
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 4))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
else:
opts = [inp]
return "\n".join(opts)
def cut2(inp):
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return inp
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
# print(opts)
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
return "\n".join(opts)
def cut3(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")])
def cut4(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
def cut5(inp):
# if not re.search(r'[^\w\s]', inp[-1]):
# inp += '。'
inp = inp.strip("\n")
punds = r'[,.;?!、,。?!;:…]'
items = re.split(f'({punds})', inp)
mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
# 在句子不存在符号或句尾无符号的时候保证文本完整
if len(items)%2 == 1:
mergeitems.append(items[-1])
opt = "\n".join(mergeitems)
return opt
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分
parts = re.split('(\d+)', s)
# 将数字部分转换为整数,非数字部分保持不变
parts = [int(part) if part.isdigit() else part for part in parts]
return parts
def change_choices():
SoVITS_names, GPT_names = get_weights_names()
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
pretrained_sovits_name = "GPT_SoVITS/pretrained_models/s2G488k.pth"
pretrained_gpt_name = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
SoVITS_weight_root = "SoVITS_weights"
GPT_weight_root = "GPT_weights"
os.makedirs(SoVITS_weight_root, exist_ok=True)
os.makedirs(GPT_weight_root, exist_ok=True)
def get_weights_names():
SoVITS_names = [pretrained_sovits_name]
for name in os.listdir(SoVITS_weight_root):
if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (SoVITS_weight_root, name))
GPT_names = [pretrained_gpt_name]
for name in os.listdir(GPT_weight_root):
if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (GPT_weight_root, name))
return SoVITS_names, GPT_names
SoVITS_names, GPT_names = get_weights_names()
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
)
with gr.Group():
gr.Markdown(value=i18n("模型切换"))
with gr.Row():
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown], [])
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频超过会报错"), type="filepath")
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
with gr.Row():
text = gr.Textbox(label=i18n("需要合成的文本"), value="")
text_language = gr.Dropdown(
label=i18n("需要合成的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
with gr.Row():
gr.Markdown(value=i18n("gpt采样参数(无参考文本时不要太低)"))
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
inference_button = gr.Button(i18n("合成语音"), variant="primary")
output = gr.Audio(label=i18n("输出的语音"))
inference_button.click(
get_tts_wav,
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
[output],
)
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
with gr.Row():
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
button1 = gr.Button(i18n("凑四句一切"), variant="primary")
button2 = gr.Button(i18n("凑50字一切"), variant="primary")
button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
button5 = gr.Button(i18n("按标点符号切"), variant="primary")
text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
button1.click(cut1, [text_inp], [text_opt])
button2.click(cut2, [text_inp], [text_opt])
button3.click(cut3, [text_inp], [text_opt])
button4.click(cut4, [text_inp], [text_opt])
button5.click(cut5, [text_inp], [text_opt])
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
app.queue(concurrency_count=511, max_size=1022).launch(
server_name="0.0.0.0",
inbrowser=True,
share=is_share,
server_port=infer_ttswebui,
quiet=True,
)