允许指定经典推理而不并行

This commit is contained in:
XTer 2024-03-12 23:29:04 +08:00
parent a057c697e7
commit c80f4b9557
6 changed files with 231 additions and 138 deletions

View File

@ -1,25 +0,0 @@
使用说明:
当您想推理时,请开启 TTS后台服务 与 TTS WebUI
当您想调整模型时,请开启 TTS模型管理 与 TTS后台服务 、 TTS WebUI
当您仅仅想训练时,只需开启 GPTsoVITS WebUI
GPT-soVITS Start V (推理特化包版本)
版本1.3.20240311.1
作者:小贼丑
QQ/Wechat406267780
本界面启动器适配推理特化包2.1 by Xter
推理特化包信息:
fork地址 by XTerhttps://github.com/X-T-E-R/GPT-SoVITS-Inference
推理与模型管理子模块地址 by XTerhttps://github.com/X-T-E-R/TTS-for-GPT-soVITS
原项目 by RVC-Boss花儿不哭https://github.com/RVC-Boss/GPT-SoVITS
推理特化包交流群863760614
版权说明:
本软件仅为人工智能语音合成项目的一个启动器,请使用者注意,所有数据来源和模型使用均应尊重原始版权所有者的权益。使用本软件产生的任何内容,用户需自行承担相应的版权责任。
用户应理解,本软件的使用不授予任何数据来源或模型的版权、许可或使用权。所有使用必须遵守相关版权法律和协议,确保数据的合法获取和使用。
如不认可本声明或不承担相应责任,请立即停止使用本软件,并删除相关文件。继续使用即视为接受本声明条款,愿意遵守上述规定。

View File

@ -1,30 +1,73 @@
[Title]
Text=特化推理包2.1 by Xter | 原项目 by RVC-Boss花儿不哭
[HomeImage]
Route=\Cfg\HomeImage.jpg
[TTSModuleManagerStart]
Route=10 启动模型管理界面(可选).bat
[TTSBackgroundServiceStart]
FolderName=TTS 后台服务
ButtonName=TTS 后台服务
Route=3 启动后端程序.bat
[TTSWebUIStart]
FolderName=TTS WebUI
Hide=N
ButtonName=TTS WebUI
Route=4 启动前端合成程序(可选,依赖后端).bat
[GPTsoVITSWebUIStart]
FolderName=GPTsV WebUI
Hide=N
ButtonName=GPTsV WebUI
Route=11 启动原项目的训练界面(小白别开,请根据页面上的文档链接自行研究,推理群不包解答).bat
[TTSModuleFolder]
Route=..\trained
[TTSOnekeyStartFolder]
Route=..\0 一键启动脚本
[GPTsoVITSRootFolder]
Route=..\
[TTSJsonConfig]
Route=..\Inference\config.json
[TTSModuleManagerStart]
FolderName=TTS 模型管理
Hide=N
ButtonName=TTS 模型管理
Route=10 启动模型管理界面(可选).bat
[UpdateProgramStart]
FolderName=更新项目
ButtonName=更新项目
Route=0 一键更新项目.bat
[UpdateDependenciesStart]
FolderName=更新依赖
Hide=N
ButtonName=更新依赖
Route=1 一键更新本项目所需要的依赖.bat
[ForceUpdateStart]
FolderName=强制更新
Hide=N
ButtonName=强制更新
Route=999 强制更新会覆盖你的设置慎用和0功能类似.bat
[TTSModuleFolder]
ButtonName=TTS模型目录
Hide=N
Route=..\trained
[TTSOnekeyStartFolder]
ButtonName=TTS一键启动目录
Hide=N
Route=..\0 一键启动脚本
[GPTsoVITSRootFolder]
ButtonName=GPTsoVITS根目录
Hide=N
Route=..\
[TTSDocument]
ButtonName=TTS项目文档
Hide=N
Route=https://www.yuque.com/xter/zibxlp
[GPTsoVITSDocument]
ButtonName=GPTsoVITS项目文档
Hide=N
Route=https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e
[StartDocument]
ButtonName=启动器项目文档
Hide=N
Route=https://note.youdao.com/s/DlzSyaLl
[JsonConfig]
Route=..\Inference\config.json
[AboutTxt]
Route=\Cfg\About.txt
[Status]
Hide=N
TWU Hide=N
GWU Hide=N
TMM Hide=N
UpD Hide=N
FUp Hide=N
Line Hide=156

