Merge branch 'main' into dev

This commit is contained in:
刘洋 2024-10-05 14:45:05 +08:00 committed by GitHub
commit c0e9136b75
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 137982 additions and 268 deletions

2
.gitignore vendored
View File

@ -14,3 +14,5 @@ GPT_weights_v2
SoVITS_weights_v2 SoVITS_weights_v2
TEMP TEMP
weight.json weight.json
ffmpeg*
ffprobe*

View File

@ -1,11 +1,10 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e # reference: https://github.com/lifeiteng/vall-e
import math
from typing import List, Optional
import torch import torch
import random
import numpy as np
from tqdm import tqdm from tqdm import tqdm
from typing import List
from AR.models.utils import make_pad_mask from AR.models.utils import make_pad_mask
from AR.models.utils import ( from AR.models.utils import (
topk_sampling, topk_sampling,
@ -37,6 +36,34 @@ default_config = {
"EOS": 1024, "EOS": 1024,
} }
# @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
if scale is None:
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
else:
scale_factor = scale
attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask, float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_weight.masked_fill_(attn_mask, 0)
else:
attn_mask[attn_mask!=float("-inf")] =0
attn_mask[attn_mask==float("-inf")] =1
attn_weight.masked_fill_(attn_mask, 0)
return attn_weight @ value
@torch.jit.script @torch.jit.script
class T2SMLP: class T2SMLP:
@ -84,30 +111,74 @@ class T2SBlock:
self.norm_b2 = norm_b2 self.norm_b2 = norm_b2
self.norm_eps2 = norm_eps2 self.norm_eps2 = norm_eps2
def process_prompt(self, x, attn_mask: torch.Tensor): self.false = torch.tensor(False, dtype=torch.bool)
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
@torch.jit.ignore
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
if padding_mask is None:
return x
if padding_mask.dtype == torch.bool:
return x.masked_fill(padding_mask, 0)
else:
return x * padding_mask
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0] batch_size = q.shape[0]
q_len = q.shape[1] q_len = q.shape[1]
kv_len = k.shape[1] kv_len = k.shape[1]
k_cache = k q = self.to_mask(q, padding_mask)
v_cache = v k_cache = self.to_mask(k, padding_mask)
v_cache = self.to_mask(v, padding_mask)
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
if torch_sdpa:
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask) attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
else:
attn = scaled_dot_product_attention(q, k, v, attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim) attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = F.linear(attn, self.out_w, self.out_b) attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
x = F.layer_norm( if padding_mask is not None:
x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1 for i in range(batch_size):
# mask = padding_mask[i,:,0]
if self.false.device!= padding_mask.device:
self.false = self.false.to(padding_mask.device)
idx = torch.where(padding_mask[i,:,0]==self.false)[0]
x_item = x[i,idx,:].unsqueeze(0)
attn_item = attn[i,idx,:].unsqueeze(0)
x_item = x_item + attn_item
x_item = F.layer_norm(
x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
) )
x_item = x_item + self.mlp.forward(x_item)
x_item = F.layer_norm(
x_item,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
x[i,idx,:] = x_item.squeeze(0)
x = self.to_mask(x, padding_mask)
else:
x = x + attn
x = F.layer_norm( x = F.layer_norm(
x + self.mlp.forward(x), x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
[self.hidden_dim], [self.hidden_dim],
self.norm_w2, self.norm_w2,
self.norm_b2, self.norm_b2,
@ -115,30 +186,37 @@ class T2SBlock:
) )
return x, k_cache, v_cache return x, k_cache, v_cache
def decode_next_token(self, x, k_cache, v_cache): def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1) k_cache = torch.cat([k_cache, k], dim=1)
v_cache = torch.cat([v_cache, v], dim=1) v_cache = torch.cat([v_cache, v], dim=1)
kv_len = k_cache.shape[1]
batch_size = q.shape[0] batch_size = q.shape[0]
q_len = q.shape[1] q_len = q.shape[1]
kv_len = k_cache.shape[1]
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim) if torch_sdpa:
attn = F.scaled_dot_product_attention(q, k, v)
else:
attn = scaled_dot_product_attention(q, k, v, attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(attn, self.out_w, self.out_b) attn = F.linear(attn, self.out_w, self.out_b)
x = x + attn
x = F.layer_norm( x = F.layer_norm(
x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1 x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
) )
x = x + self.mlp.forward(x)
x = F.layer_norm( x = F.layer_norm(
x + self.mlp.forward(x), x,
[self.hidden_dim], [self.hidden_dim],
self.norm_w2, self.norm_w2,
self.norm_b2, self.norm_b2,
@ -154,20 +232,27 @@ class T2STransformer:
self.blocks = blocks self.blocks = blocks
def process_prompt( def process_prompt(
self, x, attn_mask: torch.Tensor): self, x:torch.Tensor, attn_mask : torch.Tensor,
padding_mask : Optional[torch.Tensor]=None,
torch_sdpa:bool=True
):
k_cache : List[torch.Tensor] = [] k_cache : List[torch.Tensor] = []
v_cache : List[torch.Tensor] = [] v_cache : List[torch.Tensor] = []
for i in range(self.num_blocks): for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask) x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
k_cache.append(k_cache_) k_cache.append(k_cache_)
v_cache.append(v_cache_) v_cache.append(v_cache_)
return x, k_cache, v_cache return x, k_cache, v_cache
def decode_next_token( def decode_next_token(
self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor] self, x:torch.Tensor,
k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor],
attn_mask : Optional[torch.Tensor]=None,
torch_sdpa:bool=True
): ):
for i in range(self.num_blocks): for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i]) x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
return x, k_cache, v_cache return x, k_cache, v_cache
@ -235,7 +320,7 @@ class Text2SemanticDecoder(nn.Module):
layer.linear2.weight, layer.linear2.weight,
layer.linear2.bias layer.linear2.bias
) )
# (layer.self_attn.in_proj_weight, layer.self_attn.in_proj_bias)
block = T2SBlock( block = T2SBlock(
self.num_head, self.num_head,
self.model_dim, self.model_dim,
@ -283,7 +368,7 @@ class Text2SemanticDecoder(nn.Module):
(0, y_len), (0, y_len),
value=True, value=True,
) )
# x_attn_mask[:, x_len]=False
y_attn_mask = F.pad( y_attn_mask = F.pad(
torch.triu( torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@ -488,16 +573,225 @@ class Text2SemanticDecoder(nn.Module):
# 错位 # 错位
return targets[:, :-1], targets[:, 1:] return targets[:, :-1], targets[:, 1:]
def infer_panel( def infer_panel_batch_infer(
self, self,
x, #####全部文本token x:List[torch.LongTensor], #####全部文本token
x_lens, x_lens:torch.LongTensor,
prompts, ####参考音频token prompts:torch.LongTensor, ####参考音频token
bert_feature, bert_feature:List[torch.LongTensor],
top_k: int = -100, top_k: int = -100,
top_p: int = 100, top_p: int = 100,
early_stop_num: int = -1, early_stop_num: int = -1,
temperature: float = 1.0, temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs,
):
if prompts is None:
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
max_len = kwargs.get("max_len",x_lens.max())
x_list = []
for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
x_item = self.ar_text_embedding(x_item.unsqueeze(0))
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
x_item = self.ar_text_position(x_item).squeeze(0)
x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item
x_list.append(x_item)
x = torch.stack(x_list, dim=0)
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
k_cache = None
v_cache = None
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
ref_free = False
else:
y_emb = None
y_len = 0
prefix_len = 0
y_lens = torch.LongTensor([y_len]*x.shape[0]).to(x.device)
y_pos = None
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
y_paddind_mask = make_pad_mask(y_lens, y_len)
x_paddind_mask = make_pad_mask(x_lens, max_len)
# (bsz, x_len + y_len)
xy_padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
x_mask = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
for i in range(bsz):
l = x_lens[i]
_xy_padding_mask[i,l:max_len,:]=True
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
xy_attn_mask = xy_attn_mask.bool()
xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim)
###### decode #####
y_list = [None]*y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
if idx == 0:
xy_attn_mask = F.pad(xy_attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
logits = logits[:, :-1]
else:
xy_attn_mask = F.pad(xy_attn_mask,(0,1),value=False)
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
####### 移除batch中已经生成完毕的序列,进一步优化计算量
tokens = torch.argmax(logits, dim=-1)
reserved_idx_of_batch_for_y = None
if (self.EOS in samples[:, 0]) or \
(self.EOS in tokens): ###如果生成到EOS则停止
l1 = samples[:, 0]==self.EOS
l2 = tokens==self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx - 1
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# 只保留batch中未生成完毕的序列
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
xy_attn_mask = torch.index_select(xy_attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
print("use early stop num:", early_stop_num)
stop = True
for i, batch_index in enumerate(batch_idx_map):
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
if not (None in idx_list):
stop = True
if stop:
if y.shape[1]==0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
if (None in idx_list):
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ###如果没有生成到EOS就用最大长度代替
if ref_free:
return y_list, [0]*x.shape[0]
# print(idx_list)
return y_list, idx_list
def infer_panel_naive_batched(self,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
):
y_list = []
idx_list = []
for i in range(len(x)):
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
top_k,
top_p,
early_stop_num,
temperature,
repetition_penalty,
**kwargs)
y_list.append(y[0])
idx_list.append(idx)
return y_list, idx_list
def infer_panel_naive(
self,
x:torch.LongTensor, #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
): ):
x = self.ar_text_embedding(x) x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2)) x = x + self.bert_proj(bert_feature.transpose(1, 2))
@ -528,9 +822,10 @@ class Text2SemanticDecoder(nn.Module):
y_pos = None y_pos = None
xy_pos = x xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device) y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
prompts = y
ref_free = True ref_free = True
bsz = x.shape[0]
src_len = x_len + y_len
x_attn_mask_pad = F.pad( x_attn_mask_pad = F.pad(
x_attn_mask, x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y) (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
@ -541,13 +836,15 @@ class Text2SemanticDecoder(nn.Module):
(x_len, 0), (x_len, 0),
value=False, value=False,
) )
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
x.device .unsqueeze(0)\
) .expand(bsz*self.num_head, -1, -1)\
.view(bsz, self.num_head, src_len, src_len)\
.to(device=x.device, dtype=torch.bool)
for idx in tqdm(range(1500)): for idx in tqdm(range(1500)):
if xy_attn_mask is not None: if xy_attn_mask is not None:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask) xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
else: else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache) xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
@ -557,10 +854,12 @@ class Text2SemanticDecoder(nn.Module):
if idx == 0: if idx == 0:
xy_attn_mask = None xy_attn_mask = None
if(idx<11):###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1] logits = logits[:, :-1]
samples = sample( samples = sample(
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0].unsqueeze(0) )[0]
y = torch.concat([y, samples], dim=1) y = torch.concat([y, samples], dim=1)
@ -584,3 +883,19 @@ class Text2SemanticDecoder(nn.Module):
if ref_free: if ref_free:
return y[:, :-1], 0 return y[:, :-1], 0
return y[:, :-1], idx - 1 return y[:, :-1], idx - 1
def infer_panel(
self,
x:torch.LongTensor, #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
):
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)

View File

