Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
Synced 2025-04-05 04:22:46 +08:00

All in one! Merge the main branch and the fast_inference_ branch (#1490)

* Merge the main branch and the fast_inference_ branch
* Fix some bugs

This commit is contained in:
parent 9c75f35ece
commit 52c50c6c81
2  .gitignore  (vendored)
@@ -14,3 +14,5 @@ GPT_weights_v2
 SoVITS_weights_v2
 TEMP
 weight.json
+ffmpeg*
+ffprobe*
@@ -1,11 +1,10 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
+import math
+from typing import List, Optional
 import torch
-import random
-import numpy as np

 from tqdm import tqdm
-from typing import List

 from AR.models.utils import make_pad_mask
 from AR.models.utils import (
     topk_sampling,
@@ -37,6 +36,34 @@ default_config = {
     "EOS": 1024,
 }

+# @torch.jit.script  ## if enabled, the first inference pass is very slow and inference speed is unstable
+# Efficient implementation equivalent to the following:
+def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
+    B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
+    if scale is None:
+        scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
+    else:
+        scale_factor = scale
+    attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask, float("-inf"))
+        else:
+            attn_bias += attn_mask
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_weight.masked_fill_(attn_mask, 0)
+        else:
+            attn_mask[attn_mask != float("-inf")] = 0
+            attn_mask[attn_mask == float("-inf")] = 1
+            attn_weight.masked_fill_(attn_mask, 0)
+
+    return attn_weight @ value
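A minimal equivalence check for the hand-rolled kernel above, as a sketch (shapes are illustrative; it sticks to the no-mask case, since the float-mask branch passes a non-bool tensor to `masked_fill_`, which current PyTorch versions reject without a `.bool()` cast):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, head_dim)
manual = scaled_dot_product_attention(q, k, v)   # implementation above
fused = F.scaled_dot_product_attention(q, k, v)  # PyTorch fused kernel
assert torch.allclose(manual, fused, atol=1e-6)
```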

 @torch.jit.script
 class T2SMLP:
@@ -84,61 +111,112 @@ class T2SBlock:
         self.norm_b2 = norm_b2
         self.norm_eps2 = norm_eps2

-    def process_prompt(self, x, attn_mask: torch.Tensor):
-        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+        self.false = torch.tensor(False, dtype=torch.bool)
+
+    @torch.jit.ignore
+    def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
+        if padding_mask is None:
+            return x
+
+        if padding_mask.dtype == torch.bool:
+            return x.masked_fill(padding_mask, 0)
+        else:
+            return x * padding_mask
+
+    def process_prompt(self, x:torch.Tensor, attn_mask:torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
+        q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)

         batch_size = q.shape[0]
         q_len = q.shape[1]
         kv_len = k.shape[1]

-        k_cache = k
-        v_cache = v
+        q = self.to_mask(q, padding_mask)
+        k_cache = self.to_mask(k, padding_mask)
+        v_cache = self.to_mask(v, padding_mask)

         q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
         k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
         v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

-        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
+        if torch_sdpa:
+            attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
+        else:
+            attn = scaled_dot_product_attention(q, k, v, attn_mask)

-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
-        attn = F.linear(attn, self.out_w, self.out_b)
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
+        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)

-        x = F.layer_norm(
-            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-        )
-        x = F.layer_norm(
-            x + self.mlp.forward(x),
-            [self.hidden_dim],
-            self.norm_w2,
-            self.norm_b2,
-            self.norm_eps2,
-        )
+        if padding_mask is not None:
+            for i in range(batch_size):
+                # mask = padding_mask[i,:,0]
+                if self.false.device != padding_mask.device:
+                    self.false = self.false.to(padding_mask.device)
+                idx = torch.where(padding_mask[i,:,0]==self.false)[0]
+                x_item = x[i,idx,:].unsqueeze(0)
+                attn_item = attn[i,idx,:].unsqueeze(0)
+                x_item = x_item + attn_item
+                x_item = F.layer_norm(
+                    x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+                )
+                x_item = x_item + self.mlp.forward(x_item)
+                x_item = F.layer_norm(
+                    x_item,
+                    [self.hidden_dim],
+                    self.norm_w2,
+                    self.norm_b2,
+                    self.norm_eps2,
+                )
+                x[i,idx,:] = x_item.squeeze(0)
+            x = self.to_mask(x, padding_mask)
+        else:
+            x = x + attn
+            x = F.layer_norm(
+                x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+            )
+            x = x + self.mlp.forward(x)
+            x = F.layer_norm(
+                x,
+                [self.hidden_dim],
+                self.norm_w2,
+                self.norm_b2,
+                self.norm_eps2,
+            )
         return x, k_cache, v_cache

-    def decode_next_token(self, x, k_cache, v_cache):
+    def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
         q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)

         k_cache = torch.cat([k_cache, k], dim=1)
         v_cache = torch.cat([v_cache, v], dim=1)
-        kv_len = k_cache.shape[1]

         batch_size = q.shape[0]
         q_len = q.shape[1]
+        kv_len = k_cache.shape[1]

         q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
         k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
         v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

-        attn = F.scaled_dot_product_attention(q, k, v)
-
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
+        if torch_sdpa:
+            attn = F.scaled_dot_product_attention(q, k, v)
+        else:
+            attn = scaled_dot_product_attention(q, k, v, attn_mask)
+
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
+        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
         attn = F.linear(attn, self.out_w, self.out_b)

+        x = x + attn
         x = F.layer_norm(
-            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
         )
+        x = x + self.mlp.forward(x)
         x = F.layer_norm(
-            x + self.mlp.forward(x),
+            x,
             [self.hidden_dim],
             self.norm_w2,
             self.norm_b2,
@@ -149,25 +227,32 @@ class T2SBlock:

 @torch.jit.script
 class T2STransformer:
-    def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
-        self.num_blocks: int = num_blocks
+    def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
+        self.num_blocks : int = num_blocks
         self.blocks = blocks

     def process_prompt(
-        self, x, attn_mask: torch.Tensor):
-        k_cache: List[torch.Tensor] = []
-        v_cache: List[torch.Tensor] = []
+        self, x:torch.Tensor, attn_mask:torch.Tensor,
+        padding_mask:Optional[torch.Tensor]=None,
+        torch_sdpa:bool=True
+        ):
+        k_cache : List[torch.Tensor] = []
+        v_cache : List[torch.Tensor] = []
         for i in range(self.num_blocks):
-            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask)
+            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
             k_cache.append(k_cache_)
             v_cache.append(v_cache_)
         return x, k_cache, v_cache

     def decode_next_token(
-        self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
+        self, x:torch.Tensor,
+        k_cache: List[torch.Tensor],
+        v_cache: List[torch.Tensor],
+        attn_mask:Optional[torch.Tensor]=None,
+        torch_sdpa:bool=True
+        ):
         for i in range(self.num_blocks):
-            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
         return x, k_cache, v_cache
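The split into `process_prompt` (prefill, which builds the per-block KV caches) and `decode_next_token` (which appends one token's K/V to the caches) is the core of the fast-inference path. A minimal driver loop, as a sketch only (the `transformer` instance is assumed to be an already-constructed `T2STransformer`; shapes are example values):

```python
import torch

bsz, prompt_len, hidden = 2, 16, 512
xy_pos = torch.randn(bsz, prompt_len, hidden)          # prompt embeddings
attn_mask = torch.zeros(bsz, prompt_len, prompt_len, dtype=torch.bool)

# Prefill: one full pass over the prompt fills the per-block K/V caches.
x, k_cache, v_cache = transformer.process_prompt(xy_pos, attn_mask)

for step in range(10):
    next_emb = torch.randn(bsz, 1, hidden)  # embedding of the last sampled token
    # Decode: the single new token attends against cached K/V; caches grow by 1.
    x, k_cache, v_cache = transformer.decode_next_token(next_emb, k_cache, v_cache)
```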
@@ -235,7 +320,7 @@ class Text2SemanticDecoder(nn.Module):
                 layer.linear2.weight,
                 layer.linear2.bias
             )
             # (layer.self_attn.in_proj_weight, layer.self_attn.in_proj_bias)

             block = T2SBlock(
                 self.num_head,
                 self.model_dim,
@@ -253,7 +338,7 @@ class Text2SemanticDecoder(nn.Module):
             )

             blocks.append(block)

         self.t2s_transformer = T2STransformer(self.num_layers, blocks)

     def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
@@ -283,7 +368,7 @@ class Text2SemanticDecoder(nn.Module):
             (0, y_len),
             value=True,
         )
         # x_attn_mask[:, x_len]=False
         y_attn_mask = F.pad(
             torch.triu(
                 torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@@ -344,7 +429,7 @@ class Text2SemanticDecoder(nn.Module):

         A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
         loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)

         loss = loss_1 + loss_2

         return loss, acc
@@ -488,16 +573,225 @@ class Text2SemanticDecoder(nn.Module):
         # shift by one position
         return targets[:, :-1], targets[:, 1:]

-    def infer_panel(
-        self,
-        x,  ##### all text tokens
-        x_lens,
-        prompts,  #### reference audio tokens
-        bert_feature,
-        top_k: int = -100,
-        top_p: int = 100,
-        early_stop_num: int = -1,
-        temperature: float = 1.0,
+    def infer_panel_batch_infer(
+        self,
+        x:List[torch.LongTensor],  ##### all text tokens
+        x_lens:torch.LongTensor,
+        prompts:torch.LongTensor,  #### reference audio tokens
+        bert_feature:List[torch.LongTensor],
+        top_k: int = -100,
+        top_p: int = 100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.35,
+        **kwargs,
     ):
+        if prompts is None:
+            print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
+            return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
+
+        max_len = kwargs.get("max_len", x_lens.max())
+        x_list = []
+        for x_item, bert_item in zip(x, bert_feature):
+            # max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
+            x_item = self.ar_text_embedding(x_item.unsqueeze(0))
+            x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
+            x_item = self.ar_text_position(x_item).squeeze(0)
+            x_item = F.pad(x_item, (0,0,0,max_len-x_item.shape[0]), value=0) if x_item.shape[0] < max_len else x_item
+            x_list.append(x_item)
+        x = torch.stack(x_list, dim=0)
+
+        # AR Decoder
+        y = prompts
+
+        x_len = x.shape[1]
+        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
+        stop = False
+
+        k_cache = None
+        v_cache = None
+        ################### first step ##########################
+        if y is not None:
+            y_emb = self.ar_audio_embedding(y)
+            y_len = y_emb.shape[1]
+            prefix_len = y.shape[1]
+            y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
+            y_pos = self.ar_audio_position(y_emb)
+            xy_pos = torch.concat([x, y_pos], dim=1)
+            ref_free = False
+        else:
+            y_emb = None
+            y_len = 0
+            prefix_len = 0
+            y_lens = torch.LongTensor([y_len]*x.shape[0]).to(x.device)
+            y_pos = None
+            xy_pos = x
+            y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
+            ref_free = True
+
+        ##### create mask #####
+        bsz = x.shape[0]
+        src_len = x_len + y_len
+        y_paddind_mask = make_pad_mask(y_lens, y_len)
+        x_paddind_mask = make_pad_mask(x_lens, max_len)
+
+        # (bsz, x_len + y_len)
+        xy_padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
+
+        x_mask = F.pad(
+            x_attn_mask,
+            (0, y_len),  ### extend the all-zero xx block to all-zero xx plus all-one xy, (x, x+y)
+            value=True,
+        )
+        y_mask = F.pad(  ### extend the upper-right ones of yy into the xy zeros on the left, (y, x+y)
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+            (x_len, 0),
+            value=False,
+        )
+
+        xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
+        _xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
+
+        for i in range(bsz):
+            l = x_lens[i]
+            _xy_padding_mask[i, l:max_len, :] = True
+
+        xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
+        xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
+        xy_attn_mask = xy_attn_mask.bool()
+        xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim)
+
+        ###### decode #####
+        y_list = [None]*y.shape[0]
+        batch_idx_map = list(range(y.shape[0]))
+        idx_list = [None]*y.shape[0]
+        for idx in tqdm(range(1500)):
+            if idx == 0:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False)
+            else:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False)
+            logits = self.ar_predict_layer(
+                xy_dec[:, -1]
+            )
+
+            if idx == 0:
+                xy_attn_mask = F.pad(xy_attn_mask[:,:,-1].unsqueeze(-2), (0,1), value=False)
+                logits = logits[:, :-1]
+            else:
+                xy_attn_mask = F.pad(xy_attn_mask, (0,1), value=False)
+
+            samples = sample(
+                logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
+            )[0]
+
+            y = torch.concat([y, samples], dim=1)
+
+            ####### drop finished sequences from the batch to further cut computation
+            tokens = torch.argmax(logits, dim=-1)
+            reserved_idx_of_batch_for_y = None
+            if (self.EOS in samples[:, 0]) or \
+                (self.EOS in tokens):  ### stop once EOS is generated
+                l1 = samples[:, 0] == self.EOS
+                l2 = tokens == self.EOS
+                l = l1.logical_or(l2)
+                removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
+                reserved_idx_of_batch_for_y = torch.where(l == False)[0]
+                # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
+                for i in removed_idx_of_batch_for_y:
+                    batch_index = batch_idx_map[i]
+                    idx_list[batch_index] = idx - 1
+                    y_list[batch_index] = y[i, :-1]
+
+                batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
+
+            # keep only the sequences in the batch that are not yet finished
+            if reserved_idx_of_batch_for_y is not None:
+                # index = torch.LongTensor(batch_idx_map).to(y.device)
+                y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
+                xy_attn_mask = torch.index_select(xy_attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
+                if k_cache is not None:
+                    for i in range(len(k_cache)):
+                        k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
+                        v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
+
+            if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
+                print("use early stop num:", early_stop_num)
+                stop = True
+                for i, batch_index in enumerate(batch_idx_map):
+                    batch_index = batch_idx_map[i]
+                    idx_list[batch_index] = idx
+                    y_list[batch_index] = y[i, :-1]
+
+            if not (None in idx_list):
+                stop = True
+
+            if stop:
+                if y.shape[1] == 0:
+                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
+                    print("bad zero prediction")
+                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
+                break
+
+            ####################### update next step ###################################
+            y_emb = self.ar_audio_embedding(y[:, -1:])
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype, device=y_emb.device)
+
+        if (None in idx_list):
+            for i in range(x.shape[0]):
+                if idx_list[i] is None:
+                    idx_list[i] = 1500-1  ### if EOS was never generated, fall back to the maximum length
+
+        if ref_free:
+            return y_list, [0]*x.shape[0]
+        # print(idx_list)
+        return y_list, idx_list
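The dynamic batch shrinking above hinges on applying `torch.index_select` with the same index tensor to the sequences, the attention mask, and every cached K/V tensor, so that batch position i keeps referring to the same sequence everywhere. A toy illustration of that invariant (values are made up):

```python
import torch

y = torch.tensor([[1, 2], [3, 4], [5, 6]])   # 3 active sequences
k_cache = [torch.randn(3, 2, 8)]             # per-block cache, batch on dim 0
reserved = torch.tensor([0, 2])              # sequence 1 just emitted EOS

y = torch.index_select(y, dim=0, index=reserved)
k_cache = [torch.index_select(k, dim=0, index=reserved) for k in k_cache]
assert y.shape[0] == k_cache[0].shape[0] == 2  # both shrank consistently
```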
+    def infer_panel_naive_batched(self,
+        x:List[torch.LongTensor],  ##### all text tokens
+        x_lens:torch.LongTensor,
+        prompts:torch.LongTensor,  #### reference audio tokens
+        bert_feature:List[torch.LongTensor],
+        top_k: int = -100,
+        top_p: int = 100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.35,
+        **kwargs
+        ):
+        y_list = []
+        idx_list = []
+        for i in range(len(x)):
+            y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
+                                            x_lens[i],
+                                            prompts[i].unsqueeze(0) if prompts is not None else None,
+                                            bert_feature[i].unsqueeze(0),
+                                            top_k,
+                                            top_p,
+                                            early_stop_num,
+                                            temperature,
+                                            repetition_penalty,
+                                            **kwargs)
+            y_list.append(y[0])
+            idx_list.append(idx)
+
+        return y_list, idx_list
+
+    def infer_panel_naive(
+        self,
+        x:torch.LongTensor,  ##### all text tokens
+        x_lens:torch.LongTensor,
+        prompts:torch.LongTensor,  #### reference audio tokens
+        bert_feature:torch.LongTensor,
+        top_k: int = -100,
+        top_p: int = 100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.35,
+        **kwargs
+        ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -528,9 +822,10 @@ class Text2SemanticDecoder(nn.Module):
             y_pos = None
             xy_pos = x
             y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
+            prompts = y
             ref_free = True

+        bsz = x.shape[0]
+        src_len = x_len + y_len
         x_attn_mask_pad = F.pad(
             x_attn_mask,
             (0, y_len),  ### extend the all-zero xx block to all-zero xx plus all-one xy, (x, x+y)
@@ -541,13 +836,15 @@ class Text2SemanticDecoder(nn.Module):
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-            x.device
-        )
+        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
+                                    .unsqueeze(0)\
+                                    .expand(bsz*self.num_head, -1, -1)\
+                                    .view(bsz, self.num_head, src_len, src_len)\
+                                    .to(device=x.device, dtype=torch.bool)

         for idx in tqdm(range(1500)):
             if xy_attn_mask is not None:
-                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
+                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
             else:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)

@@ -558,9 +855,10 @@ class Text2SemanticDecoder(nn.Module):
             if idx == 0:
                 xy_attn_mask = None
                 logits = logits[:, :-1]

             samples = sample(
-                logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
-            )[0].unsqueeze(0)
+                logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
+            )[0]

             y = torch.concat([y, samples], dim=1)

@@ -583,4 +881,20 @@ class Text2SemanticDecoder(nn.Module):

         if ref_free:
             return y[:, :-1], 0
-        return y[:, :-1], idx - 1
+        return y[:, :-1], idx - 1
+
+    def infer_panel(
+        self,
+        x:torch.LongTensor,  ##### all text tokens
+        x_lens:torch.LongTensor,
+        prompts:torch.LongTensor,  #### reference audio tokens
+        bert_feature:torch.LongTensor,
+        top_k: int = -100,
+        top_p: int = 100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.35,
+        **kwargs
+        ):
+        return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
@@ -115,17 +115,17 @@ def logits_to_probs(
     top_p: Optional[int] = None,
     repetition_penalty: float = 1.0,
 ):
-    if previous_tokens is not None:
-        previous_tokens = previous_tokens.squeeze()
+    # if previous_tokens is not None:
+    #     previous_tokens = previous_tokens.squeeze()
     # print(logits.shape,previous_tokens.shape)
     # pdb.set_trace()
     if previous_tokens is not None and repetition_penalty != 1.0:
         previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
+        score = torch.gather(logits, dim=1, index=previous_tokens)
         score = torch.where(
             score < 0, score * repetition_penalty, score / repetition_penalty
         )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
+        logits.scatter_(dim=1, index=previous_tokens, src=score)

     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -133,9 +133,9 @@ def logits_to_probs(
             torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
         )
         sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[0] = False  # keep at least one option
+        sorted_indices_to_remove[:, 0] = False  # keep at least one option
         indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0, index=sorted_indices, src=sorted_indices_to_remove
+            dim=1, index=sorted_indices, src=sorted_indices_to_remove
         )
         logits = logits.masked_fill(indices_to_remove, -float("Inf"))

@@ -143,7 +143,7 @@ def logits_to_probs(

     if top_k is not None:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
+        pivot = v[:, -1].unsqueeze(-1)
         logits = torch.where(logits < pivot, -float("Inf"), logits)

     probs = torch.nn.functional.softmax(logits, dim=-1)
1023  GPT_SoVITS/TTS_infer_pack/TTS.py  (new file)
File diff suppressed because it is too large
244  GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py  (new file)
@@ -0,0 +1,244 @@

import os, sys

from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)

import re
import torch
import LangSegment
from text import chinese
from typing import Dict, List, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method

from tools.i18n.i18n import I18nAuto, scan_language_list

language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
punctuation = set(['!', '?', '…', ',', '.', '-', " "])

def get_first(text:str) -> str:
    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
    text = re.split(pattern, text)[0].strip()
    return text

def merge_short_text_in_array(texts:str, threshold:int) -> list:
    if (len(texts)) < 2:
        return texts
    result = []
    text = ""
    for ele in texts:
        text += ele
        if len(text) >= threshold:
            result.append(text)
            text = ""
    if (len(text) > 0):
        if len(result) == 0:
            result.append(text)
        else:
            result[len(result) - 1] += text
    return result


class TextPreprocessor:
    def __init__(self, bert_model:AutoModelForMaskedLM,
                 tokenizer:AutoTokenizer, device:torch.device):
        self.bert_model = bert_model
        self.tokenizer = tokenizer
        self.device = device

    def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v1")->List[Dict]:
        print(i18n("############ 切分文本 ############"))
        text = self.replace_consecutive_punctuation(text)  # the variable naming here is probably a mistake
        texts = self.pre_seg_text(text, lang, text_split_method)
        result = []
        print(i18n("############ 提取文本Bert特征 ############"))
        for text in tqdm(texts):
            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
            if phones is None or norm_text == "":
                continue
            res = {
                "phones": phones,
                "bert_features": bert_features,
                "norm_text": norm_text,
            }
            result.append(res)
        return result

    def pre_seg_text(self, text:str, lang:str, text_split_method:str):
        text = text.strip("\n")
        if len(text) == 0:
            return []
        if (text[0] not in splits and len(get_first(text)) < 4):
            text = "。" + text if lang != "en" else "." + text
        print(i18n("实际输入的目标文本:"))
        print(text)

        seg_method = get_seg_method(text_split_method)
        text = seg_method(text)

        while "\n\n" in text:
            text = text.replace("\n\n", "\n")

        _texts = text.split("\n")
        _texts = self.filter_text(_texts)
        _texts = merge_short_text_in_array(_texts, 5)
        texts = []

        for text in _texts:
            # skip blank lines in the target text so they do not raise errors
            if (len(text.strip()) == 0):
                continue
            if not re.sub("\W+", "", text):
                # if the line is pure punctuation, skip it
                continue
            if (text[-1] not in splits): text += "。" if lang != "en" else "."

            # split over-long sentences so BERT does not error out
            if (len(text) > 510):
                texts.extend(split_big_text(text))
            else:
                texts.append(text)

        print(i18n("实际输入的目标文本(切句后):"))
        print(texts)
        return texts

    def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]:
        return self.get_phones_and_bert(text, language, version)

    def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
        if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
            language = language.replace("all_", "")
            if language == "en":
                LangSegment.setfilters(["en"])
                formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
            else:
                # CJK ideographs cannot be told apart automatically, so trust the user's input language
                formattext = text
            while "  " in formattext:
                formattext = formattext.replace("  ", " ")
            if language == "zh":
                if re.search(r'[A-Za-z]', formattext):
                    formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                    formattext = chinese.mix_text_normalize(formattext)
                    return self.get_phones_and_bert(formattext, "zh", version)
                else:
                    phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
                    bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
            elif language == "yue" and re.search(r'[A-Za-z]', formattext):
                formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                formattext = chinese.mix_text_normalize(formattext)
                return self.get_phones_and_bert(formattext, "yue", version)
            else:
                phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
                bert = torch.zeros(
                    (1024, len(phones)),
                    dtype=torch.float32,
                ).to(self.device)
        elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
            textlist = []
            langlist = []
            LangSegment.setfilters(["zh", "ja", "en", "ko"])
            if language == "auto":
                for tmp in LangSegment.getTexts(text):
                    langlist.append(tmp["lang"])
                    textlist.append(tmp["text"])
            elif language == "auto_yue":
                for tmp in LangSegment.getTexts(text):
                    if tmp["lang"] == "zh":
                        tmp["lang"] = "yue"
                    langlist.append(tmp["lang"])
                    textlist.append(tmp["text"])
            else:
                for tmp in LangSegment.getTexts(text):
                    if tmp["lang"] == "en":
                        langlist.append(tmp["lang"])
                    else:
                        # CJK ideographs cannot be told apart automatically, so trust the user's input language
                        langlist.append(language)
                    textlist.append(tmp["text"])
            # print(textlist)
            # print(langlist)
            phones_list = []
            bert_list = []
            norm_text_list = []
            for i in range(len(textlist)):
                lang = langlist[i]
                phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
                bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
                phones_list.append(phones)
                norm_text_list.append(norm_text)
                bert_list.append(bert)
            bert = torch.cat(bert_list, dim=1)
            phones = sum(phones_list, [])
            norm_text = ''.join(norm_text_list)

        if not final and len(phones) < 6:
            return self.get_phones_and_bert("." + text, language, version, final=True)

        return phones, bert, norm_text


    def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
        with torch.no_grad():
            inputs = self.tokenizer(text, return_tensors="pt")
            for i in inputs:
                inputs[i] = inputs[i].to(self.device)
            res = self.bert_model(**inputs, output_hidden_states=True)
            res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
        assert len(word2ph) == len(text)
        phone_level_feature = []
        for i in range(len(word2ph)):
            repeat_feature = res[i].repeat(word2ph[i], 1)
            phone_level_feature.append(repeat_feature)
        phone_level_feature = torch.cat(phone_level_feature, dim=0)
        return phone_level_feature.T

    def clean_text_inf(self, text:str, language:str, version:str="v1"):
        phones, word2ph, norm_text = clean_text(text, language, version)
        phones = cleaned_text_to_sequence(phones, version)
        return phones, word2ph, norm_text

    def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
        language = language.replace("all_", "")
        if language == "zh":
            feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
        else:
            feature = torch.zeros(
                (1024, len(phones)),
                dtype=torch.float32,
            ).to(self.device)

        return feature

    def filter_text(self, texts):
        _text = []
        if all(text in [None, " ", "\n", ""] for text in texts):
            raise ValueError(i18n("请输入有效文本"))
        for text in texts:
            if text in [None, " ", ""]:
                pass
            else:
                _text.append(text)
        return _text

    def replace_consecutive_punctuation(self, text):
        punctuations = ''.join(re.escape(p) for p in punctuation)
        pattern = f'([{punctuations}])([{punctuations}])+'
        result = re.sub(pattern, r'\1', text)
        return result
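A hedged usage sketch for the class above (the BERT path matches tts_infer.yaml below; the sample sentence is a placeholder):

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.TextPreprocessor import TextPreprocessor

bert_dir = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

bert_model = AutoModelForMaskedLM.from_pretrained(bert_dir).to(device)
tokenizer = AutoTokenizer.from_pretrained(bert_dir)

tp = TextPreprocessor(bert_model, tokenizer, device)
# Returns a list of dicts with "phones", "bert_features", "norm_text".
fragments = tp.preprocess("今天天气不错。", lang="all_zh", text_split_method="cut5")
```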
1  GPT_SoVITS/TTS_infer_pack/__init__.py  (new file)
@@ -0,0 +1 @@
from . import TTS, text_segmentation_method
173  GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py  (new file)
@@ -0,0 +1,173 @@

import re
from typing import Callable

punctuation = set(['!', '?', '…', ',', '.', '-', " "])
METHODS = dict()

def get_method(name:str)->Callable:
    method = METHODS.get(name, None)
    if method is None:
        raise ValueError(f"Method {name} not found")
    return method

def get_method_names()->list:
    return list(METHODS.keys())

def register_method(name):
    def decorator(func):
        METHODS[name] = func
        return func
    return decorator

splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }

def split_big_text(text, max_len=510):
    # full-width and half-width punctuation marks
    punctuation = "".join(splits)

    # split the text on punctuation
    segments = re.split('([' + punctuation + '])', text)

    # initialize the result list and the current chunk
    result = []
    current_segment = ''

    for segment in segments:
        # if adding the new piece would push the current chunk past max_len,
        # flush the chunk into the result list and start a new one
        if len(current_segment + segment) > max_len:
            result.append(current_segment)
            current_segment = segment
        else:
            current_segment += segment

    # append the final chunk
    if current_segment:
        result.append(current_segment)

    return result



def split(todo_text):
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if todo_text[-1] not in splits:
        todo_text += "。"
    i_split_head = i_split_tail = 0
    len_text = len(todo_text)
    todo_texts = []
    while 1:
        if i_split_head >= len_text:
            break  # the text is guaranteed to end with punctuation, so just break; the last piece was appended on the previous pass
        if todo_text[i_split_head] in splits:
            i_split_head += 1
            todo_texts.append(todo_text[i_split_tail:i_split_head])
            i_split_tail = i_split_head
        else:
            i_split_head += 1
    return todo_texts


# no splitting
@register_method("cut0")
def cut0(inp):
    if not set(inp).issubset(punctuation):
        return inp
    else:
        return "/n"


# split every four sentences
@register_method("cut1")
def cut1(inp):
    inp = inp.strip("\n")
    inps = split(inp)
    split_idx = list(range(0, len(inps), 4))
    split_idx[-1] = None
    if len(split_idx) > 1:
        opts = []
        for idx in range(len(split_idx) - 1):
            opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
    else:
        opts = [inp]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)


# split every 50 characters
@register_method("cut2")
def cut2(inp):
    inp = inp.strip("\n")
    inps = split(inp)
    if len(inps) < 2:
        return inp
    opts = []
    summ = 0
    tmp_str = ""
    for i in range(len(inps)):
        summ += len(inps[i])
        tmp_str += inps[i]
        if summ > 50:
            summ = 0
            opts.append(tmp_str)
            tmp_str = ""
    if tmp_str != "":
        opts.append(tmp_str)
    # print(opts)
    if len(opts) > 1 and len(opts[-1]) < 50:  ## if the last piece is too short, merge it with the previous one
        opts[-2] = opts[-2] + opts[-1]
        opts = opts[:-1]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)

# split on the Chinese full stop 。
@register_method("cut3")
def cut3(inp):
    inp = inp.strip("\n")
    opts = ["%s" % item for item in inp.strip("。").split("。")]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)

# split on the English period .
@register_method("cut4")
def cut4(inp):
    inp = inp.strip("\n")
    opts = ["%s" % item for item in inp.strip(".").split(".")]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)

# split on punctuation marks
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@register_method("cut5")
def cut5(inp):
    inp = inp.strip("\n")
    punds = {',', '.', ';', '?', '!', '、', ',', '。', '?', '!', ';', ':', '…'}
    mergeitems = []
    items = []

    for i, char in enumerate(inp):
        if char in punds:
            # keep a period between two digits (e.g. "3.14") inside the current item
            if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
                items.append(char)
            else:
                items.append(char)
                mergeitems.append("".join(items))
                items = []
        else:
            items.append(char)

    if items:
        mergeitems.append("".join(items))

    opt = [item for item in mergeitems if not set(item).issubset(punds)]
    return "\n".join(opt)



if __name__ == '__main__':
    method = get_method("cut5")
    print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
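The registry above makes the cut methods pluggable: registering a custom splitter is a one-decorator affair. `cut_lines` below is a hypothetical example, not part of the module:

```python
from TTS_infer_pack.text_segmentation_method import register_method, get_method

@register_method("cut_lines")
def cut_lines(inp):
    # trivial splitter: one fragment per non-empty input line
    return "\n".join(line for line in inp.split("\n") if line.strip())

method = get_method("cut_lines")
print(method("first line\n\nsecond line"))
```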
24  GPT_SoVITS/configs/tts_infer.yaml  (new file)
@@ -0,0 +1,24 @@
custom:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cuda
  is_half: true
  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
  version: v2
  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
default:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
  version: v1
  vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
default_v2:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
  version: v2
  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
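TTS_Config consumes this file. Reading a section by hand looks roughly like the sketch below (it uses PyYAML; the fallback logic is an assumption for illustration, not the actual TTS_Config implementation):

```python
import yaml

with open("GPT_SoVITS/configs/tts_infer.yaml") as f:
    cfg = yaml.safe_load(f)

# Prefer the user-editable "custom" section, fall back to "default_v2".
section = cfg.get("custom") or cfg["default_v2"]
print(section["t2s_weights_path"], section["device"], section["is_half"])
```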
@@ -23,13 +23,15 @@ cnhubert_base_path = None


 class CNHubert(nn.Module):
-    def __init__(self):
+    def __init__(self, base_path:str=None):
         super().__init__()
-        if os.path.exists(cnhubert_base_path):...
-        else:raise FileNotFoundError(cnhubert_base_path)
-        self.model = HubertModel.from_pretrained(cnhubert_base_path, local_files_only=True)
+        if base_path is None:
+            base_path = cnhubert_base_path
+        if os.path.exists(base_path):...
+        else:raise FileNotFoundError(base_path)
+        self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            cnhubert_base_path, local_files_only=True
+            base_path, local_files_only=True
         )

     def forward(self, x):
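With this change the checkpoint directory can be passed per instance instead of only through the module-level global. A hedged sketch (the import path is assumed from the repo layout; the directory matches tts_infer.yaml above):

```python
from feature_extractor.cnhubert import CNHubert

# An explicit path overrides the module-level cnhubert_base_path global.
model = CNHubert(base_path="GPT_SoVITS/pretrained_models/chinese-hubert-base")
model.eval()
```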
331  GPT_SoVITS/inference_webui_fast.py  (new file)
@@ -0,0 +1,331 @@
'''
Recognize as mixed Chinese and English
Recognize as mixed Japanese and English
Multilingual: split the input and detect each language
Recognize everything as Chinese
Recognize everything as English
Recognize everything as Japanese
'''
import random
import os, re, logging
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))

logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
import torch


infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
is_share = os.environ.get("is_share", "False")
is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]

is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
bert_path = os.environ.get("bert_path", None)
version = os.environ.get("version", "v2")

import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto, scan_language_list

language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)


# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # make sure this can also be set when the inference UI is launched directly

if torch.cuda.is_available():
    device = "cuda"
# elif torch.backends.mps.is_available():
#     device = "mps"
else:
    device = "cpu"

dict_language_v1 = {
    i18n("中文"): "all_zh",        # recognize everything as Chinese
    i18n("英文"): "en",            # recognize everything as English  ####### unchanged
    i18n("日文"): "all_ja",        # recognize everything as Japanese
    i18n("中英混合"): "zh",        # recognize as mixed Chinese/English  #### unchanged
    i18n("日英混合"): "ja",        # recognize as mixed Japanese/English  #### unchanged
    i18n("多语种混合"): "auto",    # multilingual: split and detect languages
}
dict_language_v2 = {
    i18n("中文"): "all_zh",        # recognize everything as Chinese
    i18n("英文"): "en",            # recognize everything as English  ####### unchanged
    i18n("日文"): "all_ja",        # recognize everything as Japanese
    i18n("粤语"): "all_yue",       # recognize everything as Cantonese
    i18n("韩文"): "all_ko",        # recognize everything as Korean
    i18n("中英混合"): "zh",        # recognize as mixed Chinese/English  #### unchanged
    i18n("日英混合"): "ja",        # recognize as mixed Japanese/English  #### unchanged
    i18n("粤英混合"): "yue",       # recognize as mixed Cantonese/English  #### unchanged
    i18n("韩英混合"): "ko",        # recognize as mixed Korean/English  #### unchanged
    i18n("多语种混合"): "auto",    # multilingual: split and detect languages
    i18n("多语种混合(粤语)"): "auto_yue",  # multilingual with Cantonese
}
dict_language = dict_language_v1 if version == 'v1' else dict_language_v2

cut_method = {
    i18n("不切"): "cut0",
    i18n("凑四句一切"): "cut1",
    i18n("凑50字一切"): "cut2",
    i18n("按中文句号。切"): "cut3",
    i18n("按英文句号.切"): "cut4",
    i18n("按标点符号切"): "cut5",
}

tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device
tts_config.is_half = is_half
tts_config.version = version
if gpt_path is not None:
    tts_config.t2s_weights_path = gpt_path
if sovits_path is not None:
    tts_config.vits_weights_path = sovits_path
if cnhubert_base_path is not None:
    tts_config.cnhuhbert_base_path = cnhubert_base_path
if bert_path is not None:
    tts_config.bert_base_path = bert_path

print(tts_config)
tts_pipeline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
version = tts_config.version

def inference(text, text_lang,
              ref_audio_path,
              aux_ref_audio_paths,
              prompt_text,
              prompt_lang, top_k,
              top_p, temperature,
              text_split_method, batch_size,
              speed_factor, ref_text_free,
              split_bucket, fragment_interval,
              seed, keep_random, parallel_infer,
              repetition_penalty
              ):

    seed = -1 if keep_random else seed
    actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
    inputs = {
        "text": text,
        "text_lang": dict_language[text_lang],
        "ref_audio_path": ref_audio_path,
        "aux_ref_audio_paths": [item.name for item in aux_ref_audio_paths] if aux_ref_audio_paths is not None else [],
        "prompt_text": prompt_text if not ref_text_free else "",
        "prompt_lang": dict_language[prompt_lang],
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "text_split_method": cut_method[text_split_method],
        "batch_size": int(batch_size),
        "speed_factor": float(speed_factor),
        "split_bucket": split_bucket,
        "return_fragment": False,
        "fragment_interval": fragment_interval,
        "seed": actual_seed,
        "parallel_infer": parallel_infer,
        "repetition_penalty": repetition_penalty,
    }
    for item in tts_pipeline.run(inputs):
        yield item, actual_seed

def custom_sort_key(s):
    # use a regex to split the string into numeric and non-numeric parts
    parts = re.split('(\d+)', s)
    # convert the numeric parts to integers, leave the rest unchanged
    parts = [int(part) if part.isdigit() else part for part in parts]
    return parts


def change_choices():
    SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
    return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}


pretrained_sovits_name = ["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
pretrained_gpt_name = ["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
_ = [[], []]
for i in range(2):
    if os.path.exists(pretrained_gpt_name[i]):
        _[0].append(pretrained_gpt_name[i])
    if os.path.exists(pretrained_sovits_name[i]):
        _[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name, pretrained_sovits_name = _

SoVITS_weight_root = ["SoVITS_weights_v2", "SoVITS_weights"]
GPT_weight_root = ["GPT_weights_v2", "GPT_weights"]
for path in SoVITS_weight_root + GPT_weight_root:
    os.makedirs(path, exist_ok=True)

def get_weights_names(GPT_weight_root, SoVITS_weight_root):
    SoVITS_names = [i for i in pretrained_sovits_name]
    for path in SoVITS_weight_root:
        for name in os.listdir(path):
            if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
    GPT_names = [i for i in pretrained_gpt_name]
    for path in GPT_weight_root:
        for name in os.listdir(path):
            if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
    return SoVITS_names, GPT_names


SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)



def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
    tts_pipeline.init_vits_weights(sovits_path)
    global version, dict_language
    dict_language = dict_language_v1 if tts_pipeline.configs.version == 'v1' else dict_language_v2
    if prompt_language is not None and text_language is not None:
        if prompt_language in list(dict_language.keys()):
            prompt_text_update, prompt_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': prompt_language}
        else:
            prompt_text_update = {'__type__': 'update', 'value': ''}
            prompt_language_update = {'__type__': 'update', 'value': i18n("中文")}
        if text_language in list(dict_language.keys()):
            text_update, text_language_update = {'__type__': 'update'}, {'__type__': 'update', 'value': text_language}
        else:
            text_update = {'__type__': 'update', 'value': ''}
            text_language_update = {'__type__': 'update', 'value': i18n("中文")}
        return {'__type__': 'update', 'choices': list(dict_language.keys())}, {'__type__': 'update', 'choices': list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update



with gr.Blocks(title="GPT-SoVITS WebUI") as app:
    gr.Markdown(
        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
    )

    with gr.Column():
        # with gr.Group():
        gr.Markdown(value=i18n("模型切换"))
        with gr.Row():
            GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
            SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
            refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
            refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])

    with gr.Row():
        with gr.Column():
            gr.Markdown(value=i18n("*请上传并填写参考信息"))
            with gr.Row():
                inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
                inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"), file_count="multiple", type="file")
            prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
            with gr.Row():
                prompt_language = gr.Dropdown(
                    label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
                )
                with gr.Column():
                    ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
                    gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))

        with gr.Column():
            gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
            text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=20, max_lines=20)
            text_language = gr.Dropdown(
                label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
            )


    with gr.Group():
        gr.Markdown(value=i18n("推理设置"))
        with gr.Row():

            with gr.Column():
                batch_size = gr.Slider(minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True)
                fragment_interval = gr.Slider(minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True)
                speed_factor = gr.Slider(minimum=0.6, maximum=1.65, step=0.05, label="speed_factor", value=1.0, interactive=True)
                top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
                top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
                temperature = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True)
                repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True)
            with gr.Column():
                with gr.Row():
                    how_to_cut = gr.Dropdown(
                        label=i18n("怎么切"),
                        choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
                        value=i18n("凑四句一切"),
                        interactive=True, scale=1
                    )
                    parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
                    split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)

                with gr.Row():
                    seed = gr.Number(label=i18n("随机种子"), value=-1)
                    keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)

                output = gr.Audio(label=i18n("输出的语音"))
                with gr.Row():
                    inference_button = gr.Button(i18n("合成语音"), variant="primary")
                    stop_infer = gr.Button(i18n("终止合成"), variant="primary")


    inference_button.click(
        inference,
        [
            text, text_language, inp_ref, inp_refs,
            prompt_text, prompt_language,
            top_k, top_p, temperature,
            how_to_cut, batch_size,
            speed_factor, ref_text_free,
            split_bucket, fragment_interval,
            seed, keep_random, parallel_infer,
            repetition_penalty
        ],
        [output, seed],
    )
    stop_infer.click(tts_pipeline.stop, [], [])
    SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown, prompt_language, text_language], [prompt_language, text_language, prompt_text, prompt_language, text, text_language])
    GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])

    with gr.Group():
        gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
        with gr.Row():
            text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
            with gr.Column():
                _how_to_cut = gr.Radio(
                    label=i18n("怎么切"),
                    choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
                    value=i18n("凑四句一切"),
                    interactive=True,
                )
                cut_text = gr.Button(i18n("切分"), variant="primary")

        def to_cut(text_inp, how_to_cut):
            if len(text_inp.strip()) == 0 or text_inp == []:
                return ""
            method = get_method(cut_method[how_to_cut])
            return method(text_inp)

        text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
        cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
    gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))

if __name__ == '__main__':
    app.queue(concurrency_count=511, max_size=1022).launch(
        server_name="0.0.0.0",
        inbrowser=True,
        share=is_share,
        server_port=infer_ttswebui,
        quiet=True,
    )
458  api_v2.py  (new file)
@@ -0,0 +1,458 @@
"""
# WebAPI documentation

` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `

## Launch arguments:
    `-a` - `bind address, default "127.0.0.1"`
    `-p` - `bind port, default 9880`
    `-c` - `path to the TTS config file, default "GPT_SoVITS/configs/tts_infer.yaml"`

## Calls:

### Inference

endpoint: `/tts`
GET:
```
http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true
```

POST:
```json
{
    "text": "",                   # str.(required) text to be synthesized
    "text_lang": "",              # str.(required) language of the text to be synthesized
    "ref_audio_path": "",         # str.(required) reference audio path
    "aux_ref_audio_paths": [],    # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
    "prompt_text": "",            # str.(optional) prompt text for the reference audio
    "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
    "top_k": 5,                   # int. top k sampling
    "top_p": 1,                   # float. top p sampling
    "temperature": 1,             # float. temperature for sampling
    "text_split_method": "cut0",  # str. text split method, see text_segmentation_method.py for details.
    "batch_size": 1,              # int. batch size for inference
    "batch_threshold": 0.75,      # float. threshold for batch splitting.
    "split_bucket": True,         # bool. whether to split the batch into multiple buckets.
    "return_fragment": False,     # bool. step by step return the audio fragment.
    "speed_factor": 1.0,          # float. control the speed of the synthesized audio.
    "streaming_mode": False,      # bool. whether to return a streaming response.
    "seed": -1,                   # int. random seed for reproducibility.
    "parallel_infer": True,       # bool. whether to use parallel inference.
    "repetition_penalty": 1.35    # float. repetition penalty for T2S model.
}
```

RESP:
Success: returns the wav audio stream directly, HTTP code 200
Failure: returns json with the error message, HTTP code 400

### Command control

endpoint: `/control`

command:
"restart": restart the service
"exit": stop the service

GET:
```
http://127.0.0.1:9880/control?command=restart
```
POST:
```json
{
    "command": "restart"
}
```

RESP: none


### Switch GPT model

endpoint: `/set_gpt_weights`

GET:
```
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
```
RESP:
Success: returns "success", HTTP code 200
Failure: returns json with the error message, HTTP code 400


### Switch SoVITS model

endpoint: `/set_sovits_weights`

GET:
```
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
```

RESP:
Success: returns "success", HTTP code 200
Failure: returns json with the error message, HTTP code 400

"""
|
||||
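For reference, a minimal client for the GET form above might look like this sketch. The server address follows the defaults documented above, and the reference-audio path is the docstring's example and is assumed to exist on the server:

```python
# Hedged sketch: fetch synthesized speech from a running api_v2.py instance
# (non-streaming) and save it as a wav file.
import requests

params = {
    "text": "先帝创业未半而中道崩殂。",
    "text_lang": "zh",
    "ref_audio_path": "archive_jingyuan_1.wav",  # assumed server-side path
    "prompt_lang": "zh",
    "prompt_text": "我是「罗浮」云骑将军景元。",
    "text_split_method": "cut5",
    "media_type": "wav",
}
resp = requests.get("http://127.0.0.1:9880/tts", params=params)
resp.raise_for_status()            # http code 400 carries a json error message
with open("output.wav", "wb") as f:
    f.write(resp.content)          # http code 200 carries the raw wav stream
```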
import os
import sys
import traceback
from typing import Generator

now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))

import argparse
import subprocess
import wave
import signal
import numpy as np
import soundfile as sf
import uvicorn
from io import BytesIO
from fastapi import FastAPI, Request, HTTPException, Response, UploadFile, File
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
# print(sys.path)
i18n = I18nAuto()
cut_method_names = get_cut_method_names()
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="path to the TTS config file")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-p", "--port", type=int, default=9880, help="default: 9880")
args = parser.parse_args()
config_path = args.tts_config
# device = args.device
port = args.port
host = args.bind_addr
argv = sys.argv

if config_path in [None, ""]:
    config_path = "GPT_SoVITS/configs/tts_infer.yaml"

tts_config = TTS_Config(config_path)
print(tts_config)
tts_pipeline = TTS(tts_config)
APP = FastAPI()

class TTS_Request(BaseModel):
    text: str = None
    text_lang: str = None
    ref_audio_path: str = None
    aux_ref_audio_paths: list = None
    prompt_lang: str = None
    prompt_text: str = ""
    top_k: int = 5
    top_p: float = 1
    temperature: float = 1
    text_split_method: str = "cut5"
    batch_size: int = 1
    batch_threshold: float = 0.75
    split_bucket: bool = True
    speed_factor: float = 1.0
    fragment_interval: float = 0.3
    seed: int = -1
    media_type: str = "wav"
    streaming_mode: bool = False
    parallel_infer: bool = True
    repetition_penalty: float = 1.35
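A matching POST call mirrors the request model above; a sketch (the reference-audio path is again a placeholder taken from the docstring example):

```python
# Hedged sketch: the same request as the GET example, sent as a JSON body.
import requests

payload = {
    "text": "先帝创业未半而中道崩殂。",
    "text_lang": "zh",
    "ref_audio_path": "archive_jingyuan_1.wav",  # placeholder server-side path
    "prompt_lang": "zh",
    "prompt_text": "我是「罗浮」云骑将军景元。",
    "batch_size": 4,            # fields not sent fall back to the model defaults
    "media_type": "wav",
}
resp = requests.post("http://127.0.0.1:9880/tts", json=payload)
with open("output.wav", "wb") as f:
    f.write(resp.content)
```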
### modified from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
    with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
        audio_file.write(data)
    return io_buffer


def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
    io_buffer.write(data.tobytes())
    return io_buffer
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    # write into the buffer that was passed in, like the other pack_* helpers
    sf.write(io_buffer, data, rate, format='wav')
    return io_buffer
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
    process = subprocess.Popen([
        'ffmpeg',
        '-f', 's16le',      # input: 16-bit signed little-endian PCM
        '-ar', str(rate),   # sample rate
        '-ac', '1',         # mono
        '-i', 'pipe:0',     # read input from stdin
        '-c:a', 'aac',      # encode audio as AAC
        '-b:a', '192k',     # bitrate
        '-vn',              # no video
        '-f', 'adts',       # output a raw AAC (ADTS) stream
        'pipe:1'            # write output to stdout
    ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, _ = process.communicate(input=data.tobytes())
    io_buffer.write(out)
    return io_buffer
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
    if media_type == "ogg":
        io_buffer = pack_ogg(io_buffer, data, rate)
    elif media_type == "aac":
        io_buffer = pack_aac(io_buffer, data, rate)
    elif media_type == "wav":
        io_buffer = pack_wav(io_buffer, data, rate)
    else:
        io_buffer = pack_raw(io_buffer, data, rate)
    io_buffer.seek(0)
    return io_buffer
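A quick sketch of how these helpers compose, assuming the pack_* functions above are in scope and given a mono int16 buffer (here, one second of silence at 32 kHz):

```python
# Hedged usage sketch for pack_audio: encode a mono int16 buffer as wav bytes.
import numpy as np
from io import BytesIO

silence = np.zeros(32000, dtype=np.int16)           # 1 s of silence at 32 kHz
buf = pack_audio(BytesIO(), silence, 32000, "wav")  # dispatches to pack_wav
wav_bytes = buf.getvalue()                          # ready to send in a Response
```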
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
    # Create a wave header, then append the frame input.
    # It must be the first chunk of a streaming wav file; later chunks should
    # not repeat it, or you will hear an artifact at the start of each chunk.
    wav_buf = BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)

    wav_buf.seek(0)
    return wav_buf.read()
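This is why the streaming path below emits the header once and then switches to raw PCM. A sketch of the assembly (the PCM chunk here is hypothetical):

```python
# Hedged sketch: a playable wav stream is the header chunk followed by raw PCM.
import numpy as np

chunk = np.zeros(3200, dtype=np.int16)          # hypothetical 100 ms PCM chunk
stream = wave_header_chunk() + chunk.tobytes()  # header once, then raw frames
with open("stream_check.wav", "wb") as f:
    f.write(stream)                             # most wav players tolerate this
```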
def handle_control(command: str):
    if command == "restart":
        os.execl(sys.executable, sys.executable, *argv)
    elif command == "exit":
        os.kill(os.getpid(), signal.SIGTERM)
        exit(0)
def check_params(req: dict):
    text: str = req.get("text", "")
    text_lang: str = req.get("text_lang", "")
    ref_audio_path: str = req.get("ref_audio_path", "")
    streaming_mode: bool = req.get("streaming_mode", False)
    media_type: str = req.get("media_type", "wav")
    prompt_lang: str = req.get("prompt_lang", "")
    text_split_method: str = req.get("text_split_method", "cut5")

    if ref_audio_path in [None, ""]:
        return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
    if text in [None, ""]:
        return JSONResponse(status_code=400, content={"message": "text is required"})
    if text_lang in [None, ""]:
        return JSONResponse(status_code=400, content={"message": "text_lang is required"})
    elif text_lang.lower() not in tts_config.languages:
        return JSONResponse(status_code=400, content={"message": "text_lang is not supported"})
    if prompt_lang in [None, ""]:
        return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
    elif prompt_lang.lower() not in tts_config.languages:
        return JSONResponse(status_code=400, content={"message": "prompt_lang is not supported"})
    if media_type not in ["wav", "raw", "ogg", "aac"]:
        return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
    elif media_type == "ogg" and not streaming_mode:
        return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})

    if text_split_method not in cut_method_names:
        return JSONResponse(status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"})

    return None
async def tts_handle(req: dict):
    """
    Text to speech handler.

    Args:
        req (dict):
            {
                "text": "",                   # str.(required) text to be synthesized
                "text_lang": "",              # str.(required) language of the text to be synthesized
                "ref_audio_path": "",         # str.(required) reference audio path
                "aux_ref_audio_paths": [],    # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
                "prompt_text": "",            # str.(optional) prompt text for the reference audio
                "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
                "top_k": 5,                   # int. top k sampling
                "top_p": 1,                   # float. top p sampling
                "temperature": 1,             # float. temperature for sampling
                "text_split_method": "cut5",  # str. text split method, see text_segmentation_method.py for details.
                "batch_size": 1,              # int. batch size for inference
                "batch_threshold": 0.75,      # float. threshold for batch splitting.
                "split_bucket": True,         # bool. whether to split the batch into multiple buckets.
                "speed_factor": 1.0,          # float. control the speed of the synthesized audio.
                "fragment_interval": 0.3,     # float. control the interval of the audio fragments.
                "seed": -1,                   # int. random seed for reproducibility.
                "media_type": "wav",          # str. media type of the output audio, supports "wav", "raw", "ogg", "aac".
                "streaming_mode": False,      # bool. whether to return a streaming response.
                "parallel_infer": True,       # bool.(optional) whether to use parallel inference.
                "repetition_penalty": 1.35    # float.(optional) repetition penalty for T2S model.
            }
    returns:
        StreamingResponse: audio stream response.
    """

    streaming_mode = req.get("streaming_mode", False)
    media_type = req.get("media_type", "wav")

    check_res = check_params(req)
    if check_res is not None:
        return check_res

    if streaming_mode:
        req["return_fragment"] = True

    try:
        tts_generator = tts_pipeline.run(req)

        if streaming_mode:
            def streaming_generator(tts_generator: Generator, media_type: str):
                if media_type == "wav":
                    yield wave_header_chunk()
                    media_type = "raw"
                for sr, chunk in tts_generator:
                    yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
            # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
            return StreamingResponse(streaming_generator(tts_generator, media_type), media_type=f"audio/{media_type}")

        else:
            sr, audio_data = next(tts_generator)
            audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
            return Response(audio_data, media_type=f"audio/{media_type}")
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
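When `streaming_mode` is on, the response body is a wav header followed by raw PCM chunks, so a client can write or play audio as it arrives. A hedged sketch, reusing the placeholder reference audio from the docstring example:

```python
# Hedged sketch: consume the streaming /tts response chunk by chunk.
import requests

params = {
    "text": "先帝创业未半而中道崩殂。",
    "text_lang": "zh",
    "ref_audio_path": "archive_jingyuan_1.wav",  # placeholder server-side path
    "prompt_lang": "zh",
    "prompt_text": "我是「罗浮」云骑将军景元。",
    "media_type": "wav",
    "streaming_mode": True,
}
with requests.get("http://127.0.0.1:9880/tts", params=params, stream=True) as resp:
    with open("streamed.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=None):  # header first, then PCM
            f.write(chunk)
```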
@APP.get("/control")
|
||||
async def control(command: str = None):
|
||||
if command is None:
|
||||
return JSONResponse(status_code=400, content={"message": "command is required"})
|
||||
handle_control(command)
|
||||
|
||||
|
||||
|
||||
@APP.get("/tts")
|
||||
async def tts_get_endpoint(
|
||||
text: str = None,
|
||||
text_lang: str = None,
|
||||
ref_audio_path: str = None,
|
||||
aux_ref_audio_paths:list = None,
|
||||
prompt_lang: str = None,
|
||||
prompt_text: str = "",
|
||||
top_k:int = 5,
|
||||
top_p:float = 1,
|
||||
temperature:float = 1,
|
||||
text_split_method:str = "cut0",
|
||||
batch_size:int = 1,
|
||||
batch_threshold:float = 0.75,
|
||||
split_bucket:bool = True,
|
||||
speed_factor:float = 1.0,
|
||||
fragment_interval:float = 0.3,
|
||||
seed:int = -1,
|
||||
media_type:str = "wav",
|
||||
streaming_mode:bool = False,
|
||||
parallel_infer:bool = True,
|
||||
repetition_penalty:float = 1.35
|
||||
):
|
||||
req = {
|
||||
"text": text,
|
||||
"text_lang": text_lang.lower(),
|
||||
"ref_audio_path": ref_audio_path,
|
||||
"aux_ref_audio_paths": aux_ref_audio_paths,
|
||||
"prompt_text": prompt_text,
|
||||
"prompt_lang": prompt_lang.lower(),
|
||||
"top_k": top_k,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"text_split_method": text_split_method,
|
||||
"batch_size":int(batch_size),
|
||||
"batch_threshold":float(batch_threshold),
|
||||
"speed_factor":float(speed_factor),
|
||||
"split_bucket":split_bucket,
|
||||
"fragment_interval":fragment_interval,
|
||||
"seed":seed,
|
||||
"media_type":media_type,
|
||||
"streaming_mode":streaming_mode,
|
||||
"parallel_infer":parallel_infer,
|
||||
"repetition_penalty":float(repetition_penalty)
|
||||
}
|
||||
return await tts_handle(req)
|
||||
|
||||
|
||||
@APP.post("/tts")
|
||||
async def tts_post_endpoint(request: TTS_Request):
|
||||
req = request.dict()
|
||||
return await tts_handle(req)
|
||||
|
||||
|
||||
@APP.get("/set_refer_audio")
|
||||
async def set_refer_aduio(refer_audio_path: str = None):
|
||||
try:
|
||||
tts_pipeline.set_ref_audio(refer_audio_path)
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
|
||||
return JSONResponse(status_code=200, content={"message": "success"})
|
||||
|
||||
|
||||
# @APP.post("/set_refer_audio")
|
||||
# async def set_refer_aduio_post(audio_file: UploadFile = File(...)):
|
||||
# try:
|
||||
# # 检查文件类型,确保是音频文件
|
||||
# if not audio_file.content_type.startswith("audio/"):
|
||||
# return JSONResponse(status_code=400, content={"message": "file type is not supported"})
|
||||
|
||||
# os.makedirs("uploaded_audio", exist_ok=True)
|
||||
# save_path = os.path.join("uploaded_audio", audio_file.filename)
|
||||
# # 保存音频文件到服务器上的一个目录
|
||||
# with open(save_path , "wb") as buffer:
|
||||
# buffer.write(await audio_file.read())
|
||||
|
||||
# tts_pipeline.set_ref_audio(save_path)
|
||||
# except Exception as e:
|
||||
# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
|
||||
# return JSONResponse(status_code=200, content={"message": "success"})
|
||||
|
||||
@APP.get("/set_gpt_weights")
|
||||
async def set_gpt_weights(weights_path: str = None):
|
||||
try:
|
||||
if weights_path in ["", None]:
|
||||
return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
|
||||
tts_pipeline.init_t2s_weights(weights_path)
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": f"change gpt weight failed", "Exception": str(e)})
|
||||
|
||||
return JSONResponse(status_code=200, content={"message": "success"})
|
||||
|
||||
|
||||
@APP.get("/set_sovits_weights")
|
||||
async def set_sovits_weights(weights_path: str = None):
|
||||
try:
|
||||
if weights_path in ["", None]:
|
||||
return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
|
||||
tts_pipeline.init_vits_weights(weights_path)
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": f"change sovits weight failed", "Exception": str(e)})
|
||||
return JSONResponse(status_code=200, content={"message": "success"})
|
||||
|
||||
|
||||
|
||||
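Switching checkpoints at runtime is a pair of GET calls; a hedged sketch, using the pretrained weight paths from the docstring above:

```python
# Hedged sketch: hot-swap GPT and SoVITS weights on a running server.
import requests

base = "http://127.0.0.1:9880"
gpt = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
sovits = "GPT_SoVITS/pretrained_models/s2G488k.pth"

for endpoint, path in [("set_gpt_weights", gpt), ("set_sovits_weights", sovits)]:
    r = requests.get(f"{base}/{endpoint}", params={"weights_path": path})
    print(endpoint, r.json())  # {"message": "success"} on http code 200
```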
if __name__ == "__main__":
|
||||
try:
|
||||
uvicorn.run(app=APP, host=host, port=port, workers=1)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
exit(0)
|
13
webui.py
@ -223,8 +223,12 @@ def change_uvr5():
    p_uvr5=None
    yield i18n("UVR5已关闭"), {'__type__':'update','visible':True}, {'__type__':'update','visible':False}

-def change_tts_inference(bert_path,cnhubert_base_path,gpu_number,gpt_path,sovits_path):
+def change_tts_inference(bert_path,cnhubert_base_path,gpu_number,gpt_path,sovits_path, batched_infer_enabled):
    global p_tts_inference
+    if batched_infer_enabled:
+        cmd = '"%s" GPT_SoVITS/inference_webui_fast.py "%s"'%(python_exec, language)
+    else:
+        cmd = '"%s" GPT_SoVITS/inference_webui.py "%s"'%(python_exec, language)
    if(p_tts_inference==None):
        os.environ["gpt_path"]=gpt_path if "/" in gpt_path else "%s/%s"%(GPT_weight_root,gpt_path)
        os.environ["sovits_path"]=sovits_path if "/" in sovits_path else "%s/%s"%(SoVITS_weight_root,sovits_path)
@ -234,7 +238,6 @@ def change_tts_inference(bert_path,cnhubert_base_path,gpu_number,gpt_path,sovits
        os.environ["is_half"]=str(is_half)
        os.environ["infer_ttswebui"]=str(webui_port_infer_tts)
        os.environ["is_share"]=str(is_share)
-        cmd = '"%s" GPT_SoVITS/inference_webui.py "%s"'%(python_exec, language)
        yield i18n("TTS推理进程已开启"), {'__type__':'update','visible':False}, {'__type__':'update','visible':True}
        print(cmd)
        p_tts_inference = Popen(cmd, shell=True)
@ -1031,13 +1034,15 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
                refresh_button.click(fn=change_choices,inputs=[],outputs=[SoVITS_dropdown,GPT_dropdown])
            with gr.Row():
+               with gr.Row():
+                   batched_infer_enabled = gr.Checkbox(label=i18n("启用并行推理版本(推理速度更快)"), value=False, interactive=True, show_label=True)
                with gr.Row():
                    open_tts = gr.Button(value=i18n("开启TTS推理WebUI"),variant='primary',visible=True)
                    close_tts = gr.Button(value=i18n("关闭TTS推理WebUI"),variant='primary',visible=False)
                with gr.Row():
                    tts_info = gr.Textbox(label=i18n("TTS推理WebUI进程输出信息"))
-               open_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown], [tts_info,open_tts,close_tts])
-               close_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown], [tts_info,open_tts,close_tts])
+               open_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown, batched_infer_enabled], [tts_info,open_tts,close_tts])
+               close_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown, batched_infer_enabled], [tts_info,open_tts,close_tts])
            version_checkbox.change(switch_version,[version_checkbox],[pretrained_s2G,pretrained_s2D,pretrained_s1,GPT_dropdown,SoVITS_dropdown])
    with gr.TabItem(i18n("2-GPT-SoVITS-变声")):gr.Markdown(value=i18n("施工中,请静候佳音"))
app.queue(concurrency_count=511, max_size=1022).launch(