View File

@ -14,7 +14,7 @@ from AR.models.utils import (
logits_to_probs,
multinomial_sample_one_no_sync,
dpo_loss,
make_reject_y,
make_reject_y,
get_batch_logps
)
from AR.modules.embedding import SinePositionalEmbedding
@ -26,11 +26,6 @@ from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
try:
from flash_attn import flash_attn_with_kvcache
except ImportError:
flash_attn_with_kvcache = None
default_config = {
"embedding_dim": 512,
"hidden_dim": 512,
@ -302,7 +297,7 @@ class Text2SemanticDecoder(nn.Module):
(0, y_len),
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@ -363,7 +358,7 @@ class Text2SemanticDecoder(nn.Module):
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
loss = loss_1 + loss_2
return loss, acc
@ -432,14 +427,14 @@ class Text2SemanticDecoder(nn.Module):
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
def infer(
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@ -678,18 +673,22 @@ class Text2SemanticDecoder(nn.Module):
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
if flash_attn_with_kvcache is not None:
k_cache = [torch.empty(x.shape[0], 2048, 16, 32, dtype=x.dtype, device=x.device) for _ in range(self.num_layers)]
v_cache = [torch.empty(x.shape[0], 2048, 16, 32, dtype=x.dtype, device=x.device) for _ in range(self.num_layers)]
else:
k_cache = [None] * self.num_layers
v_cache = [None] * self.num_layers
cache = {
"all_stage": self.num_layers,
"k": [None] * self.num_layers, ###根据配置自己手写
"v": [None] * self.num_layers,
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
"y_emb": None, ##只需要对最新的samples求emb再拼历史的就行
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
# "xy_dec":None,###不需要本来只需要最后一个做logits
"first_infer": 1,
"stage": 0,
}
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
@ -697,6 +696,7 @@ class Text2SemanticDecoder(nn.Module):
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
cache["y_emb"] = y_emb
ref_free = False
else:
y_emb = None
@ -708,10 +708,10 @@ class Text2SemanticDecoder(nn.Module):
ref_free = True
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
@ -725,16 +725,14 @@ class Text2SemanticDecoder(nn.Module):
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)):
logits = self.infer_one_step(xy_pos, xy_attn_mask, k_cache, v_cache, cache_seqlens)
if idx == 0:
cache_seqlens += xy_pos.shape[1]
else:
cache_seqlens += 1
xy_attn_mask = None
if idx == 0:
logits = logits[:, :-1]
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
) ##不用改如果用了cache的默认就是只有一帧取最后一帧一样的
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
if(idx==0):###第一次跑不能EOS否则没有了
logits = logits[:, :-1] ###刨除1024终止符号的概率
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0]
@ -778,12 +776,15 @@ class Text2SemanticDecoder(nn.Module):
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
if y.shape[1] == 0:
# if prompts.shape[1] == y.shape[1]:
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
# print("bad zero prediction")
if y.shape[1]==0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
####################### update next step ###################################
cache["first_infer"] = 0
if cache["y_emb"] is not None:

View File