@ -115,17 +115,17 @@ def logits_to_probs(
top_p: Optional[int] = None, top_p: Optional[int] = None,
repetition_penalty: float = 1.0, repetition_penalty: float = 1.0,
): ):
if previous_tokens is not None: # if previous_tokens is not None:
previous_tokens = previous_tokens.squeeze() # previous_tokens = previous_tokens.squeeze()
# print(logits.shape,previous_tokens.shape) # print(logits.shape,previous_tokens.shape)
# pdb.set_trace() # pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0: if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long() previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens) score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where( score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty score < 0, score * repetition_penalty, score / repetition_penalty
) )
logits.scatter_(dim=0, index=previous_tokens, src=score) logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0: if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True) sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@ -133,9 +133,9 @@ def logits_to_probs(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1 torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
) )
sorted_indices_to_remove = cum_probs > top_p sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(
dim=0, index=sorted_indices, src=sorted_indices_to_remove dim=1, index=sorted_indices, src=sorted_indices_to_remove
) )
logits = logits.masked_fill(indices_to_remove, -float("Inf")) logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@ -143,7 +143,7 @@ def logits_to_probs(
if top_k is not None: if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v.select(-1, -1).unsqueeze(-1) pivot = v[: , -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits) logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1) probs = torch.nn.functional.softmax(logits, dim=-1)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,244 @@
import os, sys
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
import re
import torch
import LangSegment
from text import chinese
from typing import Dict, List, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
from tools.i18n.i18n import I18nAuto, scan_language_list
language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
punctuation = set(['!', '?', '', ',', '.', '-'," "])
def get_first(text:str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip()
return text
def merge_short_text_in_array(texts:str, threshold:int) -> list:
if (len(texts)) < 2:
return texts
result = []
text = ""
for ele in texts:
text += ele
if len(text) >= threshold:
result.append(text)
text = ""
if (len(text) > 0):
if len(result) == 0:
result.append(text)
else:
result[len(result) - 1] += text
return result
class TextPreprocessor:
def __init__(self, bert_model:AutoModelForMaskedLM,
tokenizer:AutoTokenizer, device:torch.device):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
print(i18n("############ 切分文本 ############"))
text = self.replace_consecutive_punctuation(text)
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
print(i18n("############ 提取文本Bert特征 ############"))
for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
if phones is None or norm_text=="":
continue
res={
"phones": phones,
"bert_features": bert_features,
"norm_text": norm_text,
}
result.append(res)
return result
def pre_seg_text(self, text:str, lang:str, text_split_method:str):
text = text.strip("\n")
if len(text) == 0:
return []
if (text[0] not in splits and len(get_first(text)) < 4):
text = "" + text if lang != "en" else "." + text
print(i18n("实际输入的目标文本:"))
print(text)
seg_method = get_seg_method(text_split_method)
text = seg_method(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
_texts = text.split("\n")
_texts = self.filter_text(_texts)
_texts = merge_short_text_in_array(_texts, 5)
texts = []
for text in _texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
if not re.sub("\W+", "", text):
# 检测一下,如果是纯符号,就跳过。
continue
if (text[-1] not in splits): text += "" if lang != "en" else "."
# 解决句子过长导致Bert报错的问题
if (len(text) > 510):
texts.extend(split_big_text(text))
else:
texts.append(text)
print(i18n("实际输入的目标文本(切句后):"))
print(texts)
return texts
def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]:
return self.get_phones_and_bert(text, language, version)
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")
if language == "en":
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
else:
# 因无法区别中日韩文汉字,以用户输入为准
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "zh":
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"zh",version)
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"yue",version)
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float32,
).to(self.device)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist=[]
langlist=[]
LangSegment.setfilters(["zh","ja","en","ko"])
if language == "auto":
for tmp in LangSegment.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
# print(textlist)
# print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text,language,version,final=True)
return phones, bert, norm_text
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
def clean_text_inf(self, text:str, language:str, version:str="v2"):
phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
language=language.replace("all_","")
if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
else:
feature = torch.zeros(
(1024, len(phones)),
dtype=torch.float32,
).to(self.device)
return feature
def filter_text(self,texts):
_text=[]
if all(text in [None, " ", "\n",""] for text in texts):
raise ValueError(i18n("请输入有效文本"))
for text in texts:
if text in [None, " ", ""]:
pass
else:
_text.append(text)
return _text
def replace_consecutive_punctuation(self,text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
return result

View File

@ -0,0 +1 @@
from . import TTS, text_segmentation_method

View File

@ -0,0 +1,173 @@
import re
from typing import Callable
punctuation = set(['!', '?', '', ',', '.', '-'," "])
METHODS = dict()
def get_method(name:str)->Callable:
method = METHODS.get(name, None)
if method is None:
raise ValueError(f"Method {name} not found")
return method
def get_method_names()->list:
return list(METHODS.keys())
def register_method(name):
def decorator(func):
METHODS[name] = func
return func
return decorator
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def split_big_text(text, max_len=510):
# 定义全角和半角标点符号
punctuation = "".join(splits)
# 切割文本
segments = re.split('([' + punctuation + '])', text)
# 初始化结果列表和当前片段
result = []
current_segment = ''
for segment in segments:
# 如果当前片段加上新的片段长度超过max_len就将当前片段加入结果列表并重置当前片段
if len(current_segment + segment) > max_len:
result.append(current_segment)
current_segment = segment
else:
current_segment += segment
# 将最后一个片段加入结果列表
if current_segment:
result.append(current_segment)
return result
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while 1:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
# 不切
@register_method("cut0")
def cut0(inp):
if not set(inp).issubset(punctuation):
return inp
else:
return "/n"
# 凑四句一切
@register_method("cut1")
def cut1(inp):
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 4))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
else:
opts = [inp]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# 凑50字一切
@register_method("cut2")
def cut2(inp):
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return inp
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
# print(opts)
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# 按中文句号。切
@register_method("cut3")
def cut3(inp):
inp = inp.strip("\n")
opts = ["%s" % item for item in inp.strip("").split("")]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
#按英文句号.切
@register_method("cut4")
def cut4(inp):
inp = inp.strip("\n")
opts = ["%s" % item for item in inp.strip(".").split(".")]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# 按标点符号切
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@register_method("cut5")
def cut5(inp):
inp = inp.strip("\n")
punds = {',', '.', ';', '?', '!', '', '', '', '', '', ';', '', ''}
mergeitems = []
items = []
for i, char in enumerate(inp):
if char in punds:
if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
items.append(char)
else:
items.append(char)
mergeitems.append("".join(items))
items = []
else:
items.append(char)
if items:
mergeitems.append("".join(items))
opt = [item for item in mergeitems if not set(item).issubset(punds)]
return "\n".join(opt)
if __name__ == '__main__':
method = get_method("cut5")
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))

1
GPT_SoVITS/configs/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*.yaml

View File

@ -0,0 +1,24 @@
custom:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cuda
is_half: true
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
version: v2
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
default:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
version: v1
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
default_v2:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
version: v2
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth

View File

