Merge branch 'fast_inference_' into half-fix
Commit: 3ca117450a
@@ -13,11 +13,11 @@ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
 from AR.modules.optim import ScaledAdam


 class Text2SemanticLightningModule(LightningModule):
-    def __init__(self, config, output_dir, is_train=True):
+    def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False):
         super().__init__()
         self.config = config
         self.top_k = 3
-        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
+        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k, flash_attn_enabled=flash_attn_enabled)
         pretrained_s1 = config.get("pretrained_s1")
         if pretrained_s1 and is_train:
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
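The two changed lines above only thread the new flag from the Lightning wrapper down to the decoder that actually implements the behavior. A minimal runnable sketch of that forwarding pattern, with illustrative class names (not the repo's real classes):

class Decoder:
    def __init__(self, top_k: int = 3, flash_attn_enabled: bool = False):
        self.flash_attn_enabled = flash_attn_enabled

class LightningWrapper:
    def __init__(self, is_train: bool = True, flash_attn_enabled: bool = False):
        # the wrapper never reads the flag itself; it only forwards it
        self.model = Decoder(top_k=3, flash_attn_enabled=flash_attn_enabled)

wrapper = LightningWrapper(is_train=False, flash_attn_enabled=True)
assert wrapper.model.flash_attn_enabled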
@@ -1,7 +1,9 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
+import os, sys
+now_dir = os.getcwd()
+sys.path.append(now_dir)
 from typing import List

 import torch
 from tqdm import tqdm

@@ -174,7 +176,7 @@ class T2STransformer:


 class Text2SemanticDecoder(nn.Module):
-    def __init__(self, config, norm_first=False, top_k=3):
+    def __init__(self, config, norm_first=False, top_k=3, flash_attn_enabled:bool=False):
         super(Text2SemanticDecoder, self).__init__()
         self.model_dim = config["model"]["hidden_dim"]
         self.embedding_dim = config["model"]["embedding_dim"]
@@ -227,6 +229,16 @@ class Text2SemanticDecoder(nn.Module):
             ignore_index=self.EOS,
         )

+        self.enable_flash_attn(flash_attn_enabled)
+
+    def enable_flash_attn(self, enable:bool=True):
+
+        if not enable:
+            print("Not Using Flash Attention")
+            self.infer_panel = self.infer_panel_batch_only
+        else:
+            self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
+            print("Using Flash Attention")

         blocks = []

         for i in range(self.num_layers):
@@ -490,7 +502,7 @@ class Text2SemanticDecoder(nn.Module):
         # shifted by one position
         return targets[:, :-1], targets[:, 1:]

-    def infer_panel(
+    def infer_panel_batch_infer_with_flash_attn(
         self,
         x,  ##### all text tokens
         x_lens,
@@ -501,8 +513,10 @@ class Text2SemanticDecoder(nn.Module):
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
+
+        bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_embedding(x)
-        x = x + self.bert_proj(bert_feature.transpose(1, 2))
+        x = x + bert_feature
         x = self.ar_text_position(x)

         # AR Decoder
@@ -540,29 +554,27 @@ class Text2SemanticDecoder(nn.Module):
         y_mask = make_pad_mask(y_lens)
         x_mask = make_pad_mask(x_lens)

         # (bsz, x_len + y_len)
         xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
-        _xy_padding_mask = (
-            xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1)
-        )

-        x_attn_mask_pad = F.pad(
+        x_mask = F.pad(
             x_attn_mask,
             (0, y_len),  ### extend the all-0 xx block into all-0 xx plus all-1 xy, shape (x, x+y)
             value=True,
         )
-        y_attn_mask = F.pad(  ### extend yy's upper-right 1s with 0s for xy on the left, shape (y, x+y)
+        y_mask = F.pad(  ### extend yy's upper-right 1s with 0s for xy on the left, shape (y, x+y)
             torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-            x.device
-        )
-        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

+        xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).expand(bsz, -1, -1).to(x.device)
+        # xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
+        xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
+        xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
+        xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
         new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
-        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
-        xy_attn_mask = new_attn_mask
+        xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))

         ###### decode #####
         y_list = [None]*y.shape[0]
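The rewritten hunk combines a structural (causal-style) mask with a per-sequence padding mask, broadcasts over batch and heads, and finally converts the boolean mask into an additive 0/-inf float mask. A simplified standalone sketch of that algebra (a pure causal mask stands in for the x/y block mask; shapes are illustrative):

import torch

bsz, num_head, x_len, y_len = 2, 4, 3, 5
src_len = x_len + y_len

# structural mask: True = not allowed to attend (here: future positions)
causal = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool), diagonal=1)
xy_mask = causal.view(1, src_len, src_len).expand(bsz, -1, -1)

# padding mask: True where a key position is padding for that sample
lengths = torch.tensor([src_len, src_len - 2])  # second sample has 2 pad frames
padding = torch.arange(src_len).unsqueeze(0) >= lengths.unsqueeze(1)  # (bsz, src)
padding = padding.view(bsz, 1, src_len).expand(-1, src_len, -1)

attn_bool = xy_mask.logical_or(padding)                     # True = masked
attn_bool = attn_bool.unsqueeze(1).expand(-1, num_head, -1, -1)
attn_mask = torch.zeros_like(attn_bool, dtype=torch.float32).masked_fill(
    attn_bool, float("-inf"))                               # additive mask
print(attn_mask.shape)  # torch.Size([2, 4, 8, 8])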
@@ -643,3 +655,165 @@ class Text2SemanticDecoder(nn.Module):
         if ref_free:
             return y_list, [0]*x.shape[0]
         return y_list, idx_list

+    def infer_panel_batch_only(
+        self,
+        x,  ##### all text tokens
+        x_lens,
+        prompts,  #### reference audio tokens
+        bert_feature,
+        top_k: int = -100,
+        top_p: int = 100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
+    ):
+        x = self.ar_text_embedding(x)
+        x = x + self.bert_proj(bert_feature.transpose(1, 2))
+        x = self.ar_text_position(x)
+
+        # AR Decoder
+        y = prompts
+
+        x_len = x.shape[1]
+        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
+        stop = False
+        # print(1111111,self.num_layers)
+        cache = {
+            "all_stage": self.num_layers,
+            "k": [None] * self.num_layers,  ### hand-written to match the config
+            "v": [None] * self.num_layers,
+            # "xy_pos": None,  ## the y positional encoding differs at every step, so it cannot be cached; xy_pos is rebuilt each time. Mostly a coding-style issue; the history could be kept consistent, but the extra compute is negligible, so leave it.
+            "y_emb": None,  ## only the newest samples need an embedding, which is then concatenated with the history
+            # "logits": None,  ### the original version already computes only the tail and concatenates, nothing to do here
+            # "xy_dec": None,  ### not needed; only the last position is used for the logits anyway
+            "first_infer": 1,
+            "stage": 0,
+        }
+        ################### first step ##########################
+        if y is not None:
+            y_emb = self.ar_audio_embedding(y)
+            y_len = y_emb.shape[1]
+            prefix_len = y.shape[1]
+            y_pos = self.ar_audio_position(y_emb)
+            xy_pos = torch.concat([x, y_pos], dim=1)
+            cache["y_emb"] = y_emb
+            ref_free = False
+        else:
+            y_emb = None
+            y_len = 0
+            prefix_len = 0
+            y_pos = None
+            xy_pos = x
+            y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
+            ref_free = True
+
+        x_attn_mask_pad = F.pad(
+            x_attn_mask,
+            (0, y_len),  ### extend the all-0 xx block into all-0 xx plus all-1 xy, shape (x, x+y)
+            value=True,
+        )
+        y_attn_mask = F.pad(  ### extend yy's upper-right 1s with 0s for xy on the left, shape (y, x+y)
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+            (x_len, 0),
+            value=False,
+        )
+        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
+            x.device
+        )
+
+        y_list = [None]*y.shape[0]
+        batch_idx_map = list(range(y.shape[0]))
+        idx_list = [None]*y.shape[0]
+        for idx in tqdm(range(1500)):
+
+            xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
+            logits = self.ar_predict_layer(
+                xy_dec[:, -1]
+            )  ## no change needed: with the cache there is only one frame by default, same as taking the last frame
+            # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
+            if(idx==0):  ### the first step must not emit EOS, otherwise nothing is left
+                logits = logits[:, :-1]  ### strip the probability of the 1024 end token
+            samples = sample(
+                logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
+            )[0]
+            # the semantic_ids generated this step and the previous y form the new y
+            # print(samples.shape)  # [1,1]; the first 1 is the batch size
+            y = torch.concat([y, samples], dim=1)
+
+            # remove sequences that have finished generating
+            reserved_idx_of_batch_for_y = None
+            if (self.EOS in torch.argmax(logits, dim=-1)) or \
+                (self.EOS in samples[:, 0]):  ### stop once EOS is generated
+                l = samples[:, 0]==self.EOS
+                removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
+                reserved_idx_of_batch_for_y = torch.where(l==False)[0]
+                # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
+                for i in removed_idx_of_batch_for_y:
+                    batch_index = batch_idx_map[i]
+                    idx_list[batch_index] = idx - 1
+                    y_list[batch_index] = y[i, :-1]
+
+                batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
+
+            # keep only the sequences that are still generating
+            if reserved_idx_of_batch_for_y is not None:
+                # index = torch.LongTensor(batch_idx_map).to(y.device)
+                y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
+                if cache["y_emb"] is not None:
+                    cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
+                if cache["k"] is not None:
+                    for i in range(self.num_layers):
+                        # k/v are stored transposed, so their batch dim is 1
+                        cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
+                        cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)
+
+            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
+                print("use early stop num:", early_stop_num)
+                stop = True
+
+            if not (None in idx_list):
+                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
+                stop = True
+            if stop:
+                # if prompts.shape[1] == y.shape[1]:
+                #     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
+                #     print("bad zero prediction")
+                if y.shape[1]==0:
+                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
+                    print("bad zero prediction")
+                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
+                break
+
+            ####################### update next step ###################################
+            cache["first_infer"] = 0
+            if cache["y_emb"] is not None:
+                y_emb = torch.cat(
+                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
+                )
+                cache["y_emb"] = y_emb
+                y_pos = self.ar_audio_position(y_emb)
+                xy_pos = y_pos[:, -1:]
+            else:
+                y_emb = self.ar_audio_embedding(y[:, -1:])
+                cache["y_emb"] = y_emb
+                y_pos = self.ar_audio_position(y_emb)
+                xy_pos = y_pos
+            y_len = y_pos.shape[1]
+
+            ### rightmost column (wrong)
+            # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
+            # xy_attn_mask[:,-1]=False
+            ### bottom row (correct)
+            xy_attn_mask = torch.zeros(
+                (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
+            )
+
+        if (None in idx_list):
+            for i in range(x.shape[0]):
+                if idx_list[i] is None:
+                    idx_list[i] = 1500-1  ### if EOS was never generated, fall back to the maximum length
+
+        if ref_free:
+            return y_list, [0]*x.shape[0]
+        return y_list, idx_list
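The core trick in the method added above is pruning finished sequences out of a running batch: once a row emits EOS, the remaining rows of y (batch dim 0) and of each layer's k/v cache (batch dim 1, because k/v are stored transposed) are kept via index_select. A minimal runnable sketch with toy tensors:

import torch

EOS = 1024
y = torch.tensor([[5, 7, EOS], [3, 9, 11]])   # sample 0 just finished
kv = torch.randn(3, 2, 8)                     # (seq, batch, dim) for one layer

finished = y[:, -1] == EOS
keep = torch.where(~finished)[0]              # indices of unfinished rows

y = torch.index_select(y, dim=0, index=keep)
kv = torch.index_select(kv, dim=1, index=keep)
print(y.shape, kv.shape)  # torch.Size([1, 3]) torch.Size([3, 1, 8])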
@@ -143,7 +143,7 @@ def logits_to_probs(

     if top_k is not None:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
+        pivot = v[: , -1].unsqueeze(-1)
         logits = torch.where(logits < pivot, -float("Inf"), logits)

     probs = torch.nn.functional.softmax(logits, dim=-1)
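For 2-D logits the old and new forms are equivalent: both take the k-th largest value per row as a pivot and mask everything below it; v[:, -1] just spells the selection with plain indexing. A self-contained check of the idiom:

import torch

logits = torch.tensor([[2.0, 0.5, 1.5, 3.0]])
top_k = 2

v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[:, -1].unsqueeze(-1)                 # smallest surviving logit, (B, 1)
filtered = torch.where(logits < pivot, -float("Inf"), logits)
print(filtered)  # tensor([[2., -inf, -inf, 3.]])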
@@ -1,4 +1,8 @@
+from copy import deepcopy
 import math
 import os, sys
+import random
 import traceback
+now_dir = os.getcwd()
+sys.path.append(now_dir)
 import ffmpeg
@@ -6,6 +10,7 @@ import os
 from typing import Generator, List, Union
 import numpy as np
 import torch
+import torch.nn.functional as F
 import yaml
 from transformers import AutoModelForMaskedLM, AutoTokenizer

@@ -17,8 +22,8 @@ from time import time as ttime
 from tools.i18n.i18n import I18nAuto
 from my_utils import load_audio
 from module.mel_processing import spectrogram_torch
-from .text_segmentation_method import splits
-from .TextPreprocessor import TextPreprocessor
+from TTS_infer_pack.text_segmentation_method import splits
+from TTS_infer_pack.TextPreprocessor import TextPreprocessor
 i18n = I18nAuto()

 # configs/tts_infer.yaml
@@ -30,6 +35,7 @@ default:
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
   vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
+  flash_attn_enabled: true

 custom:
   device: cuda
@@ -38,41 +44,81 @@ custom:
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
   vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
+  flash_attn_enabled: true

 """

+# def set_seed(seed):
+#     random.seed(seed)
+#     os.environ['PYTHONHASHSEED'] = str(seed)
+#     np.random.seed(seed)
+#     torch.manual_seed(seed)
+#     torch.cuda.manual_seed(seed)
+#     torch.cuda.manual_seed_all(seed)
+#     torch.backends.cudnn.deterministic = True
+#     torch.backends.cudnn.benchmark = False
+#     torch.backends.cudnn.enabled = True
+# set_seed(1234)

 class TTS_Config:
-    def __init__(self, configs: Union[dict, str]):
-        configs_base_path:str = "GPT_SoVITS/configs/"
-        os.makedirs(configs_base_path, exist_ok=True)
-        self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
-        if isinstance(configs, str):
-            self.configs_path = configs
-            configs:dict = self._load_configs(configs)

-        # assert isinstance(configs, dict)
-        self.default_configs:dict = configs.get("default", None)
-        if self.default_configs is None:
-            self.default_configs={
+    default_configs={
                 "device": "cpu",
                 "is_half": False,
                 "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
                 "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
                 "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
-                "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
+                "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+                "flash_attn_enabled": True
             }
-        self.configs:dict = configs.get("custom", self.default_configs)
+    configs:dict = None
+    def __init__(self, configs: Union[dict, str]=None):

-        self.device = self.configs.get("device")
-        self.is_half = self.configs.get("is_half")
-        self.t2s_weights_path = self.configs.get("t2s_weights_path")
-        self.vits_weights_path = self.configs.get("vits_weights_path")
-        self.bert_base_path = self.configs.get("bert_base_path")
-        self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
+        # set the default config file path
+        configs_base_path:str = "GPT_SoVITS/configs/"
+        os.makedirs(configs_base_path, exist_ok=True)
+        self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
+
+        if configs in ["", None]:
+            if not os.path.exists(self.configs_path):
+                self.save_configs()
+                print(f"Create default config file at {self.configs_path}")
+            configs:dict = {"default": deepcopy(self.default_configs)}
+
+        if isinstance(configs, str):
+            self.configs_path = configs
+            configs:dict = self._load_configs(self.configs_path)
+
+        assert isinstance(configs, dict)
+        default_configs:dict = configs.get("default", None)
+        if default_configs is not None:
+            self.default_configs = default_configs
+
+        self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
+
+        self.device = self.configs.get("device", torch.device("cpu"))
+        self.is_half = self.configs.get("is_half", False)
+        self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
+        self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
+        self.vits_weights_path = self.configs.get("vits_weights_path", None)
+        self.bert_base_path = self.configs.get("bert_base_path", None)
+        self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
+
+        if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
+            self.t2s_weights_path = self.default_configs['t2s_weights_path']
+            print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
+        if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
+            self.vits_weights_path = self.default_configs['vits_weights_path']
+            print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
+        if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
+            self.bert_base_path = self.default_configs['bert_base_path']
+            print(f"fall back to default bert_base_path: {self.bert_base_path}")
+        if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
+            self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
+            print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
+        self.update_configs()

         self.max_sec = None
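The rewritten constructor replaces silent None lookups with a fall-back-to-defaults pattern: any path that is missing, empty, or points at a nonexistent file is replaced by the class-level default, with a printed notice. A condensed runnable sketch of that pattern (names and paths are illustrative):

import os

DEFAULTS = {"t2s_weights_path": "pretrained/s1.ckpt"}

def resolve(configs: dict, key: str) -> str:
    value = configs.get(key, None)
    if value in [None, ""] or not os.path.exists(value):
        value = DEFAULTS[key]
        print(f"fall back to default {key}: {value}")
    return value

print(resolve({"t2s_weights_path": "missing.ckpt"}, "t2s_weights_path"))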
@@ -86,7 +132,7 @@ class TTS_Config:
         self.n_speakers:int = 300

         self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
-        print(self)
+        # print(self)

     def _load_configs(self, configs_path: str)->dict:
         with open(configs_path, 'r') as f:
@@ -94,43 +140,41 @@ class TTS_Config:

         return configs

     def save_configs(self, configs_path:str=None)->None:
         configs={
-            "default": {
-                "device": "cpu",
-                "is_half": False,
-                "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
-                "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
-                "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
-                "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
-            },
-            "custom": {
-                "device": str(self.device),
-                "is_half": self.is_half,
-                "t2s_weights_path": self.t2s_weights_path,
-                "vits_weights_path": self.vits_weights_path,
-                "bert_base_path": self.bert_base_path,
-                "cnhuhbert_base_path": self.cnhuhbert_base_path
-            }
+            "default":self.default_configs,
         }
+        if self.configs is not None:
+            configs["custom"] = self.update_configs()

         if configs_path is None:
             configs_path = self.configs_path
         with open(configs_path, 'w') as f:
             yaml.dump(configs, f)

+    def update_configs(self):
+        self.config = {
+            "device"             : str(self.device),
+            "is_half"            : self.is_half,
+            "t2s_weights_path"   : self.t2s_weights_path,
+            "vits_weights_path"  : self.vits_weights_path,
+            "bert_base_path"     : self.bert_base_path,
+            "cnhuhbert_base_path": self.cnhuhbert_base_path,
+            "flash_attn_enabled" : self.flash_attn_enabled
+        }
+        return self.config

     def __str__(self):
-        string = "----------------TTS Config--------------\n"
-        string += "device: {}\n".format(self.device)
-        string += "is_half: {}\n".format(self.is_half)
-        string += "bert_base_path: {}\n".format(self.bert_base_path)
-        string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
-        string += "vits_weights_path: {}\n".format(self.vits_weights_path)
-        string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
-        string += "----------------------------------------\n"
+        self.configs = self.update_configs()
+        string = "TTS Config".center(100, '-') + '\n'
+        for k, v in self.configs.items():
+            string += f"{str(k).ljust(20)}: {str(v)}\n"
+        string += "-" * 100 + '\n'
         return string

+    def __repr__(self):
+        return self.__str__()


 class TTS:
     def __init__(self, configs: Union[dict, str, TTS_Config]):
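save_configs() now serializes a "default" section plus the live "custom" values produced by update_configs(). A minimal sketch of that yaml round trip, under the assumption of a local scratch file (the filename here is illustrative):

import yaml

configs = {
    "default": {"device": "cpu", "is_half": False},
    "custom": {"device": "cuda", "is_half": True, "flash_attn_enabled": True},
}
with open("tts_infer_example.yaml", "w") as f:
    yaml.dump(configs, f)

with open("tts_infer_example.yaml", "r") as f:
    loaded = yaml.load(f, Loader=yaml.FullLoader)
assert loaded["custom"]["device"] == "cuda"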
@@ -166,34 +210,40 @@ class TTS:

         self.stop_flag:bool = False
+        self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32

     def _init_models(self,):
         self.init_t2s_weights(self.configs.t2s_weights_path)
         self.init_vits_weights(self.configs.vits_weights_path)
         self.init_bert_weights(self.configs.bert_base_path)
         self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
+        # self.enable_half_precision(self.configs.is_half)

     def init_cnhuhbert_weights(self, base_path: str):
         print(f"Loading CNHuBERT weights from {base_path}")
         self.cnhuhbert_model = CNHubert(base_path)
-        self.cnhuhbert_model.eval()
-        if self.configs.is_half == True:
-            self.cnhuhbert_model = self.cnhuhbert_model.half()
+        self.cnhuhbert_model = self.cnhuhbert_model.eval()
+        self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
+        if self.configs.is_half:
+            self.cnhuhbert_model = self.cnhuhbert_model.half()

     def init_bert_weights(self, base_path: str):
         print(f"Loading BERT weights from {base_path}")
         self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
         self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
+        self.bert_model = self.bert_model.eval()
+        self.bert_model = self.bert_model.to(self.configs.device)
         if self.configs.is_half:
             self.bert_model = self.bert_model.half()
-        self.bert_model = self.bert_model.to(self.configs.device)

     def init_vits_weights(self, weights_path: str):
         print(f"Loading VITS weights from {weights_path}")
         self.configs.vits_weights_path = weights_path
         self.configs.save_configs()
         dict_s2 = torch.load(weights_path, map_location=self.configs.device)
@@ -216,28 +266,80 @@ class TTS:
         if hasattr(vits_model, "enc_q"):
             del vits_model.enc_q

-        if self.configs.is_half:
-            vits_model = vits_model.half()
         vits_model = vits_model.to(self.configs.device)
-        vits_model.eval()
+        vits_model = vits_model.eval()
         vits_model.load_state_dict(dict_s2["weight"], strict=False)
         self.vits_model = vits_model
+        if self.configs.is_half:
+            self.vits_model = self.vits_model.half()

     def init_t2s_weights(self, weights_path: str):
         print(f"Loading Text2Semantic weights from {weights_path}")
         self.configs.t2s_weights_path = weights_path
         self.configs.save_configs()
         self.configs.hz = 50
         dict_s1 = torch.load(weights_path, map_location=self.configs.device)
         config = dict_s1["config"]
         self.configs.max_sec = config["data"]["max_sec"]
-        t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+        t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
+                                                 flash_attn_enabled=self.configs.flash_attn_enabled)
         t2s_model.load_state_dict(dict_s1["weight"])
-        if self.configs.is_half:
-            t2s_model = t2s_model.half()
         t2s_model = t2s_model.to(self.configs.device)
-        t2s_model.eval()
+        t2s_model = t2s_model.eval()
         self.t2s_model = t2s_model
+        if self.configs.is_half:
+            self.t2s_model = self.t2s_model.half()

+    def enable_half_precision(self, enable: bool = True):
+        '''
+        To enable half precision for the TTS model.
+        Args:
+            enable: bool, whether to enable half precision.
+        '''
+        if self.configs.device == "cpu" and enable:
+            print("Half precision is not supported on CPU.")
+            return
+
+        self.configs.is_half = enable
+        self.precison = torch.float16 if enable else torch.float32
+        self.configs.save_configs()
+        if enable:
+            if self.t2s_model is not None:
+                self.t2s_model = self.t2s_model.half()
+            if self.vits_model is not None:
+                self.vits_model = self.vits_model.half()
+            if self.bert_model is not None:
+                self.bert_model = self.bert_model.half()
+            if self.cnhuhbert_model is not None:
+                self.cnhuhbert_model = self.cnhuhbert_model.half()
+        else:
+            if self.t2s_model is not None:
+                self.t2s_model = self.t2s_model.float()
+            if self.vits_model is not None:
+                self.vits_model = self.vits_model.float()
+            if self.bert_model is not None:
+                self.bert_model = self.bert_model.float()
+            if self.cnhuhbert_model is not None:
+                self.cnhuhbert_model = self.cnhuhbert_model.float()

+    def set_device(self, device: torch.device):
+        '''
+        To set the device for all models.
+        Args:
+            device: torch.device, the device to use for all models.
+        '''
+        self.configs.device = device
+        self.configs.save_configs()
+        if self.t2s_model is not None:
+            self.t2s_model = self.t2s_model.to(device)
+        if self.vits_model is not None:
+            self.vits_model = self.vits_model.to(device)
+        if self.bert_model is not None:
+            self.bert_model = self.bert_model.to(device)
+        if self.cnhuhbert_model is not None:
+            self.cnhuhbert_model = self.cnhuhbert_model.to(device)

     def set_ref_audio(self, ref_audio_path:str):
         '''
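The new enable_half_precision() centralizes what the init_* methods used to do ad hoc: refuse on CPU, record the precision, and cast every held module in one place. A toy runnable version of the same shape (a single Linear stands in for the four models):

import torch

class Pipeline:
    def __init__(self, device: str = "cpu"):
        self.device = device
        self.model = torch.nn.Linear(4, 4).to(device)
        self.precision = torch.float32

    def enable_half_precision(self, enable: bool = True):
        if self.device == "cpu" and enable:
            print("Half precision is not supported on CPU.")
            return
        self.precision = torch.float16 if enable else torch.float32
        self.model = self.model.half() if enable else self.model.float()

p = Pipeline("cpu")
p.enable_half_precision(True)   # prints the warning, keeps fp32
print(p.precision)              # torch.float32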
@@ -338,7 +440,7 @@ class TTS:
             pos_end = min(pos+batch_size,index_and_len_list.shape[0])
             while pos < pos_end:
                 batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
-                score=batch[(pos_end-pos)//2]/batch.mean()
+                score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
                 if (score>=threshold) or (pos_end-pos==1):
                     batch_index=index_and_len_list[pos:pos_end, 0].tolist()
                     batch_index_list_len += len(batch_index)
@@ -359,6 +461,7 @@ class TTS:
         for batch_idx, index_list in enumerate(batch_index_list):
             item_list = [data[idx] for idx in index_list]
             phones_list = []
+            phones_len_list = []
             # bert_features_list = []
             all_phones_list = []
             all_phones_len_list = []
@@ -368,37 +471,40 @@ class TTS:
             phones_max_len = 0
             for item in item_list:
                 if prompt_data is not None:
-                    all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1)
+                    all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
+                                            .to(dtype=self.precison)
                     all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
                     phones = torch.LongTensor(item["phones"])
                     # norm_text = prompt_data["norm_text"]+item["norm_text"]
                 else:
-                    all_bert_features = item["bert_features"]
+                    all_bert_features = item["bert_features"]\
+                                            .to(dtype=self.precison)
                     phones = torch.LongTensor(item["phones"])
-                    all_phones = phones.clone()
+                    all_phones = phones
                     # norm_text = item["norm_text"]

                 bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
                 phones_max_len = max(phones_max_len, phones.shape[-1])

                 phones_list.append(phones)
+                phones_len_list.append(phones.shape[-1])
                 all_phones_list.append(all_phones)
                 all_phones_len_list.append(all_phones.shape[-1])
                 all_bert_features_list.append(all_bert_features)
                 norm_text_batch.append(item["norm_text"])

             phones_batch = phones_list
             max_len = max(bert_max_len, phones_max_len)
             # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
             all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
-            all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
-            all_bert_features_batch.zero_()

+            # all_bert_features_batch = all_bert_features_list
+            all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison)
             for idx, item in enumerate(all_bert_features_list):
                 if item != None:
                     all_bert_features_batch[idx, :, : item.shape[-1]] = item

             batch = {
                 "phones": phones_batch,
                 "phones_len": torch.LongTensor(phones_len_list),
                 "all_phones": all_phones_batch,
                 "all_phones_len": torch.LongTensor(all_phones_len_list),
                 "all_bert_features": all_bert_features_batch,
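The batching idiom above — allocate one zero tensor at the maximum length in the target precision, then copy each variable-length feature into its slice — is worth seeing in isolation. A runnable sketch with toy shapes:

import torch

feats = [torch.randn(1024, 7), torch.randn(1024, 12)]
max_len = max(f.shape[-1] for f in feats)

batch = torch.zeros(len(feats), 1024, max_len, dtype=torch.float32)
for i, f in enumerate(feats):
    batch[i, :, : f.shape[-1]] = f
print(batch.shape)  # torch.Size([2, 1024, 12])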
@@ -446,8 +552,8 @@ class TTS:
                     "prompt_text": "",        # str. prompt text for the reference audio
                     "prompt_lang": "",        # str. language of the prompt text for the reference audio
                     "top_k": 5,               # int. top k sampling
-                    "top_p": 0.9,             # float. top p sampling
-                    "temperature": 0.6,       # float. temperature for sampling
+                    "top_p": 1,               # float. top p sampling
+                    "temperature": 1,         # float. temperature for sampling
                     "text_split_method": "",  # str. text split method, see text_segmentation_method.py for details.
                     "batch_size": 1,          # int. batch size for inference
                     "batch_threshold": 0.75,  # float. threshold for batch splitting.
@@ -465,9 +571,9 @@ class TTS:
         ref_audio_path:str = inputs.get("ref_audio_path", "")
         prompt_text:str = inputs.get("prompt_text", "")
         prompt_lang:str = inputs.get("prompt_lang", "")
-        top_k:int = inputs.get("top_k", 20)
-        top_p:float = inputs.get("top_p", 0.9)
-        temperature:float = inputs.get("temperature", 0.6)
+        top_k:int = inputs.get("top_k", 5)
+        top_p:float = inputs.get("top_p", 1)
+        temperature:float = inputs.get("temperature", 1)
         text_split_method:str = inputs.get("text_split_method", "")
         batch_size = inputs.get("batch_size", 1)
         batch_threshold = inputs.get("batch_threshold", 0.75)
@@ -522,7 +628,11 @@ class TTS:

         ###### text preprocessing ########
         data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
-        audio = []
+        if len(data) == 0:
+            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
+                                                       dtype=np.int16)
+            return

         t1 = ttime()
         data, batch_index_list = self.to_batch(data,
                                                prompt_data=self.prompt_cache if not no_prompt_text else None,
@@ -531,24 +641,23 @@ class TTS:
                                                split_bucket=split_bucket
                                                )
         t2 = ttime()
-        zero_wav = torch.zeros(
-            int(self.configs.sampling_rate * 0.3),
-            dtype=torch.float16 if self.configs.is_half else torch.float32,
-            device=self.configs.device
-        )

+        try:
+            print("############ 推理 ############")
             ###### inference ######
             t_34 = 0.0
             t_45 = 0.0
+            audio = []
             for item in data:
                 t3 = ttime()
                 batch_phones = item["phones"]
                 batch_phones_len = item["phones_len"]
                 all_phoneme_ids = item["all_phones"]
                 all_phoneme_lens = item["all_phones_len"]
                 all_bert_features = item["all_bert_features"]
                 norm_text = item["norm_text"]

                 # batch_phones = batch_phones.to(self.configs.device)
                 batch_phones_len = batch_phones_len.to(self.configs.device)
                 all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
                 all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
                 all_bert_features = all_bert_features.to(self.configs.device)
@@ -559,7 +668,7 @@ class TTS:
                 if no_prompt_text:
                     prompt = None
                 else:
-                    prompt = self.prompt_cache["prompt_semantic"].clone().repeat(all_phoneme_ids.shape[0], 1).to(self.configs.device)
+                    prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)

                 with torch.no_grad():
                     pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
@@ -576,41 +685,54 @@ class TTS:
                 t4 = ttime()
                 t_34 += t4 - t3

-                refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].clone().to(self.configs.device)
-                if self.configs.is_half:
-                    refer_audio_spepc = refer_audio_spepc.half()
+                refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
+                                                    .to(dtype=self.precison, device=self.configs.device)

-                ## decoding the whole batch at once produces artifacts in the generated audio
-                # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
-                # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
-                # batch_phones = batch_phones.to(self.configs.device)
-                # batch_audio_fragment =(self.vits_model.decode(
-                #         pred_semantic, batch_phones, refer_audio_spepc
-                #     ).detach()[:, 0, :])
-                # max_audio=torch.abs(batch_audio_fragment).max()  # crude guard against 16-bit clipping
-                # if max_audio>1: batch_audio_fragment/=max_audio
-                # batch_audio_fragment = batch_audio_fragment.cpu().numpy()
-
-                ## switched to serial processing
                 batch_audio_fragment = []
-                for i, idx in enumerate(idx_list):
-                    phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-                    _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # .unsqueeze(0)  # mq needs one extra unsqueeze
-                    audio_fragment =(self.vits_model.decode(
-                            _pred_semantic, phones, refer_audio_spepc
-                        ).detach()[0, 0, :])
-                    max_audio=torch.abs(audio_fragment).max()  # crude guard against 16-bit clipping
-                    if max_audio>1: audio_fragment/=max_audio
-                    audio_fragment = torch.cat([audio_fragment, zero_wav], dim=0)
-                    batch_audio_fragment.append(
-                        audio_fragment.cpu().numpy()
-                    )  ### try reconstructing without the prompt part

+                # ## vits parallel inference, method 1
+                # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+                # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
+                # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
+                # max_len = 0
+                # for i in range(0, len(batch_phones)):
+                #     max_len = max(max_len, batch_phones[i].shape[-1])
+                # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
+                # batch_phones = batch_phones.to(self.configs.device)
+                # batch_audio_fragment = (self.vits_model.batched_decode(
+                #         pred_semantic, pred_semantic_len, batch_phones, batch_phones_len, refer_audio_spepc
+                #     ))

+                # ## vits parallel inference, method 2
+                pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+                upsample_rate = math.prod(self.vits_model.upsample_rates)
+                audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
+                audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
+                all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
+                _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
+                _batch_audio_fragment = (self.vits_model.decode(
+                        all_pred_semantic, _batch_phones, refer_audio_spepc
+                    ).detach()[0, 0, :])
+                audio_frag_end_idx.insert(0, 0)
+                batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]

+                # ## vits serial inference
+                # for i, idx in enumerate(idx_list):
+                #     phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+                #     _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # .unsqueeze(0)  # mq needs one extra unsqueeze
+                #     audio_fragment = (self.vits_model.decode(
+                #             _pred_semantic, phones, refer_audio_spepc
+                #         ).detach()[0, 0, :])
+                #     batch_audio_fragment.append(
+                #         audio_fragment
+                #     )  ### try reconstructing without the prompt part

                 t5 = ttime()
                 t_45 += t5 - t4
                 if return_fragment:
                     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
-                    yield self.audio_postprocess(batch_audio_fragment,
+                    yield self.audio_postprocess([batch_audio_fragment],
                                                  self.configs.sampling_rate,
                                                  batch_index_list,
                                                  speed_factor,
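"Method 2" above decodes all fragments as one concatenated sequence and then cuts the waveform back apart at cumulative boundaries: each fragment's sample count is its semantic-token count times 2 (25 Hz codes are interpolated to 50 Hz) times the decoder's total upsampling factor. A runnable sketch of just that bookkeeping (the upsample rates and token counts are illustrative):

import math

upsample_rates = [10, 8, 2, 2]               # illustrative values
upsample_rate = math.prod(upsample_rates)    # 320

token_counts = [12, 30, 21]                  # semantic tokens per fragment
frag_samples = [n * 2 * upsample_rate for n in token_counts]
ends = [sum(frag_samples[: i + 1]) for i in range(len(frag_samples))]
ends.insert(0, 0)

decoded = list(range(ends[-1]))              # stands in for the decoded waveform
fragments = [decoded[ends[i - 1]:ends[i]] for i in range(1, len(ends))]
print([len(f) for f in fragments])           # [7680, 19200, 13440]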
@@ -619,7 +741,8 @@ class TTS:
                     audio.append(batch_audio_fragment)

                 if self.stop_flag:
-                    yield self.configs.sampling_rate, (zero_wav.cpu().numpy()).astype(np.int16)
+                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
+                                                               dtype=np.int16)
                     return

             if not return_fragment:
@@ -629,19 +752,55 @@ class TTS:
                                              batch_index_list,
                                              speed_factor,
                                              split_bucket)
+        except Exception as e:
+            traceback.print_exc()
+            # must yield an empty clip, otherwise GPU memory is not released
+            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
+                                                       dtype=np.int16)
+            # reset the models, otherwise GPU memory is not fully released
+            del self.t2s_model
+            del self.vits_model
+            self.t2s_model = None
+            self.vits_model = None
+            self.init_t2s_weights(self.configs.t2s_weights_path)
+            self.init_vits_weights(self.configs.vits_weights_path)
+        finally:
+            self.empty_cache()

+    def empty_cache(self):
+        try:
+            if str(self.configs.device) == "cuda":
+                torch.cuda.empty_cache()
+            elif str(self.configs.device) == "mps":
+                torch.mps.empty_cache()
+        except:
+            pass

     def audio_postprocess(self,
-                          audio:np.ndarray,
+                          audio:List[torch.Tensor],
                           sr:int,
                           batch_index_list:list=None,
                           speed_factor:float=1.0,
                           split_bucket:bool=True)->tuple[int, np.ndarray]:
+        zero_wav = torch.zeros(
+                        int(self.configs.sampling_rate * 0.3),
+                        dtype=self.precison,
+                        device=self.configs.device
+                    )

+        for i, batch in enumerate(audio):
+            for j, audio_fragment in enumerate(batch):
+                max_audio=torch.abs(audio_fragment).max()  # crude guard against 16-bit clipping
+                if max_audio>1: audio_fragment/=max_audio
+                audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
+                audio[i][j] = audio_fragment.cpu().numpy()

         if split_bucket:
             audio = self.recovery_order(audio, batch_index_list)
         else:
-            audio = [item for batch in audio for item in batch]
+            # audio = [item for batch in audio for item in batch]
+            audio = sum(audio, [])

         audio = np.concatenate(audio, 0)
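Flattening via sum(audio, []) and the commented-out list comprehension are equivalent; the interesting branch is the bucketed one, which has to put fragments back in their original text order. A sketch of that order-restoration idea under the assumption that each batch carries the original indices of its items (the real recovery_order() may differ in detail):

batches = [["a0", "a1"], ["b0"]]
flat = sum(batches, [])
assert flat == [item for batch in batches for item in batch]

# with bucketing, restore each item to its pre-bucketing position
batch_index_list = [[2, 0], [1]]
restored = [None] * 3
for batch, idxs in zip(batches, batch_index_list):
    for item, idx in zip(batch, idxs):
        restored[idx] = item
print(restored)  # ['a1', 'b0', 'a0']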
@@ -1,4 +1,9 @@

+import os, sys
+
+from tqdm import tqdm
+now_dir = os.getcwd()
+sys.path.append(now_dir)

 import re
 import torch
@@ -7,11 +12,11 @@ from typing import Dict, List, Tuple
 from text.cleaner import clean_text
 from text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
-from .text_segmentation_method import splits, get_method as get_seg_method
+from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method

-# from tools.i18n.i18n import I18nAuto
+from tools.i18n.i18n import I18nAuto

-# i18n = I18nAuto()
+i18n = I18nAuto()

 def get_first(text:str) -> str:
     pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
@@ -36,6 +41,10 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
     return result

+
+
+
+
 class TextPreprocessor:
     def __init__(self, bert_model:AutoModelForMaskedLM,
                  tokenizer:AutoTokenizer, device:torch.device):
@@ -44,10 +53,14 @@ class TextPreprocessor:
         self.device = device

     def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]:
+        print(i18n("############ 切分文本 ############"))
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
-        for text in texts:
+        print(i18n("############ 提取文本Bert特征 ############"))
+        for text in tqdm(texts):
             phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
+            if phones is None:
+                continue
             res={
                 "phones": phones,
                 "bert_features": bert_features,
@@ -60,30 +73,42 @@ class TextPreprocessor:
         text = text.strip("\n")
         if (text[0] not in splits and len(get_first(text)) < 4):
             text = "。" + text if lang != "en" else "." + text
-        # print(i18n("实际输入的目标文本:"), text)
+        print(i18n("实际输入的目标文本:"))
+        print(text)

         seg_method = get_seg_method(text_split_method)
         text = seg_method(text)

         while "\n\n" in text:
             text = text.replace("\n\n", "\n")
-        # print(i18n("实际输入的目标文本(切句后):"), text)

         _texts = text.split("\n")
         _texts = merge_short_text_in_array(_texts, 5)
         texts = []

         for text in _texts:
             # skip blank lines in the input text, which would otherwise raise errors
             if (len(text.strip()) == 0):
                 continue
             if (text[-1] not in splits): text += "。" if lang != "en" else "."

+            # split sentences that are too long for BERT, which would otherwise raise errors
+            if (len(text) > 510):
+                texts.extend(split_big_text(text))
+            else:
                 texts.append(text)

+        print(i18n("实际输入的目标文本(切句后):"))
+        print(texts)
         return texts

     def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]:
         textlist, langlist = self.seg_text(texts, language)
-        phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
+        if len(textlist) == 0:
+            return None, None, None
+
+        phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
         return phones, bert_features, norm_text
@@ -92,8 +117,10 @@ class TextPreprocessor:
         textlist=[]
         langlist=[]
         if language in ["auto", "zh", "ja"]:
-            # LangSegment.setfilters(["zh","ja","en","ko"])
+            LangSegment.setfilters(["zh","ja","en","ko"])
             for tmp in LangSegment.getTexts(text):
+                if tmp["text"] == "":
+                    continue
                 if tmp["lang"] == "ko":
                     langlist.append("zh")
                 elif tmp["lang"] == "en":
@@ -103,18 +130,22 @@ class TextPreprocessor:
                     langlist.append(language if language!="auto" else tmp["lang"])
                 textlist.append(tmp["text"])
         elif language == "en":
-            # LangSegment.setfilters(["en"])
+            LangSegment.setfilters(["en"])
             formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
             while "  " in formattext:
                 formattext = formattext.replace("  ", " ")
+            if formattext != "":
                 textlist.append(formattext)
                 langlist.append("en")

         elif language in ["all_zh","all_ja"]:

             formattext = text
             while "  " in formattext:
                 formattext = formattext.replace("  ", " ")
             language = language.replace("all_","")
+            if text == "":
+                return [],[]
             textlist.append(formattext)
             langlist.append(language)
@@ -139,8 +170,7 @@ class TextPreprocessor:
         bert_feature = torch.cat(bert_feature_list, dim=1)
         # phones = sum(phones_list, [])
         norm_text = ''.join(norm_text_list)

-        return phones, bert_feature, norm_text
+        return phones_list, bert_feature, norm_text

     def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
@@ -174,3 +204,7 @@ class TextPreprocessor:
         ).to(self.device)

         return feature
+
+
+
+
@@ -24,6 +24,32 @@ def register_method(name):

 splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }

+def split_big_text(text, max_len=510):
+    # full-width and half-width punctuation marks
+    punctuation = "".join(splits)
+
+    # cut the text at punctuation
+    segments = re.split('([' + punctuation + '])', text)
+
+    # initialize the result list and the current chunk
+    result = []
+    current_segment = ''
+
+    for segment in segments:
+        # if adding the new piece would push the current chunk past max_len,
+        # flush the current chunk to the result list and start a new one
+        if len(current_segment + segment) > max_len:
+            result.append(current_segment)
+            current_segment = segment
+        else:
+            current_segment += segment
+
+    # append the final chunk
+    if current_segment:
+        result.append(current_segment)
+
+    return result
+
+
 def split(todo_text):
     todo_text = todo_text.replace("……", "。").replace("——", ",")
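Usage sketch for the added split_big_text(): sentences are cut at the punctuation set and re-grouped greedily so no chunk exceeds max_len characters (510 leaves room for BERT's special tokens in a 512 window). The function is restated verbatim so the example runs standalone:

import re

splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}

def split_big_text(text, max_len=510):
    punctuation = "".join(splits)
    segments = re.split('([' + punctuation + '])', text)
    result, current = [], ''
    for segment in segments:
        if len(current + segment) > max_len:
            result.append(current)
            current = segment
        else:
            current += segment
    if current:
        result.append(current)
    return result

print(split_big_text("abc,defg。hi!", max_len=5))  # ['abc,', 'defg。', 'hi!']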
@@ -121,6 +147,6 @@ def cut5(inp):


 if __name__ == '__main__':
-    method = get_method("cut1")
+    method = get_method("cut5")
     print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
@@ -2,6 +2,7 @@ custom:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cuda
+  flash_attn_enabled: true
   is_half: true
   t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
   vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
@@ -9,6 +10,7 @@ default:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cpu
+  flash_attn_enabled: true
  is_half: false
   t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
   vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
@@ -20,7 +20,6 @@ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
 logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
-import pdb
 import torch
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py


 infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
@@ -28,17 +27,24 @@ infer_ttswebui = int(infer_ttswebui)
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+from config import is_half,is_share
+gpt_path = os.environ.get("gpt_path", None)
+sovits_path = os.environ.get("sovits_path", None)
+cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
+bert_path = os.environ.get("bert_path", None)

 import gradio as gr
 from TTS_infer_pack.TTS import TTS, TTS_Config
-from TTS_infer_pack.text_segmentation_method import cut1, cut2, cut3, cut4, cut5
-from tools.i18n.i18n import I18nAuto
+from TTS_infer_pack.text_segmentation_method import get_method
+from tools.i18n.i18n import I18nAuto

 i18n = I18nAuto()

+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # make sure this is also set when the inference UI is launched directly

 if torch.cuda.is_available():
     device = "cuda"
 # elif torch.backends.mps.is_available():
 #     device = "mps"
 else:
     device = "cpu"
@@ -63,6 +69,16 @@ cut_method = {
 tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
 tts_config.device = device
 tts_config.is_half = is_half
+if gpt_path is not None:
+    tts_config.t2s_weights_path = gpt_path
+if sovits_path is not None:
+    tts_config.vits_weights_path = sovits_path
+if cnhubert_base_path is not None:
+    tts_config.cnhuhbert_base_path = cnhubert_base_path
+if bert_path is not None:
+    tts_config.bert_base_path = bert_path
+
+print(tts_config)
 tts_pipline = TTS(tts_config)
 gpt_path = tts_config.t2s_weights_path
 sovits_path = tts_config.vits_weights_path
@@ -88,9 +104,11 @@ def inference(text, text_lang,
         "batch_size":int(batch_size),
         "speed_factor":float(speed_factor),
         "split_bucket":split_bucket,
-        "return_fragment":False,
+        "return_fragment":False
     }
-    yield next(tts_pipline.run(inputs))

+    for item in tts_pipline.run(inputs):
+        yield item

 def custom_sort_key(s):
     # use a regular expression to split the string into its numeric and non-numeric parts
def custom_sort_key(s):
|
||||
# 使用正则表达式提取字符串中的数字部分和非数字部分
|
||||
@ -167,7 +185,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
with gr.Row():
|
||||
|
||||
with gr.Column():
|
||||
batch_size = gr.Slider(minimum=1,maximum=20,step=1,label=i18n("batch_size"),value=1,interactive=True)
|
||||
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
|
||||
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
|
||||
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
||||
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
||||
@@ -179,7 +197,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                     value=i18n("凑四句一切"),
                     interactive=True,
                 )
-            split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
+            with gr.Row():
+                split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
             # with gr.Column():
             output = gr.Audio(label=i18n("输出的语音"))
         with gr.Row():
@@ -1,5 +1,6 @@
 import copy
 import math
+from typing import List
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -986,6 +987,55 @@ class SynthesizerTrn(nn.Module):
         o = self.dec((z * y_mask)[:, :, :], g=ge)
         return o

+    @torch.no_grad()
+    def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5):
+        ge = None
+        if refer is not None:
+            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
+            refer_mask = torch.unsqueeze(
+                commons.sequence_mask(refer_lengths, refer.size(2)), 1
+            ).to(refer.dtype)
+            ge = self.ref_enc(refer * refer_mask, refer_mask)
+
+        # y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, codes.size(2)), 1).to(
+        #     codes.dtype
+        # )
+        y_lengths = (y_lengths * 2).long().to(codes.device)
+        text_lengths = text_lengths.long().to(text.device)
+        # y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
+        # text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
+
+        # assume decoding after padding is fine; the impact is unknown, but it sounds OK?
+        quantized = self.quantizer.decode(codes)
+        if self.semantic_frame_rate == "25hz":
+            quantized = F.interpolate(
+                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
+            )
+
+        x, m_p, logs_p, y_mask = self.enc_p(
+            quantized, y_lengths, text, text_lengths, ge
+        )
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+
+        z = self.flow(z_p, y_mask, g=ge, reverse=True)
+        z_masked = (z * y_mask)[:, :, :]
+
+        # serial: strip the padding before decoding
+        o_list:List[torch.Tensor] = []
+        for i in range(z_masked.shape[0]):
+            z_slice = z_masked[i, :, :y_lengths[i]].unsqueeze(0)
+            o = self.dec(z_slice, g=ge)[0, 0, :].detach()
+            o_list.append(o)
+
+        # parallel (problematic): decode first, then strip the padding
+        # o = self.dec(z_masked, g=ge)
+        # upsample_rate = int(math.prod(self.upsample_rates))
+        # o_lengths = y_lengths*upsample_rate
+        # o_list = [o[i, 0, :idx].detach() for i, idx in enumerate(o_lengths)]
+
+        return o_list

     def extract_latent(self, x):
         ssl = self.ssl_proj(x)
         quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
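The serial branch of batched_decode() is the interesting part: each padded latent is cut back to its true length before decoding, which avoids the artifacts the commented parallel path produces. A runnable sketch of just that shape bookkeeping, with a 1x1 Conv1d standing in for the real decoder (an assumption; self.dec is a full vocoder):

import torch

y_lengths = torch.tensor([6, 4])
z_masked = torch.randn(2, 192, 6)                 # (batch, channels, padded_frames)

decoder = torch.nn.Conv1d(192, 1, kernel_size=1)  # stands in for self.dec

o_list = []
for i in range(z_masked.shape[0]):
    z_slice = z_masked[i, :, : y_lengths[i]].unsqueeze(0)
    o_list.append(decoder(z_slice)[0, 0, :].detach())
print([o.shape[0] for o in o_list])               # [6, 4]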