@ -1,3 +1,4 @@
from copy import deepcopy
import math
import os, sys
import random
@ -50,22 +51,7 @@ custom:
class TTS_Config:
def __init__(self, configs: Union[dict, str]):
if isinstance(configs, str) and configs=="":
self.default_configs:dict = None
self.configs_path = "GPT_SoVITS/configs/tts_infer.yaml"
else:
configs_base_path:str = "GPT_SoVITS/configs/"
os.makedirs(configs_base_path, exist_ok=True)
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
if isinstance(configs, str):
self.configs_path = configs
configs:dict = self._load_configs(configs)
# assert isinstance(configs, dict)
self.default_configs:dict = configs.get("default", None)
if self.default_configs is None:
self.default_configs={
default_configs={
"device": "cpu",
"is_half": False,
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
@ -74,18 +60,54 @@ class TTS_Config:
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
"flash_attn_enabled": True
}
if isinstance(configs, dict):
self.configs:dict = configs.get("custom", self.default_configs)
else:
self.configs:dict = self.default_configs
configs:dict = None
def __init__(self, configs: Union[dict, str]=None):
self.device = self.configs.get("device")
self.is_half = self.configs.get("is_half")
self.t2s_weights_path = self.configs.get("t2s_weights_path")
self.vits_weights_path = self.configs.get("vits_weights_path")
self.bert_base_path = self.configs.get("bert_base_path")
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
self.flash_attn_enabled = self.configs.get("flash_attn_enabled")
# 设置默认配置文件路径
configs_base_path:str = "GPT_SoVITS/configs/"
os.makedirs(configs_base_path, exist_ok=True)
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
if configs in ["", None]:
if not os.path.exists(self.configs_path):
self.save_configs()
print(f"Create default config file at {self.configs_path}")
configs:dict = {"default": deepcopy(self.default_configs)}
if isinstance(configs, str):
self.configs_path = configs
configs:dict = self._load_configs(self.configs_path)
assert isinstance(configs, dict)
default_configs:dict = configs.get("default", None)
if default_configs is not None:
self.default_configs = default_configs
self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
self.device = self.configs.get("device", torch.device("cpu"))
self.is_half = self.configs.get("is_half", False)
self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
self.vits_weights_path = self.configs.get("vits_weights_path", None)
self.bert_base_path = self.configs.get("bert_base_path", None)
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
self.t2s_weights_path = self.default_configs['t2s_weights_path']
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
self.vits_weights_path = self.default_configs['vits_weights_path']
print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
self.bert_base_path = self.default_configs['bert_base_path']
print(f"fall back to default bert_base_path: {self.bert_base_path}")
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
self.update_configs()
self.max_sec = None
@ -109,24 +131,18 @@ class TTS_Config:
def save_configs(self, configs_path:str=None)->None:
configs={
"default": {
"device": "cpu",
"is_half": False,
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
"flash_attn_enabled": True
},
"custom": self.update_configs()
"default":self.default_configs,
}
if self.configs is not None:
configs["custom"] = self.update_configs()
if configs_path is None:
configs_path = self.configs_path
with open(configs_path, 'w') as f:
yaml.dump(configs, f)
def update_configs(self):
config = {
self.config = {
"device" : str(self.device),
"is_half" : self.is_half,
"t2s_weights_path" : self.t2s_weights_path,
@ -135,7 +151,7 @@ class TTS_Config:
"cnhuhbert_base_path": self.cnhuhbert_base_path,
"flash_attn_enabled" : self.flash_attn_enabled
}
return config
return self.config
def __str__(self):
self.configs = self.update_configs()
@ -144,6 +160,9 @@ class TTS_Config:
string += f"{str(k).ljust(20)}: {str(v)}\n"
string += "-" * 100 + '\n'
return string
def __repr__(self):
return self.__str__()
class TTS:
@ -180,35 +199,40 @@ class TTS:
self.stop_flag:bool = False
self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
def _init_models(self,):
self.init_t2s_weights(self.configs.t2s_weights_path)
self.init_vits_weights(self.configs.vits_weights_path)
self.init_bert_weights(self.configs.bert_base_path)
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
# self.enable_half_precision(self.configs.is_half)
def init_cnhuhbert_weights(self, base_path: str):
print(f"Loading CNHuBERT weights from {base_path}")
self.cnhuhbert_model = CNHubert(base_path)
self.cnhuhbert_model=self.cnhuhbert_model.eval()
if self.configs.is_half == True:
self.cnhuhbert_model = self.cnhuhbert_model.half()
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
if self.configs.is_half:
self.cnhuhbert_model = self.cnhuhbert_model.half()
def init_bert_weights(self, base_path: str):
print(f"Loading BERT weights from {base_path}")
self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
self.bert_model=self.bert_model.eval()
self.bert_model = self.bert_model.to(self.configs.device)
if self.configs.is_half:
self.bert_model = self.bert_model.half()
self.bert_model = self.bert_model.to(self.configs.device)
def init_vits_weights(self, weights_path: str):
print(f"Loading VITS weights from {weights_path}")
self.configs.vits_weights_path = weights_path
self.configs.save_configs()
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
@ -231,15 +255,16 @@ class TTS:
if hasattr(vits_model, "enc_q"):
del vits_model.enc_q
if self.configs.is_half:
vits_model = vits_model.half()
vits_model = vits_model.to(self.configs.device)
vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model
if self.configs.is_half:
self.vits_model = self.vits_model.half()
def init_t2s_weights(self, weights_path: str):
print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path
self.configs.save_configs()
self.configs.hz = 50
@ -249,11 +274,61 @@ class TTS:
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
flash_attn_enabled=self.configs.flash_attn_enabled)
t2s_model.load_state_dict(dict_s1["weight"])
if self.configs.is_half:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(self.configs.device)
t2s_model = t2s_model.eval()
self.t2s_model = t2s_model
if self.configs.is_half:
self.t2s_model = self.t2s_model.half()
def enable_half_precision(self, enable: bool = True):
'''
To enable half precision for the TTS model.
Args:
enable: bool, whether to enable half precision.
'''
if self.configs.device == "cpu" and enable:
print("Half precision is not supported on CPU.")
return
self.configs.is_half = enable
self.precison = torch.float16 if enable else torch.float32
self.configs.save_configs()
if enable:
if self.t2s_model is not None:
self.t2s_model =self.t2s_model.half()
if self.vits_model is not None:
self.vits_model = self.vits_model.half()
if self.bert_model is not None:
self.bert_model =self.bert_model.half()
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.half()
else:
if self.t2s_model is not None:
self.t2s_model = self.t2s_model.float()
if self.vits_model is not None:
self.vits_model = self.vits_model.float()
if self.bert_model is not None:
self.bert_model = self.bert_model.float()
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.float()
def set_device(self, device: torch.device):
'''
To set the device for all models.
Args:
device: torch.device, the device to use for all models.
'''
self.configs.device = device
self.configs.save_configs()
if self.t2s_model is not None:
self.t2s_model = self.t2s_model.to(device)
if self.vits_model is not None:
self.vits_model = self.vits_model.to(device)
if self.bert_model is not None:
self.bert_model = self.bert_model.to(device)
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.to(device)
def set_ref_audio(self, ref_audio_path:str):
'''
@ -354,7 +429,7 @@ class TTS:
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
while pos < pos_end:
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
score=batch[(pos_end-pos)//2]/batch.mean()
score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
if (score>=threshold) or (pos_end-pos==1):
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
batch_index_list_len += len(batch_index)
@ -386,13 +461,13 @@ class TTS:
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
.to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
.to(dtype=self.precison)
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
phones = torch.LongTensor(item["phones"])
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]\
.to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
.to(dtype=self.precison)
phones = torch.LongTensor(item["phones"])
all_phones = phones
# norm_text = item["norm_text"]
@ -412,7 +487,7 @@ class TTS:
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
# all_bert_features_batch = all_bert_features_list
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=torch.float32)
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison)
for idx, item in enumerate(all_bert_features_list):
all_bert_features_batch[idx, :, : item.shape[-1]] = item
@ -542,6 +617,11 @@ class TTS:
###### text preprocessing ########
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
if len(data) == 0:
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
dtype=np.int16)
return
t1 = ttime()
data, batch_index_list = self.to_batch(data,
prompt_data=self.prompt_cache if not no_prompt_text else None,
@ -594,10 +674,8 @@ class TTS:
t4 = ttime()
t_34 += t4 - t3
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].to(self.configs.device)
if self.configs.is_half:
refer_audio_spepc = refer_audio_spepc.half()
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
.to(dtype=self.precison, device=self.configs.device)
batch_audio_fragment = []
@ -679,7 +757,7 @@ class TTS:
split_bucket:bool=True)->tuple[int, np.ndarray]:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * 0.3),
dtype=torch.float16 if self.configs.is_half else torch.float32,
dtype=self.precison,
device=self.configs.device
)
@ -697,13 +775,9 @@ class TTS:
# audio = [item for batch in audio for item in batch]
audio = sum(audio, [])
try:
audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16)
except:
audio = np.array(audio)
audio = (audio * 32768).astype(np.int16)
audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16)
try:
if speed_factor != 1.0:

@ -1 +1 @@
Subproject commit d0716731840c8f405a8c1c60277a646029a2c2ae
Subproject commit 66f358cc1140490bb1f06ae4550e5c8a0dedd86f