@ -0,0 +1,737 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
from typing import Optional
from my_utils import load_audio
from text import cleaned_text_to_sequence
import torch
import torchaudio
from torch import IntTensor, LongTensor, Tensor, nn
from torch.nn import functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer
from feature_extractor import cnhubert
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from module.models_onnx import SynthesizerTrn
import os
import soundfile
default_config = {
"embedding_dim": 512,
"hidden_dim": 512,
"num_head": 8,
"num_layers": 12,
"num_codebook": 8,
"p_dropout": 0.0,
"vocab_size": 1024 + 1,
"phoneme_vocab_size": 512,
"EOS": 1024,
}
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
config = dict_s1["config"]
config["model"]["dropout"] = float(config["model"]["dropout"])
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
t2s_model = t2s_model.eval()
return t2s_model
@torch.jit.script
def logits_to_probs(
logits,
previous_tokens: Optional[torch.Tensor] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
# if previous_tokens is not None:
# previous_tokens = previous_tokens.squeeze()
# print(logits.shape,previous_tokens.shape)
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
logits = logits / max(temperature, 1e-5)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[: , -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
@torch.jit.script
def multinomial_sample_one_no_sync(probs_sort):
# Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@torch.jit.script
def sample(
logits,
previous_tokens,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
@torch.jit.script
def spectrogram_torch(y:Tensor, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
hann_window = torch.hann_window(win_size,device=y.device,dtype=y.dtype)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
@torch.jit.script
class T2SMLP:
def __init__(self, w1, b1, w2, b2):
self.w1 = w1
self.b1 = b1
self.w2 = w2
self.b2 = b2
def forward(self, x):
x = F.relu(F.linear(x, self.w1, self.b1))
x = F.linear(x, self.w2, self.b2)
return x
@torch.jit.script
class T2SBlock:
def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1: float,
norm_w2,
norm_b2,
norm_eps2: float,
):
self.num_heads = num_heads
self.mlp = mlp
self.hidden_dim: int = hidden_dim
self.qkv_w = qkv_w
self.qkv_b = qkv_b
self.out_w = out_w
self.out_b = out_b
self.norm_w1 = norm_w1
self.norm_b1 = norm_b1
self.norm_eps1 = norm_eps1
self.norm_w2 = norm_w2
self.norm_b2 = norm_b2
self.norm_eps2 = norm_eps2
self.false = torch.tensor(False, dtype=torch.bool)
@torch.jit.ignore
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
if padding_mask is None:
return x
if padding_mask.dtype == torch.bool:
return x.masked_fill(padding_mask, 0)
else:
return x * padding_mask
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None):
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k.shape[1]
q = self.to_mask(q, padding_mask)
k_cache = self.to_mask(k, padding_mask)
v_cache = self.to_mask(v, padding_mask)
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
if padding_mask is not None:
for i in range(batch_size):
# mask = padding_mask[i,:,0]
if self.false.device!= padding_mask.device:
self.false = self.false.to(padding_mask.device)
idx = torch.where(padding_mask[i,:,0]==self.false)[0]
x_item = x[i,idx,:].unsqueeze(0)
attn_item = attn[i,idx,:].unsqueeze(0)
x_item = x_item + attn_item
x_item = F.layer_norm(
x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x_item = x_item + self.mlp.forward(x_item)
x_item = F.layer_norm(
x_item,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
x[i,idx,:] = x_item.squeeze(0)
x = self.to_mask(x, padding_mask)
else:
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
v_cache = torch.cat([v_cache, v], dim=1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k_cache.shape[1]
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(attn, self.out_w, self.out_b)
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
@torch.jit.script
class T2STransformer:
def __init__(self, num_blocks : int, blocks: list[T2SBlock]):
self.num_blocks : int = num_blocks
self.blocks = blocks
def process_prompt(
self, x:torch.Tensor, attn_mask : torch.Tensor,padding_mask : Optional[torch.Tensor]=None):
k_cache : list[torch.Tensor] = []
v_cache : list[torch.Tensor] = []
for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
k_cache.append(k_cache_)
v_cache.append(v_cache_)
return x, k_cache, v_cache
def decode_next_token(
self, x:torch.Tensor,
k_cache: list[torch.Tensor],
v_cache: list[torch.Tensor]):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
return x, k_cache, v_cache
class VitsModel(nn.Module):
def __init__(self, vits_path):
super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu")
self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
self.vq_model = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
def forward(self, text_seq, pred_semantic, ref_audio):
refer = spectrogram_torch(
ref_audio,
self.hps.data.filter_length,
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False
)
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
class T2SModel(nn.Module):
def __init__(self,raw_t2s:Text2SemanticLightningModule):
super(T2SModel, self).__init__()
self.model_dim = raw_t2s.model.model_dim
self.embedding_dim = raw_t2s.model.embedding_dim
self.num_head = raw_t2s.model.num_head
self.num_layers = raw_t2s.model.num_layers
self.vocab_size = raw_t2s.model.vocab_size
self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
# self.p_dropout = float(raw_t2s.model.p_dropout)
self.EOS:int = int(raw_t2s.model.EOS)
self.norm_first = raw_t2s.model.norm_first
assert self.EOS == self.vocab_size - 1
self.hz = 50
self.bert_proj = raw_t2s.model.bert_proj
self.ar_text_embedding = raw_t2s.model.ar_text_embedding
self.ar_text_position = raw_t2s.model.ar_text_position
self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding
self.ar_audio_position = raw_t2s.model.ar_audio_position
# self.t2s_transformer = T2STransformer(self.num_layers, blocks)
# self.t2s_transformer = raw_t2s.model.t2s_transformer
blocks = []
h = raw_t2s.model.h
for i in range(self.num_layers):
layer = h.layers[i]
t2smlp = T2SMLP(
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
)
block = T2SBlock(
self.num_head,
self.model_dim,
t2smlp,
layer.self_attn.in_proj_weight,
layer.self_attn.in_proj_bias,
layer.self_attn.out_proj.weight,
layer.self_attn.out_proj.bias,
layer.norm1.weight,
layer.norm1.bias,
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
)
blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
# self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
self.ar_predict_layer = raw_t2s.model.ar_predict_layer
# self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
self.max_sec = raw_t2s.config["data"]["max_sec"]
self.top_k = int(raw_t2s.config["inference"]["top_k"])
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor):
bert = torch.cat([ref_bert.T, text_bert.T], 1)
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
bert = bert.unsqueeze(0)
x = self.ar_text_embedding(all_phoneme_ids)
x = x + self.bert_proj(bert.transpose(1, 2))
x:torch.Tensor = self.ar_text_position(x)
early_stop_num = self.early_stop_num
#[1,N,512] [1,N]
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
y = prompts
# x_example = x[:,:,0] * 0.0
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
bsz = x.shape[0]
src_len = x_len + y_len
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
.unsqueeze(0)\
.expand(bsz*self.num_head, -1, -1)\
.view(bsz, self.num_head, src_len, src_len)\
.to(device=x.device, dtype=torch.bool)
idx = 0
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
logits = self.ar_predict_layer(xy_dec[:, -1])
logits = logits[:, :-1]
samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
y = torch.concat([y, samples], dim=1)
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
stop = False
# for idx in range(1, 50):
for idx in range(1, 1500):
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
if(idx<11):###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
y = torch.concat([y, samples], dim=1)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
break
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
return y[:, -idx:].unsqueeze(0)
bert_path = os.environ.get(
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
)
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
@torch.jit.script
def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
phone_level_feature = []
for i in range(word2ph.shape[0]):
repeat_feature = res[i].repeat(word2ph[i].item(), 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
# [sum(word2ph), 1024]
return phone_level_feature
class MyBertModel(torch.nn.Module):
def __init__(self, bert_model):
super(MyBertModel, self).__init__()
self.bert = bert_model
def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
return build_phone_level_feature(res, word2ph)
class SSLModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.ssl = cnhubert.get_model().model
def forward(self, ref_audio_16k)-> torch.Tensor:
ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
return ssl_content
class ExportSSLModel(torch.nn.Module):
def __init__(self,ssl:SSLModel):
super().__init__()
self.ssl = ssl
def forward(self, ref_audio:torch.Tensor):
return self.ssl(ref_audio)
@torch.jit.export
def resample(self,ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
audio = resamplex(ref_audio,src_sr,dst_sr).float()
return audio
def export_bert(ref_bert_inputs):
ref_bert_inputs = {
'input_ids': ref_bert_inputs['input_ids'],
'attention_mask': ref_bert_inputs['attention_mask'],
'token_type_ids': ref_bert_inputs['token_type_ids'],
'word2ph': ref_bert_inputs['word2ph']
}
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
my_bert_model = MyBertModel(bert_model)
my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
my_bert_model.save("onnx/bert_model.pt")
print('#### exported bert ####')
def export(gpt_path, vits_path):
tokenizer = AutoTokenizer.from_pretrained(bert_path)
ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
bert = MyBertModel(bert_model)
# export_bert(ref_bert_inputs)
ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
ssl = SSLModel()
s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
torch.jit.script(s).save("onnx/xw/ssl_model.pt")
print('#### exported ssl ####')
ref_bert = bert(**ref_bert_inputs)
text_bert = bert(**text_berf_inputs)
ssl_content = ssl(ref_audio)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path)
vits.eval()
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
dict_s1 = torch.load(gpt_path, map_location="cpu")
raw_t2s = get_raw_t2s_model(dict_s1)
t2s_m = T2SModel(raw_t2s)
t2s_m.eval()
t2s = torch.jit.script(t2s_m)
print('#### script t2s_m ####')
print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
gpt_sovits = GPT_SoVITS(t2s,vits)
gpt_sovits.eval()
ref_audio_sr = s.resample(ref_audio,16000,32000)
print('ref_audio_sr:',ref_audio_sr.shape)
gpt_sovits_export = torch.jit.trace(
gpt_sovits,
example_inputs=(
ssl_content,
ref_audio_sr,
ref_seq,
text_seq,
ref_bert,
text_bert),
check_trace=False) # 默认是True 但是 check 的时候可能是随机生成的一个奇怪维度的值,导致报错
gpt_sovits_export.save("onnx/xw/gpt_sovits_model.pt")
print('#### exported gpt_sovits ####')
@torch.jit.script
def parse_audio(ref_audio):
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()#.to(ref_audio.device)
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,32000).float()#.to(ref_audio.device)
return ref_audio_16k,ref_audio_sr
@torch.jit.script
def resamplex(ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
return torchaudio.functional.resample(ref_audio,src_sr,dst_sr).float()
class GPT_SoVITS(nn.Module):
def __init__(self, t2s:T2SModel,vits:VitsModel):
super().__init__()
self.t2s = t2s
self.vits = vits
def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor):
codes = self.vits.vq_model.extract_latent(ssl_content.float())
prompt_semantic = codes[0, 0]
prompts = prompt_semantic.unsqueeze(0)
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
audio = self.vits(text_seq, pred_semantic, ref_audio_sr)
return audio
def test(gpt_path, vits_path):
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
bert = MyBertModel(bert_model)
# bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
dict_s1 = torch.load(gpt_path, map_location="cpu")
raw_t2s = get_raw_t2s_model(dict_s1)
t2s = T2SModel(raw_t2s)
t2s.eval()
# t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path)
vits.eval()
ssl = ExportSSLModel(SSLModel())
ssl.eval()
gpt_sovits = GPT_SoVITS(t2s,vits)
# vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda')
# ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda')
ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
text_berf_inputs = tokenizer("大家好,我有一个春晚问题.", return_tensors="pt")
text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
text_berf_inputs['word2ph'] = torch.Tensor([2,2,2,1,2,2,2,2,2,2,2,2,1]).int()
ref_bert = bert(
ref_bert_inputs['input_ids'],
ref_bert_inputs['attention_mask'],
ref_bert_inputs['token_type_ids'],
ref_bert_inputs['word2ph']
)
text_bert = bert(text_berf_inputs['input_ids'],
text_berf_inputs['attention_mask'],
text_berf_inputs['token_type_ids'],
text_berf_inputs['word2ph'])
#[1,N]
ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
print('ref_audio:',ref_audio.shape)
ref_audio_sr = ssl.resample(ref_audio,16000,32000)
print('start ssl')
ssl_content = ssl(ref_audio)
print('start gpt_sovits:')
with torch.no_grad():
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)
print('start write wav')
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
# audio = vits(text_seq, pred_semantic1, ref_audio)
# soundfile.write("out.wav", audio, 32000)
import text
import json
def export_symbel(version='v2'):
if version=='v1':
symbols = text._symbol_to_id_v1
with open(f"onnx/symbols_v1.json", "w") as file:
json.dump(symbols, file, indent=4)
else:
symbols = text._symbol_to_id_v2
with open(f"onnx/symbols_v2.json", "w") as file:
json.dump(symbols, file, indent=4)
if __name__ == "__main__":
export(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
# test(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
# export_symbel()

View File

@ -23,13 +23,15 @@ cnhubert_base_path = None
class CNHubert(nn.Module): class CNHubert(nn.Module):
def __init__(self): def __init__(self, base_path:str=None):
super().__init__() super().__init__()
if os.path.exists(cnhubert_base_path):... if base_path is None:
else:raise FileNotFoundError(cnhubert_base_path) base_path = cnhubert_base_path
self.model = HubertModel.from_pretrained(cnhubert_base_path, local_files_only=True) if os.path.exists(base_path):...
else:raise FileNotFoundError(base_path)
self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
cnhubert_base_path, local_files_only=True base_path, local_files_only=True
) )
def forward(self, x): def forward(self, x):

View File

@ -21,6 +21,11 @@ import LangSegment, os, re, sys, json
import pdb import pdb
import torch import torch
try:
import gradio.analytics as analytics
analytics.version_check = lambda:None
except:...
version=os.environ.get("version","v2") version=os.environ.get("version","v2")
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"] pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"] pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
@ -392,7 +397,8 @@ def merge_short_text_in_array(texts, threshold):
##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature ##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
# cache_tokens={}#暂未实现清理机制 # cache_tokens={}#暂未实现清理机制
cache= {} cache= {}
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=123): def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free
=False,speed=1,if_freeze=False,inp_refs=None):
global cache global cache
if ref_wav_path:pass if ref_wav_path:pass
else:gr.Warning(i18n('请上传参考音频')) else:gr.Warning(i18n('请上传参考音频'))

View File

