mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
[fast_inference] 优化batch inference的mask策略 (#1477)
* 优化了batch inference的mask策略,使音频合成的质量更加稳定;改善了一些代码逻辑。 * 删除无用代码
This commit is contained in:
parent
7c43b41e6d
commit
f5a5f1890f
2
.gitignore
vendored
2
.gitignore
vendored
@ -10,3 +10,5 @@ reference
|
||||
GPT_weights
|
||||
SoVITS_weights
|
||||
TEMP
|
||||
ffmpeg*
|
||||
ffprobe*
|
@ -1,9 +1,10 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import math
|
||||
import os, sys
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -38,6 +39,34 @@ default_config = {
|
||||
"EOS": 1024,
|
||||
}
|
||||
|
||||
@torch.jit.script
|
||||
# Efficient implementation equivalent to the following:
|
||||
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
|
||||
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
|
||||
if scale is None:
|
||||
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
|
||||
else:
|
||||
scale_factor = scale
|
||||
attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_bias.masked_fill_(attn_mask, float("-inf"))
|
||||
else:
|
||||
attn_bias += attn_mask
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
attn_weight += attn_bias
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_weight.masked_fill_(attn_mask, 0)
|
||||
else:
|
||||
attn_mask[attn_mask!=float("-inf")] =0
|
||||
attn_mask[attn_mask==float("-inf")] =1
|
||||
attn_weight.masked_fill_(attn_mask, 0)
|
||||
|
||||
return attn_weight @ value
|
||||
|
||||
@torch.jit.script
|
||||
class T2SMLP:
|
||||
@ -86,10 +115,16 @@ class T2SBlock:
|
||||
self.norm_eps2 = norm_eps2
|
||||
|
||||
@torch.jit.ignore
|
||||
def to_mask(self, x, padding_mask):
|
||||
return x*padding_mask if padding_mask is not None else x
|
||||
|
||||
def process_prompt(self, x, attn_mask : torch.Tensor, padding_mask:torch.Tensor=None):
|
||||
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
|
||||
if padding_mask is None:
|
||||
return x
|
||||
|
||||
if padding_mask.dtype == torch.bool:
|
||||
return x.masked_fill(padding_mask, 0)
|
||||
else:
|
||||
return x * padding_mask
|
||||
|
||||
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None):
|
||||
|
||||
|
||||
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
@ -106,27 +141,48 @@ class T2SBlock:
|
||||
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
|
||||
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||
|
||||
x = self.to_mask(x + attn, padding_mask)
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = self.to_mask(x + self.mlp.forward(self.to_mask(x, padding_mask)), padding_mask)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
if padding_mask is not None:
|
||||
for i in range(batch_size):
|
||||
# mask = padding_mask[i,:,0]
|
||||
idx = torch.where(padding_mask[i,:,0]==False)[0]
|
||||
x_item = x[i,idx,:].unsqueeze(0)
|
||||
attn_item = attn[i,idx,:].unsqueeze(0)
|
||||
x_item = x_item + attn_item
|
||||
x_item = F.layer_norm(
|
||||
x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x_item = x_item + self.mlp.forward(x_item)
|
||||
x_item = F.layer_norm(
|
||||
x_item,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
x[i,idx,:] = x_item.squeeze(0)
|
||||
x = self.to_mask(x, padding_mask)
|
||||
else:
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
[self.hidden_dim],
|
||||
self.norm_w2,
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(self, x, k_cache, v_cache):
|
||||
|
||||
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None):
|
||||
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
k_cache = torch.cat([k_cache, k], dim=1)
|
||||
@ -141,7 +197,7 @@ class T2SBlock:
|
||||
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v)
|
||||
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
@ -169,8 +225,8 @@ class T2STransformer:
|
||||
self.blocks = blocks
|
||||
|
||||
def process_prompt(
|
||||
self, x, attn_mask : torch.Tensor,
|
||||
padding_mask : torch.Tensor=None,
|
||||
self, x:torch.Tensor, attn_mask : torch.Tensor,
|
||||
padding_mask : Optional[torch.Tensor]=None,
|
||||
):
|
||||
k_cache : List[torch.Tensor] = []
|
||||
v_cache : List[torch.Tensor] = []
|
||||
@ -181,10 +237,10 @@ class T2STransformer:
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(
|
||||
self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
|
||||
self, x:torch.Tensor, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor], attn_mask : Optional[torch.Tensor]=None,
|
||||
):
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
|
||||
@ -506,12 +562,12 @@ class Text2SemanticDecoder(nn.Module):
|
||||
# 错位
|
||||
return targets[:, :-1], targets[:, 1:]
|
||||
|
||||
def infer_panel_batch_infer_with_flash_attn(
|
||||
def infer_panel_batch_infer(
|
||||
self,
|
||||
x:torch.LongTensor, #####全部文本token
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:torch.LongTensor,
|
||||
bert_feature:List[torch.LongTensor],
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
@ -519,30 +575,21 @@ class Text2SemanticDecoder(nn.Module):
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs,
|
||||
):
|
||||
# # fp16 会对结果产生影响(和没pad相比)
|
||||
# bert_feature_dtype = bert_feature[0].dtype
|
||||
# if not hasattr(self.bert_proj, "dtype"):
|
||||
# self.bert_proj.dtype = torch.float32
|
||||
# self.bert_proj=self.bert_proj.float()
|
||||
if prompts is None:
|
||||
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
|
||||
return self.infer_panel_0307(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
|
||||
|
||||
## 先对phones进行embedding、对bert_features进行project,再pad到相同长度(padding策略会影响T2S模型生成的结果。)
|
||||
## pad之后再进行Linear会有误差(和没pad相比),就离谱。。。
|
||||
max_len = kwargs.get("max_len",x_lens.max())
|
||||
# for x_item, bert_item in zip(x, bert_feature):
|
||||
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
x_list = [self.ar_text_embedding(item) for item in x]
|
||||
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
|
||||
x_list = []
|
||||
for x_item, bert_item in zip(x, bert_feature):
|
||||
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
x_item = self.ar_text_embedding(x_item.unsqueeze(0))
|
||||
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
|
||||
x_item = self.ar_text_position(x_item).squeeze(0)
|
||||
x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item
|
||||
x_list.append(x_item)
|
||||
x = torch.stack(x_list, dim=0)
|
||||
|
||||
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
|
||||
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
|
||||
bert_feature = torch.stack(bert_features_list, dim=0)
|
||||
|
||||
|
||||
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2).float()).to(dtype=bert_feature_dtype)
|
||||
# x = self.ar_text_embedding(x)
|
||||
x = x + bert_feature
|
||||
x = self.ar_text_position(x)
|
||||
|
||||
# AR Decoder
|
||||
y = prompts
|
||||
@ -593,16 +640,17 @@ class Text2SemanticDecoder(nn.Module):
|
||||
value=False,
|
||||
)
|
||||
|
||||
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
|
||||
# xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1).view(1 , src_len, src_len).expand(bsz, -1, -1).to(x.device)
|
||||
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
|
||||
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
||||
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
|
||||
|
||||
for i in range(bsz):
|
||||
l = x_lens[i]
|
||||
_xy_padding_mask[i,l:max_len,:]=True
|
||||
|
||||
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
|
||||
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
|
||||
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
|
||||
xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
|
||||
|
||||
xy_padding_mask = ~xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim)
|
||||
xy_padding_mask = xy_padding_mask.to(dtype=x.dtype)
|
||||
xy_attn_mask = xy_attn_mask.bool()
|
||||
xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim)
|
||||
|
||||
###### decode #####
|
||||
y_list = [None]*y.shape[0]
|
||||
@ -612,27 +660,32 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if idx == 0:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask)
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask)
|
||||
|
||||
logits = self.ar_predict_layer(
|
||||
xy_dec[:, -1]
|
||||
)
|
||||
|
||||
if idx == 0:
|
||||
xy_attn_mask = None
|
||||
xy_attn_mask = F.pad(xy_attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
|
||||
logits = logits[:, :-1]
|
||||
|
||||
else:
|
||||
xy_attn_mask = F.pad(xy_attn_mask,(0,1),value=False)
|
||||
|
||||
samples = sample(
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||
)[0]
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||
)[0]
|
||||
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
|
||||
####### 移除batch中已经生成完毕的序列,进一步优化计算量
|
||||
tokens = torch.argmax(logits, dim=-1)
|
||||
reserved_idx_of_batch_for_y = None
|
||||
if (self.EOS in samples[:, 0]) or \
|
||||
(self.EOS in torch.argmax(logits, dim=-1)): ###如果生成到EOS,则停止
|
||||
l = samples[:, 0]==self.EOS
|
||||
(self.EOS in tokens): ###如果生成到EOS,则停止
|
||||
l1 = samples[:, 0]==self.EOS
|
||||
l2 = tokens==self.EOS
|
||||
l = l1.logical_or(l2)
|
||||
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
|
||||
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
|
||||
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
||||
@ -647,6 +700,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if reserved_idx_of_batch_for_y is not None:
|
||||
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
||||
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
xy_attn_mask = torch.index_select(xy_attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
if k_cache is not None :
|
||||
for i in range(len(k_cache)):
|
||||
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
@ -682,13 +736,14 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
if ref_free:
|
||||
return y_list, [0]*x.shape[0]
|
||||
# print(idx_list)
|
||||
return y_list, idx_list
|
||||
|
||||
def infer_panel_0307(self,
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:torch.LongTensor,
|
||||
bert_feature:List[torch.LongTensor],
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
@ -699,9 +754,9 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_list = []
|
||||
idx_list = []
|
||||
for i in range(len(x)):
|
||||
y, idx = self.infer_panel_with_flash_attn_only(x[i].unsqueeze(0),
|
||||
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
|
||||
x_lens[i],
|
||||
prompts[i].unsqueeze(0),
|
||||
prompts[i].unsqueeze(0) if prompts is not None else None,
|
||||
bert_feature[i].unsqueeze(0),
|
||||
top_k,
|
||||
top_p,
|
||||
@ -714,7 +769,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
return y_list, idx_list
|
||||
|
||||
def infer_panel_with_flash_attn_only(
|
||||
def infer_panel_naive(
|
||||
self,
|
||||
x:torch.LongTensor, #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
@ -771,8 +826,9 @@ class Text2SemanticDecoder(nn.Module):
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).unsqueeze(0).expand(bsz*self.num_head, -1, -1).view(bsz, self.num_head, src_len, src_len).to(x.device)
|
||||
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
|
||||
xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
|
||||
xy_attn_mask = xy_attn_mask.bool()
|
||||
# new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
|
||||
# xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))
|
||||
for idx in tqdm(range(1500)):
|
||||
if xy_attn_mask is not None:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
|
||||
@ -813,3 +869,18 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if ref_free:
|
||||
return y[:, :-1], 0
|
||||
return y[:, :-1], idx - 1
|
||||
|
||||
def infer_panel(
|
||||
self,
|
||||
x:torch.LongTensor, #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:torch.LongTensor,
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs
|
||||
):
|
||||
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
|
||||
|
@ -81,6 +81,8 @@ class TTS_Config:
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
}
|
||||
configs:dict = None
|
||||
languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
||||
|
||||
def __init__(self, configs: Union[dict, str]=None):
|
||||
|
||||
# 设置默认配置文件路径
|
||||
@ -97,7 +99,7 @@ class TTS_Config:
|
||||
if isinstance(configs, str):
|
||||
self.configs_path = configs
|
||||
configs:dict = self._load_configs(self.configs_path)
|
||||
|
||||
|
||||
assert isinstance(configs, dict)
|
||||
default_configs:dict = configs.get("default", None)
|
||||
if default_configs is not None:
|
||||
@ -138,8 +140,7 @@ class TTS_Config:
|
||||
self.hop_length:int = 640
|
||||
self.win_length:int = 2048
|
||||
self.n_speakers:int = 300
|
||||
|
||||
self.languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
||||
|
||||
|
||||
|
||||
def _load_configs(self, configs_path: str)->dict:
|
||||
@ -489,8 +490,8 @@ class TTS:
|
||||
all_phones_len_list = []
|
||||
all_bert_features_list = []
|
||||
norm_text_batch = []
|
||||
bert_max_len = 0
|
||||
phones_max_len = 0
|
||||
all_bert_max_len = 0
|
||||
all_phones_max_len = 0
|
||||
for item in item_list:
|
||||
if prompt_data is not None:
|
||||
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
||||
@ -505,8 +506,8 @@ class TTS:
|
||||
all_phones = phones
|
||||
# norm_text = item["norm_text"]
|
||||
|
||||
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
|
||||
phones_max_len = max(phones_max_len, phones.shape[-1])
|
||||
all_bert_max_len = max(all_bert_max_len, all_bert_features.shape[-1])
|
||||
all_phones_max_len = max(all_phones_max_len, all_phones.shape[-1])
|
||||
|
||||
phones_list.append(phones)
|
||||
phones_len_list.append(phones.shape[-1])
|
||||
@ -520,7 +521,7 @@ class TTS:
|
||||
all_bert_features_batch = all_bert_features_list
|
||||
|
||||
|
||||
max_len = max(bert_max_len, phones_max_len)
|
||||
max_len = max(all_bert_max_len, all_phones_max_len)
|
||||
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
#### 直接对phones和bert_features进行pad。(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略)
|
||||
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
@ -630,7 +631,7 @@ class TTS:
|
||||
|
||||
if parallel_infer:
|
||||
print(i18n("并行推理模式已开启"))
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer_with_flash_attn
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
|
||||
else:
|
||||
print(i18n("并行推理模式已关闭"))
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_0307
|
||||
@ -942,4 +943,4 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
|
||||
# 将管道输出解码为 NumPy 数组
|
||||
processed_audio = np.frombuffer(out, np.int16)
|
||||
|
||||
return processed_audio
|
||||
return processed_audio
|
Loading…
x
Reference in New Issue
Block a user