Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
Synced 2025-10-08 16:00:01 +08:00

Commit df213e6aee
Merge remote-tracking branch 'beta/fast_inference_' (fixes multilingual handling)

The diff below touches the AR decoder (Text2SemanticDecoder), the sampling utilities (logits_to_probs), the TTS pipeline and its config class (TTS, TTS_Config), the text preprocessor (TextPreprocessor), and the inference WebUI. File headers were lost in extraction; the hunk headers retain the enclosing class or function.
@@ -234,10 +234,15 @@ class Text2SemanticDecoder(nn.Module):
             ignore_index=self.EOS,
         )

-        if not flash_attn_enabled:
+        self.enable_flash_attn(flash_attn_enabled)
+
+    def enable_flash_attn(self, enable:bool=True):
+
+        if not enable:
             print("Not Using Flash Attention")
             self.infer_panel = self.infer_panel_batch_only
         else:
+            self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
             print("Using Flash Attention")
         blocks = []
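The refactor above lifts an inline if/else out of __init__ into a method, so the attention backend (and the infer_panel entry point bound to it) can be toggled after construction; the real method goes on to rebuild the transformer blocks, which this sketch omits. A minimal sketch of the dispatch pattern, where the Decoder class and the string return values are illustrative, not from the repo:

class Decoder:
    """Minimal sketch: swap the inference entry point at runtime."""

    def __init__(self, flash_attn_enabled: bool = False):
        self.enable_flash_attn(flash_attn_enabled)

    def enable_flash_attn(self, enable: bool = True):
        # Bind infer_panel to whichever implementation is active.
        if not enable:
            print("Not Using Flash Attention")
            self.infer_panel = self.infer_panel_batch_only
        else:
            print("Using Flash Attention")
            self.infer_panel = self.infer_panel_batch_infer_with_flash_attn

    def infer_panel_batch_only(self):
        return "sdpa path"

    def infer_panel_batch_infer_with_flash_attn(self):
        return "flash-attn path"

d = Decoder(flash_attn_enabled=True)
print(d.infer_panel())   # flash-attn path
d.enable_flash_attn(False)
print(d.infer_panel())   # sdpa path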
@@ -502,91 +507,7 @@ class Text2SemanticDecoder(nn.Module):
         # offset by one position (inputs vs. targets)
         return targets[:, :-1], targets[:, 1:]

-    def infer_one_step(self, x, xy_attn_mask, k_cache, v_cache, cache_seqlens):
-        hidden_dim = x.shape[-1]
-
-        for layer_id in range(self.num_layers):
-            layer = self.h.layers[layer_id]
-
-            q, k, v = F.linear(
-                x,
-                layer.self_attn.in_proj_weight,
-                layer.self_attn.in_proj_bias
-            ).chunk(3, dim=-1)
-
-            batch_size = q.shape[0]
-            q_len = q.shape[1]
-
-            if flash_attn_with_kvcache is None:
-                past_k = k_cache[layer_id]
-                past_v = v_cache[layer_id]
-
-                if past_k is not None:
-                    k = torch.cat([past_k, k], 1)
-                    v = torch.cat([past_v, v], 1)
-                k_cache[layer_id] = k
-                v_cache[layer_id] = v
-                kv_len = k.shape[1]
-
-                q = q.view(batch_size, q_len, layer.self_attn.num_heads, -1).transpose(1, 2)
-                k = k.view(batch_size, kv_len, layer.self_attn.num_heads, -1).transpose(1, 2)
-                v = v.view(batch_size, kv_len, layer.self_attn.num_heads, -1).transpose(1, 2)
-
-                if xy_attn_mask is None:
-                    attn = F.scaled_dot_product_attention(q, k, v)
-                else:
-                    attn = F.scaled_dot_product_attention(q, k, v, ~xy_attn_mask)
-
-                attn = attn.permute(2, 0, 1, 3).reshape(-1, hidden_dim)
-            else:
-                q = q.view(batch_size, q_len, layer.self_attn.num_heads, -1)
-                k = k.view(batch_size, q_len, layer.self_attn.num_heads, -1)
-                v = v.view(batch_size, q_len, layer.self_attn.num_heads, -1)
-
-                if xy_attn_mask is None:
-                    attn = flash_attn_with_kvcache(q, k_cache[layer_id], v_cache[layer_id], k, v, cache_seqlens=cache_seqlens, causal=True)
-                else:
-                    # NOTE: there's a slight difference with the result produced by SDPA.
-                    x_len = (~xy_attn_mask).sum(1)[0].item()
-
-                    attn_x = flash_attn_with_kvcache(
-                        q[:, :x_len],
-                        k_cache[layer_id],
-                        v_cache[layer_id],
-                        k[:, :x_len],
-                        v[:, :x_len],
-                        cache_seqlens=cache_seqlens,
-                        causal=False
-                    )
-
-                    attn_y = flash_attn_with_kvcache(
-                        q[:, x_len:],
-                        k_cache[layer_id],
-                        v_cache[layer_id],
-                        k[:, x_len:],
-                        v[:, x_len:],
-                        cache_seqlens=cache_seqlens + x_len,
-                        causal=True
-                    )
-
-                    attn = torch.cat([attn_x, attn_y], dim=1)
-                    attn = attn.view(-1, hidden_dim)
-
-            attn_out = F.linear(attn, layer.self_attn.out_proj.weight, layer.self_attn.out_proj.bias)
-
-            x = layer.norm1(x + attn_out, None)
-
-            x = layer.norm2(x + layer.linear2(F.relu(layer.linear1(x))), None)
-
-        xy_dec = x
-
-        logits = self.ar_predict_layer(
-            xy_dec[:, -1]
-        )
-
-        return logits
-
-    def infer_panel(
+    def infer_panel_batch_infer_with_flash_attn(
         self,
         x,  ##### all text tokens
         x_lens,
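The deleted infer_one_step interleaved two paths: a pure-PyTorch fallback that grows k_cache/v_cache by concatenation and calls F.scaled_dot_product_attention, and a flash_attn_with_kvcache path that splits the prompt (non-causal attention over the text prefix) from the generated suffix (causal). A runnable sketch of just the fallback pattern; the helper name, toy sizes, and single-layer setup are mine, not the repo's:

import torch
import torch.nn.functional as F

def kv_cache_attn_step(q, k, v, k_cache, v_cache, num_heads=4):
    """One decode step with a concat-based KV cache (the SDPA fallback pattern).

    q, k, v: (batch, new_len, hidden) projections for the new tokens only.
    k_cache, v_cache: (batch, past_len, hidden) or None on the first step.
    Returns the attention output and the updated caches.
    """
    if k_cache is not None:
        k = torch.cat([k_cache, k], dim=1)   # append new keys to the cache
        v = torch.cat([v_cache, v], dim=1)
    bsz, q_len, hidden = q.shape
    kv_len = k.shape[1]
    # (batch, heads, seq, head_dim) layout expected by SDPA
    qh = q.view(bsz, q_len, num_heads, -1).transpose(1, 2)
    kh = k.view(bsz, kv_len, num_heads, -1).transpose(1, 2)
    vh = v.view(bsz, kv_len, num_heads, -1).transpose(1, 2)
    # No mask here; the removed code passed ~xy_attn_mask during prefill.
    attn = F.scaled_dot_product_attention(qh, kh, vh)
    attn = attn.transpose(1, 2).reshape(bsz, q_len, hidden)
    return attn, k, v

# Prefill with 8 tokens, then decode one token reusing the cache.
x = torch.randn(2, 8, 64)
out, kc, vc = kv_cache_attn_step(x, x, x, None, None)
step = torch.randn(2, 1, 64)
out, kc, vc = kv_cache_attn_step(step, step, step, kc, vc)
print(out.shape, kc.shape)  # torch.Size([2, 1, 64]) torch.Size([2, 9, 64])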
@@ -597,8 +518,10 @@ class Text2SemanticDecoder(nn.Module):
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
+
+        bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_embedding(x)
-        x = x + self.bert_proj(bert_feature.transpose(1, 2))
+        x = x + bert_feature
         x = self.ar_text_position(x)

         # AR Decoder
@@ -635,30 +558,28 @@ class Text2SemanticDecoder(nn.Module):
         y_lens = torch.LongTensor([y_len]*bsz).to(x.device)
         y_mask = make_pad_mask(y_lens)
         x_mask = make_pad_mask(x_lens)

+        # (bsz, x_len + y_len)
         xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
-        _xy_padding_mask = (
-            xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1)
-        )

-        x_attn_mask_pad = F.pad(
+        x_mask = F.pad(
             x_attn_mask,
             (0, y_len),  ### extend the all-zero xx block with all-one xy: (x, x+y)
             value=True,
         )
-        y_attn_mask = F.pad(  ### extend yy's upper-right ones with zeros for xy on the left: (y, x+y)
+        y_mask = F.pad(  ### extend yy's upper-right ones with zeros for xy on the left: (y, x+y)
             torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-            x.device
-        )
-        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
+        xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).expand(bsz, -1, -1).to(x.device)
+        # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
+        xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
+        xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
+        xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
         new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
-        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
-        xy_attn_mask = new_attn_mask
+        xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))

         ###### decode #####
         y_list = [None]*y.shape[0]
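The rewritten block builds one boolean (src_len, src_len) attention pattern per batch element, ORs it with a per-sample padding row mask, and only then expands to num_head and converts to additive -inf form, instead of expanding to heads first. A self-contained sketch with toy sizes; the variable names follow the diff, but the sizes and the sample padding are mine:

import torch
import torch.nn.functional as F

bsz, num_head, x_len, y_len = 2, 4, 3, 5
src_len = x_len + y_len

# Text tokens attend to all text; audio tokens attend to text plus causal audio.
x_attn_mask = torch.zeros(x_len, x_len, dtype=torch.bool)
x_part = F.pad(x_attn_mask, (0, y_len), value=True)            # (x_len, src_len)
y_part = F.pad(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
               (x_len, 0), value=False)                        # (y_len, src_len)
xy_mask = torch.cat([x_part, y_part], dim=0)                   # (src_len, src_len)
xy_mask = xy_mask.view(1, src_len, src_len).expand(bsz, -1, -1)

# True marks padded positions; here sample 1 pads its last two audio tokens.
xy_padding = torch.zeros(bsz, src_len, dtype=torch.bool)
xy_padding[1, -2:] = True
xy_padding = xy_padding.view(bsz, 1, src_len).expand(-1, src_len, -1)

xy_attn_mask = xy_mask.logical_or(xy_padding)                  # (bsz, src, src)
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, num_head, -1, -1)

# Additive form: 0 where attention is allowed, -inf where it is blocked.
additive = torch.zeros_like(xy_attn_mask, dtype=torch.float32)
additive = additive.masked_fill(xy_attn_mask, float("-inf"))
print(additive.shape)  # torch.Size([2, 4, 8, 8])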
@@ -730,7 +651,7 @@ class Text2SemanticDecoder(nn.Module):
             ####################### update next step ###################################
             y_emb = self.ar_audio_embedding(y[:, -1:])
             xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)

         if (None in idx_list):
             for i in range(x.shape[0]):
                 if idx_list[i] is None:
@@ -143,7 +143,7 @@ def logits_to_probs(

     if top_k is not None:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
+        pivot = v[:, -1].unsqueeze(-1)
         logits = torch.where(logits < pivot, -float("Inf"), logits)

     probs = torch.nn.functional.softmax(logits, dim=-1)
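The pivot change swaps Tensor.select(-1, -1) for plain indexing. For the 2-D (batch, vocab) logits this function sees, v[:, -1] and v.select(-1, -1) pick the same element: the smallest of the top-k values per row, below which everything is masked out. A quick sketch of the filtering step:

import torch

def top_k_filter(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep the top-k logits per row, set everything else to -inf."""
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    pivot = v[:, -1].unsqueeze(-1)   # k-th largest value per row
    # equivalent for 2-D input: pivot = v.select(-1, -1).unsqueeze(-1)
    return torch.where(logits < pivot, torch.full_like(logits, -float("Inf")), logits)

logits = torch.tensor([[0.1, 2.0, 1.0, 3.0]])
print(top_k_filter(logits, top_k=2))
# tensor([[-inf, 2., -inf, 3.]])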
@@ -1,5 +1,6 @@
 import math
 import os, sys
+import random
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 import ffmpeg
@@ -7,6 +8,7 @@ import os
 from typing import Generator, List, Union
 import numpy as np
 import torch
+import torch.nn.functional as F
 import yaml
 from transformers import AutoModelForMaskedLM, AutoTokenizer

@@ -97,7 +99,6 @@ class TTS_Config:
             configs = yaml.load(f, Loader=yaml.FullLoader)

         return configs

-
     def save_configs(self, configs_path:str=None)->None:
         configs={
@@ -110,32 +111,31 @@ class TTS_Config:
                 "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
                 "flash_attn_enabled": True
             },
-            "custom": {
-                "device": str(self.device),
-                "is_half": self.is_half,
-                "t2s_weights_path": self.t2s_weights_path,
-                "vits_weights_path": self.vits_weights_path,
-                "bert_base_path": self.bert_base_path,
-                "cnhuhbert_base_path": self.cnhuhbert_base_path,
-                "flash_attn_enabled": self.flash_attn_enabled
-            }
+            "custom": self.update_configs()
         }
         if configs_path is None:
             configs_path = self.configs_path
         with open(configs_path, 'w') as f:
             yaml.dump(configs, f)

+    def update_configs(self):
+        config = {
+            "device"             : str(self.device),
+            "is_half"            : self.is_half,
+            "t2s_weights_path"   : self.t2s_weights_path,
+            "vits_weights_path"  : self.vits_weights_path,
+            "bert_base_path"     : self.bert_base_path,
+            "cnhuhbert_base_path": self.cnhuhbert_base_path,
+            "flash_attn_enabled" : self.flash_attn_enabled
+        }
+        return config

     def __str__(self):
-        string = "----------------TTS Config--------------\n"
-        string += "device: {}\n".format(self.device)
-        string += "is_half: {}\n".format(self.is_half)
-        string += "bert_base_path: {}\n".format(self.bert_base_path)
-        string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
-        string += "vits_weights_path: {}\n".format(self.vits_weights_path)
-        string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
-        string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
-        string += "----------------------------------------\n"
+        self.configs = self.update_configs()
+        string = "TTS Config".center(100, '-') + '\n'
+        for k, v in self.configs.items():
+            string += f"{str(k).ljust(20)}: {str(v)}\n"
+        string += "-" * 100 + '\n'
         return string
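With update_configs() factored out, save_configs and __str__ now share one source of truth for the current settings instead of hand-listing each field twice. A sketch of the resulting behavior on a stripped-down stand-in (MiniConfig is illustrative, not the repo's TTS_Config):

class MiniConfig:
    device = "cpu"
    is_half = False

    def update_configs(self):
        return {"device": self.device, "is_half": self.is_half}

    def __str__(self):
        configs = self.update_configs()
        string = "TTS Config".center(100, "-") + "\n"
        for k, v in configs.items():
            string += f"{str(k).ljust(20)}: {str(v)}\n"
        string += "-" * 100 + "\n"
        return string

print(MiniConfig())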
@@ -184,7 +184,7 @@ class TTS:

     def init_cnhuhbert_weights(self, base_path: str):
         self.cnhuhbert_model = CNHubert(base_path)
-        self.cnhuhbert_model.eval()
+        self.cnhuhbert_model = self.cnhuhbert_model.eval()
         if self.configs.is_half == True:
             self.cnhuhbert_model = self.cnhuhbert_model.half()
         self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
@@ -194,6 +194,7 @@ class TTS:
     def init_bert_weights(self, base_path: str):
         self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
         self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
+        self.bert_model = self.bert_model.eval()
         if self.configs.is_half:
             self.bert_model = self.bert_model.half()
         self.bert_model = self.bert_model.to(self.configs.device)
@@ -226,7 +227,7 @@ class TTS:
         if self.configs.is_half:
             vits_model = vits_model.half()
         vits_model = vits_model.to(self.configs.device)
-        vits_model.eval()
+        vits_model = vits_model.eval()
         vits_model.load_state_dict(dict_s2["weight"], strict=False)
         self.vits_model = vits_model

@@ -244,7 +245,7 @@ class TTS:
         if self.configs.is_half:
             t2s_model = t2s_model.half()
         t2s_model = t2s_model.to(self.configs.device)
-        t2s_model.eval()
+        t2s_model = t2s_model.eval()
         self.t2s_model = t2s_model

     def set_ref_audio(self, ref_audio_path:str):
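The repeated model.eval() to model = model.eval() edits above are behavior-neutral: nn.Module.eval() switches the module in place and returns self, so the reassignment mainly keeps the style consistent with the surrounding .half() and .to(device) reassignments. A sketch:

import torch.nn as nn

model = nn.Linear(4, 4)
same = model.eval()        # eval() mutates in place and returns self
assert same is model
assert not model.training

# The chained style the diff converges on:
model = nn.Linear(4, 4).eval().to("cpu")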
@@ -377,12 +378,14 @@ class TTS:
         phones_max_len = 0
         for item in item_list:
             if prompt_data is not None:
-                all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
+                all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
+                    .to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
                 all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
                 phones = torch.LongTensor(item["phones"])
                 # norm_text = prompt_data["norm_text"]+item["norm_text"]
             else:
-                all_bert_features = item["bert_features"]
+                all_bert_features = item["bert_features"]\
+                    .to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
                 phones = torch.LongTensor(item["phones"])
                 all_phones = phones
                 # norm_text = item["norm_text"]
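The trailing .to(dtype=...) added here pins the BERT features to the dtype the half-precision pipeline expects, instead of inheriting whatever the cached features were stored in. A small sketch of the pattern with made-up shapes:

import torch

is_half = True
prompt_feat = torch.randn(1024, 7)   # cached reference-audio features
item_feat = torch.randn(1024, 11)    # features for the new text segment

all_bert_features = torch.cat([prompt_feat, item_feat], dim=1)\
    .to(dtype=torch.float32 if not is_half else torch.float16)
print(all_bert_features.dtype)  # torch.float16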
@@ -401,12 +404,10 @@ class TTS:
         max_len = max(bert_max_len, phones_max_len)
         # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
         all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
-        all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
-        all_bert_features_batch.zero_()
+        # all_bert_features_batch = all_bert_features_list
+        all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=torch.float32)

         for idx, item in enumerate(all_bert_features_list):
-            if item != None:
-                all_bert_features_batch[idx, :, : item.shape[-1]] = item
+            all_bert_features_batch[idx, :, : item.shape[-1]] = item

         batch = {
             "phones": phones_batch,
@@ -458,8 +459,8 @@ class TTS:
             "prompt_text": "", # str. prompt text for the reference audio
             "prompt_lang": "", # str. language of the prompt text for the reference audio
             "top_k": 5, # int. top k sampling
-            "top_p": 0.9, # float. top p sampling
-            "temperature": 0.6, # float. temperature for sampling
+            "top_p": 1, # float. top p sampling
+            "temperature": 1, # float. temperature for sampling
             "text_split_method": "", # str. text split method, see text_segmentation_method.py for details.
             "batch_size": 1, # int. batch size for inference
             "batch_threshold": 0.75, # float. threshold for batch splitting.
@@ -477,9 +478,9 @@ class TTS:
         ref_audio_path:str = inputs.get("ref_audio_path", "")
         prompt_text:str = inputs.get("prompt_text", "")
         prompt_lang:str = inputs.get("prompt_lang", "")
-        top_k:int = inputs.get("top_k", 20)
-        top_p:float = inputs.get("top_p", 0.9)
-        temperature:float = inputs.get("temperature", 0.6)
+        top_k:int = inputs.get("top_k", 5)
+        top_p:float = inputs.get("top_p", 1)
+        temperature:float = inputs.get("temperature", 1)
         text_split_method:str = inputs.get("text_split_method", "")
         batch_size = inputs.get("batch_size", 1)
         batch_threshold = inputs.get("batch_threshold", 0.75)
@@ -497,10 +498,6 @@ class TTS:
         if split_bucket:
             print(i18n("分桶处理模式已开启"))

-        # if vits_batched_inference:
-        #     print(i18n("VITS批量推理模式已开启"))
-        # else:
-        #     print(i18n("VITS单句推理模式已开启"))

         no_prompt_text = False
         if prompt_text in [None, ""]:
||||||
@ -547,7 +544,7 @@ class TTS:
|
|||||||
)
|
)
|
||||||
t2 = ttime()
|
t2 = ttime()
|
||||||
|
|
||||||
|
print("############ 推理 ############")
|
||||||
###### inference ######
|
###### inference ######
|
||||||
t_34 = 0.0
|
t_34 = 0.0
|
||||||
t_45 = 0.0
|
t_45 = 0.0
|
||||||
@@ -601,6 +598,10 @@ class TTS:
                 # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                 # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
                 # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
+                # max_len = 0
+                # for i in range(0, len(batch_phones)):
+                #     max_len = max(max_len, batch_phones[i].shape[-1])
+                # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
                 # batch_phones = batch_phones.to(self.configs.device)
                 # batch_audio_fragment = (self.vits_model.batched_decode(
                 #     pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
@@ -654,7 +655,12 @@ class TTS:
                                     self.configs.sampling_rate,
                                     batch_index_list,
                                     speed_factor,
                                     split_bucket)
+
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
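The new try/except lets the pipeline free cached GPU memory between runs without crashing on machines where CUDA is absent or uninitialized (CPU or MPS setups). An equivalent, slightly more explicit guard, sketched as an alternative rather than what the commit uses:

import torch

def release_cuda_cache() -> None:
    """Free cached GPU memory between runs; silently a no-op off-GPU."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

release_cuda_cache()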
@@ -1,5 +1,7 @@

 import os, sys
+
+from tqdm import tqdm
 now_dir = os.getcwd()
 sys.path.append(now_dir)

@@ -12,9 +14,9 @@ from text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method

-# from tools.i18n.i18n import I18nAuto
+from tools.i18n.i18n import I18nAuto

-# i18n = I18nAuto()
+i18n = I18nAuto()

 def get_first(text:str) -> str:
     pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
@@ -51,9 +53,11 @@ class TextPreprocessor:
         self.device = device

     def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
+        print(i18n("############ 切分文本 ############"))
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
-        for text in texts:
+        print(i18n("############ 提取文本Bert特征 ############"))
+        for text in tqdm(texts):
             phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
             res={
                 "phones": phones,
@@ -67,14 +71,16 @@ class TextPreprocessor:
         text = text.strip("\n")
         if (text[0] not in splits and len(get_first(text)) < 4):
             text = "。" + text if lang != "en" else "." + text
-        # print(i18n("实际输入的目标文本:"), text)
+        print(i18n("实际输入的目标文本:"))
+        print(text)

         seg_method = get_seg_method(text_split_method)
         text = seg_method(text)

         while "\n\n" in text:
             text = text.replace("\n\n", "\n")
-        # print(i18n("实际输入的目标文本(切句后):"), text)
+        print(i18n("实际输入的目标文本(切句后):"))
+        print(text)
         _texts = text.split("\n")
         _texts = merge_short_text_in_array(_texts, 5)
         texts = []
@@ -105,7 +111,7 @@ class TextPreprocessor:
         textlist=[]
         langlist=[]
         if language in ["auto", "zh", "ja"]:
-            # LangSegment.setfilters(["zh","ja","en","ko"])
+            LangSegment.setfilters(["zh","ja","en","ko"])
             for tmp in LangSegment.getTexts(text):
                 if tmp["lang"] == "ko":
                     langlist.append("zh")
@@ -116,7 +122,7 @@ class TextPreprocessor:
                 langlist.append(language if language!="auto" else tmp["lang"])
                 textlist.append(tmp["text"])
         elif language == "en":
-            # LangSegment.setfilters(["en"])
+            LangSegment.setfilters(["en"])
             formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
             while "  " in formattext:
                 formattext = formattext.replace("  ", " ")
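The context lines above (whose doubled spaces the page extraction collapsed) squeeze runs of spaces one pair at a time with a while loop; a single regex pass does the same, sketched here as an alternative rather than what the code uses:

import re

formattext = "hello   world  again"

# Loop form, as in the diff context:
while "  " in formattext:
    formattext = formattext.replace("  ", " ")

# One-pass equivalent:
formattext = re.sub(r" {2,}", " ", "hello   world  again")
print(formattext)  # "hello world again"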
@@ -152,8 +158,7 @@ class TextPreprocessor:
         bert_feature = torch.cat(bert_feature_list, dim=1)
         # phones = sum(phones_list, [])
         norm_text = ''.join(norm_text_list)
-        return phones, bert_feature, norm_text
+        return phones_list, bert_feature, norm_text

     def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
@@ -45,9 +45,11 @@ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # make sure this is set even when the inference UI is launched directly

 if torch.cuda.is_available():
     device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
 else:
     device = "cpu"

 dict_language = {
     i18n("中文"): "all_zh", # recognize all text as Chinese
     i18n("英文"): "en", # recognize all text as English ####### unchanged
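The new elif gives Apple-silicon users the MPS backend before falling back to CPU; the PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable in the hunk header lets operators missing from MPS run on CPU instead of erroring. The selection logic, runnable standalone (setting the variable before importing torch, which is the safe ordering):

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # CPU fallback for missing MPS ops

import torch

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)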
@@ -103,10 +105,12 @@ def inference(text, text_lang,
         "batch_size":int(batch_size),
         "speed_factor":float(speed_factor),
         "split_bucket":split_bucket,
-        "return_fragment":False,
+        "return_fragment":False
     }
-    yield next(tts_pipline.run(inputs))
+    for item in tts_pipline.run(inputs):
+        yield item

 def custom_sort_key(s):
     # use a regular expression to split the string into numeric and non-numeric parts
     parts = re.split('(\d+)', s)
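This change matters for fragment streaming: yield next(gen) pulls exactly one item from the pipeline's generator and drops the rest, while the loop (or equivalently yield from) streams every fragment to Gradio. A minimal reproduction, where run() stands in for tts_pipline.run():

def run():
    """Stand-in for tts_pipline.run(): yields several audio fragments."""
    for fragment in ("frag-0", "frag-1", "frag-2"):
        yield fragment

def inference_old():
    yield next(run())      # only ever produces the first fragment

def inference_new():
    for item in run():     # streams every fragment, as the patch does
        yield item         # equivalently: yield from run()

print(list(inference_old()))  # ['frag-0']
print(list(inference_new()))  # ['frag-0', 'frag-1', 'frag-2']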
@ -182,7 +186,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
batch_size = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("batch_size"),value=20,interactive=True)
|
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=1,interactive=True)
|
||||||
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
|
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
|
||||||
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
||||||
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
||||||