@ -0,0 +1,336 @@
'''
按中英混合识别
按日英混合识别
多语种启动切分识别语种
全部按中文识别
全部按英文识别
全部按日文识别
'''
import random
import os, re, logging
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
import torch
try:
import gradio.analytics as analytics
analytics.version_check = lambda:None
except:...
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
is_share = os.environ.get("is_share", "False")
is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
bert_path = os.environ.get("bert_path", None)
version=os.environ.get("version","v2")
import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto, scan_language_list
language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
if torch.cuda.is_available():
device = "cuda"
# elif torch.backends.mps.is_available():
# device = "mps"
else:
device = "cpu"
dict_language_v1 = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
i18n("日文"): "all_ja",#全部按日文识别
i18n("中英混合"): "zh",#按中英混合识别####不变
i18n("日英混合"): "ja",#按日英混合识别####不变
i18n("多语种混合"): "auto",#多语种启动切分识别语种
}
dict_language_v2 = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
i18n("日文"): "all_ja",#全部按日文识别
i18n("粤语"): "all_yue",#全部按中文识别
i18n("韩文"): "all_ko",#全部按韩文识别
i18n("中英混合"): "zh",#按中英混合识别####不变
i18n("日英混合"): "ja",#按日英混合识别####不变
i18n("粤英混合"): "yue",#按粤英混合识别####不变
i18n("韩英混合"): "ko",#按韩英混合识别####不变
i18n("多语种混合"): "auto",#多语种启动切分识别语种
i18n("多语种混合(粤语)"): "auto_yue",#多语种启动切分识别语种
}
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
cut_method = {
i18n("不切"):"cut0",
i18n("凑四句一切"): "cut1",
i18n("凑50字一切"): "cut2",
i18n("按中文句号。切"): "cut3",
i18n("按英文句号.切"): "cut4",
i18n("按标点符号切"): "cut5",
}
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device
tts_config.is_half = is_half
tts_config.version = version
if gpt_path is not None:
tts_config.t2s_weights_path = gpt_path
if sovits_path is not None:
tts_config.vits_weights_path = sovits_path
if cnhubert_base_path is not None:
tts_config.cnhuhbert_base_path = cnhubert_base_path
if bert_path is not None:
tts_config.bert_base_path = bert_path
print(tts_config)
tts_pipeline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
version = tts_config.version
def inference(text, text_lang,
ref_audio_path,
aux_ref_audio_paths,
prompt_text,
prompt_lang, top_k,
top_p, temperature,
text_split_method, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed, keep_random, parallel_infer,
repetition_penalty
):
seed = -1 if keep_random else seed
actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
inputs={
"text": text,
"text_lang": dict_language[text_lang],
"ref_audio_path": ref_audio_path,
"aux_ref_audio_paths": [item.name for item in aux_ref_audio_paths] if aux_ref_audio_paths is not None else [],
"prompt_text": prompt_text if not ref_text_free else "",
"prompt_lang": dict_language[prompt_lang],
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"text_split_method": cut_method[text_split_method],
"batch_size":int(batch_size),
"speed_factor":float(speed_factor),
"split_bucket":split_bucket,
"return_fragment":False,
"fragment_interval":fragment_interval,
"seed":actual_seed,
"parallel_infer": parallel_infer,
"repetition_penalty": repetition_penalty,
}
for item in tts_pipeline.run(inputs):
yield item, actual_seed
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分
parts = re.split('(\d+)', s)
# 将数字部分转换为整数,非数字部分保持不变
parts = [int(part) if part.isdigit() else part for part in parts]
return parts
def change_choices():
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
_ =[[],[]]
for i in range(2):
if os.path.exists(pretrained_gpt_name[i]):
_[0].append(pretrained_gpt_name[i])
if os.path.exists(pretrained_sovits_name[i]):
_[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name,pretrained_sovits_name = _
SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"]
GPT_weight_root=["GPT_weights_v2","GPT_weights"]
for path in SoVITS_weight_root+GPT_weight_root:
os.makedirs(path,exist_ok=True)
def get_weights_names(GPT_weight_root, SoVITS_weight_root):
SoVITS_names = [i for i in pretrained_sovits_name]
for path in SoVITS_weight_root:
for name in os.listdir(path):
if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
GPT_names = [i for i in pretrained_gpt_name]
for path in GPT_weight_root:
for name in os.listdir(path):
if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
return SoVITS_names, GPT_names
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
tts_pipeline.init_vits_weights(sovits_path)
global version, dict_language
dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
if prompt_language is not None and text_language is not None:
if prompt_language in list(dict_language.keys()):
prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
else:
prompt_text_update = {'__type__':'update', 'value':''}
prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
if text_language in list(dict_language.keys()):
text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language}
else:
text_update = {'__type__':'update', 'value':''}
text_language_update = {'__type__':'update', 'value':i18n("中文")}
return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
)
with gr.Column():
# with gr.Group():
gr.Markdown(value=i18n("模型切换"))
with gr.Row():
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
with gr.Row():
with gr.Column():
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频超过会报错)"), type="filepath")
inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"),file_count="multiple")
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
with gr.Row():
prompt_language = gr.Dropdown(
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
)
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
with gr.Column():
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=20, max_lines=20)
text_language = gr.Dropdown(
label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
)
with gr.Group():
gr.Markdown(value=i18n("推理设置"))
with gr.Row():
with gr.Column():
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
with gr.Column():
with gr.Row():
how_to_cut = gr.Dropdown(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True, scale=1
)
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
with gr.Row():
seed = gr.Number(label=i18n("随机种子"),value=-1)
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
output = gr.Audio(label=i18n("输出的语音"))
with gr.Row():
inference_button = gr.Button(i18n("合成语音"), variant="primary")
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
inference_button.click(
inference,
[
text,text_language, inp_ref, inp_refs,
prompt_text, prompt_language,
top_k, top_p, temperature,
how_to_cut, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed, keep_random, parallel_infer,
repetition_penalty
],
[output, seed],
)
stop_infer.click(tts_pipeline.stop, [], [])
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language])
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
with gr.Group():
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
with gr.Row():
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
with gr.Column():
_how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
cut_text= gr.Button(i18n("切分"), variant="primary")
def to_cut(text_inp, how_to_cut):
if len(text_inp.strip()) == 0 or text_inp==[]:
return ""
method = get_method(cut_method[how_to_cut])
return method(text_inp)
text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
if __name__ == '__main__':
app.queue().launch(#concurrency_count=511, max_size=1022
server_name="0.0.0.0",
inbrowser=True,
share=is_share,
server_port=infer_ttswebui,
quiet=True,
)

View File

@ -4,8 +4,8 @@ from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from module import commons from module import commons
from module.modules import LayerNorm
from typing import Optional
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5): def __init__(self, channels, eps=1e-5):
@ -59,6 +59,7 @@ class Encoder(nn.Module):
# self.cond_layer = weight_norm(cond_layer, name='weight') # self.cond_layer = weight_norm(cond_layer, name='weight')
# self.gin_channels = 256 # self.gin_channels = 256
self.cond_layer_idx = self.n_layers self.cond_layer_idx = self.n_layers
self.spk_emb_linear = nn.Linear(256, self.hidden_channels)
if "gin_channels" in kwargs: if "gin_channels" in kwargs:
self.gin_channels = kwargs["gin_channels"] self.gin_channels = kwargs["gin_channels"]
if self.gin_channels != 0: if self.gin_channels != 0:
@ -98,22 +99,36 @@ class Encoder(nn.Module):
) )
self.norm_layers_2.append(LayerNorm(hidden_channels)) self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, g=None): # def forward(self, x, x_mask, g=None):
# attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
# x = x * x_mask
# for i in range(self.n_layers):
# if i == self.cond_layer_idx and g is not None:
# g = self.spk_emb_linear(g.transpose(1, 2))
# g = g.transpose(1, 2)
# x = x + g
# x = x * x_mask
# y = self.attn_layers[i](x, x, attn_mask)
# y = self.drop(y)
# x = self.norm_layers_1[i](x + y)
# y = self.ffn_layers[i](x, x_mask)
# y = self.drop(y)
# x = self.norm_layers_2[i](x + y)
# x = x * x_mask
# return x
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask x = x * x_mask
for i in range(self.n_layers): for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2):
if i == self.cond_layer_idx and g is not None: y = attn_layers(x, x, attn_mask)
g = self.spk_emb_linear(g.transpose(1, 2))
g = g.transpose(1, 2)
x = x + g
x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y) y = self.drop(y)
x = self.norm_layers_1[i](x + y) x = norm_layers_1(x + y)
y = self.ffn_layers[i](x, x_mask) y = ffn_layers(x, x_mask)
y = self.drop(y) y = self.drop(y)
x = self.norm_layers_2[i](x + y) x = norm_layers_2(x + y)
x = x * x_mask x = x * x_mask
return x return x
@ -172,17 +187,18 @@ class MultiHeadAttention(nn.Module):
self.conv_k.weight.copy_(self.conv_q.weight) self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias) self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask=None): def forward(self, x, c, attn_mask:Optional[torch.Tensor]=None):
q = self.conv_q(x) q = self.conv_q(x)
k = self.conv_k(c) k = self.conv_k(c)
v = self.conv_v(c) v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask) # x, self.attn = self.attention(q, k, v, mask=attn_mask)
x, _ = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x) x = self.conv_o(x)
return x return x
def attention(self, query, key, value, mask=None): def attention(self, query, key, value, mask:Optional[torch.Tensor]=None):
# reshape [b, d, t] -> [b, n_h, t, d_k] # reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, _ = (*key.size(), query.size(2)) b, d, t_s, _ = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
@ -304,7 +320,7 @@ class FFN(nn.Module):
filter_channels, filter_channels,
kernel_size, kernel_size,
p_dropout=0.0, p_dropout=0.0,
activation=None, activation="",
causal=False, causal=False,
): ):
super().__init__() super().__init__()
@ -316,10 +332,11 @@ class FFN(nn.Module):
self.activation = activation self.activation = activation
self.causal = causal self.causal = causal
if causal: # 从上下文看这里一定是 False
self.padding = self._causal_padding # if causal:
else: # self.padding = self._causal_padding
self.padding = self._same_padding # else:
# self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
@ -335,6 +352,9 @@ class FFN(nn.Module):
x = self.conv_2(self.padding(x * x_mask)) x = self.conv_2(self.padding(x * x_mask))
return x * x_mask return x * x_mask
def padding(self, x):
return self._same_padding(x)
def _causal_padding(self, x): def _causal_padding(self, x):
if self.kernel_size == 1: if self.kernel_size == 1:
return x return x
@ -352,3 +372,35 @@ class FFN(nn.Module):
padding = [[0, 0], [0, 0], [pad_l, pad_r]] padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding)) x = F.pad(x, commons.convert_pad_shape(padding))
return x return x
class MRTE(nn.Module):
def __init__(
self,
content_enc_channels=192,
hidden_size=512,
out_channels=192,
kernel_size=5,
n_heads=4,
ge_layer=2,
):
super(MRTE, self).__init__()
self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge):
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask)
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.c_post(x * ssl_mask)
return x

View File

@ -13,10 +13,10 @@ def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2) return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape): # def convert_pad_shape(pad_shape):
l = pad_shape[::-1] # l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist] # pad_shape = [item for sublist in l for item in sublist]
return pad_shape # return pad_shape
def intersperse(lst, item): def intersperse(lst, item):

View File

@ -1,5 +1,6 @@
import copy import copy
import math import math
from typing import Optional
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
@ -11,9 +12,10 @@ from module import attentions_onnx as attentions
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer from module.quantize import ResidualVectorQuantizer
from text import symbols # from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
@ -182,6 +184,7 @@ class TextEncoder(nn.Module):
kernel_size, kernel_size,
p_dropout, p_dropout,
latent_channels=192, latent_channels=192,
version="v2",
): ):
super().__init__() super().__init__()
self.out_channels = out_channels self.out_channels = out_channels
@ -192,6 +195,7 @@ class TextEncoder(nn.Module):
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.p_dropout = p_dropout self.p_dropout = p_dropout
self.latent_channels = latent_channels self.latent_channels = latent_channels
self.version = version
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1) self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
@ -207,9 +211,14 @@ class TextEncoder(nn.Module):
self.encoder_text = attentions.Encoder( self.encoder_text = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
) )
if self.version == "v1":
symbols = symbols_v1.symbols
else:
symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels) self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE() self.mrte = attentions.MRTE()
self.encoder2 = attentions.Encoder( self.encoder2 = attentions.Encoder(
hidden_channels, hidden_channels,
@ -240,25 +249,6 @@ class TextEncoder(nn.Module):
m, logs = torch.split(stats, self.out_channels, dim=1) m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask return y, m, logs, y_mask
def extract_latent(self, x):
x = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
return codes.transpose(0, 1)
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
quantized = self.quantizer.decode(codes)
y = self.vq_proj(quantized) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
y = self.mrte(y, y_mask, refer, refer_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask, quantized
class ResidualCouplingBlock(nn.Module): class ResidualCouplingBlock(nn.Module):
def __init__( def __init__(
@ -439,7 +429,7 @@ class Generator(torch.nn.Module):
if gin_channels != 0: if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g=None): def forward(self, x, g:Optional[torch.Tensor]=None):
x = self.conv_pre(x) x = self.conv_pre(x)
if g is not None: if g is not None:
x = x + self.cond(g) x = x + self.cond(g)
@ -817,6 +807,7 @@ class SynthesizerTrn(nn.Module):
use_sdp=True, use_sdp=True,
semantic_frame_rate=None, semantic_frame_rate=None,
freeze_quantizer=None, freeze_quantizer=None,
version="v2",
**kwargs **kwargs
): ):
super().__init__() super().__init__()
@ -837,6 +828,7 @@ class SynthesizerTrn(nn.Module):
self.segment_size = segment_size self.segment_size = segment_size
self.n_speakers = n_speakers self.n_speakers = n_speakers
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.version = version
self.use_sdp = use_sdp self.use_sdp = use_sdp
self.enc_p = TextEncoder( self.enc_p = TextEncoder(
@ -847,6 +839,7 @@ class SynthesizerTrn(nn.Module):
n_layers, n_layers,
kernel_size, kernel_size,
p_dropout, p_dropout,
version=version,
) )
self.dec = Generator( self.dec = Generator(
inter_channels, inter_channels,
@ -858,22 +851,24 @@ class SynthesizerTrn(nn.Module):
upsample_kernel_sizes, upsample_kernel_sizes,
gin_channels=gin_channels, gin_channels=gin_channels,
) )
self.enc_q = PosteriorEncoder( # self.enc_q = PosteriorEncoder(
spec_channels, # spec_channels,
inter_channels, # inter_channels,
hidden_channels, # hidden_channels,
5, # 5,
1, # 1,
16, # 16,
gin_channels=gin_channels, # gin_channels=gin_channels,
) # )
self.flow = ResidualCouplingBlock( self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
) )
self.ref_enc = modules.MelStyleEncoder( # self.version=os.environ.get("version","v1")
spec_channels, style_vector_dim=gin_channels if self.version == "v1":
) self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else:
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
ssl_dim = 768 ssl_dim = 768
self.ssl_dim = ssl_dim self.ssl_dim = ssl_dim
@ -894,7 +889,10 @@ class SynthesizerTrn(nn.Module):
def forward(self, codes, text, refer): def forward(self, codes, text, refer):
refer_mask = torch.ones_like(refer[:1,:1,:]) refer_mask = torch.ones_like(refer[:1,:1,:])
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask) ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
quantized = self.quantizer.decode(codes) quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz": if self.semantic_frame_rate == "25hz":

