mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 15:19:59 +08:00
允许指定经典推理而不并行
This commit is contained in:
parent
a057c697e7
commit
c80f4b9557
@ -1,25 +0,0 @@
|
||||
使用说明:
|
||||
|
||||
当您想推理时,请开启 TTS后台服务 与 TTS WebUI
|
||||
当您想调整模型时,请开启 TTS模型管理 与 TTS后台服务 、 TTS WebUI
|
||||
当您仅仅想训练时,只需开启 GPTsoVITS WebUI
|
||||
|
||||
|
||||
GPT-soVITS Start V (推理特化包版本)
|
||||
版本:1.3.20240311.1
|
||||
作者:小贼丑
|
||||
QQ/Wechat:406267780
|
||||
|
||||
本界面启动器适配:推理特化包2.1 by Xter
|
||||
|
||||
推理特化包信息:
|
||||
fork地址 by XTer:https://github.com/X-T-E-R/GPT-SoVITS-Inference
|
||||
推理与模型管理子模块地址 by XTer:https://github.com/X-T-E-R/TTS-for-GPT-soVITS
|
||||
原项目 by RVC-Boss(花儿不哭):https://github.com/RVC-Boss/GPT-SoVITS
|
||||
|
||||
推理特化包交流群:863760614
|
||||
|
||||
版权说明:
|
||||
本软件仅为人工智能语音合成项目的一个启动器,请使用者注意,所有数据来源和模型使用均应尊重原始版权所有者的权益。使用本软件产生的任何内容,用户需自行承担相应的版权责任。
|
||||
用户应理解,本软件的使用不授予任何数据来源或模型的版权、许可或使用权。所有使用必须遵守相关版权法律和协议,确保数据的合法获取和使用。
|
||||
如不认可本声明或不承担相应责任,请立即停止使用本软件,并删除相关文件。继续使用即视为接受本声明条款,愿意遵守上述规定。
|
@ -1,30 +1,73 @@
|
||||
[Title]
|
||||
Text=特化推理包2.1 by Xter | 原项目 by RVC-Boss(花儿不哭)
|
||||
[HomeImage]
|
||||
Route=\Cfg\HomeImage.jpg
|
||||
[TTSModuleManagerStart]
|
||||
Route=10 启动模型管理界面(可选).bat
|
||||
[TTSBackgroundServiceStart]
|
||||
FolderName=TTS 后台服务
|
||||
ButtonName=TTS 后台服务
|
||||
Route=3 启动后端程序.bat
|
||||
[TTSWebUIStart]
|
||||
FolderName=TTS WebUI
|
||||
Hide=N
|
||||
ButtonName=TTS WebUI
|
||||
Route=4 启动前端合成程序(可选,依赖后端).bat
|
||||
[GPTsoVITSWebUIStart]
|
||||
FolderName=GPTsV WebUI
|
||||
Hide=N
|
||||
ButtonName=GPTsV WebUI
|
||||
Route=11 启动原项目的训练界面(小白别开,请根据页面上的文档链接自行研究,推理群不包解答).bat
|
||||
[TTSModuleFolder]
|
||||
Route=..\trained
|
||||
[TTSOnekeyStartFolder]
|
||||
Route=..\0 一键启动脚本
|
||||
[GPTsoVITSRootFolder]
|
||||
Route=..\
|
||||
[TTSJsonConfig]
|
||||
Route=..\Inference\config.json
|
||||
[TTSModuleManagerStart]
|
||||
FolderName=TTS 模型管理
|
||||
Hide=N
|
||||
ButtonName=TTS 模型管理
|
||||
Route=10 启动模型管理界面(可选).bat
|
||||
[UpdateProgramStart]
|
||||
FolderName=更新项目
|
||||
ButtonName=更新项目
|
||||
Route=0 一键更新项目.bat
|
||||
[UpdateDependenciesStart]
|
||||
FolderName=更新依赖
|
||||
Hide=N
|
||||
ButtonName=更新依赖
|
||||
Route=1 一键更新本项目所需要的依赖.bat
|
||||
[ForceUpdateStart]
|
||||
FolderName=强制更新
|
||||
Hide=N
|
||||
ButtonName=强制更新
|
||||
Route=999 强制更新:会覆盖你的设置,慎用,和0功能类似.bat
|
||||
[TTSModuleFolder]
|
||||
ButtonName=TTS模型目录
|
||||
Hide=N
|
||||
Route=..\trained
|
||||
[TTSOnekeyStartFolder]
|
||||
ButtonName=TTS一键启动目录
|
||||
Hide=N
|
||||
Route=..\0 一键启动脚本
|
||||
[GPTsoVITSRootFolder]
|
||||
ButtonName=GPTsoVITS根目录
|
||||
Hide=N
|
||||
Route=..\
|
||||
[TTSDocument]
|
||||
ButtonName=TTS项目文档
|
||||
Hide=N
|
||||
Route=https://www.yuque.com/xter/zibxlp
|
||||
[GPTsoVITSDocument]
|
||||
ButtonName=GPTsoVITS项目文档
|
||||
Hide=N
|
||||
Route=https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e
|
||||
[StartDocument]
|
||||
ButtonName=启动器项目文档
|
||||
Hide=N
|
||||
Route=https://note.youdao.com/s/DlzSyaLl
|
||||
[JsonConfig]
|
||||
Route=..\Inference\config.json
|
||||
[AboutTxt]
|
||||
Route=\Cfg\About.txt
|
||||
[Status]
|
||||
Hide=N
|
||||
TWU Hide=N
|
||||
GWU Hide=N
|
||||
TMM Hide=N
|
||||
UpD Hide=N
|
||||
FUp Hide=N
|
||||
Line Hide=156
|
||||
|
Binary file not shown.
@ -14,7 +14,7 @@ from AR.models.utils import (
|
||||
logits_to_probs,
|
||||
multinomial_sample_one_no_sync,
|
||||
dpo_loss,
|
||||
make_reject_y,
|
||||
make_reject_y,
|
||||
get_batch_logps
|
||||
)
|
||||
from AR.modules.embedding import SinePositionalEmbedding
|
||||
@ -26,11 +26,6 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_with_kvcache
|
||||
except ImportError:
|
||||
flash_attn_with_kvcache = None
|
||||
|
||||
default_config = {
|
||||
"embedding_dim": 512,
|
||||
"hidden_dim": 512,
|
||||
@ -302,7 +297,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(0, y_len),
|
||||
value=True,
|
||||
)
|
||||
|
||||
|
||||
y_attn_mask = F.pad(
|
||||
torch.triu(
|
||||
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
||||
@ -363,7 +358,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
|
||||
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
|
||||
|
||||
|
||||
loss = loss_1 + loss_2
|
||||
|
||||
return loss, acc
|
||||
@ -432,14 +427,14 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
|
||||
def infer(
|
||||
self,
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
top_k: int = -100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
self,
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
top_k: int = -100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
@ -678,18 +673,22 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
# AR Decoder
|
||||
y = prompts
|
||||
|
||||
|
||||
x_len = x.shape[1]
|
||||
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
||||
stop = False
|
||||
# print(1111111,self.num_layers)
|
||||
|
||||
if flash_attn_with_kvcache is not None:
|
||||
k_cache = [torch.empty(x.shape[0], 2048, 16, 32, dtype=x.dtype, device=x.device) for _ in range(self.num_layers)]
|
||||
v_cache = [torch.empty(x.shape[0], 2048, 16, 32, dtype=x.dtype, device=x.device) for _ in range(self.num_layers)]
|
||||
else:
|
||||
k_cache = [None] * self.num_layers
|
||||
v_cache = [None] * self.num_layers
|
||||
cache = {
|
||||
"all_stage": self.num_layers,
|
||||
"k": [None] * self.num_layers, ###根据配置自己手写
|
||||
"v": [None] * self.num_layers,
|
||||
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
|
||||
"y_emb": None, ##只需要对最新的samples求emb,再拼历史的就行
|
||||
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
|
||||
# "xy_dec":None,###不需要,本来只需要最后一个做logits
|
||||
"first_infer": 1,
|
||||
"stage": 0,
|
||||
}
|
||||
################### first step ##########################
|
||||
if y is not None:
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
@ -697,6 +696,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
prefix_len = y.shape[1]
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
cache["y_emb"] = y_emb
|
||||
ref_free = False
|
||||
else:
|
||||
y_emb = None
|
||||
@ -708,10 +708,10 @@ class Text2SemanticDecoder(nn.Module):
|
||||
ref_free = True
|
||||
|
||||
x_attn_mask_pad = F.pad(
|
||||
x_attn_mask,
|
||||
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
||||
value=True,
|
||||
)
|
||||
x_attn_mask,
|
||||
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
||||
value=True,
|
||||
)
|
||||
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
(x_len, 0),
|
||||
@ -725,16 +725,14 @@ class Text2SemanticDecoder(nn.Module):
|
||||
batch_idx_map = list(range(y.shape[0]))
|
||||
idx_list = [None]*y.shape[0]
|
||||
for idx in tqdm(range(1500)):
|
||||
logits = self.infer_one_step(xy_pos, xy_attn_mask, k_cache, v_cache, cache_seqlens)
|
||||
|
||||
if idx == 0:
|
||||
cache_seqlens += xy_pos.shape[1]
|
||||
else:
|
||||
cache_seqlens += 1
|
||||
xy_attn_mask = None
|
||||
|
||||
if idx == 0:
|
||||
logits = logits[:, :-1]
|
||||
|
||||
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
|
||||
logits = self.ar_predict_layer(
|
||||
xy_dec[:, -1]
|
||||
) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的
|
||||
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
|
||||
if(idx==0):###第一次跑不能EOS否则没有了
|
||||
logits = logits[:, :-1] ###刨除1024终止符号的概率
|
||||
samples = sample(
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
|
||||
)[0]
|
||||
@ -778,12 +776,15 @@ class Text2SemanticDecoder(nn.Module):
|
||||
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
||||
stop = True
|
||||
if stop:
|
||||
if y.shape[1] == 0:
|
||||
# if prompts.shape[1] == y.shape[1]:
|
||||
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||
# print("bad zero prediction")
|
||||
if y.shape[1]==0:
|
||||
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||
print("bad zero prediction")
|
||||
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
||||
break
|
||||
|
||||
|
||||
####################### update next step ###################################
|
||||
cache["first_infer"] = 0
|
||||
if cache["y_emb"] is not None:
|
||||
|
@ -1,3 +1,4 @@
|
||||
from copy import deepcopy
|
||||
import math
|
||||
import os, sys
|
||||
import random
|
||||
@ -50,22 +51,7 @@ custom:
|
||||
|
||||
|
||||
class TTS_Config:
|
||||
def __init__(self, configs: Union[dict, str]):
|
||||
if isinstance(configs, str) and configs=="":
|
||||
self.default_configs:dict = None
|
||||
self.configs_path = "GPT_SoVITS/configs/tts_infer.yaml"
|
||||
else:
|
||||
configs_base_path:str = "GPT_SoVITS/configs/"
|
||||
os.makedirs(configs_base_path, exist_ok=True)
|
||||
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
|
||||
if isinstance(configs, str):
|
||||
self.configs_path = configs
|
||||
configs:dict = self._load_configs(configs)
|
||||
|
||||
# assert isinstance(configs, dict)
|
||||
self.default_configs:dict = configs.get("default", None)
|
||||
if self.default_configs is None:
|
||||
self.default_configs={
|
||||
default_configs={
|
||||
"device": "cpu",
|
||||
"is_half": False,
|
||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||
@ -74,18 +60,54 @@ class TTS_Config:
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
"flash_attn_enabled": True
|
||||
}
|
||||
if isinstance(configs, dict):
|
||||
self.configs:dict = configs.get("custom", self.default_configs)
|
||||
else:
|
||||
self.configs:dict = self.default_configs
|
||||
configs:dict = None
|
||||
def __init__(self, configs: Union[dict, str]=None):
|
||||
|
||||
self.device = self.configs.get("device")
|
||||
self.is_half = self.configs.get("is_half")
|
||||
self.t2s_weights_path = self.configs.get("t2s_weights_path")
|
||||
self.vits_weights_path = self.configs.get("vits_weights_path")
|
||||
self.bert_base_path = self.configs.get("bert_base_path")
|
||||
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
|
||||
self.flash_attn_enabled = self.configs.get("flash_attn_enabled")
|
||||
# 设置默认配置文件路径
|
||||
configs_base_path:str = "GPT_SoVITS/configs/"
|
||||
os.makedirs(configs_base_path, exist_ok=True)
|
||||
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
|
||||
|
||||
if configs in ["", None]:
|
||||
if not os.path.exists(self.configs_path):
|
||||
self.save_configs()
|
||||
print(f"Create default config file at {self.configs_path}")
|
||||
configs:dict = {"default": deepcopy(self.default_configs)}
|
||||
|
||||
if isinstance(configs, str):
|
||||
self.configs_path = configs
|
||||
configs:dict = self._load_configs(self.configs_path)
|
||||
|
||||
assert isinstance(configs, dict)
|
||||
default_configs:dict = configs.get("default", None)
|
||||
if default_configs is not None:
|
||||
self.default_configs = default_configs
|
||||
|
||||
self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
|
||||
|
||||
|
||||
self.device = self.configs.get("device", torch.device("cpu"))
|
||||
self.is_half = self.configs.get("is_half", False)
|
||||
self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
|
||||
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
|
||||
self.vits_weights_path = self.configs.get("vits_weights_path", None)
|
||||
self.bert_base_path = self.configs.get("bert_base_path", None)
|
||||
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
|
||||
|
||||
|
||||
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
|
||||
self.t2s_weights_path = self.default_configs['t2s_weights_path']
|
||||
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
|
||||
if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
|
||||
self.vits_weights_path = self.default_configs['vits_weights_path']
|
||||
print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
|
||||
if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
|
||||
self.bert_base_path = self.default_configs['bert_base_path']
|
||||
print(f"fall back to default bert_base_path: {self.bert_base_path}")
|
||||
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
|
||||
self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
|
||||
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
|
||||
self.update_configs()
|
||||
|
||||
|
||||
self.max_sec = None
|
||||
@ -109,24 +131,18 @@ class TTS_Config:
|
||||
|
||||
def save_configs(self, configs_path:str=None)->None:
|
||||
configs={
|
||||
"default": {
|
||||
"device": "cpu",
|
||||
"is_half": False,
|
||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
|
||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
"flash_attn_enabled": True
|
||||
},
|
||||
"custom": self.update_configs()
|
||||
"default":self.default_configs,
|
||||
}
|
||||
if self.configs is not None:
|
||||
configs["custom"] = self.update_configs()
|
||||
|
||||
if configs_path is None:
|
||||
configs_path = self.configs_path
|
||||
with open(configs_path, 'w') as f:
|
||||
yaml.dump(configs, f)
|
||||
|
||||
def update_configs(self):
|
||||
config = {
|
||||
self.config = {
|
||||
"device" : str(self.device),
|
||||
"is_half" : self.is_half,
|
||||
"t2s_weights_path" : self.t2s_weights_path,
|
||||
@ -135,7 +151,7 @@ class TTS_Config:
|
||||
"cnhuhbert_base_path": self.cnhuhbert_base_path,
|
||||
"flash_attn_enabled" : self.flash_attn_enabled
|
||||
}
|
||||
return config
|
||||
return self.config
|
||||
|
||||
def __str__(self):
|
||||
self.configs = self.update_configs()
|
||||
@ -144,6 +160,9 @@ class TTS_Config:
|
||||
string += f"{str(k).ljust(20)}: {str(v)}\n"
|
||||
string += "-" * 100 + '\n'
|
||||
return string
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class TTS:
|
||||
@ -180,35 +199,40 @@ class TTS:
|
||||
|
||||
|
||||
self.stop_flag:bool = False
|
||||
self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
|
||||
|
||||
def _init_models(self,):
|
||||
self.init_t2s_weights(self.configs.t2s_weights_path)
|
||||
self.init_vits_weights(self.configs.vits_weights_path)
|
||||
self.init_bert_weights(self.configs.bert_base_path)
|
||||
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
|
||||
# self.enable_half_precision(self.configs.is_half)
|
||||
|
||||
|
||||
|
||||
def init_cnhuhbert_weights(self, base_path: str):
|
||||
print(f"Loading CNHuBERT weights from {base_path}")
|
||||
self.cnhuhbert_model = CNHubert(base_path)
|
||||
self.cnhuhbert_model=self.cnhuhbert_model.eval()
|
||||
if self.configs.is_half == True:
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
||||
|
||||
|
||||
|
||||
def init_bert_weights(self, base_path: str):
|
||||
print(f"Loading BERT weights from {base_path}")
|
||||
self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
|
||||
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
|
||||
self.bert_model=self.bert_model.eval()
|
||||
self.bert_model = self.bert_model.to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
self.bert_model = self.bert_model.half()
|
||||
self.bert_model = self.bert_model.to(self.configs.device)
|
||||
|
||||
|
||||
|
||||
def init_vits_weights(self, weights_path: str):
|
||||
print(f"Loading VITS weights from {weights_path}")
|
||||
self.configs.vits_weights_path = weights_path
|
||||
self.configs.save_configs()
|
||||
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
|
||||
@ -231,15 +255,16 @@ class TTS:
|
||||
if hasattr(vits_model, "enc_q"):
|
||||
del vits_model.enc_q
|
||||
|
||||
if self.configs.is_half:
|
||||
vits_model = vits_model.half()
|
||||
vits_model = vits_model.to(self.configs.device)
|
||||
vits_model = vits_model.eval()
|
||||
vits_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
self.vits_model = vits_model
|
||||
if self.configs.is_half:
|
||||
self.vits_model = self.vits_model.half()
|
||||
|
||||
|
||||
def init_t2s_weights(self, weights_path: str):
|
||||
print(f"Loading Text2Semantic weights from {weights_path}")
|
||||
self.configs.t2s_weights_path = weights_path
|
||||
self.configs.save_configs()
|
||||
self.configs.hz = 50
|
||||
@ -249,11 +274,61 @@ class TTS:
|
||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
|
||||
flash_attn_enabled=self.configs.flash_attn_enabled)
|
||||
t2s_model.load_state_dict(dict_s1["weight"])
|
||||
if self.configs.is_half:
|
||||
t2s_model = t2s_model.half()
|
||||
t2s_model = t2s_model.to(self.configs.device)
|
||||
t2s_model = t2s_model.eval()
|
||||
self.t2s_model = t2s_model
|
||||
if self.configs.is_half:
|
||||
self.t2s_model = self.t2s_model.half()
|
||||
|
||||
def enable_half_precision(self, enable: bool = True):
|
||||
'''
|
||||
To enable half precision for the TTS model.
|
||||
Args:
|
||||
enable: bool, whether to enable half precision.
|
||||
|
||||
'''
|
||||
if self.configs.device == "cpu" and enable:
|
||||
print("Half precision is not supported on CPU.")
|
||||
return
|
||||
|
||||
self.configs.is_half = enable
|
||||
self.precison = torch.float16 if enable else torch.float32
|
||||
self.configs.save_configs()
|
||||
if enable:
|
||||
if self.t2s_model is not None:
|
||||
self.t2s_model =self.t2s_model.half()
|
||||
if self.vits_model is not None:
|
||||
self.vits_model = self.vits_model.half()
|
||||
if self.bert_model is not None:
|
||||
self.bert_model =self.bert_model.half()
|
||||
if self.cnhuhbert_model is not None:
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
||||
else:
|
||||
if self.t2s_model is not None:
|
||||
self.t2s_model = self.t2s_model.float()
|
||||
if self.vits_model is not None:
|
||||
self.vits_model = self.vits_model.float()
|
||||
if self.bert_model is not None:
|
||||
self.bert_model = self.bert_model.float()
|
||||
if self.cnhuhbert_model is not None:
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.float()
|
||||
|
||||
def set_device(self, device: torch.device):
|
||||
'''
|
||||
To set the device for all models.
|
||||
Args:
|
||||
device: torch.device, the device to use for all models.
|
||||
'''
|
||||
self.configs.device = device
|
||||
self.configs.save_configs()
|
||||
if self.t2s_model is not None:
|
||||
self.t2s_model = self.t2s_model.to(device)
|
||||
if self.vits_model is not None:
|
||||
self.vits_model = self.vits_model.to(device)
|
||||
if self.bert_model is not None:
|
||||
self.bert_model = self.bert_model.to(device)
|
||||
if self.cnhuhbert_model is not None:
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.to(device)
|
||||
|
||||
def set_ref_audio(self, ref_audio_path:str):
|
||||
'''
|
||||
@ -354,7 +429,7 @@ class TTS:
|
||||
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
|
||||
while pos < pos_end:
|
||||
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
|
||||
score=batch[(pos_end-pos)//2]/batch.mean()
|
||||
score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
|
||||
if (score>=threshold) or (pos_end-pos==1):
|
||||
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
|
||||
batch_index_list_len += len(batch_index)
|
||||
@ -386,13 +461,13 @@ class TTS:
|
||||
for item in item_list:
|
||||
if prompt_data is not None:
|
||||
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
||||
.to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
|
||||
.to(dtype=self.precison)
|
||||
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
|
||||
phones = torch.LongTensor(item["phones"])
|
||||
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
||||
else:
|
||||
all_bert_features = item["bert_features"]\
|
||||
.to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
|
||||
.to(dtype=self.precison)
|
||||
phones = torch.LongTensor(item["phones"])
|
||||
all_phones = phones
|
||||
# norm_text = item["norm_text"]
|
||||
@ -412,7 +487,7 @@ class TTS:
|
||||
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
# all_bert_features_batch = all_bert_features_list
|
||||
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=torch.float32)
|
||||
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison)
|
||||
for idx, item in enumerate(all_bert_features_list):
|
||||
all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
||||
|
||||
@ -542,6 +617,11 @@ class TTS:
|
||||
|
||||
###### text preprocessing ########
|
||||
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
|
||||
if len(data) == 0:
|
||||
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
|
||||
dtype=np.int16)
|
||||
return
|
||||
|
||||
t1 = ttime()
|
||||
data, batch_index_list = self.to_batch(data,
|
||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||
@ -594,10 +674,8 @@ class TTS:
|
||||
t4 = ttime()
|
||||
t_34 += t4 - t3
|
||||
|
||||
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
refer_audio_spepc = refer_audio_spepc.half()
|
||||
|
||||
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
|
||||
.to(dtype=self.precison, device=self.configs.device)
|
||||
|
||||
batch_audio_fragment = []
|
||||
|
||||
@ -679,7 +757,7 @@ class TTS:
|
||||
split_bucket:bool=True)->tuple[int, np.ndarray]:
|
||||
zero_wav = torch.zeros(
|
||||
int(self.configs.sampling_rate * 0.3),
|
||||
dtype=torch.float16 if self.configs.is_half else torch.float32,
|
||||
dtype=self.precison,
|
||||
device=self.configs.device
|
||||
)
|
||||
|
||||
@ -697,13 +775,9 @@ class TTS:
|
||||
# audio = [item for batch in audio for item in batch]
|
||||
audio = sum(audio, [])
|
||||
|
||||
|
||||
try:
|
||||
audio = np.concatenate(audio, 0)
|
||||
audio = (audio * 32768).astype(np.int16)
|
||||
except:
|
||||
audio = np.array(audio)
|
||||
audio = (audio * 32768).astype(np.int16)
|
||||
|
||||
audio = np.concatenate(audio, 0)
|
||||
audio = (audio * 32768).astype(np.int16)
|
||||
|
||||
try:
|
||||
if speed_factor != 1.0:
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit d0716731840c8f405a8c1c60277a646029a2c2ae
|
||||
Subproject commit 66f358cc1140490bb1f06ae4550e5c8a0dedd86f
|
Loading…
x
Reference in New Issue
Block a user