mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-09 00:10:00 +08:00
允许指定经典推理而不并行
This commit is contained in:
parent
a057c697e7
commit
c80f4b9557
@ -1,25 +0,0 @@
|
|||||||
使用说明:
|
|
||||||
|
|
||||||
当您想推理时,请开启 TTS后台服务 与 TTS WebUI
|
|
||||||
当您想调整模型时,请开启 TTS模型管理 与 TTS后台服务 、 TTS WebUI
|
|
||||||
当您仅仅想训练时,只需开启 GPTsoVITS WebUI
|
|
||||||
|
|
||||||
|
|
||||||
GPT-soVITS Start V (推理特化包版本)
|
|
||||||
版本:1.3.20240311.1
|
|
||||||
作者:小贼丑
|
|
||||||
QQ/Wechat:406267780
|
|
||||||
|
|
||||||
本界面启动器适配:推理特化包2.1 by Xter
|
|
||||||
|
|
||||||
推理特化包信息:
|
|
||||||
fork地址 by XTer:https://github.com/X-T-E-R/GPT-SoVITS-Inference
|
|
||||||
推理与模型管理子模块地址 by XTer:https://github.com/X-T-E-R/TTS-for-GPT-soVITS
|
|
||||||
原项目 by RVC-Boss(花儿不哭):https://github.com/RVC-Boss/GPT-SoVITS
|
|
||||||
|
|
||||||
推理特化包交流群:863760614
|
|
||||||
|
|
||||||
版权说明:
|
|
||||||
本软件仅为人工智能语音合成项目的一个启动器,请使用者注意,所有数据来源和模型使用均应尊重原始版权所有者的权益。使用本软件产生的任何内容,用户需自行承担相应的版权责任。
|
|
||||||
用户应理解,本软件的使用不授予任何数据来源或模型的版权、许可或使用权。所有使用必须遵守相关版权法律和协议,确保数据的合法获取和使用。
|
|
||||||
如不认可本声明或不承担相应责任,请立即停止使用本软件,并删除相关文件。继续使用即视为接受本声明条款,愿意遵守上述规定。
|
|
@ -1,30 +1,73 @@
|
|||||||
|
[Title]
|
||||||
|
Text=特化推理包2.1 by Xter | 原项目 by RVC-Boss(花儿不哭)
|
||||||
[HomeImage]
|
[HomeImage]
|
||||||
Route=\Cfg\HomeImage.jpg
|
Route=\Cfg\HomeImage.jpg
|
||||||
[TTSModuleManagerStart]
|
|
||||||
Route=10 启动模型管理界面(可选).bat
|
|
||||||
[TTSBackgroundServiceStart]
|
[TTSBackgroundServiceStart]
|
||||||
|
FolderName=TTS 后台服务
|
||||||
|
ButtonName=TTS 后台服务
|
||||||
Route=3 启动后端程序.bat
|
Route=3 启动后端程序.bat
|
||||||
[TTSWebUIStart]
|
[TTSWebUIStart]
|
||||||
|
FolderName=TTS WebUI
|
||||||
|
Hide=N
|
||||||
|
ButtonName=TTS WebUI
|
||||||
Route=4 启动前端合成程序(可选,依赖后端).bat
|
Route=4 启动前端合成程序(可选,依赖后端).bat
|
||||||
[GPTsoVITSWebUIStart]
|
[GPTsoVITSWebUIStart]
|
||||||
|
FolderName=GPTsV WebUI
|
||||||
|
Hide=N
|
||||||
|
ButtonName=GPTsV WebUI
|
||||||
Route=11 启动原项目的训练界面(小白别开,请根据页面上的文档链接自行研究,推理群不包解答).bat
|
Route=11 启动原项目的训练界面(小白别开,请根据页面上的文档链接自行研究,推理群不包解答).bat
|
||||||
[TTSModuleFolder]
|
[TTSModuleManagerStart]
|
||||||
Route=..\trained
|
FolderName=TTS 模型管理
|
||||||
[TTSOnekeyStartFolder]
|
Hide=N
|
||||||
Route=..\0 一键启动脚本
|
ButtonName=TTS 模型管理
|
||||||
[GPTsoVITSRootFolder]
|
Route=10 启动模型管理界面(可选).bat
|
||||||
Route=..\
|
|
||||||
[TTSJsonConfig]
|
|
||||||
Route=..\Inference\config.json
|
|
||||||
[UpdateProgramStart]
|
[UpdateProgramStart]
|
||||||
|
FolderName=更新项目
|
||||||
|
ButtonName=更新项目
|
||||||
Route=0 一键更新项目.bat
|
Route=0 一键更新项目.bat
|
||||||
[UpdateDependenciesStart]
|
[UpdateDependenciesStart]
|
||||||
|
FolderName=更新依赖
|
||||||
|
Hide=N
|
||||||
|
ButtonName=更新依赖
|
||||||
Route=1 一键更新本项目所需要的依赖.bat
|
Route=1 一键更新本项目所需要的依赖.bat
|
||||||
[ForceUpdateStart]
|
[ForceUpdateStart]
|
||||||
|
FolderName=强制更新
|
||||||
|
Hide=N
|
||||||
|
ButtonName=强制更新
|
||||||
Route=999 强制更新:会覆盖你的设置,慎用,和0功能类似.bat
|
Route=999 强制更新:会覆盖你的设置,慎用,和0功能类似.bat
|
||||||
|
[TTSModuleFolder]
|
||||||
|
ButtonName=TTS模型目录
|
||||||
|
Hide=N
|
||||||
|
Route=..\trained
|
||||||
|
[TTSOnekeyStartFolder]
|
||||||
|
ButtonName=TTS一键启动目录
|
||||||
|
Hide=N
|
||||||
|
Route=..\0 一键启动脚本
|
||||||
|
[GPTsoVITSRootFolder]
|
||||||
|
ButtonName=GPTsoVITS根目录
|
||||||
|
Hide=N
|
||||||
|
Route=..\
|
||||||
[TTSDocument]
|
[TTSDocument]
|
||||||
|
ButtonName=TTS项目文档
|
||||||
|
Hide=N
|
||||||
Route=https://www.yuque.com/xter/zibxlp
|
Route=https://www.yuque.com/xter/zibxlp
|
||||||
[GPTsoVITSDocument]
|
[GPTsoVITSDocument]
|
||||||
|
ButtonName=GPTsoVITS项目文档
|
||||||
|
Hide=N
|
||||||
Route=https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e
|
Route=https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e
|
||||||
|
[StartDocument]
|
||||||
|
ButtonName=启动器项目文档
|
||||||
|
Hide=N
|
||||||
|
Route=https://note.youdao.com/s/DlzSyaLl
|
||||||
|
[JsonConfig]
|
||||||
|
Route=..\Inference\config.json
|
||||||
[AboutTxt]
|
[AboutTxt]
|
||||||
Route=\Cfg\About.txt
|
Route=\Cfg\About.txt
|
||||||
|
[Status]
|
||||||
|
Hide=N
|
||||||
|
TWU Hide=N
|
||||||
|
GWU Hide=N
|
||||||
|
TMM Hide=N
|
||||||
|
UpD Hide=N
|
||||||
|
FUp Hide=N
|
||||||
|
Line Hide=156
|
||||||
|
Binary file not shown.
@ -14,7 +14,7 @@ from AR.models.utils import (
|
|||||||
logits_to_probs,
|
logits_to_probs,
|
||||||
multinomial_sample_one_no_sync,
|
multinomial_sample_one_no_sync,
|
||||||
dpo_loss,
|
dpo_loss,
|
||||||
make_reject_y,
|
make_reject_y,
|
||||||
get_batch_logps
|
get_batch_logps
|
||||||
)
|
)
|
||||||
from AR.modules.embedding import SinePositionalEmbedding
|
from AR.modules.embedding import SinePositionalEmbedding
|
||||||
@ -26,11 +26,6 @@ from torch import nn
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torchmetrics.classification import MulticlassAccuracy
|
from torchmetrics.classification import MulticlassAccuracy
|
||||||
|
|
||||||
try:
|
|
||||||
from flash_attn import flash_attn_with_kvcache
|
|
||||||
except ImportError:
|
|
||||||
flash_attn_with_kvcache = None
|
|
||||||
|
|
||||||
default_config = {
|
default_config = {
|
||||||
"embedding_dim": 512,
|
"embedding_dim": 512,
|
||||||
"hidden_dim": 512,
|
"hidden_dim": 512,
|
||||||
@ -302,7 +297,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
(0, y_len),
|
(0, y_len),
|
||||||
value=True,
|
value=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
y_attn_mask = F.pad(
|
y_attn_mask = F.pad(
|
||||||
torch.triu(
|
torch.triu(
|
||||||
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
||||||
@ -363,7 +358,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
|
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
|
||||||
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
|
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
|
||||||
|
|
||||||
loss = loss_1 + loss_2
|
loss = loss_1 + loss_2
|
||||||
|
|
||||||
return loss, acc
|
return loss, acc
|
||||||
@ -432,14 +427,14 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
|
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
|
||||||
def infer(
|
def infer(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
x_lens,
|
x_lens,
|
||||||
prompts,
|
prompts,
|
||||||
bert_feature,
|
bert_feature,
|
||||||
top_k: int = -100,
|
top_k: int = -100,
|
||||||
early_stop_num: int = -1,
|
early_stop_num: int = -1,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
):
|
):
|
||||||
x = self.ar_text_embedding(x)
|
x = self.ar_text_embedding(x)
|
||||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||||
@ -678,18 +673,22 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
# AR Decoder
|
# AR Decoder
|
||||||
y = prompts
|
y = prompts
|
||||||
|
|
||||||
x_len = x.shape[1]
|
x_len = x.shape[1]
|
||||||
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
||||||
stop = False
|
stop = False
|
||||||
# print(1111111,self.num_layers)
|
# print(1111111,self.num_layers)
|
||||||
|
cache = {
|
||||||
if flash_attn_with_kvcache is not None:
|
"all_stage": self.num_layers,
|
||||||
k_cache = [torch.empty(x.shape[0], 2048, 16, 32, dtype=x.dtype, device=x.device) for _ in range(self.num_layers)]
|
"k": [None] * self.num_layers, ###根据配置自己手写
|
||||||
v_cache = [torch.empty(x.shape[0], 2048, 16, 32, dtype=x.dtype, device=x.device) for _ in range(self.num_layers)]
|
"v": [None] * self.num_layers,
|
||||||
else:
|
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
|
||||||
k_cache = [None] * self.num_layers
|
"y_emb": None, ##只需要对最新的samples求emb,再拼历史的就行
|
||||||
v_cache = [None] * self.num_layers
|
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
|
||||||
|
# "xy_dec":None,###不需要,本来只需要最后一个做logits
|
||||||
|
"first_infer": 1,
|
||||||
|
"stage": 0,
|
||||||
|
}
|
||||||
################### first step ##########################
|
################### first step ##########################
|
||||||
if y is not None:
|
if y is not None:
|
||||||
y_emb = self.ar_audio_embedding(y)
|
y_emb = self.ar_audio_embedding(y)
|
||||||
@ -697,6 +696,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
prefix_len = y.shape[1]
|
prefix_len = y.shape[1]
|
||||||
y_pos = self.ar_audio_position(y_emb)
|
y_pos = self.ar_audio_position(y_emb)
|
||||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||||
|
cache["y_emb"] = y_emb
|
||||||
ref_free = False
|
ref_free = False
|
||||||
else:
|
else:
|
||||||
y_emb = None
|
y_emb = None
|
||||||
@ -708,10 +708,10 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
ref_free = True
|
ref_free = True
|
||||||
|
|
||||||
x_attn_mask_pad = F.pad(
|
x_attn_mask_pad = F.pad(
|
||||||
x_attn_mask,
|
x_attn_mask,
|
||||||
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
||||||
value=True,
|
value=True,
|
||||||
)
|
)
|
||||||
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||||
(x_len, 0),
|
(x_len, 0),
|
||||||
@ -725,16 +725,14 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
batch_idx_map = list(range(y.shape[0]))
|
batch_idx_map = list(range(y.shape[0]))
|
||||||
idx_list = [None]*y.shape[0]
|
idx_list = [None]*y.shape[0]
|
||||||
for idx in tqdm(range(1500)):
|
for idx in tqdm(range(1500)):
|
||||||
logits = self.infer_one_step(xy_pos, xy_attn_mask, k_cache, v_cache, cache_seqlens)
|
|
||||||
|
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
|
||||||
if idx == 0:
|
logits = self.ar_predict_layer(
|
||||||
cache_seqlens += xy_pos.shape[1]
|
xy_dec[:, -1]
|
||||||
else:
|
) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的
|
||||||
cache_seqlens += 1
|
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
|
||||||
xy_attn_mask = None
|
if(idx==0):###第一次跑不能EOS否则没有了
|
||||||
|
logits = logits[:, :-1] ###刨除1024终止符号的概率
|
||||||
if idx == 0:
|
|
||||||
logits = logits[:, :-1]
|
|
||||||
samples = sample(
|
samples = sample(
|
||||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
|
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
|
||||||
)[0]
|
)[0]
|
||||||
@ -778,12 +776,15 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
||||||
stop = True
|
stop = True
|
||||||
if stop:
|
if stop:
|
||||||
if y.shape[1] == 0:
|
# if prompts.shape[1] == y.shape[1]:
|
||||||
|
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||||
|
# print("bad zero prediction")
|
||||||
|
if y.shape[1]==0:
|
||||||
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||||
print("bad zero prediction")
|
print("bad zero prediction")
|
||||||
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
||||||
break
|
break
|
||||||
|
|
||||||
####################### update next step ###################################
|
####################### update next step ###################################
|
||||||
cache["first_infer"] = 0
|
cache["first_infer"] = 0
|
||||||
if cache["y_emb"] is not None:
|
if cache["y_emb"] is not None:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from copy import deepcopy
|
||||||
import math
|
import math
|
||||||
import os, sys
|
import os, sys
|
||||||
import random
|
import random
|
||||||
@ -50,22 +51,7 @@ custom:
|
|||||||
|
|
||||||
|
|
||||||
class TTS_Config:
|
class TTS_Config:
|
||||||
def __init__(self, configs: Union[dict, str]):
|
default_configs={
|
||||||
if isinstance(configs, str) and configs=="":
|
|
||||||
self.default_configs:dict = None
|
|
||||||
self.configs_path = "GPT_SoVITS/configs/tts_infer.yaml"
|
|
||||||
else:
|
|
||||||
configs_base_path:str = "GPT_SoVITS/configs/"
|
|
||||||
os.makedirs(configs_base_path, exist_ok=True)
|
|
||||||
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
|
|
||||||
if isinstance(configs, str):
|
|
||||||
self.configs_path = configs
|
|
||||||
configs:dict = self._load_configs(configs)
|
|
||||||
|
|
||||||
# assert isinstance(configs, dict)
|
|
||||||
self.default_configs:dict = configs.get("default", None)
|
|
||||||
if self.default_configs is None:
|
|
||||||
self.default_configs={
|
|
||||||
"device": "cpu",
|
"device": "cpu",
|
||||||
"is_half": False,
|
"is_half": False,
|
||||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||||
@ -74,18 +60,54 @@ class TTS_Config:
|
|||||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||||
"flash_attn_enabled": True
|
"flash_attn_enabled": True
|
||||||
}
|
}
|
||||||
if isinstance(configs, dict):
|
configs:dict = None
|
||||||
self.configs:dict = configs.get("custom", self.default_configs)
|
def __init__(self, configs: Union[dict, str]=None):
|
||||||
else:
|
|
||||||
self.configs:dict = self.default_configs
|
|
||||||
|
|
||||||
self.device = self.configs.get("device")
|
# 设置默认配置文件路径
|
||||||
self.is_half = self.configs.get("is_half")
|
configs_base_path:str = "GPT_SoVITS/configs/"
|
||||||
self.t2s_weights_path = self.configs.get("t2s_weights_path")
|
os.makedirs(configs_base_path, exist_ok=True)
|
||||||
self.vits_weights_path = self.configs.get("vits_weights_path")
|
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
|
||||||
self.bert_base_path = self.configs.get("bert_base_path")
|
|
||||||
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
|
if configs in ["", None]:
|
||||||
self.flash_attn_enabled = self.configs.get("flash_attn_enabled")
|
if not os.path.exists(self.configs_path):
|
||||||
|
self.save_configs()
|
||||||
|
print(f"Create default config file at {self.configs_path}")
|
||||||
|
configs:dict = {"default": deepcopy(self.default_configs)}
|
||||||
|
|
||||||
|
if isinstance(configs, str):
|
||||||
|
self.configs_path = configs
|
||||||
|
configs:dict = self._load_configs(self.configs_path)
|
||||||
|
|
||||||
|
assert isinstance(configs, dict)
|
||||||
|
default_configs:dict = configs.get("default", None)
|
||||||
|
if default_configs is not None:
|
||||||
|
self.default_configs = default_configs
|
||||||
|
|
||||||
|
self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
|
||||||
|
|
||||||
|
|
||||||
|
self.device = self.configs.get("device", torch.device("cpu"))
|
||||||
|
self.is_half = self.configs.get("is_half", False)
|
||||||
|
self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
|
||||||
|
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
|
||||||
|
self.vits_weights_path = self.configs.get("vits_weights_path", None)
|
||||||
|
self.bert_base_path = self.configs.get("bert_base_path", None)
|
||||||
|
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
|
||||||
|
|
||||||
|
|
||||||
|
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
|
||||||
|
self.t2s_weights_path = self.default_configs['t2s_weights_path']
|
||||||
|
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
|
||||||
|
if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
|
||||||
|
self.vits_weights_path = self.default_configs['vits_weights_path']
|
||||||
|
print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
|
||||||
|
if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
|
||||||
|
self.bert_base_path = self.default_configs['bert_base_path']
|
||||||
|
print(f"fall back to default bert_base_path: {self.bert_base_path}")
|
||||||
|
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
|
||||||
|
self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
|
||||||
|
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
|
||||||
|
self.update_configs()
|
||||||
|
|
||||||
|
|
||||||
self.max_sec = None
|
self.max_sec = None
|
||||||
@ -109,24 +131,18 @@ class TTS_Config:
|
|||||||
|
|
||||||
def save_configs(self, configs_path:str=None)->None:
|
def save_configs(self, configs_path:str=None)->None:
|
||||||
configs={
|
configs={
|
||||||
"default": {
|
"default":self.default_configs,
|
||||||
"device": "cpu",
|
|
||||||
"is_half": False,
|
|
||||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
|
||||||
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
|
|
||||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
|
||||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
|
||||||
"flash_attn_enabled": True
|
|
||||||
},
|
|
||||||
"custom": self.update_configs()
|
|
||||||
}
|
}
|
||||||
|
if self.configs is not None:
|
||||||
|
configs["custom"] = self.update_configs()
|
||||||
|
|
||||||
if configs_path is None:
|
if configs_path is None:
|
||||||
configs_path = self.configs_path
|
configs_path = self.configs_path
|
||||||
with open(configs_path, 'w') as f:
|
with open(configs_path, 'w') as f:
|
||||||
yaml.dump(configs, f)
|
yaml.dump(configs, f)
|
||||||
|
|
||||||
def update_configs(self):
|
def update_configs(self):
|
||||||
config = {
|
self.config = {
|
||||||
"device" : str(self.device),
|
"device" : str(self.device),
|
||||||
"is_half" : self.is_half,
|
"is_half" : self.is_half,
|
||||||
"t2s_weights_path" : self.t2s_weights_path,
|
"t2s_weights_path" : self.t2s_weights_path,
|
||||||
@ -135,7 +151,7 @@ class TTS_Config:
|
|||||||
"cnhuhbert_base_path": self.cnhuhbert_base_path,
|
"cnhuhbert_base_path": self.cnhuhbert_base_path,
|
||||||
"flash_attn_enabled" : self.flash_attn_enabled
|
"flash_attn_enabled" : self.flash_attn_enabled
|
||||||
}
|
}
|
||||||
return config
|
return self.config
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
self.configs = self.update_configs()
|
self.configs = self.update_configs()
|
||||||
@ -144,6 +160,9 @@ class TTS_Config:
|
|||||||
string += f"{str(k).ljust(20)}: {str(v)}\n"
|
string += f"{str(k).ljust(20)}: {str(v)}\n"
|
||||||
string += "-" * 100 + '\n'
|
string += "-" * 100 + '\n'
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__str__()
|
||||||
|
|
||||||
|
|
||||||
class TTS:
|
class TTS:
|
||||||
@ -180,35 +199,40 @@ class TTS:
|
|||||||
|
|
||||||
|
|
||||||
self.stop_flag:bool = False
|
self.stop_flag:bool = False
|
||||||
|
self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
|
||||||
|
|
||||||
def _init_models(self,):
|
def _init_models(self,):
|
||||||
self.init_t2s_weights(self.configs.t2s_weights_path)
|
self.init_t2s_weights(self.configs.t2s_weights_path)
|
||||||
self.init_vits_weights(self.configs.vits_weights_path)
|
self.init_vits_weights(self.configs.vits_weights_path)
|
||||||
self.init_bert_weights(self.configs.bert_base_path)
|
self.init_bert_weights(self.configs.bert_base_path)
|
||||||
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
|
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
|
||||||
|
# self.enable_half_precision(self.configs.is_half)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def init_cnhuhbert_weights(self, base_path: str):
|
def init_cnhuhbert_weights(self, base_path: str):
|
||||||
|
print(f"Loading CNHuBERT weights from {base_path}")
|
||||||
self.cnhuhbert_model = CNHubert(base_path)
|
self.cnhuhbert_model = CNHubert(base_path)
|
||||||
self.cnhuhbert_model=self.cnhuhbert_model.eval()
|
self.cnhuhbert_model=self.cnhuhbert_model.eval()
|
||||||
if self.configs.is_half == True:
|
|
||||||
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
|
||||||
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
|
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
|
||||||
|
if self.configs.is_half:
|
||||||
|
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def init_bert_weights(self, base_path: str):
|
def init_bert_weights(self, base_path: str):
|
||||||
|
print(f"Loading BERT weights from {base_path}")
|
||||||
self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
|
self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
|
||||||
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
|
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
|
||||||
self.bert_model=self.bert_model.eval()
|
self.bert_model=self.bert_model.eval()
|
||||||
|
self.bert_model = self.bert_model.to(self.configs.device)
|
||||||
if self.configs.is_half:
|
if self.configs.is_half:
|
||||||
self.bert_model = self.bert_model.half()
|
self.bert_model = self.bert_model.half()
|
||||||
self.bert_model = self.bert_model.to(self.configs.device)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def init_vits_weights(self, weights_path: str):
|
def init_vits_weights(self, weights_path: str):
|
||||||
|
print(f"Loading VITS weights from {weights_path}")
|
||||||
self.configs.vits_weights_path = weights_path
|
self.configs.vits_weights_path = weights_path
|
||||||
self.configs.save_configs()
|
self.configs.save_configs()
|
||||||
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
|
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
|
||||||
@ -231,15 +255,16 @@ class TTS:
|
|||||||
if hasattr(vits_model, "enc_q"):
|
if hasattr(vits_model, "enc_q"):
|
||||||
del vits_model.enc_q
|
del vits_model.enc_q
|
||||||
|
|
||||||
if self.configs.is_half:
|
|
||||||
vits_model = vits_model.half()
|
|
||||||
vits_model = vits_model.to(self.configs.device)
|
vits_model = vits_model.to(self.configs.device)
|
||||||
vits_model = vits_model.eval()
|
vits_model = vits_model.eval()
|
||||||
vits_model.load_state_dict(dict_s2["weight"], strict=False)
|
vits_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
self.vits_model = vits_model
|
self.vits_model = vits_model
|
||||||
|
if self.configs.is_half:
|
||||||
|
self.vits_model = self.vits_model.half()
|
||||||
|
|
||||||
|
|
||||||
def init_t2s_weights(self, weights_path: str):
|
def init_t2s_weights(self, weights_path: str):
|
||||||
|
print(f"Loading Text2Semantic weights from {weights_path}")
|
||||||
self.configs.t2s_weights_path = weights_path
|
self.configs.t2s_weights_path = weights_path
|
||||||
self.configs.save_configs()
|
self.configs.save_configs()
|
||||||
self.configs.hz = 50
|
self.configs.hz = 50
|
||||||
@ -249,11 +274,61 @@ class TTS:
|
|||||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
|
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
|
||||||
flash_attn_enabled=self.configs.flash_attn_enabled)
|
flash_attn_enabled=self.configs.flash_attn_enabled)
|
||||||
t2s_model.load_state_dict(dict_s1["weight"])
|
t2s_model.load_state_dict(dict_s1["weight"])
|
||||||
if self.configs.is_half:
|
|
||||||
t2s_model = t2s_model.half()
|
|
||||||
t2s_model = t2s_model.to(self.configs.device)
|
t2s_model = t2s_model.to(self.configs.device)
|
||||||
t2s_model = t2s_model.eval()
|
t2s_model = t2s_model.eval()
|
||||||
self.t2s_model = t2s_model
|
self.t2s_model = t2s_model
|
||||||
|
if self.configs.is_half:
|
||||||
|
self.t2s_model = self.t2s_model.half()
|
||||||
|
|
||||||
|
def enable_half_precision(self, enable: bool = True):
|
||||||
|
'''
|
||||||
|
To enable half precision for the TTS model.
|
||||||
|
Args:
|
||||||
|
enable: bool, whether to enable half precision.
|
||||||
|
|
||||||
|
'''
|
||||||
|
if self.configs.device == "cpu" and enable:
|
||||||
|
print("Half precision is not supported on CPU.")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.configs.is_half = enable
|
||||||
|
self.precison = torch.float16 if enable else torch.float32
|
||||||
|
self.configs.save_configs()
|
||||||
|
if enable:
|
||||||
|
if self.t2s_model is not None:
|
||||||
|
self.t2s_model =self.t2s_model.half()
|
||||||
|
if self.vits_model is not None:
|
||||||
|
self.vits_model = self.vits_model.half()
|
||||||
|
if self.bert_model is not None:
|
||||||
|
self.bert_model =self.bert_model.half()
|
||||||
|
if self.cnhuhbert_model is not None:
|
||||||
|
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
||||||
|
else:
|
||||||
|
if self.t2s_model is not None:
|
||||||
|
self.t2s_model = self.t2s_model.float()
|
||||||
|
if self.vits_model is not None:
|
||||||
|
self.vits_model = self.vits_model.float()
|
||||||
|
if self.bert_model is not None:
|
||||||
|
self.bert_model = self.bert_model.float()
|
||||||
|
if self.cnhuhbert_model is not None:
|
||||||
|
self.cnhuhbert_model = self.cnhuhbert_model.float()
|
||||||
|
|
||||||
|
def set_device(self, device: torch.device):
|
||||||
|
'''
|
||||||
|
To set the device for all models.
|
||||||
|
Args:
|
||||||
|
device: torch.device, the device to use for all models.
|
||||||
|
'''
|
||||||
|
self.configs.device = device
|
||||||
|
self.configs.save_configs()
|
||||||
|
if self.t2s_model is not None:
|
||||||
|
self.t2s_model = self.t2s_model.to(device)
|
||||||
|
if self.vits_model is not None:
|
||||||
|
self.vits_model = self.vits_model.to(device)
|
||||||
|
if self.bert_model is not None:
|
||||||
|
self.bert_model = self.bert_model.to(device)
|
||||||
|
if self.cnhuhbert_model is not None:
|
||||||
|
self.cnhuhbert_model = self.cnhuhbert_model.to(device)
|
||||||
|
|
||||||
def set_ref_audio(self, ref_audio_path:str):
|
def set_ref_audio(self, ref_audio_path:str):
|
||||||
'''
|
'''
|
||||||
@ -354,7 +429,7 @@ class TTS:
|
|||||||
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
|
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
|
||||||
while pos < pos_end:
|
while pos < pos_end:
|
||||||
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
|
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
|
||||||
score=batch[(pos_end-pos)//2]/batch.mean()
|
score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
|
||||||
if (score>=threshold) or (pos_end-pos==1):
|
if (score>=threshold) or (pos_end-pos==1):
|
||||||
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
|
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
|
||||||
batch_index_list_len += len(batch_index)
|
batch_index_list_len += len(batch_index)
|
||||||
@ -386,13 +461,13 @@ class TTS:
|
|||||||
for item in item_list:
|
for item in item_list:
|
||||||
if prompt_data is not None:
|
if prompt_data is not None:
|
||||||
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
||||||
.to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
|
.to(dtype=self.precison)
|
||||||
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
|
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
|
||||||
phones = torch.LongTensor(item["phones"])
|
phones = torch.LongTensor(item["phones"])
|
||||||
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
||||||
else:
|
else:
|
||||||
all_bert_features = item["bert_features"]\
|
all_bert_features = item["bert_features"]\
|
||||||
.to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
|
.to(dtype=self.precison)
|
||||||
phones = torch.LongTensor(item["phones"])
|
phones = torch.LongTensor(item["phones"])
|
||||||
all_phones = phones
|
all_phones = phones
|
||||||
# norm_text = item["norm_text"]
|
# norm_text = item["norm_text"]
|
||||||
@ -412,7 +487,7 @@ class TTS:
|
|||||||
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||||
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||||
# all_bert_features_batch = all_bert_features_list
|
# all_bert_features_batch = all_bert_features_list
|
||||||
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=torch.float32)
|
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison)
|
||||||
for idx, item in enumerate(all_bert_features_list):
|
for idx, item in enumerate(all_bert_features_list):
|
||||||
all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
||||||
|
|
||||||
@ -542,6 +617,11 @@ class TTS:
|
|||||||
|
|
||||||
###### text preprocessing ########
|
###### text preprocessing ########
|
||||||
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
|
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
|
||||||
|
if len(data) == 0:
|
||||||
|
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
|
||||||
|
dtype=np.int16)
|
||||||
|
return
|
||||||
|
|
||||||
t1 = ttime()
|
t1 = ttime()
|
||||||
data, batch_index_list = self.to_batch(data,
|
data, batch_index_list = self.to_batch(data,
|
||||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||||
@ -594,10 +674,8 @@ class TTS:
|
|||||||
t4 = ttime()
|
t4 = ttime()
|
||||||
t_34 += t4 - t3
|
t_34 += t4 - t3
|
||||||
|
|
||||||
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].to(self.configs.device)
|
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
|
||||||
if self.configs.is_half:
|
.to(dtype=self.precison, device=self.configs.device)
|
||||||
refer_audio_spepc = refer_audio_spepc.half()
|
|
||||||
|
|
||||||
|
|
||||||
batch_audio_fragment = []
|
batch_audio_fragment = []
|
||||||
|
|
||||||
@ -679,7 +757,7 @@ class TTS:
|
|||||||
split_bucket:bool=True)->tuple[int, np.ndarray]:
|
split_bucket:bool=True)->tuple[int, np.ndarray]:
|
||||||
zero_wav = torch.zeros(
|
zero_wav = torch.zeros(
|
||||||
int(self.configs.sampling_rate * 0.3),
|
int(self.configs.sampling_rate * 0.3),
|
||||||
dtype=torch.float16 if self.configs.is_half else torch.float32,
|
dtype=self.precison,
|
||||||
device=self.configs.device
|
device=self.configs.device
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -697,13 +775,9 @@ class TTS:
|
|||||||
# audio = [item for batch in audio for item in batch]
|
# audio = [item for batch in audio for item in batch]
|
||||||
audio = sum(audio, [])
|
audio = sum(audio, [])
|
||||||
|
|
||||||
|
|
||||||
try:
|
audio = np.concatenate(audio, 0)
|
||||||
audio = np.concatenate(audio, 0)
|
audio = (audio * 32768).astype(np.int16)
|
||||||
audio = (audio * 32768).astype(np.int16)
|
|
||||||
except:
|
|
||||||
audio = np.array(audio)
|
|
||||||
audio = (audio * 32768).astype(np.int16)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if speed_factor != 1.0:
|
if speed_factor != 1.0:
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit d0716731840c8f405a8c1c60277a646029a2c2ae
|
Subproject commit 66f358cc1140490bb1f06ae4550e5c8a0dedd86f
|
Loading…
x
Reference in New Issue
Block a user