View File

@ -1,10 +1,11 @@
from module.models_onnx import SynthesizerTrn, symbols from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
import torch import torch
import torchaudio import torchaudio
from torch import nn from torch import nn
from feature_extractor import cnhubert from feature_extractor import cnhubert
cnhubert_base_path = "pretrained_models/chinese-hubert-base"
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model() ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
@ -196,6 +197,11 @@ class VitsModel(nn.Module):
super().__init__() super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu") dict_s2 = torch.load(vits_path,map_location="cpu")
self.hps = dict_s2["config"] self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
self.hps = DictToAttrRecursive(self.hps) self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz" self.hps.model.semantic_frame_rate = "25hz"
self.vq_model = SynthesizerTrn( self.vq_model = SynthesizerTrn(
@ -267,13 +273,13 @@ class SSLModel(nn.Module):
return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2) return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
def export(vits_path, gpt_path, project_name): def export(vits_path, gpt_path, project_name, vits_model="v2"):
vits = VitsModel(vits_path) vits = VitsModel(vits_path)
gpt = T2SModel(gpt_path, vits) gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt) gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel() ssl = SSLModel()
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float() ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float() text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float() ref_audio = torch.randn((1, 48000 * 5)).float()
@ -288,19 +294,23 @@ def export(vits_path, gpt_path, project_name):
ssl_content = ssl(ref_audio_16k).float() ssl_content = ssl(ref_audio_16k).float()
debug = False # debug = False
debug = True
# gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
if debug: if debug:
a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug) a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate) soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate) soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
return else:
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy() a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
soundfile.write("out.wav", a, vits.hps.data.sampling_rate) soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name) if vits_model == "v1":
symbols = symbols_v1
else:
symbols = symbols_v2
MoeVSConf = { MoeVSConf = {
"Folder": f"{project_name}", "Folder": f"{project_name}",
@ -311,8 +321,8 @@ def export(vits_path, gpt_path, project_name):
"EmbeddingDim": gpt.t2s_model.embedding_dim, "EmbeddingDim": gpt.t2s_model.embedding_dim,
"Dict": "BasicDict", "Dict": "BasicDict",
"BertPath": "chinese-roberta-wwm-ext-large", "BertPath": "chinese-roberta-wwm-ext-large",
"Symbol": symbols, # "Symbol": symbols,
"AddBlank": False "AddBlank": False,
} }
MoeVSConfJson = json.dumps(MoeVSConf) MoeVSConfJson = json.dumps(MoeVSConf)

View File

@ -27,7 +27,7 @@ if is_g2pw:
print("当前使用g2pw进行拼音推理") print("当前使用g2pw进行拼音推理")
from text.g2pw import G2PWPinyin, correct_pronunciation from text.g2pw import G2PWPinyin, correct_pronunciation
parent_directory = os.path.dirname(current_file_path) parent_directory = os.path.dirname(current_file_path)
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True) g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True)
rep_map = { rep_map = {
"": ",", "": ",",

View File

@ -50,8 +50,6 @@ def clean_text(text, language, version=None):
else: else:
phones = language_module.g2p(norm_text) phones = language_module.g2p(norm_text)
word2ph = None word2ph = None
for ph in phones:
phones = ['UNK' if ph not in symbols else ph for ph in phones] phones = ['UNK' if ph not in symbols else ph for ph in phones]
return phones, word2ph, norm_text return phones, word2ph, norm_text

View File

@ -45022,3 +45022,5 @@
黄发台背: ['huang2', 'fa1', 'tai2', 'bei4'] 黄发台背: ['huang2', 'fa1', 'tai2', 'bei4']
鼎铛玉石: ['ding3', 'cheng1', 'yu4', 'shi2'] 鼎铛玉石: ['ding3', 'cheng1', 'yu4', 'shi2']
齿豁头童: ['chi3', 'huo1', 'tou2', 'tong2'] 齿豁头童: ['chi3', 'huo1', 'tou2', 'tong2']
牦牛: ['mao2', 'niu2']
牦: ['mao2']

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,9 @@
# modified from https://github.com/CjangCjengh/vits/blob/main/text/japanese.py # modified from https://github.com/CjangCjengh/vits/blob/main/text/japanese.py
import re import re
import pyopenjtalk
import os import os
import hashlib import hashlib
try:
import pyopenjtalk
current_file_path = os.path.dirname(__file__) current_file_path = os.path.dirname(__file__)
def get_hash(fp: str) -> str: def get_hash(fp: str) -> str:
hash_md5 = hashlib.md5() hash_md5 = hashlib.md5()
@ -24,6 +24,11 @@ if os.path.exists(USERDIC_CSV_PATH):
if os.path.exists(USERDIC_BIN_PATH): if os.path.exists(USERDIC_BIN_PATH):
pyopenjtalk.update_global_jtalk_with_user_dict(USERDIC_BIN_PATH) pyopenjtalk.update_global_jtalk_with_user_dict(USERDIC_BIN_PATH)
except Exception as e:
# print(e)
import pyopenjtalk
# failed to load user dictionary, ignore.
pass
from text.symbols import punctuation from text.symbols import punctuation
@ -80,10 +85,6 @@ def post_replace_ph(ph):
if ph in rep_map.keys(): if ph in rep_map.keys():
ph = rep_map[ph] ph = rep_map[ph]
# if ph in symbols:
# return ph
# if ph not in symbols:
# ph = "UNK"
return ph return ph
@ -103,6 +104,8 @@ def symbols_to_japanese(text):
def preprocess_jap(text, with_prosody=False): def preprocess_jap(text, with_prosody=False):
"""Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html""" """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html"""
text = symbols_to_japanese(text) text = symbols_to_japanese(text)
# English words to lower case, should have no influence on japanese words.
text = text.lower()
sentences = re.split(_japanese_marks, text) sentences = re.split(_japanese_marks, text)
marks = re.findall(_japanese_marks, text) marks = re.findall(_japanese_marks, text)
text = [] text = []
@ -219,5 +222,5 @@ def g2p(norm_text, with_prosody=True):
if __name__ == "__main__": if __name__ == "__main__":
phones = g2p("こんにちは, hello, AKITOです,よろしくお願いしますね") phones = g2p("Hello.こんにちは今日もNiCe天気ですねtokyotowerに行きましょう")
print(phones) print(phones)

View File

@ -681,6 +681,7 @@ class ToneSandhi:
and seg[i - 1][0] == "" and seg[i - 1][0] == ""
and seg[i - 2][0] == word and seg[i - 2][0] == word
and pos == "v" and pos == "v"
and seg[i - 2][1] == "v"
): ):
continue continue
else: else:

View File

@ -186,6 +186,7 @@ def replace_positive_quantifier(match) -> str:
match_2: str = match_2 if match_2 else "" match_2: str = match_2 if match_2 else ""
quantifiers: str = match.group(3) quantifiers: str = match.group(3)
number: str = num2str(number) number: str = num2str(number)
number = "" if number == "" else number
result = f"{number}{match_2}{quantifiers}" result = f"{number}{match_2}{quantifiers}"
return result return result

View File

@ -184,8 +184,8 @@ D:\GPT-SoVITS\xxx/xxx.wav|xxx|en|I like playing Genshin.
#### Integrated Package Users #### Integrated Package Users
Double-click `go-webui.bat`or use `go-webui.ps` Double-click `go-webui.bat`or use `go-webui.ps1`
if you want to switch to V1,then double-click`go-webui-v1.bat` or use `go-webui-v1.ps` if you want to switch to V1,then double-click`go-webui-v1.bat` or use `go-webui-v1.ps1`
#### Others #### Others
@ -220,7 +220,7 @@ Or maunally switch version in WebUI
#### Integrated Package Users #### Integrated Package Users
Double-click `go-webui-v2.bat` or use `go-webui-v2.ps` ,then open the inference webui at `1-GPT-SoVITS-TTS/1C-inference` Double-click `go-webui-v2.bat` or use `go-webui-v2.ps1` ,then open the inference webui at `1-GPT-SoVITS-TTS/1C-inference`
#### Others #### Others

174
api.py
View File

@ -20,6 +20,7 @@
`-hp` - `覆盖 config.py 使用半精度` `-hp` - `覆盖 config.py 使用半精度`
`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"` `-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"`
·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"` ·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"`
·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"`
·-cp` - `文本切分符号设定, 默认为空, ",.,。"字符串的方式传入` ·-cp` - `文本切分符号设定, 默认为空, ",.,。"字符串的方式传入`
`-hb` - `cnhubert路径` `-hb` - `cnhubert路径`
@ -74,7 +75,7 @@ RESP:
手动指定当次推理所使用的参考音频并提供参数: 手动指定当次推理所使用的参考音频并提供参数:
GET: GET:
`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1` `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"`
POST: POST:
```json ```json
{ {
@ -86,7 +87,8 @@ POST:
"top_k": 20, "top_k": 20,
"top_p": 0.6, "top_p": 0.6,
"temperature": 0.6, "temperature": 0.6,
"speed": 1 "speed": 1,
"inp_refs": ["456.wav","789.wav"]
} }
``` ```
@ -153,7 +155,7 @@ from time import time as ttime
import torch import torch
import librosa import librosa
import soundfile as sf import soundfile as sf
from fastapi import FastAPI, Request, HTTPException from fastapi import FastAPI, Request, Query, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn import uvicorn
from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import AutoModelForMaskedLM, AutoTokenizer
@ -195,8 +197,24 @@ def is_full(*items): # 任意一项为空返回False
return True return True
def change_sovits_weights(sovits_path): class Speaker:
global vq_model, hps def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None):
self.name = name
self.sovits = sovits
self.gpt = gpt
self.phones = phones
self.bert = bert
self.prompt = prompt
speaker_list = {}
class Sovits:
def __init__(self, vq_model, hps):
self.vq_model = vq_model
self.hps = hps
def get_sovits_weights(sovits_path):
dict_s2 = torch.load(sovits_path, map_location="cpu") dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"] hps = dict_s2["config"]
hps = DictToAttrRecursive(hps) hps = DictToAttrRecursive(hps)
@ -205,7 +223,7 @@ def change_sovits_weights(sovits_path):
hps.model.version = "v1" hps.model.version = "v1"
else: else:
hps.model.version = "v2" hps.model.version = "v2"
print("sovits版本:",hps.model.version) logger.info(f"模型版本: {hps.model.version}")
model_params_dict = vars(hps.model) model_params_dict = vars(hps.model)
vq_model = SynthesizerTrn( vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
@ -222,10 +240,17 @@ def change_sovits_weights(sovits_path):
vq_model.eval() vq_model.eval()
vq_model.load_state_dict(dict_s2["weight"], strict=False) vq_model.load_state_dict(dict_s2["weight"], strict=False)
sovits = Sovits(vq_model, hps)
return sovits
def change_gpt_weights(gpt_path): class Gpt:
global hz, max_sec, t2s_model, config def __init__(self, max_sec, t2s_model):
self.max_sec = max_sec
self.t2s_model = t2s_model
global hz
hz = 50 hz = 50
def get_gpt_weights(gpt_path):
dict_s1 = torch.load(gpt_path, map_location="cpu") dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"] config = dict_s1["config"]
max_sec = config["data"]["max_sec"] max_sec = config["data"]["max_sec"]
@ -238,6 +263,19 @@ def change_gpt_weights(gpt_path):
total = sum([param.nelement() for param in t2s_model.parameters()]) total = sum([param.nelement() for param in t2s_model.parameters()])
logger.info("Number of parameter: %.2fM" % (total / 1e6)) logger.info("Number of parameter: %.2fM" % (total / 1e6))
gpt = Gpt(max_sec, t2s_model)
return gpt
def change_gpt_sovits_weights(gpt_path,sovits_path):
try:
gpt = get_gpt_weights(gpt_path)
sovits = get_sovits_weights(sovits_path)
except Exception as e:
return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def get_bert_feature(text, word2ph): def get_bert_feature(text, word2ph):
with torch.no_grad(): with torch.no_grad():
@ -289,14 +327,14 @@ def get_phones_and_bert(text,language,version,final=False):
if language == "zh": if language == "zh":
if re.search(r'[A-Za-z]', formattext): if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext) formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext,"zh",version) return get_phones_and_bert(formattext,"zh",version)
else: else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version) phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device) bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext): elif language == "yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext) formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext,"yue",version) return get_phones_and_bert(formattext,"yue",version)
else: else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version) phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
@ -375,8 +413,11 @@ class DictToAttrRecursive(dict):
def get_spepc(hps, filename): def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate)) audio,_ = librosa.load(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio) audio = torch.FloatTensor(audio)
maxx=audio.abs().max()
if(maxx>1):
audio/=min(2,maxx)
audio_norm = audio audio_norm = audio
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
@ -448,22 +489,32 @@ def pack_raw(audio_bytes, data, rate):
def pack_wav(audio_bytes, rate): def pack_wav(audio_bytes, rate):
if is_int32:
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int32)
wav_bytes = BytesIO()
sf.write(wav_bytes, data, rate, format='WAV', subtype='PCM_32')
else:
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16) data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
wav_bytes = BytesIO() wav_bytes = BytesIO()
sf.write(wav_bytes, data, rate, format='wav') sf.write(wav_bytes, data, rate, format='WAV')
return wav_bytes return wav_bytes
def pack_aac(audio_bytes, data, rate): def pack_aac(audio_bytes, data, rate):
if is_int32:
pcm = 's32le'
bit_rate = '256k'
else:
pcm = 's16le'
bit_rate = '128k'
process = subprocess.Popen([ process = subprocess.Popen([
'ffmpeg', 'ffmpeg',
'-f', 's16le', # 输入16位有符号小端整数PCM '-f', pcm, # 输入16位有符号小端整数PCM
'-ar', str(rate), # 设置采样率 '-ar', str(rate), # 设置采样率
'-ac', '1', # 单声道 '-ac', '1', # 单声道
'-i', 'pipe:0', # 从管道读取输入 '-i', 'pipe:0', # 从管道读取输入
'-c:a', 'aac', # 音频编码器为AAC '-c:a', 'aac', # 音频编码器为AAC
'-b:a', '192k', # 比特率 '-b:a', bit_rate, # 比特率
'-vn', # 不包含视频 '-vn', # 不包含视频
'-f', 'adts', # 输出AAC数据流格式 '-f', 'adts', # 输出AAC数据流格式
'pipe:1' # 将输出写入管道 'pipe:1' # 将输出写入管道
@ -504,10 +555,21 @@ def only_punc(text):
return not any(t.isalnum() or t.isalpha() for t in text) return not any(t.isalnum() or t.isalpha() for t in text)
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1): splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"):
infer_sovits = speaker_list[spk].sovits
vq_model = infer_sovits.vq_model
hps = infer_sovits.hps
infer_gpt = speaker_list[spk].gpt
t2s_model = infer_gpt.t2s_model
max_sec = infer_gpt.max_sec
t0 = ttime() t0 = ttime()
prompt_text = prompt_text.strip("\n") prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
prompt_language, text = prompt_language, text.strip("\n") prompt_language, text = prompt_language, text.strip("\n")
dtype = torch.float16 if is_half == True else torch.float32
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
with torch.no_grad(): with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) wav16k, sr = librosa.load(ref_wav_path, sr=16000)
@ -523,6 +585,19 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
codes = vq_model.extract_latent(ssl_content) codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device)
refers=[]
if(inp_refs):
for path in inp_refs:
try:
refer = get_spepc(hps, path).to(dtype).to(device)
refers.append(refer)
except Exception as e:
logger.error(e)
if(len(refers)==0):
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
t1 = ttime() t1 = ttime()
version = vq_model.version version = vq_model.version
os.environ['version'] = version os.environ['version'] = version
@ -538,16 +613,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
continue continue
audio_opt = [] audio_opt = []
if (text[-1] not in splits): text += "" if text_language != "en" else "."
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
bert = torch.cat([bert1, bert2], 1) bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0) bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime() t2 = ttime()
with torch.no_grad(): with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel( pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids, all_phoneme_ids,
all_phoneme_len, all_phoneme_len,
@ -558,22 +632,21 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
top_p = top_p, top_p = top_p,
temperature = temperature, temperature = temperature,
early_stop_num=hz * max_sec) early_stop_num=hz * max_sec)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
t3 = ttime() t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if (is_half == True):
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = \ audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refer,speed=speed).detach().cpu().numpy()[ refers,speed=speed).detach().cpu().numpy()[
0, 0] ###试试重建不带上prompt部分 0, 0] ###试试重建不带上prompt部分
max_audio=np.abs(audio).max()
if max_audio>1:
audio/=max_audio
audio_opt.append(audio) audio_opt.append(audio)
audio_opt.append(zero_wav) audio_opt.append(zero_wav)
t4 = ttime() t4 = ttime()
if is_int32:
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate)
else:
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate) audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
if stream_mode == "normal": if stream_mode == "normal":
@ -615,7 +688,7 @@ def handle_change(path, text, language):
return JSONResponse({"code": 0, "message": "Success"}, status_code=200) return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed): def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs):
if ( if (
refer_wav_path == "" or refer_wav_path is None refer_wav_path == "" or refer_wav_path is None
or prompt_text == "" or prompt_text is None or prompt_text == "" or prompt_text is None
@ -634,7 +707,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
else: else:
text = cut_text(text,cut_punc) text = cut_text(text,cut_punc)
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed), media_type="audio/"+media_type) return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type)
@ -691,6 +764,7 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals
# 此时 full_precision==True, half_precision==False # 此时 full_precision==True, half_precision==False
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac") parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac")
parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32")
parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…") parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…")
# 切割常用分句符为 `python ./api.py -cp ".?!。?!"` # 切割常用分句符为 `python ./api.py -cp ".?!。?!"`
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
@ -752,6 +826,14 @@ else:
media_type = "ogg" media_type = "ogg"
logger.info(f"编码格式: {media_type}") logger.info(f"编码格式: {media_type}")
# 音频数据类型
if args.sub_type.lower() == 'int32':
is_int32 = True
logger.info(f"数据类型: int32")
else:
is_int32 = False
logger.info(f"数据类型: int16")
# 初始化模型 # 初始化模型
cnhubert.cnhubert_base_path = cnhubert_base_path cnhubert.cnhubert_base_path = cnhubert_base_path
tokenizer = AutoTokenizer.from_pretrained(bert_path) tokenizer = AutoTokenizer.from_pretrained(bert_path)
@ -763,9 +845,7 @@ if is_half:
else: else:
bert_model = bert_model.to(device) bert_model = bert_model.to(device)
ssl_model = ssl_model.to(device) ssl_model = ssl_model.to(device)
change_sovits_weights(sovits_path) change_gpt_sovits_weights(gpt_path = gpt_path, sovits_path = sovits_path)
change_gpt_weights(gpt_path)
@ -777,14 +857,18 @@ app = FastAPI()
@app.post("/set_model") @app.post("/set_model")
async def set_model(request: Request): async def set_model(request: Request):
json_post_raw = await request.json() json_post_raw = await request.json()
global gpt_path return change_gpt_sovits_weights(
gpt_path=json_post_raw.get("gpt_model_path") gpt_path = json_post_raw.get("gpt_model_path"),
global sovits_path
sovits_path = json_post_raw.get("sovits_model_path") sovits_path = json_post_raw.get("sovits_model_path")
logger.info("gptpath"+gpt_path+";vitspath"+sovits_path) )
change_sovits_weights(sovits_path)
change_gpt_weights(gpt_path)
return "ok" @app.get("/set_model")
async def set_model(
gpt_model_path: str = None,
sovits_model_path: str = None,
):
return change_gpt_sovits_weights(gpt_path = gpt_model_path, sovits_path = sovits_model_path)
@app.post("/control") @app.post("/control")
@ -827,10 +911,11 @@ async def tts_endpoint(request: Request):
json_post_raw.get("text"), json_post_raw.get("text"),
json_post_raw.get("text_language"), json_post_raw.get("text_language"),
json_post_raw.get("cut_punc"), json_post_raw.get("cut_punc"),
json_post_raw.get("top_k", 10), json_post_raw.get("top_k", 15),
json_post_raw.get("top_p", 1.0), json_post_raw.get("top_p", 1.0),
json_post_raw.get("temperature", 1.0), json_post_raw.get("temperature", 1.0),
json_post_raw.get("speed", 1.0) json_post_raw.get("speed", 1.0),
json_post_raw.get("inp_refs", [])
) )
@ -842,12 +927,13 @@ async def tts_endpoint(
text: str = None, text: str = None,
text_language: str = None, text_language: str = None,
cut_punc: str = None, cut_punc: str = None,
top_k: int = 10, top_k: int = 15,
top_p: float = 1.0, top_p: float = 1.0,
temperature: float = 1.0, temperature: float = 1.0,
speed: float = 1.0 speed: float = 1.0,
inp_refs: list = Query(default=[])
): ):
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed) return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs)
if __name__ == "__main__": if __name__ == "__main__":

458
api_v2.py Normal file
View File

@ -0,0 +1,458 @@
"""
# WebAPI文档
` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `
## 执行参数:
`-a` - `绑定地址, 默认"127.0.0.1"`
`-p` - `绑定端口, 默认9880`
`-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"`
## 调用:
### 推理
endpoint: `/tts`
GET:
```
http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是罗浮云骑将军景元不必拘谨将军只是一时的身份你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true
```
POST:
```json
{
"text": "", # str.(required) text to be synthesized
"text_lang: "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"streaming_mode": False, # bool. whether to return a streaming response.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
}
```
RESP:
成功: 直接返回 wav 音频流 http code 200
失败: 返回包含错误信息的 json, http code 400
### 命令控制
endpoint: `/control`
command:
"restart": 重新运行
"exit": 结束运行
GET:
```
http://127.0.0.1:9880/control?command=restart
```
POST:
```json
{
"command": "restart"
}
```
RESP:
### 切换GPT模型
endpoint: `/set_gpt_weights`
GET:
```
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
```
RESP:
成功: 返回"success", http code 200
失败: 返回包含错误信息的 json, http code 400
### 切换Sovits模型
endpoint: `/set_sovits_weights`
GET:
```
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
```
RESP:
成功: 返回"success", http code 200
失败: 返回包含错误信息的 json, http code 400
"""
import os
import sys
import traceback
from typing import Generator
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
import argparse
import subprocess
import wave
import signal
import numpy as np
import soundfile as sf
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi import FastAPI, UploadFile, File
import uvicorn
from io import BytesIO
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
# print(sys.path)
i18n = I18nAuto()
cut_method_names = get_cut_method_names()
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
args = parser.parse_args()
config_path = args.tts_config
# device = args.device
port = args.port
host = args.bind_addr
argv = sys.argv
if config_path in [None, ""]:
config_path = "GPT-SoVITS/configs/tts_infer.yaml"
tts_config = TTS_Config(config_path)
print(tts_config)
tts_pipeline = TTS(tts_config)
APP = FastAPI()
class TTS_Request(BaseModel):
text: str = None
text_lang: str = None
ref_audio_path: str = None
aux_ref_audio_paths: list = None
prompt_lang: str = None
prompt_text: str = ""
top_k:int = 5
top_p:float = 1
temperature:float = 1
text_split_method:str = "cut5"
batch_size:int = 1
batch_threshold:float = 0.75
split_bucket:bool = True
speed_factor:float = 1.0
fragment_interval:float = 0.3
seed:int = -1
media_type:str = "wav"
streaming_mode:bool = False
parallel_infer:bool = True
repetition_penalty:float = 1.35
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
audio_file.write(data)
return io_buffer
def pack_raw(io_buffer:BytesIO, data:np.ndarray, rate:int):
io_buffer.write(data.tobytes())
return io_buffer
def pack_wav(io_buffer:BytesIO, data:np.ndarray, rate:int):
io_buffer = BytesIO()
sf.write(io_buffer, data, rate, format='wav')
return io_buffer
def pack_aac(io_buffer:BytesIO, data:np.ndarray, rate:int):
process = subprocess.Popen([
'ffmpeg',
'-f', 's16le', # 输入16位有符号小端整数PCM
'-ar', str(rate), # 设置采样率
'-ac', '1', # 单声道
'-i', 'pipe:0', # 从管道读取输入
'-c:a', 'aac', # 音频编码器为AAC
'-b:a', '192k', # 比特率
'-vn', # 不包含视频
'-f', 'adts', # 输出AAC数据流格式
'pipe:1' # 将输出写入管道
], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, _ = process.communicate(input=data.tobytes())
io_buffer.write(out)
return io_buffer
def pack_audio(io_buffer:BytesIO, data:np.ndarray, rate:int, media_type:str):
if media_type == "ogg":
io_buffer = pack_ogg(io_buffer, data, rate)
elif media_type == "aac":
io_buffer = pack_aac(io_buffer, data, rate)
elif media_type == "wav":
io_buffer = pack_wav(io_buffer, data, rate)
else:
io_buffer = pack_raw(io_buffer, data, rate)
io_buffer.seek(0)
return io_buffer
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
# This will create a wave header then append the frame input
# It should be first on a streaming wav file
# Other frames better should not have it (else you will hear some artifacts each chunk start)
wav_buf = BytesIO()
with wave.open(wav_buf, "wb") as vfout:
vfout.setnchannels(channels)
vfout.setsampwidth(sample_width)
vfout.setframerate(sample_rate)
vfout.writeframes(frame_input)
wav_buf.seek(0)
return wav_buf.read()
def handle_control(command:str):
if command == "restart":
os.execl(sys.executable, sys.executable, *argv)
elif command == "exit":
os.kill(os.getpid(), signal.SIGTERM)
exit(0)
def check_params(req:dict):
text:str = req.get("text", "")
text_lang:str = req.get("text_lang", "")
ref_audio_path:str = req.get("ref_audio_path", "")
streaming_mode:bool = req.get("streaming_mode", False)
media_type:str = req.get("media_type", "wav")
prompt_lang:str = req.get("prompt_lang", "")
text_split_method:str = req.get("text_split_method", "cut5")
if ref_audio_path in [None, ""]:
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
if text in [None, ""]:
return JSONResponse(status_code=400, content={"message": "text is required"})
if (text_lang in [None, ""]) :
return JSONResponse(status_code=400, content={"message": "text_lang is required"})
elif text_lang.lower() not in tts_config.languages:
return JSONResponse(status_code=400, content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"})
if (prompt_lang in [None, ""]) :
return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
elif prompt_lang.lower() not in tts_config.languages:
return JSONResponse(status_code=400, content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"})
if media_type not in ["wav", "raw", "ogg", "aac"]:
return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
elif media_type == "ogg" and not streaming_mode:
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
if text_split_method not in cut_method_names:
return JSONResponse(status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"})
return None
async def tts_handle(req:dict):
"""
Text to speech handler.
Args:
req (dict):
{
"text": "", # str.(required) text to be synthesized
"text_lang: "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
"streaming_mode": False, # bool. whether to return a streaming response.
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
}
returns:
StreamingResponse: audio stream response.
"""
streaming_mode = req.get("streaming_mode", False)
return_fragment = req.get("return_fragment", False)
media_type = req.get("media_type", "wav")
check_res = check_params(req)
if check_res is not None:
return check_res
if streaming_mode or return_fragment:
req["return_fragment"] = True
try:
tts_generator=tts_pipeline.run(req)
if streaming_mode:
def streaming_generator(tts_generator:Generator, media_type:str):
if media_type == "wav":
yield wave_header_chunk()
media_type = "raw"
for sr, chunk in tts_generator:
yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
# _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
else:
sr, audio_data = next(tts_generator)
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
return Response(audio_data, media_type=f"audio/{media_type}")
except Exception as e:
return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
@APP.get("/control")
async def control(command: str = None):
if command is None:
return JSONResponse(status_code=400, content={"message": "command is required"})
handle_control(command)
@APP.get("/tts")
async def tts_get_endpoint(
text: str = None,
text_lang: str = None,
ref_audio_path: str = None,
aux_ref_audio_paths:list = None,
prompt_lang: str = None,
prompt_text: str = "",
top_k:int = 5,
top_p:float = 1,
temperature:float = 1,
text_split_method:str = "cut0",
batch_size:int = 1,
batch_threshold:float = 0.75,
split_bucket:bool = True,
speed_factor:float = 1.0,
fragment_interval:float = 0.3,
seed:int = -1,
media_type:str = "wav",
streaming_mode:bool = False,
parallel_infer:bool = True,
repetition_penalty:float = 1.35
):
req = {
"text": text,
"text_lang": text_lang.lower(),
"ref_audio_path": ref_audio_path,
"aux_ref_audio_paths": aux_ref_audio_paths,
"prompt_text": prompt_text,
"prompt_lang": prompt_lang.lower(),
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"text_split_method": text_split_method,
"batch_size":int(batch_size),
"batch_threshold":float(batch_threshold),
"speed_factor":float(speed_factor),
"split_bucket":split_bucket,
"fragment_interval":fragment_interval,
"seed":seed,
"media_type":media_type,
"streaming_mode":streaming_mode,
"parallel_infer":parallel_infer,
"repetition_penalty":float(repetition_penalty)
}
return await tts_handle(req)
@APP.post("/tts")
async def tts_post_endpoint(request: TTS_Request):
req = request.dict()
return await tts_handle(req)
@APP.get("/set_refer_audio")
async def set_refer_aduio(refer_audio_path: str = None):
try:
tts_pipeline.set_ref_audio(refer_audio_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
return JSONResponse(status_code=200, content={"message": "success"})
# @APP.post("/set_refer_audio")
# async def set_refer_aduio_post(audio_file: UploadFile = File(...)):
# try:
# # 检查文件类型,确保是音频文件
# if not audio_file.content_type.startswith("audio/"):
# return JSONResponse(status_code=400, content={"message": "file type is not supported"})
# os.makedirs("uploaded_audio", exist_ok=True)
# save_path = os.path.join("uploaded_audio", audio_file.filename)
# # 保存音频文件到服务器上的一个目录
# with open(save_path , "wb") as buffer:
# buffer.write(await audio_file.read())
# tts_pipeline.set_ref_audio(save_path)
# except Exception as e:
# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
# return JSONResponse(status_code=200, content={"message": "success"})
@APP.get("/set_gpt_weights")
async def set_gpt_weights(weights_path: str = None):
try:
if weights_path in ["", None]:
return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
tts_pipeline.init_t2s_weights(weights_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": f"change gpt weight failed", "Exception": str(e)})
return JSONResponse(status_code=200, content={"message": "success"})
@APP.get("/set_sovits_weights")
async def set_sovits_weights(weights_path: str = None):
try:
if weights_path in ["", None]:
return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
tts_pipeline.init_vits_weights(weights_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": f"change sovits weight failed", "Exception": str(e)})
return JSONResponse(status_code=200, content={"message": "success"})
if __name__ == "__main__":
try:
uvicorn.run(app=APP, host=host, port=port, workers=1)
except Exception as e:
traceback.print_exc()
os.kill(os.getpid(), signal.SIGTERM)
exit(0)

View File

@ -232,5 +232,16 @@
7-计时逻辑优化 https://github.com/RVC-Boss/GPT-SoVITS/pull/1387 7-计时逻辑优化 https://github.com/RVC-Boss/GPT-SoVITS/pull/1387
### 20240821
1-fast_inference分支合并进mainhttps://github.com/RVC-Boss/GPT-SoVITS/pull/1490
2-支持通过ssml标签优化数字、电话、时间日期等https://github.com/RVC-Boss/GPT-SoVITS/issues/1508
3-api修复优化https://github.com/RVC-Boss/GPT-SoVITS/pull/1503
4-修复了参考音频混合只能上传一条的bug:https://github.com/RVC-Boss/GPT-SoVITS/pull/1422
5-增加了各种数据集检查,若缺失会弹出warning:https://github.com/RVC-Boss/GPT-SoVITS/pull/1422

View File

@ -181,8 +181,8 @@ D:\GPT-SoVITS\xxx/xxx.wav|xxx|zh|我爱玩原神。
#### 整合包用户 #### 整合包用户
双击`go-webui.bat`或者使用`go-webui.ps` 双击`go-webui.bat`或者使用`go-webui.ps1`
若想使用V1,则双击`go-webui-v1.bat`或者使用`go-webui-v1.ps` 若想使用V1,则双击`go-webui-v1.bat`或者使用`go-webui-v1.ps1`
#### 其他 #### 其他
@ -217,7 +217,7 @@ python webui.py v1 <language(optional)>
#### 整合包用户 #### 整合包用户
双击 `go-webui.bat` 或者使用 `go-webui.ps` ,然后在 `1-GPT-SoVITS-TTS/1C-推理` 中打开推理webUI 双击 `go-webui.bat` 或者使用 `go-webui.ps1` ,然后在 `1-GPT-SoVITS-TTS/1C-推理` 中打开推理webUI
#### 其他 #### 其他

View File

@ -171,8 +171,8 @@ D:\GPT-SoVITS\xxx/xxx.wav|xxx|en|I like playing Genshin.
#### 統合パッケージ利用者 #### 統合パッケージ利用者
`go-webui.bat`をダブルクリックするか、`go-webui.ps`を使用します。 `go-webui.bat`をダブルクリックするか、`go-webui.ps1`を使用します。
V1に切り替えたい場合は、`go-webui-v1.bat`をダブルクリックするか、`go-webui-v1.ps`を使用してください。 V1に切り替えたい場合は、`go-webui-v1.bat`をダブルクリックするか、`go-webui-v1.ps1`を使用してください。
#### その他 #### その他
@ -207,7 +207,7 @@ python webui.py v1 <言語(オプション)>
#### 統合パッケージ利用者 #### 統合パッケージ利用者
`go-webui-v2.bat`をダブルクリックするか、`go-webui-v2.ps`を使用して、`1-GPT-SoVITS-TTS/1C-inference`で推論webuiを開きます。 `go-webui-v2.bat`をダブルクリックするか、`go-webui-v2.ps1`を使用して、`1-GPT-SoVITS-TTS/1C-inference`で推論webuiを開きます。
#### その他 #### その他

View File

@ -175,8 +175,8 @@ D:\GPT-SoVITS\xxx/xxx.wav|xxx|en|I like playing Genshin.
#### 통합 패키지 사용자 #### 통합 패키지 사용자
`go-webui.bat`을 더블 클릭하거나 `go-webui.ps`를 사용하십시오. `go-webui.bat`을 더블 클릭하거나 `go-webui.ps1`를 사용하십시오.
V1으로 전환하려면, `go-webui-v1.bat`을 더블 클릭하거나 `go-webui-v1.ps`를 사용하십시오. V1으로 전환하려면, `go-webui-v1.bat`을 더블 클릭하거나 `go-webui-v1.ps1`를 사용하십시오.
#### 기타 #### 기타
@ -211,7 +211,7 @@ python webui.py v1 <언어(옵션)>
#### 통합 패키지 사용자 #### 통합 패키지 사용자
`go-webui-v2.bat`을 더블 클릭하거나 `go-webui-v2.ps`를 사용한 다음 `1-GPT-SoVITS-TTS/1C-inference`에서 추론 webui를 엽니다. `go-webui-v2.bat`을 더블 클릭하거나 `go-webui-v2.ps1`를 사용한 다음 `1-GPT-SoVITS-TTS/1C-inference`에서 추론 webui를 엽니다.
#### 기타 #### 기타

View File

@ -172,8 +172,8 @@ D:\GPT-SoVITS\xxx/xxx.wav|xxx|en|I like playing Genshin.
#### Entegre Paket Kullanıcıları #### Entegre Paket Kullanıcıları
`go-webui.bat` dosyasına çift tıklayın veya `go-webui.ps` kullanın. `go-webui.bat` dosyasına çift tıklayın veya `go-webui.ps1` kullanın.
V1'e geçmek istiyorsanız, `go-webui-v1.bat` dosyasına çift tıklayın veya `go-webui-v1.ps` kullanın. V1'e geçmek istiyorsanız, `go-webui-v1.bat` dosyasına çift tıklayın veya `go-webui-v1.ps1` kullanın.
#### Diğerleri #### Diğerleri
@ -208,7 +208,7 @@ veya WebUI'de manuel olarak sürüm değiştirin.
#### Entegre Paket Kullanıcıları #### Entegre Paket Kullanıcıları
`go-webui-v2.bat` dosyasına çift tıklayın veya `go-webui-v2.ps` kullanın, ardından çıkarım webui'sini `1-GPT-SoVITS-TTS/1C-inference` adresinde açın. `go-webui-v2.bat` dosyasına çift tıklayın veya `go-webui-v2.ps1` kullanın, ardından çıkarım webui'sini `1-GPT-SoVITS-TTS/1C-inference` adresinde açın.
#### Diğerleri #### Diğerleri

View File

@ -1,2 +0,0 @@
runtime\python.exe webui.py v1 zh_CN
pause

View File

@ -1,4 +0,0 @@
$ErrorActionPreference = "SilentlyContinue"
chcp 65001
& "$PSScriptRoot\runtime\python.exe" "$PSScriptRoot\webui.py v1 zh_CN"
pause

View File

@ -40,3 +40,4 @@ ko_pron
opencc; sys_platform != 'linux' opencc; sys_platform != 'linux'
opencc==1.1.1; sys_platform == 'linux' opencc==1.1.1; sys_platform == 'linux'
python_mecab_ko; sys_platform != 'win32' python_mecab_ko; sys_platform != 'win32'
fastapi<0.112.2

View File

@ -9,6 +9,7 @@ DEFAULT_LANGUAGE: str = "zh_CN" # 默认语言
TITLE_LEN : int = 60 # 标题显示长度 TITLE_LEN : int = 60 # 标题显示长度
KEY_LEN : int = 30 # 键名显示长度 KEY_LEN : int = 30 # 键名显示长度
SHOW_KEYS : bool = False # 是否显示键信息 SHOW_KEYS : bool = False # 是否显示键信息
SORT_KEYS : bool = False # 是否按全局键名写入文件
def extract_i18n_strings(node): def extract_i18n_strings(node):
i18n_strings = [] i18n_strings = []
@ -49,6 +50,7 @@ def scan_i18n_strings():
return code_keys return code_keys
def update_i18n_json(json_file, standard_keys): def update_i18n_json(json_file, standard_keys):
standard_keys = sorted(standard_keys)
print(f" Process {json_file} ".center(TITLE_LEN, "=")) print(f" Process {json_file} ".center(TITLE_LEN, "="))
# 读取 JSON 文件 # 读取 JSON 文件
with open(json_file, "r", encoding="utf-8") as f: with open(json_file, "r", encoding="utf-8") as f:
@ -79,8 +81,13 @@ def update_i18n_json(json_file, standard_keys):
print(f"{'Removed Unused Key'.ljust(KEY_LEN)}: {key}") print(f"{'Removed Unused Key'.ljust(KEY_LEN)}: {key}")
# 按键顺序排序 # 按键顺序排序
json_data = OrderedDict( json_data = OrderedDict(
sorted(json_data.items(), sorted(
key=lambda x: list(standard_keys).index(x[0]))) json_data.items(),
key=lambda x: (
list(standard_keys).index(x[0]) if x[0] in standard_keys and not x[1].startswith('#!') else len(json_data),
)
)
)
# 打印处理后的 JSON 条目数 # 打印处理后的 JSON 条目数
if len(miss_keys) != 0 or len(diff_keys) != 0: if len(miss_keys) != 0 or len(diff_keys) != 0:
print(f"{'Total Keys (After)'.ljust(KEY_LEN)}: {len(json_data)}") print(f"{'Total Keys (After)'.ljust(KEY_LEN)}: {len(json_data)}")
@ -107,7 +114,7 @@ def update_i18n_json(json_file, standard_keys):
print(f"\033[32m[Passed] All Keys Translated\033[0m") print(f"\033[32m[Passed] All Keys Translated\033[0m")
# 将处理后的结果写入 JSON 文件 # 将处理后的结果写入 JSON 文件
with open(json_file, "w", encoding="utf-8") as f: with open(json_file, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=4, sort_keys=True) json.dump(json_data, f, ensure_ascii=False, indent=4, sort_keys=SORT_KEYS)
f.write("\n") f.write("\n")
print(f" Updated {json_file} ".center(TITLE_LEN, "=") + '\n') print(f" Updated {json_file} ".center(TITLE_LEN, "=") + '\n')

View File

@ -4,6 +4,11 @@ import json
import os import os
import uuid import uuid
try:
import gradio.analytics as analytics
analytics.version_check = lambda:None
except:...
import librosa import librosa
import gradio as gr import gradio as gr
import numpy as np import numpy as np

View File

@ -14,6 +14,11 @@ from mdxnet import MDXNetDereverb
from vr import AudioPre, AudioPreDeEcho from vr import AudioPre, AudioPreDeEcho
from bsroformer import BsRoformer_Loader from bsroformer import BsRoformer_Loader
try:
import gradio.analytics as analytics
analytics.version_check = lambda:None
except:...
weight_uvr5_root = "tools/uvr5/uvr5_weights" weight_uvr5_root = "tools/uvr5/uvr5_weights"
uvr5_names = [] uvr5_names = []
for name in os.listdir(weight_uvr5_root): for name in os.listdir(weight_uvr5_root):
@ -194,6 +199,7 @@ with gr.Blocks(title="UVR5 WebUI") as app:
[vc_output4], [vc_output4],
api_name="uvr_convert", api_name="uvr_convert",
) )
app.queue(max_size=1022).launch( app.queue(max_size=1022).launch(
server_name="0.0.0.0", server_name="0.0.0.0",
inbrowser=True, inbrowser=True,

View File

@ -223,8 +223,12 @@ def change_uvr5():
p_uvr5=None p_uvr5=None
yield i18n("UVR5已关闭"), {'__type__':'update','visible':True}, {'__type__':'update','visible':False} yield i18n("UVR5已关闭"), {'__type__':'update','visible':True}, {'__type__':'update','visible':False}
def change_tts_inference(bert_path,cnhubert_base_path,gpu_number,gpt_path,sovits_path): def change_tts_inference(bert_path,cnhubert_base_path,gpu_number,gpt_path,sovits_path, batched_infer_enabled):
global p_tts_inference global p_tts_inference
if batched_infer_enabled:
cmd = '"%s" GPT_SoVITS/inference_webui_fast.py "%s"'%(python_exec, language)
else:
cmd = '"%s" GPT_SoVITS/inference_webui.py "%s"'%(python_exec, language)
if(p_tts_inference==None): if(p_tts_inference==None):
os.environ["gpt_path"]=gpt_path if "/" in gpt_path else "%s/%s"%(GPT_weight_root,gpt_path) os.environ["gpt_path"]=gpt_path if "/" in gpt_path else "%s/%s"%(GPT_weight_root,gpt_path)
os.environ["sovits_path"]=sovits_path if "/"in sovits_path else "%s/%s"%(SoVITS_weight_root,sovits_path) os.environ["sovits_path"]=sovits_path if "/"in sovits_path else "%s/%s"%(SoVITS_weight_root,sovits_path)
@ -234,7 +238,6 @@ def change_tts_inference(bert_path,cnhubert_base_path,gpu_number,gpt_path,sovits
os.environ["is_half"]=str(is_half) os.environ["is_half"]=str(is_half)
os.environ["infer_ttswebui"]=str(webui_port_infer_tts) os.environ["infer_ttswebui"]=str(webui_port_infer_tts)
os.environ["is_share"]=str(is_share) os.environ["is_share"]=str(is_share)
cmd = '"%s" GPT_SoVITS/inference_webui.py "%s"'%(python_exec, language)
yield i18n("TTS推理进程已开启"), {'__type__':'update','visible':False}, {'__type__':'update','visible':True} yield i18n("TTS推理进程已开启"), {'__type__':'update','visible':False}, {'__type__':'update','visible':True}
print(cmd) print(cmd)
p_tts_inference = Popen(cmd, shell=True) p_tts_inference = Popen(cmd, shell=True)
@ -1031,15 +1034,18 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary") refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices,inputs=[],outputs=[SoVITS_dropdown,GPT_dropdown]) refresh_button.click(fn=change_choices,inputs=[],outputs=[SoVITS_dropdown,GPT_dropdown])
with gr.Row(): with gr.Row():
with gr.Row():
batched_infer_enabled = gr.Checkbox(label=i18n("启用并行推理版本(推理速度更快)"), value=False, interactive=True, show_label=True)
with gr.Row(): with gr.Row():
open_tts = gr.Button(value=i18n("开启TTS推理WebUI"),variant='primary',visible=True) open_tts = gr.Button(value=i18n("开启TTS推理WebUI"),variant='primary',visible=True)
close_tts = gr.Button(value=i18n("关闭TTS推理WebUI"),variant='primary',visible=False) close_tts = gr.Button(value=i18n("关闭TTS推理WebUI"),variant='primary',visible=False)
with gr.Row(): with gr.Row():
tts_info = gr.Textbox(label=i18n("TTS推理WebUI进程输出信息")) tts_info = gr.Textbox(label=i18n("TTS推理WebUI进程输出信息"))
open_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown], [tts_info,open_tts,close_tts]) open_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown, batched_infer_enabled], [tts_info,open_tts,close_tts])
close_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown], [tts_info,open_tts,close_tts]) close_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown, batched_infer_enabled], [tts_info,open_tts,close_tts])
version_checkbox.change(switch_version,[version_checkbox],[pretrained_s2G,pretrained_s2D,pretrained_s1,GPT_dropdown,SoVITS_dropdown]) version_checkbox.change(switch_version,[version_checkbox],[pretrained_s2G,pretrained_s2D,pretrained_s1,GPT_dropdown,SoVITS_dropdown])
with gr.TabItem(i18n("2-GPT-SoVITS-变声")):gr.Markdown(value=i18n("施工中,请静候佳音")) with gr.TabItem(i18n("2-GPT-SoVITS-变声")):gr.Markdown(value=i18n("施工中,请静候佳音"))
app.queue(max_size=1022).launch( app.queue(max_size=1022).launch(
server_name="0.0.0.0", server_name="0.0.0.0",
inbrowser=True, inbrowser=True,