mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2026-05-14 22:18:11 +08:00)

Refactor the audio preparation workflow to use torchaudio for resampling in place of librosa, for better performance. Introduce a caching mechanism for resampling transforms, and update PrepareRefSemanticBatchWorker to report detailed timing metrics for profiling. Additionally, implement a new CPU limiter to manage resource allocation during audio processing. These changes improve the efficiency and maintainability of the TTS system.
2583 lines · 119 KiB · Python
import gc
import asyncio
import concurrent.futures
import math
import os
import random
import sys
import time
import traceback
from copy import deepcopy

now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import List, Tuple, Union

from runtime_preload import preload_text_runtime_deps

preload_text_runtime_deps()

import ffmpeg
import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import yaml
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from BigVGAN.bigvgan import BigVGAN
from feature_extractor.cnhubert import CNHubert
from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
from peft import LoraConfig, get_peft_model
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
from transformers import AutoModelForMaskedLM, AutoTokenizer
from tqdm import tqdm

from tools.audio_sr import AP_BWE
from tools.i18n.i18n import I18nAuto, scan_language_list
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor, StageLimiter
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
from TTS_infer_pack.prepare_ref_semantic_batch_worker import (
    PrepareRefSemanticBatchWorker,
    prepare_prompt_semantic_wav16k,
)
from TTS_infer_pack.prepare_text_cpu_worker import PrepareTextCpuWorker
from sv import SV
resample_transform_dict = {}


def resample(audio_tensor, sr0, sr1, device):
    global resample_transform_dict
    key = "%s-%s-%s" % (sr0, sr1, str(device))
    if key not in resample_transform_dict:
        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
    return resample_transform_dict[key](audio_tensor)
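
# Usage sketch for resample() (illustrative only, not called here): repeated
# conversions between the same (sr0, sr1, device) triple reuse one cached
# torchaudio.transforms.Resample instance, so the resampling kernel is built
# once rather than on every call.
#
#   wav = torch.randn(1, 32000)               # 1 s of 32 kHz audio
#   out = resample(wav, 32000, 16000, "cpu")  # first call builds and caches the transform
#   out = resample(wav, 32000, 16000, "cpu")  # cache hit: no kernel rebuild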


language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)


spec_min = -12
spec_max = 2


def norm_spec(x):
    return (x - spec_min) / (spec_max - spec_min) * 2 - 1


def denorm_spec(x):
    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
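
# norm_spec/denorm_spec are inverses: they map log-mel values from
# [spec_min, spec_max] = [-12, 2] onto [-1, 1] and back. For example,
# norm_spec(-12) == -1.0, norm_spec(2) == 1.0, and norm_spec(-5) == 0.0,
# so denorm_spec(norm_spec(x)) == x over that range.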


mel_fn = lambda x: mel_spectrogram_torch(
    x,
    **{
        "n_fft": 1024,
        "win_size": 1024,
        "hop_size": 256,
        "num_mels": 100,
        "sampling_rate": 24000,
        "fmin": 0,
        "fmax": None,
        "center": False,
    },
)

mel_fn_v4 = lambda x: mel_spectrogram_torch(
    x,
    **{
        "n_fft": 1280,
        "win_size": 1280,
        "hop_size": 320,
        "num_mels": 100,
        "sampling_rate": 32000,
        "fmin": 0,
        "fmax": None,
        "center": False,
    },
)
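
# Both mel functions produce 100 mel bins; the naming suggests mel_fn serves
# the v3 path and mel_fn_v4 the v4 path. Frame-rate arithmetic from the
# parameters above: 24 kHz with hop 256 gives 24000 / 256 = 93.75 frames/s,
# while 32 kHz with hop 320 gives 32000 / 320 = 100 frames/s.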


def speed_change(input_audio: np.ndarray, speed: float, sr: int):
    # Convert the NumPy array to a raw PCM byte stream
    raw_audio = input_audio.astype(np.int16).tobytes()

    # Set up the ffmpeg input stream
    input_stream = ffmpeg.input("pipe:", format="s16le", acodec="pcm_s16le", ar=str(sr), ac=1)

    # Apply the tempo change
    output_stream = input_stream.filter("atempo", speed)

    # Pipe the output and capture it
    out, _ = output_stream.output("pipe:", format="s16le", acodec="pcm_s16le").run(
        input=raw_audio, capture_stdout=True, capture_stderr=True
    )

    # Decode the piped output back into a NumPy array
    processed_audio = np.frombuffer(out, np.int16)

    return processed_audio
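
# Usage sketch (illustrative only): speed_change expects mono int16 PCM and
# returns int16 PCM at the same sample rate but with a different duration.
#
#   pcm = (wav_float * 32767).astype(np.int16)       # hypothetical float -> int16 conversion
#   faster = speed_change(pcm, speed=1.25, sr=32000)
#
# Note that ffmpeg's atempo filter only accepts a bounded range of tempo
# factors, so extreme speeds may require chaining several atempo filters.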


class DictToAttrRecursive(dict):
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")
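
# Usage sketch (illustrative only): nested dicts become attribute-accessible,
# e.g. DictToAttrRecursive({"model": {"hidden": 192}}).model.hidden == 192,
# while dict-style access (obj["model"]) keeps working because the class still
# subclasses dict.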


class NO_PROMPT_ERROR(Exception):
    pass


# configs/tts_infer.yaml
"""
custom:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
  version: v2
v1:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
  vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
  version: v1
v2:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
  version: v2
v3:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
  version: v3
v4:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth
  version: v4
"""
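
# Note: when tts_infer.yaml contains a "custom" section it takes precedence;
# otherwise TTS_Config falls back to the bundled per-version defaults, with
# "v2ProPlus" as the final default (see TTS_Config.__init__ below).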


def set_seed(seed: int):
    seed = int(seed)
    seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
    print(f"Set seed to {seed}")
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    try:
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            # torch.backends.cudnn.deterministic = True
            # torch.backends.cudnn.benchmark = False
            # torch.backends.cudnn.enabled = True
            # Enabling TF32 would reduce numerical precision, so keep it off
            torch.backends.cuda.matmul.allow_tf32 = False
            torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    return seed
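
# Usage sketch (illustrative only): passing -1 draws a fresh random 32-bit
# seed, and the realized seed is returned so callers can log or reuse it.
#
#   actual_seed = set_seed(-1)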


class TTS_Config:
    default_configs = {
        "v1": {
            "device": "cpu",
            "is_half": False,
            "version": "v1",
            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
            "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        },
        "v2": {
            "device": "cpu",
            "is_half": False,
            "version": "v2",
            "t2s_weights_path": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
            "vits_weights_path": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        },
        "v3": {
            "device": "cpu",
            "is_half": False,
            "version": "v3",
            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
            "vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth",
            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        },
        "v4": {
            "device": "cpu",
            "is_half": False,
            "version": "v4",
            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
            "vits_weights_path": "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        },
        "v2Pro": {
            "device": "cpu",
            "is_half": False,
            "version": "v2Pro",
            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
            "vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth",
            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        },
        "v2ProPlus": {
            "device": "cpu",
            "is_half": False,
            "version": "v2ProPlus",
            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
            "vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth",
            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        },
    }
    configs: dict = None
    v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
    v2_languages: list = ["auto", "auto_yue", "en", "zh", "ja", "yue", "ko", "all_zh", "all_ja", "all_yue", "all_ko"]
    languages: list = v2_languages
    mute_tokens: dict = {
        "v1": 486,
        "v2": 486,
        "v2Pro": 486,
        "v2ProPlus": 486,
        "v3": 486,
        "v4": 486,
    }
    mute_emb_sim_matrix: torch.Tensor = None
    # "all_zh",    # treat all input as Chinese
    # "en",        # treat all input as English            ####### unchanged
    # "all_ja",    # treat all input as Japanese
    # "all_yue",   # treat all input as Cantonese
    # "all_ko",    # treat all input as Korean
    # "zh",        # mixed Chinese/English recognition     #### unchanged
    # "ja",        # mixed Japanese/English recognition    #### unchanged
    # "yue",       # mixed Cantonese/English recognition   #### unchanged
    # "ko",        # mixed Korean/English recognition      #### unchanged
    # "auto",      # multilingual: split text and auto-detect each language
    # "auto_yue",  # multilingual: split text and auto-detect, Cantonese-aware

    def __init__(self, configs: Union[dict, str] = None):
        # Set the default config file path
        configs_base_path: str = "GPT_SoVITS/configs/"
        os.makedirs(configs_base_path, exist_ok=True)
        self.configs_path: str = os.path.join(configs_base_path, "tts_infer.yaml")

        if configs in ["", None]:
            if not os.path.exists(self.configs_path):
                self.save_configs()
                print(f"Create default config file at {self.configs_path}")
            configs: dict = deepcopy(self.default_configs)

        if isinstance(configs, str):
            self.configs_path = configs
            configs: dict = self._load_configs(self.configs_path)

        assert isinstance(configs, dict)
        configs_ = deepcopy(self.default_configs)
        configs_.update(configs)
        self.configs: dict = configs_.get("custom", configs_["v2ProPlus"])
        self.default_configs = deepcopy(configs_)

        self.device = self.configs.get("device", torch.device("cpu"))
        if "cuda" in str(self.device) and not torch.cuda.is_available():
            print("Warning: CUDA is not available, set device to CPU.")
            self.device = torch.device("cpu")

        self.is_half = self.configs.get("is_half", False)
        if str(self.device) == "cpu" and self.is_half:
            print("Warning: Half precision is not supported on CPU, set is_half to False.")
            self.is_half = False

        version = self.configs.get("version", None)
        self.version = version
        assert self.version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"], "Invalid version!"
        self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
        self.vits_weights_path = self.configs.get("vits_weights_path", None)
        self.bert_base_path = self.configs.get("bert_base_path", None)
        self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
        self.languages = self.v1_languages if self.version == "v1" else self.v2_languages

        self.use_vocoder: bool = False

        if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
            self.t2s_weights_path = self.default_configs[version]["t2s_weights_path"]
            print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
        if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
            self.vits_weights_path = self.default_configs[version]["vits_weights_path"]
            print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
        if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
            self.bert_base_path = self.default_configs[version]["bert_base_path"]
            print(f"fall back to default bert_base_path: {self.bert_base_path}")
        if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
            self.cnhuhbert_base_path = self.default_configs[version]["cnhuhbert_base_path"]
            print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
        self.update_configs()

        self.max_sec = None
        self.hz: int = 50
        self.semantic_frame_rate: str = "25hz"
        self.segment_size: int = 20480
        self.filter_length: int = 2048
        self.sampling_rate: int = 32000
        self.hop_length: int = 640
        self.win_length: int = 2048
        self.n_speakers: int = 300

    def _load_configs(self, configs_path: str) -> dict:
        if os.path.exists(configs_path):
            ...
        else:
            print(i18n("路径不存在,使用默认配置"))  # "Path does not exist; using the default config"
            self.save_configs(configs_path)
        with open(configs_path, "r", encoding="utf-8") as f:
            configs = yaml.load(f, Loader=yaml.FullLoader)

        return configs

    def save_configs(self, configs_path: str = None) -> None:
        configs = deepcopy(self.default_configs)
        if self.configs is not None:
            configs["custom"] = self.update_configs()

        if configs_path is None:
            configs_path = self.configs_path
        with open(configs_path, "w") as f:
            yaml.dump(configs, f)

    def update_configs(self):
        self.config = {
            "device": str(self.device),
            "is_half": self.is_half,
            "version": self.version,
            "t2s_weights_path": self.t2s_weights_path,
            "vits_weights_path": self.vits_weights_path,
            "bert_base_path": self.bert_base_path,
            "cnhuhbert_base_path": self.cnhuhbert_base_path,
        }
        return self.config

    def update_version(self, version: str) -> None:
        self.version = version
        self.languages = self.v1_languages if self.version == "v1" else self.v2_languages

    def __str__(self):
        self.configs = self.update_configs()
        string = "TTS Config".center(100, "-") + "\n"
        for k, v in self.configs.items():
            string += f"{str(k).ljust(20)}: {str(v)}\n"
        string += "-" * 100 + "\n"
        return string

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        return hash(self.configs_path)

    def __eq__(self, other):
        return isinstance(other, TTS_Config) and self.configs_path == other.configs_path


class TTS:
    def __init__(self, configs: Union[dict, str, TTS_Config]):
        if isinstance(configs, TTS_Config):
            self.configs = configs
        else:
            self.configs: TTS_Config = TTS_Config(configs)

        self.t2s_model: Text2SemanticLightningModule = None
        self.vits_model: Union[SynthesizerTrn, SynthesizerTrnV3] = None
        self.bert_tokenizer: AutoTokenizer = None
        self.bert_model: AutoModelForMaskedLM = None
        self.cnhuhbert_model: CNHubert = None
        self.vocoder = None
        self.sr_model: AP_BWE = None
        self.sv_model = None
        self.sr_model_not_exist: bool = False

        self.vocoder_configs: dict = {
            "sr": None,
            "T_ref": None,
            "T_chunk": None,
            "upsample_rate": None,
            "overlapped_len": None,
        }
        self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1")))
        self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4")))
        self.prepare_ref_audio_cpu_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_CPU_SLOTS", "8")))
        self.prepare_bert_batch_worker = None
        self.prepare_ref_semantic_batch_worker = None
        self.prepare_text_cpu_worker = None
        self.prepare_text_cpu_workers = max(
            0,
            int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")),
        )
        self.prepare_text_cpu_executor = None

        self._init_models()
        self.refresh_runtime_components()

        self.prompt_cache: dict = {
            "ref_audio_path": None,
            "prompt_semantic": None,
            "refer_spec": [],
            "prompt_text": None,
            "prompt_lang": None,
            "phones": None,
            "bert_features": None,
            "norm_text": None,
            "aux_ref_audio_paths": [],
        }

        self.stop_flag: bool = False
        self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32

    def refresh_runtime_components(self):
        self.prepare_bert_batch_worker = None
        self.prepare_ref_semantic_batch_worker = None
        self.prepare_text_cpu_worker = None
        if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0":
            self.prepare_bert_batch_worker = PrepareBertBatchWorker(
                bert_model=self.bert_model,
                tokenizer=self.bert_tokenizer,
                device=self.configs.device,
                stage_limiter=self.prepare_bert_stage_limiter,
                batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")),
                max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")),
                max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")),
                max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_BERT_MAX_PENDING_TASKS", "0")),
                admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_ADMISSION_POLL_MS", "1")),
                high_pressure_pending_threshold=int(
                    os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "0")
                ),
                high_pressure_batch_window_ms=int(
                    os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_BATCH_WINDOW_MS", "1")
                ),
                high_pressure_max_batch_items=int(
                    os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_ITEMS", "32")
                ),
                high_pressure_max_batch_tokens=int(
                    os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_TOKENS", "8192")
                ),
            )
        if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0":
            ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES")
            if ref_max_batch_samples is None:
                ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_FRAMES", "960000")
            self.prepare_ref_semantic_batch_worker = PrepareRefSemanticBatchWorker(
                ssl_model=self.cnhuhbert_model,
                vits_model=self.vits_model,
                device=self.configs.device,
                is_half=self.configs.is_half,
                zero_wav_samples=int(self.configs.sampling_rate * 0.3),
                stage_limiter=self.prepare_ref_audio_stage_limiter,
                batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_WINDOW_MS", "5")),
                max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_ITEMS", "8")),
                max_batch_samples=int(ref_max_batch_samples),
            )

        self.text_preprocessor = TextPreprocessor(
            self.bert_model,
            self.bert_tokenizer,
            self.configs.device,
            version=self.configs.version,
            bert_stage_limiter=self.prepare_bert_stage_limiter,
            bert_batch_worker=self.prepare_bert_batch_worker,
        )
        if self.prepare_text_cpu_workers > 0:
            self.prepare_text_cpu_worker = PrepareTextCpuWorker(
                process_fn=lambda text, language: self.text_preprocessor.preprocess_text_segments(
                    text,
                    language,
                    self.configs.version,
                ),
                worker_count=self.prepare_text_cpu_workers,
                max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_PENDING_TASKS", "0")),
                admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_ADMISSION_POLL_MS", "1")),
                admission_controller=self._build_text_cpu_admission_state,
            )

    @staticmethod
    def _safe_queue_qsize(executor) -> int | None:
        if executor is None:
            return None
        queue = getattr(executor, "_work_queue", None)
        if queue is None or not hasattr(queue, "qsize"):
            return None
        try:
            return int(queue.qsize())
        except Exception:
            return None

    def snapshot_prepare_runtime_components(self) -> dict:
        g2pw_runtime = None
        try:
            from text import chinese2

            g2pw_instance = getattr(chinese2, "g2pw", None)
            g2pw_backend = None if g2pw_instance is None else getattr(g2pw_instance, "_g2pw", None)
            if g2pw_backend is not None and hasattr(g2pw_backend, "snapshot"):
                g2pw_runtime = dict(g2pw_backend.snapshot())
        except Exception:
            g2pw_runtime = None
        return {
            "text_cpu": {
                "workers": int(self.prepare_text_cpu_workers),
                "queue_size": self._safe_queue_qsize(self.prepare_text_cpu_executor),
                "enabled": bool(self.prepare_text_cpu_worker is not None or self.prepare_text_cpu_executor is not None),
                "worker": (
                    None if self.prepare_text_cpu_worker is None else dict(self.prepare_text_cpu_worker.snapshot())
                ),
                "admission": self._build_text_cpu_admission_state(),
            },
            "bert": {
                "stage_limiter": dict(self.prepare_bert_stage_limiter.snapshot()),
                "batch_worker": (
                    None if self.prepare_bert_batch_worker is None else dict(self.prepare_bert_batch_worker.snapshot())
                ),
                "batching_enabled": bool(self.prepare_bert_batch_worker is not None),
            },
            "ref_semantic": {
                "stage_limiter": dict(self.prepare_ref_audio_stage_limiter.snapshot()),
                "batch_worker": (
                    None
                    if self.prepare_ref_semantic_batch_worker is None
                    else dict(self.prepare_ref_semantic_batch_worker.snapshot())
                ),
                "batching_enabled": bool(self.prepare_ref_semantic_batch_worker is not None),
            },
            "text_preprocessor": (
                None
                if self.text_preprocessor is None or not hasattr(self.text_preprocessor, "snapshot")
                else self.text_preprocessor.snapshot()
            ),
            "g2pw": g2pw_runtime,
        }

    def _build_text_cpu_admission_state(self) -> dict:
        bert_pending_soft_max = max(
            0,
            int(
                os.environ.get(
                    "GPTSOVITS_PREPARE_TEXT_CPU_BERT_PENDING_SOFT_MAX",
                    os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "32"),
                )
            ),
        )
        if self.prepare_bert_batch_worker is None or bert_pending_soft_max <= 0:
            return {
                "blocked": False,
                "reason": "",
                "bert_pending": 0,
                "bert_active_batch_size": 0,
                "bert_pending_soft_max": int(bert_pending_soft_max),
            }
        bert_state = dict(self.prepare_bert_batch_worker.snapshot())
        bert_pending = int(bert_state.get("pending", 0))
        bert_active_batch_size = int(bert_state.get("active_batch_size", 0))
        blocked = bert_pending >= bert_pending_soft_max
        return {
            "blocked": bool(blocked),
            "reason": ("bert_pending" if blocked else ""),
            "bert_pending": int(bert_pending),
            "bert_active_batch_size": int(bert_active_batch_size),
            "bert_pending_soft_max": int(bert_pending_soft_max),
        }
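
    # Worked example (illustrative only): with the soft max at its default of
    # 32, a BERT batch-worker snapshot of {"pending": 40, "active_batch_size": 16}
    # yields {"blocked": True, "reason": "bert_pending", ...}, so the text-CPU
    # workers stop admitting new tasks until BERT pressure drops below the cap.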

    def _init_models(
        self,
    ):
        self.init_t2s_weights(self.configs.t2s_weights_path)
        self.init_vits_weights(self.configs.vits_weights_path)
        self.init_bert_weights(self.configs.bert_base_path)
        self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
        # self.enable_half_precision(self.configs.is_half)

    def init_cnhuhbert_weights(self, base_path: str):
        print(f"Loading CNHuBERT weights from {base_path}")
        self.cnhuhbert_model = CNHubert(base_path)
        self.cnhuhbert_model = self.cnhuhbert_model.eval()
        self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
        if self.configs.is_half and str(self.configs.device) != "cpu":
            self.cnhuhbert_model = self.cnhuhbert_model.half()

    def init_bert_weights(self, base_path: str):
        print(f"Loading BERT weights from {base_path}")
        self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
        self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
        self.bert_model = self.bert_model.eval()
        self.bert_model = self.bert_model.to(self.configs.device)
        if self.configs.is_half and str(self.configs.device) != "cpu":
            self.bert_model = self.bert_model.half()

    def init_vits_weights(self, weights_path: str):
        self.configs.vits_weights_path = weights_path
        version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
        if "Pro" in model_version:
            self.init_sv_model()
        path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]

        if if_lora_v3 and not os.path.exists(path_sovits):
            # "The SoVITS %s base model is missing; cannot load the corresponding LoRA weights"
            info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
            raise FileNotFoundError(info)

        # dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
        dict_s2 = load_sovits_new(weights_path)
        hps = dict_s2["config"]
        hps["model"]["semantic_frame_rate"] = "25hz"
        if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
            hps["model"]["version"] = "v2"  # v3 model, v2 symbols
        elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
            hps["model"]["version"] = "v1"
        else:
            hps["model"]["version"] = "v2"
        version = hps["model"]["version"]
        v3v4set = {"v3", "v4"}
        if model_version not in v3v4set:
            if "Pro" not in model_version:
                model_version = version
            else:
                hps["model"]["version"] = model_version
        else:
            hps["model"]["version"] = model_version

        self.configs.filter_length = hps["data"]["filter_length"]
        self.configs.segment_size = hps["train"]["segment_size"]
        self.configs.sampling_rate = hps["data"]["sampling_rate"]
        self.configs.hop_length = hps["data"]["hop_length"]
        self.configs.win_length = hps["data"]["win_length"]
        self.configs.n_speakers = hps["data"]["n_speakers"]
        self.configs.semantic_frame_rate = hps["model"]["semantic_frame_rate"]
        kwargs = hps["model"]
        # print(f"self.configs.sampling_rate:{self.configs.sampling_rate}")

        self.configs.update_version(model_version)

        # print(f"model_version:{model_version}")
        # print(f'hps["model"]["version"]:{hps["model"]["version"]}')
        if model_version not in v3v4set:
            vits_model = SynthesizerTrn(
                self.configs.filter_length // 2 + 1,
                self.configs.segment_size // self.configs.hop_length,
                n_speakers=self.configs.n_speakers,
                **kwargs,
            )
            self.configs.use_vocoder = False
        else:
            kwargs["version"] = model_version
            vits_model = SynthesizerTrnV3(
                self.configs.filter_length // 2 + 1,
                self.configs.segment_size // self.configs.hop_length,
                n_speakers=self.configs.n_speakers,
                **kwargs,
            )
            self.configs.use_vocoder = True
            self.init_vocoder(model_version)
        if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
            del vits_model.enc_q

        self.is_v2pro = model_version in {"v2Pro", "v2ProPlus"}

        if not if_lora_v3:
            print(
                f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
            )
        else:
            print(
                f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits)['weight'], strict=False)}"
            )
            lora_rank = dict_s2["lora_rank"]
            lora_config = LoraConfig(
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
                r=lora_rank,
                lora_alpha=lora_rank,
                init_lora_weights=True,
            )
            vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
            print(
                f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
            )

            vits_model.cfm = vits_model.cfm.merge_and_unload()

        vits_model = vits_model.to(self.configs.device)
        vits_model = vits_model.eval()

        self.vits_model = vits_model
        if self.configs.is_half and str(self.configs.device) != "cpu":
            self.vits_model = self.vits_model.half()

        self.configs.save_configs()

    def init_t2s_weights(self, weights_path: str):
        print(f"Loading Text2Semantic weights from {weights_path}")
        self.configs.t2s_weights_path = weights_path
        self.configs.save_configs()
        self.configs.hz = 50
        dict_s1 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
        config = dict_s1["config"]
        self.configs.max_sec = config["data"]["max_sec"]
        t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
        t2s_model.load_state_dict(dict_s1["weight"])
        t2s_model = t2s_model.to(self.configs.device)
        t2s_model = t2s_model.eval()
        self.t2s_model = t2s_model
        if self.configs.is_half and str(self.configs.device) != "cpu":
            self.t2s_model = self.t2s_model.half()

        codebook = t2s_model.model.ar_audio_embedding.weight.clone()
        mute_emb = codebook[self.configs.mute_tokens[self.configs.version]].unsqueeze(0)
        sim_matrix = F.cosine_similarity(mute_emb.float(), codebook.float(), dim=-1)
        self.configs.mute_emb_sim_matrix = sim_matrix
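
    # The similarity matrix above scores every codebook entry against the
    # "mute" token embedding (index 486 for all current versions, per
    # TTS_Config.mute_tokens) and has shape (codebook_size,). Presumably later
    # stages use it to flag near-silent semantic tokens; the exact consumer is
    # outside this excerpt.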

    def init_vocoder(self, version: str):
        if version == "v3":
            if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
                return
            if self.vocoder is not None:
                self.vocoder.cpu()
                del self.vocoder
                self.empty_cache()

            self.vocoder = BigVGAN.from_pretrained(
                "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
                use_cuda_kernel=False,
            )  # if True, RuntimeError: Ninja is required to load C++ extensions
            # remove weight norm in the model and set to eval mode
            self.vocoder.remove_weight_norm()

            self.vocoder_configs["sr"] = 24000
            self.vocoder_configs["T_ref"] = 468
            self.vocoder_configs["T_chunk"] = 934
            self.vocoder_configs["upsample_rate"] = 256
            self.vocoder_configs["overlapped_len"] = 12

        elif version == "v4":
            if self.vocoder is not None and self.vocoder.__class__.__name__ == "Generator":
                return
            if self.vocoder is not None:
                self.vocoder.cpu()
                del self.vocoder
                self.empty_cache()

            self.vocoder = Generator(
                initial_channel=100,
                resblock="1",
                resblock_kernel_sizes=[3, 7, 11],
                resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                upsample_rates=[10, 6, 2, 2, 2],
                upsample_initial_channel=512,
                upsample_kernel_sizes=[20, 12, 4, 4, 4],
                gin_channels=0,
                is_bias=True,
            )
            self.vocoder.remove_weight_norm()
            state_dict_g = torch.load(
                "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
                map_location="cpu",
                weights_only=False,
            )
            print("loading vocoder", self.vocoder.load_state_dict(state_dict_g))

            self.vocoder_configs["sr"] = 48000
            self.vocoder_configs["T_ref"] = 500
            self.vocoder_configs["T_chunk"] = 1000
            self.vocoder_configs["upsample_rate"] = 480
            self.vocoder_configs["overlapped_len"] = 12

        self.vocoder = self.vocoder.eval()
        if self.configs.is_half:
            self.vocoder = self.vocoder.half().to(self.configs.device)
        else:
            self.vocoder = self.vocoder.to(self.configs.device)

    def init_sr_model(self):
        if self.sr_model is not None:
            return
        try:
            self.sr_model: AP_BWE = AP_BWE(self.configs.device, DictToAttrRecursive)
            self.sr_model_not_exist = False
        except FileNotFoundError:
            # "The super-resolution model weights are not downloaded, so super-resolution is skipped.
            #  To enable it, download the files first as described in the tutorial."
            print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
            self.sr_model_not_exist = True

    def init_sv_model(self):
        if self.sv_model is not None:
            return
        self.sv_model = SV(self.configs.device, self.configs.is_half)

    def enable_half_precision(self, enable: bool = True, save: bool = True):
        """
        Enable or disable half precision for the TTS models.

        Args:
            enable: bool, whether to enable half precision.
        """
        if str(self.configs.device) == "cpu" and enable:
            print("Half precision is not supported on CPU.")
            return

        self.configs.is_half = enable
        self.precision = torch.float16 if enable else torch.float32
        if save:
            self.configs.save_configs()
        if enable:
            if self.t2s_model is not None:
                self.t2s_model = self.t2s_model.half()
            if self.vits_model is not None:
                self.vits_model = self.vits_model.half()
            if self.bert_model is not None:
                self.bert_model = self.bert_model.half()
            if self.cnhuhbert_model is not None:
                self.cnhuhbert_model = self.cnhuhbert_model.half()
            if self.vocoder is not None:
                self.vocoder = self.vocoder.half()
        else:
            if self.t2s_model is not None:
                self.t2s_model = self.t2s_model.float()
            if self.vits_model is not None:
                self.vits_model = self.vits_model.float()
            if self.bert_model is not None:
                self.bert_model = self.bert_model.float()
            if self.cnhuhbert_model is not None:
                self.cnhuhbert_model = self.cnhuhbert_model.float()
            if self.vocoder is not None:
                self.vocoder = self.vocoder.float()

    def set_device(self, device: torch.device, save: bool = True):
        """
        Set the device for all models.

        Args:
            device: torch.device, the device to use for all models.
        """
        self.configs.device = device
        if save:
            self.configs.save_configs()
        if self.t2s_model is not None:
            self.t2s_model = self.t2s_model.to(device)
        if self.vits_model is not None:
            self.vits_model = self.vits_model.to(device)
        if self.bert_model is not None:
            self.bert_model = self.bert_model.to(device)
        if self.cnhuhbert_model is not None:
            self.cnhuhbert_model = self.cnhuhbert_model.to(device)
        if self.vocoder is not None:
            self.vocoder = self.vocoder.to(device)
        if self.sr_model is not None:
            self.sr_model = self.sr_model.to(device)

    def set_ref_audio(self, ref_audio_path: str):
        """
        Set the reference audio for the TTS model,
        including the prompt_semantic and refer_spec.

        Args:
            ref_audio_path: str, the path of the reference audio.
        """
        bundle = self.extract_ref_audio_bundle(ref_audio_path)
        if self.prompt_cache["refer_spec"] in [[], None]:
            self.prompt_cache["refer_spec"] = [bundle["refer_spec"]]
        else:
            self.prompt_cache["refer_spec"][0] = bundle["refer_spec"]
        self.prompt_cache["prompt_semantic"] = bundle["prompt_semantic"]
        self.prompt_cache["raw_audio"] = bundle["raw_audio"]
        self.prompt_cache["raw_sr"] = bundle["raw_sr"]
        self._set_ref_audio_path(ref_audio_path)

    def _load_ref_audio_raw(self, ref_audio_path: str):
        raw_audio, raw_sr = torchaudio.load(ref_audio_path)
        return raw_audio.float(), int(raw_sr)

    @torch.inference_mode()
    def _extract_prompt_semantic_from_prepared_wav16k(self, wav16k: torch.Tensor):
        wav16k = wav16k.to(self.configs.device)
        if self.configs.is_half:
            wav16k = wav16k.half()
        hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
        codes = self.vits_model.extract_latent(hubert_feature)
        return codes[0, 0].to(self.configs.device)

    @torch.inference_mode()
    def _extract_prompt_semantic_profile_from_prepared_wav16k(self, wav16k: torch.Tensor):
        forward_start = time.perf_counter()
        prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
        forward_ms = (time.perf_counter() - forward_start) * 1000.0
        return prompt_semantic, forward_ms

    @torch.inference_mode()
    def _prepare_prompt_semantic_wav16k_profile(self, raw_audio: torch.Tensor, raw_sr: int):
        limiter = getattr(self, "prepare_ref_audio_cpu_limiter", None)
        if limiter is None:
            cpu_prepare_start = time.perf_counter()
            wav16k = prepare_prompt_semantic_wav16k(
                raw_audio=raw_audio,
                raw_sr=raw_sr,
                zero_wav_samples=int(self.configs.sampling_rate * 0.3),
            )
            cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
            return wav16k, cpu_prepare_ms, {"wait_ms": 0.0, "slots": 0.0, "peak_inflight": 0.0}

        with limiter.enter() as limiter_stats:
            cpu_prepare_start = time.perf_counter()
            wav16k = prepare_prompt_semantic_wav16k(
                raw_audio=raw_audio,
                raw_sr=raw_sr,
                zero_wav_samples=int(self.configs.sampling_rate * 0.3),
            )
            cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0
        return wav16k, cpu_prepare_ms, {
            "wait_ms": float(limiter_stats.get("wait_ms", 0.0)),
            "slots": float(limiter_stats.get("slots", 0.0)),
            "peak_inflight": float(limiter_stats.get("peak_inflight", 0.0)),
        }

    @torch.inference_mode()
    def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
        wav16k, cpu_prepare_ms, _ = self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
        prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k)
        return prompt_semantic, cpu_prepare_ms, forward_ms

    @torch.inference_mode()
    def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
        prompt_semantic, _, _ = self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr)
        return prompt_semantic

    def extract_prompt_semantic(self, ref_wav_path: str):
        raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path)
        return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)

    def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
        raw_audio_device = raw_audio.to(self.configs.device).float()

        if raw_sr != self.configs.sampling_rate:
            audio = raw_audio_device
            if audio.shape[0] == 2:
                audio = audio.mean(0).unsqueeze(0)
            audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
        else:
            audio = raw_audio_device
            if audio.shape[0] == 2:
                audio = audio.mean(0).unsqueeze(0)

        maxx = audio.abs().max()
        if maxx > 1:
            audio /= min(2, maxx)
        spec = spectrogram_torch(
            audio,
            self.configs.filter_length,
            self.configs.sampling_rate,
            self.configs.hop_length,
            self.configs.win_length,
            center=False,
        )
        if self.configs.is_half:
            spec = spec.half()
        if self.is_v2pro:
            audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
            if self.configs.is_half:
                audio = audio.half()
        else:
            audio = None
        return spec, audio, raw_audio, raw_sr
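
    # Shape sketch (illustrative only): a stereo input (2, T) is mixed down to
    # (1, T), resampled to self.configs.sampling_rate if needed, peak-limited,
    # and turned into a linear spectrogram of shape
    # (1, filter_length // 2 + 1, n_frames). For v2Pro models a 16 kHz copy of
    # the audio is also returned, presumably for the speaker-verification (SV)
    # branch that init_sv_model sets up for Pro variants.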

    def extract_ref_spec(self, ref_audio_path: str):
        raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
        return self._extract_ref_spec_from_raw(raw_audio, raw_sr)

    def extract_ref_audio_bundle(self, ref_audio_path: str):
        load_start = time.perf_counter()
        raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
        load_ms = (time.perf_counter() - load_start) * 1000.0
        if self.prepare_ref_semantic_batch_worker is None:
            wav16k, prompt_semantic_cpu_prepare_ms, prompt_semantic_cpu_limiter_stats = (
                self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr)
            )
            with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
                prompt_semantic_start = time.perf_counter()
                prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(
                    wav16k
                )
                prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
                ref_spec_start = time.perf_counter()
                refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
                ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
            audio_stage_wait_ms = float(limiter_stats["wait_ms"])
            audio_stage_slots = float(limiter_stats["slots"])
            audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
            prompt_semantic_profile = {
                "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
                "prompt_semantic_cpu_prepare_wait_ms": float(prompt_semantic_cpu_limiter_stats.get("wait_ms", 0.0)),
                "prompt_semantic_cpu_prepare_slots": float(prompt_semantic_cpu_limiter_stats.get("slots", 0.0)),
                "prompt_semantic_cpu_prepare_inflight_peak": float(
                    prompt_semantic_cpu_limiter_stats.get("peak_inflight", 0.0)
                ),
                "prompt_semantic_worker_queue_wait_ms": 0.0,
                "prompt_semantic_batch_collect_wait_ms": 0.0,
                "prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]),
                "prompt_semantic_batch_dispatch_delay_ms": 0.0,
                "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms),
                "prompt_semantic_forward_ms": float(prompt_semantic_forward_ms),
                "prompt_semantic_scatter_ms": 0.0,
                "prompt_semantic_stage_slots": float(limiter_stats["slots"]),
                "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
                "prompt_semantic_batch_size": 1.0,
                "prompt_semantic_batch_samples": 0.0,
            }
            ref_spec_wait_ms = 0.0
            return {
                "prompt_semantic": prompt_semantic,
                "refer_spec": refer_spec,
                "raw_audio": raw_audio,
                "raw_sr": raw_sr,
                "profile": {
                    "audio_load_ms": load_ms,
                    "audio_stage_wait_ms": audio_stage_wait_ms,
                    "audio_stage_slots": audio_stage_slots,
                    "audio_stage_inflight_peak": audio_stage_inflight_peak,
                    "prompt_semantic_ms": prompt_semantic_ms,
                    "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
                    "prompt_semantic_cpu_prepare_wait_ms": float(
                        prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)
                    ),
                    "prompt_semantic_cpu_prepare_slots": float(
                        prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)
                    ),
                    "prompt_semantic_cpu_prepare_inflight_peak": float(
                        prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
                    ),
                    "prompt_semantic_worker_queue_wait_ms": float(
                        prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
                    ),
                    "prompt_semantic_batch_collect_wait_ms": float(
                        prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
                    ),
                    "prompt_semantic_stage_limiter_wait_ms": float(
                        prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
                    ),
                    "prompt_semantic_batch_dispatch_delay_ms": float(
                        prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
                    ),
                    "prompt_semantic_cpu_prepare_ms": float(
                        prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
                    ),
                    "prompt_semantic_forward_ms": float(
                        prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)
                    ),
                    "prompt_semantic_scatter_ms": float(
                        prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)
                    ),
                    "prompt_semantic_stage_slots": float(
                        prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)
                    ),
                    "prompt_semantic_stage_inflight_peak": float(
                        prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
                    ),
                    "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
                    "prompt_semantic_batch_samples": float(
                        prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
                    ),
                    "ref_spec_wait_ms": ref_spec_wait_ms,
                    "ref_spec_ms": ref_spec_ms,
                    "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms,
                },
            }

        prompt_semantic_profile = {
            "prompt_semantic_wait_ms": 0.0,
            "prompt_semantic_cpu_prepare_wait_ms": 0.0,
            "prompt_semantic_cpu_prepare_slots": float(getattr(self.prepare_ref_audio_cpu_limiter, "slots", 0.0)),
            "prompt_semantic_cpu_prepare_inflight_peak": 0.0,
            "prompt_semantic_worker_queue_wait_ms": 0.0,
            "prompt_semantic_batch_collect_wait_ms": 0.0,
            "prompt_semantic_stage_limiter_wait_ms": 0.0,
            "prompt_semantic_batch_dispatch_delay_ms": 0.0,
            "prompt_semantic_cpu_prepare_ms": 0.0,
            "prompt_semantic_forward_ms": 0.0,
            "prompt_semantic_scatter_ms": 0.0,
            "prompt_semantic_stage_slots": 0.0,
            "prompt_semantic_stage_inflight_peak": 0.0,
            "prompt_semantic_batch_size": 1.0,
            "prompt_semantic_batch_samples": 0.0,
        }
        if self.prepare_ref_semantic_batch_worker is not None:
            prompt_semantic, worker_profile = self.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr)
            prompt_semantic_profile.update(worker_profile)
        prompt_semantic_ms = (
            float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
            + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
            + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
        )
        with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats:
            ref_spec_start = time.perf_counter()
            refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
            ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
        audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float(
            ref_spec_limiter_stats["wait_ms"]
        )
        audio_stage_slots = max(
            float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
            float(ref_spec_limiter_stats["slots"]),
        )
        audio_stage_inflight_peak = max(
            float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
            float(ref_spec_limiter_stats["peak_inflight"]),
        )
        return {
            "prompt_semantic": prompt_semantic,
            "refer_spec": refer_spec,
            "raw_audio": raw_audio,
            "raw_sr": raw_sr,
            "profile": {
                "audio_load_ms": load_ms,
                "audio_stage_wait_ms": audio_stage_wait_ms,
                "audio_stage_slots": audio_stage_slots,
                "audio_stage_inflight_peak": audio_stage_inflight_peak,
                "prompt_semantic_ms": prompt_semantic_ms,
                "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
                "prompt_semantic_cpu_prepare_wait_ms": float(
                    prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)
                ),
                "prompt_semantic_cpu_prepare_slots": float(
                    prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)
                ),
                "prompt_semantic_cpu_prepare_inflight_peak": float(
                    prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0)
                ),
                "prompt_semantic_worker_queue_wait_ms": float(
                    prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
                ),
                "prompt_semantic_batch_collect_wait_ms": float(
                    prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
                ),
                "prompt_semantic_stage_limiter_wait_ms": float(
                    prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
                ),
                "prompt_semantic_batch_dispatch_delay_ms": float(
                    prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
                ),
                "prompt_semantic_cpu_prepare_ms": float(
                    prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
                ),
                "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
                "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
                "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
                "prompt_semantic_stage_inflight_peak": float(
                    prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
                ),
                "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
                "prompt_semantic_batch_samples": float(
                    prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
                ),
                "ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]),
                "ref_spec_ms": ref_spec_ms,
                "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms,
            },
        }

    async def extract_ref_audio_bundle_async(self, ref_audio_path: str):
        if self.prepare_ref_semantic_batch_worker is None:
            return await asyncio.to_thread(self.extract_ref_audio_bundle, ref_audio_path)

        load_start = time.perf_counter()
        raw_audio, raw_sr = await asyncio.to_thread(self._load_ref_audio_raw, ref_audio_path)
        load_ms = (time.perf_counter() - load_start) * 1000.0

        prompt_semantic_task = asyncio.create_task(
            self.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr)
        )

        def _build_ref_spec_profile():
            with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats:
                ref_spec_start = time.perf_counter()
                refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
                ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
            return refer_spec, {
                "ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]),
                "ref_spec_ms": float(ref_spec_ms),
                "audio_stage_slots": float(ref_spec_limiter_stats["slots"]),
                "audio_stage_inflight_peak": float(ref_spec_limiter_stats["peak_inflight"]),
            }

        ref_spec_task = asyncio.create_task(asyncio.to_thread(_build_ref_spec_profile))
        (prompt_semantic, prompt_semantic_profile), (refer_spec, ref_spec_profile) = await asyncio.gather(
            prompt_semantic_task,
            ref_spec_task,
        )

        prompt_semantic_ms = (
            float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
            + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
            + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
        )
        audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float(
            ref_spec_profile.get("ref_spec_wait_ms", 0.0)
        )
        audio_stage_slots = max(
            float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
            float(ref_spec_profile.get("audio_stage_slots", 0.0)),
        )
        audio_stage_inflight_peak = max(
            float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
            float(ref_spec_profile.get("audio_stage_inflight_peak", 0.0)),
        )
        return {
            "prompt_semantic": prompt_semantic,
            "refer_spec": refer_spec,
            "raw_audio": raw_audio,
            "raw_sr": raw_sr,
            "profile": {
                "audio_load_ms": float(load_ms),
                "audio_stage_wait_ms": float(audio_stage_wait_ms),
                "audio_stage_slots": float(audio_stage_slots),
                "audio_stage_inflight_peak": float(audio_stage_inflight_peak),
                "prompt_semantic_ms": float(prompt_semantic_ms),
                "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
                "prompt_semantic_worker_queue_wait_ms": float(
                    prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0)
                ),
                "prompt_semantic_batch_collect_wait_ms": float(
                    prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0)
                ),
                "prompt_semantic_stage_limiter_wait_ms": float(
                    prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0)
                ),
                "prompt_semantic_batch_dispatch_delay_ms": float(
                    prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0)
                ),
                "prompt_semantic_cpu_prepare_ms": float(
                    prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
                ),
                "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
                "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
                "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
                "prompt_semantic_stage_inflight_peak": float(
                    prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
                ),
                "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
                "prompt_semantic_batch_samples": float(
                    prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
                ),
                "ref_spec_wait_ms": float(ref_spec_profile.get("ref_spec_wait_ms", 0.0)),
                "ref_spec_ms": float(ref_spec_profile.get("ref_spec_ms", 0.0)),
                "bundle_total_ms": float(
                    load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_profile.get("ref_spec_ms", 0.0)
                ),
            },
        }

    def extract_text_features(self, text: str, language: str, profile: dict | None = None):
        return self.text_preprocessor.segment_and_extract_feature_for_text(
            text, language, self.configs.version, profile=profile
        )

    def prepare_text_segments(self, text: str, language: str):
        return self.text_preprocessor.preprocess_text_segments(text, language, self.configs.version)

    def resolve_g2pw_segments(self, prepared_segments, profile: dict | None = None):
        return self.text_preprocessor.resolve_g2pw_segments(prepared_segments, profile=profile)

    def build_text_features_from_segments(self, prepared_segments, profile: dict | None = None):
        return self.text_preprocessor.build_phones_and_bert_from_segments(prepared_segments, profile=profile)

    async def build_text_features_from_segments_async(self, prepared_segments, profile: dict | None = None):
        return await self.text_preprocessor.build_phones_and_bert_from_segments_async(
            prepared_segments,
            profile=profile,
        )

    async def build_text_feature_pair_from_segments_async(
        self,
        prompt_segments,
        target_segments,
        prompt_profile: dict | None = None,
        target_profile: dict | None = None,
    ):
        return await self.text_preprocessor.build_phones_and_bert_pair_from_segments_async(
            prompt_segments,
            target_segments,
            prompt_profile=prompt_profile,
            target_profile=target_profile,
        )

    def _set_ref_audio_path(self, ref_audio_path):
        self.prompt_cache["ref_audio_path"] = ref_audio_path

    def _set_ref_spec(self, ref_audio_path):
        spec_audio = self._get_ref_spec(ref_audio_path)
        if self.prompt_cache["refer_spec"] in [[], None]:
            self.prompt_cache["refer_spec"] = [spec_audio]
        else:
            self.prompt_cache["refer_spec"][0] = spec_audio

    def _get_ref_spec(self, ref_audio_path):
        spec, audio, raw_audio, raw_sr = self.extract_ref_spec(ref_audio_path)
        self.prompt_cache["raw_audio"] = raw_audio
        self.prompt_cache["raw_sr"] = raw_sr
        return spec, audio

    def _set_prompt_semantic(self, ref_wav_path: str):
        prompt_semantic = self.extract_prompt_semantic(ref_wav_path)
        self.prompt_cache["prompt_semantic"] = prompt_semantic

    def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None):
        seq = sequences[0]
        ndim = seq.dim()
        if axis < 0:
            axis += ndim
        dtype: torch.dtype = seq.dtype
        pad_value = torch.tensor(pad_value, dtype=dtype)
        seq_lengths = [seq.shape[axis] for seq in sequences]
        if max_length is None:
            max_length = max(seq_lengths)
        else:
            max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length

        padded_sequences = []
        for seq, length in zip(sequences, seq_lengths):
            padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
            padded_seq = torch.nn.functional.pad(seq, padding, value=pad_value)
            padded_sequences.append(padded_seq)
        batch = torch.stack(padded_sequences)
        return batch
|
||
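    # A minimal usage sketch for `batch_sequences` (illustrative values):
    #
    #   seqs = [torch.zeros(3, dtype=torch.long), torch.zeros(5, dtype=torch.long)]
    #   batch = tts.batch_sequences(seqs, axis=0, pad_value=0)
    #   # batch.shape == (2, 5): each sequence is right-padded along `axis`
    #   # to the longest length (or to `max_length` when that is larger).
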
    def to_batch(
        self,
        data: list,
        prompt_data: dict = None,
        batch_size: int = 5,
        threshold: float = 0.75,
        split_bucket: bool = True,
        device: torch.device = torch.device("cpu"),
        precision: torch.dtype = torch.float32,
    ):
        _data: list = []
        index_and_len_list = []
        for idx, item in enumerate(data):
            norm_text_len = len(item["norm_text"])
            index_and_len_list.append([idx, norm_text_len])

        batch_index_list = []
        if split_bucket:
            index_and_len_list.sort(key=lambda x: x[1])
            index_and_len_list = np.array(index_and_len_list, dtype=np.int64)

            batch_index_list_len = 0
            pos = 0
            while pos < index_and_len_list.shape[0]:
                # batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
                pos_end = min(pos + batch_size, index_and_len_list.shape[0])
                while pos < pos_end:
                    batch = index_and_len_list[pos:pos_end, 1].astype(np.float32)
                    score = batch[(pos_end - pos) // 2] / (batch.mean() + 1e-8)
                    if (score >= threshold) or (pos_end - pos == 1):
                        batch_index = index_and_len_list[pos:pos_end, 0].tolist()
                        batch_index_list_len += len(batch_index)
                        batch_index_list.append(batch_index)
                        pos = pos_end
                        break
                    pos_end = pos_end - 1

            assert batch_index_list_len == len(data)

        else:
            for i in range(len(data)):
                if i % batch_size == 0:
                    batch_index_list.append([])
                batch_index_list[-1].append(i)

        for batch_idx, index_list in enumerate(batch_index_list):
            item_list = [data[idx] for idx in index_list]
            phones_list = []
            phones_len_list = []
            # bert_features_list = []
            all_phones_list = []
            all_phones_len_list = []
            all_bert_features_list = []
            norm_text_batch = []
            all_bert_max_len = 0
            all_phones_max_len = 0
            for item in item_list:
                if prompt_data is not None:
                    all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1).to(
                        dtype=precision, device=device
                    )
                    all_phones = torch.LongTensor(prompt_data["phones"] + item["phones"]).to(device)
                    phones = torch.LongTensor(item["phones"]).to(device)
                    # norm_text = prompt_data["norm_text"]+item["norm_text"]
                else:
                    all_bert_features = item["bert_features"].to(dtype=precision, device=device)
                    phones = torch.LongTensor(item["phones"]).to(device)
                    all_phones = phones
                    # norm_text = item["norm_text"]

                all_bert_max_len = max(all_bert_max_len, all_bert_features.shape[-1])
                all_phones_max_len = max(all_phones_max_len, all_phones.shape[-1])

                phones_list.append(phones)
                phones_len_list.append(phones.shape[-1])
                all_phones_list.append(all_phones)
                all_phones_len_list.append(all_phones.shape[-1])
                all_bert_features_list.append(all_bert_features)
                norm_text_batch.append(item["norm_text"])

            phones_batch = phones_list
            all_phones_batch = all_phones_list
            all_bert_features_batch = all_bert_features_list

            max_len = max(all_bert_max_len, all_phones_max_len)
            # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
            #### Pad phones and bert_features directly. (The padding strategy affects what the T2S model
            #### generates, but does not directly affect the repetition probability; that is mainly driven by the masking strategy.)
            # all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
            # all_bert_features_batch = all_bert_features_list
            # all_bert_features_batch = torch.zeros((len(all_bert_features_list), 1024, max_len), dtype=precision, device=device)
            # for idx, item in enumerate(all_bert_features_list):
            #     all_bert_features_batch[idx, :, : item.shape[-1]] = item

            # #### Alternatively: embed phones and project bert_features first, then pad to the same length.
            # #### (Same caveat as above: padding affects T2S output, while repetition is driven by masking.)
            # all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
            # all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
            # all_phones_batch = torch.stack(all_phones_list, dim=0)

            # all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
            # all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
            # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)

            batch = {
                "phones": phones_batch,
                "phones_len": torch.LongTensor(phones_len_list).to(device),
                "all_phones": all_phones_batch,
                "all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
                "all_bert_features": all_bert_features_batch,
                "norm_text": norm_text_batch,
                "max_len": max_len,
            }
            _data.append(batch)

        return _data, batch_index_list
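    # Worked example of the bucketing heuristic above (illustrative numbers):
    # with sorted lengths [8, 9, 10, 30] and batch_size=4, the candidate window
    # [8, 9, 10, 30] scores batch[2] / mean = 10 / 14.25 ≈ 0.70 < 0.75, so the
    # window shrinks; [8, 9, 10] scores 9 / 9 = 1.0 ≥ 0.75 and becomes one
    # bucket, leaving [30] to start the next. Buckets therefore group texts of
    # similar length, which reduces padding waste inside each batch.
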
    def recovery_order(self, data: list, batch_index_list: list) -> list:
        """
        Recover the original order of the audio fragments according to batch_index_list.

        Args:
            data (List[List[torch.Tensor]]): the out-of-order audio fragments.
            batch_index_list (List[List[int]]): the batch index list.

        Returns:
            list (List[torch.Tensor]): the data in the original order.
        """
        length = len(sum(batch_index_list, []))
        _data = [None] * length
        for i, index_list in enumerate(batch_index_list):
            for j, index in enumerate(index_list):
                _data[index] = data[i][j]
        return _data
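    # Illustrative: with batch_index_list [[2, 0], [1]] and data [[a, b], [c]],
    # recovery_order returns [b, c, a], i.e. fragment i of the original text
    # order ends up at position i regardless of how the buckets were formed.
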
    def stop(
        self,
    ):
        """
        Stop the inference process.
        """
        self.stop_flag = True

    @torch.no_grad()
    def run(self, inputs: dict):
        """
        Text to speech inference.

        Args:
            inputs (dict):
                {
                    "text": "",  # str.(required) text to be synthesized
                    "text_lang": "",  # str.(required) language of the text to be synthesized
                    "ref_audio_path": "",  # str.(required) reference audio path
                    "aux_ref_audio_paths": [],  # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
                    "prompt_text": "",  # str.(optional) prompt text for the reference audio
                    "prompt_lang": "",  # str.(required) language of the prompt text for the reference audio
                    "top_k": 15,  # int. top k sampling
                    "top_p": 1,  # float. top p sampling
                    "temperature": 1,  # float. temperature for sampling
                    "text_split_method": "cut1",  # str. text split method, see text_segmentation_method.py for details.
                    "batch_size": 1,  # int. batch size for inference
                    "batch_threshold": 0.75,  # float. threshold for batch splitting.
                    "split_bucket": True,  # bool. whether to split the batch into multiple buckets.
                    "speed_factor": 1.0,  # float. control the speed of the synthesized audio.
                    "fragment_interval": 0.3,  # float. control the interval of the audio fragments.
                    "seed": -1,  # int. random seed for reproducibility.
                    "parallel_infer": True,  # bool. whether to use parallel inference.
                    "repetition_penalty": 1.35,  # float. repetition penalty for the T2S model.
                    "sample_steps": 32,  # int. number of sampling steps for VITS model V3.
                    "super_sampling": False,  # bool. whether to use super-sampling for audio when using VITS model V3.
                    "return_fragment": False,  # bool. return the audio fragment by fragment. (best quality, slowest response speed; the old streaming mode)
                    "streaming_mode": False,  # bool. return audio chunk by chunk. (medium quality, slow response speed)
                    "overlap_length": 2,  # int. overlap length of semantic tokens for streaming mode.
                    "min_chunk_length": 16,  # int. minimum chunk length of semantic tokens for streaming mode (affects audio chunk size).
                    "fixed_length_chunk": False,  # bool. when enabled, achieves a faster streaming response at lower quality.
                }
        Returns:
            Tuple[int, np.ndarray]: sampling rate and audio data.
        """
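        # Minimal request sketch (illustrative; paths and texts are placeholders):
        #   req = {"text": "...", "text_lang": "zh", "ref_audio_path": "ref.wav",
        #          "prompt_text": "...", "prompt_lang": "zh"}
        #   sr, wav = next(tts.run(req))  # run() is a generator; iterate it fully for all fragments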
        ########## variables initialization ###########
        self.stop_flag: bool = False
        text: str = inputs.get("text", "")
        text_lang: str = inputs.get("text_lang", "")
        ref_audio_path: str = inputs.get("ref_audio_path", "")
        aux_ref_audio_paths: list = inputs.get("aux_ref_audio_paths", [])
        prompt_text: str = inputs.get("prompt_text", "")
        prompt_lang: str = inputs.get("prompt_lang", "")
        top_k: int = inputs.get("top_k", 15)
        top_p: float = inputs.get("top_p", 1)
        temperature: float = inputs.get("temperature", 1)
        text_split_method: str = inputs.get("text_split_method", "cut1")
        batch_size = inputs.get("batch_size", 1)
        batch_threshold = inputs.get("batch_threshold", 0.75)
        speed_factor = inputs.get("speed_factor", 1.0)
        split_bucket = inputs.get("split_bucket", True)
        return_fragment = inputs.get("return_fragment", False)
        fragment_interval = inputs.get("fragment_interval", 0.3)
        seed = inputs.get("seed", -1)
        seed = -1 if seed in ["", None] else seed
        actual_seed = set_seed(seed)
        parallel_infer = inputs.get("parallel_infer", True)
        repetition_penalty = inputs.get("repetition_penalty", 1.35)
        sample_steps = inputs.get("sample_steps", 32)
        super_sampling = inputs.get("super_sampling", False)
        streaming_mode = inputs.get("streaming_mode", False)
        overlap_length = inputs.get("overlap_length", 2)
        min_chunk_length = inputs.get("min_chunk_length", 16)
        fixed_length_chunk = inputs.get("fixed_length_chunk", False)
        # Cosine-similarity threshold between a semantic token and the mute token; positions scoring
        # above it are treated as valid split points. (Spelling kept to match the infer_panel keyword.)
        chunk_split_thershold = 0.0

        if parallel_infer and not streaming_mode:
            print(i18n("并行推理模式已开启"))
            self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
        elif not parallel_infer and streaming_mode and not self.configs.use_vocoder:
            print(i18n("流式推理模式已开启"))
            self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
        elif streaming_mode and self.configs.use_vocoder:
            print(i18n("SoVits V3/4模型不支持流式推理模式,已自动回退到分段返回模式"))
            streaming_mode = False
            return_fragment = True
            if parallel_infer:
                self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
            else:
                self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
                # self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
        elif parallel_infer and streaming_mode:
            print(i18n("不支持同时开启并行推理和流式推理模式,已自动关闭并行推理模式"))
            parallel_infer = False
            self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive
        else:
            print(i18n("朴素推理模式已开启"))
            self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
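        # Dispatch summary (as implemented above): parallel & non-streaming ->
        # batched infer panel; streaming without a vocoder model -> naive panel;
        # streaming with a V3/V4 vocoder model -> fall back to fragment mode;
        # parallel + streaming -> parallel inference is disabled; otherwise the
        # naive batched panel is used.
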
        if return_fragment and streaming_mode:
            print(i18n("流式推理模式不支持分段返回,已自动关闭分段返回"))
            return_fragment = False

        if (return_fragment or streaming_mode) and split_bucket:
            print(i18n("分段返回模式/流式推理模式不支持分桶处理,已自动关闭分桶处理"))
            split_bucket = False

        if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer):
            print(i18n("分桶处理模式已开启"))
        elif speed_factor != 1.0:
            print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理"))
            split_bucket = False
        elif self.configs.use_vocoder and parallel_infer:
            print(i18n("当开启并行推理模式时,SoVits V3/4模型不支持分桶处理,已自动关闭分桶处理"))
            split_bucket = False
        else:
            print(i18n("分桶处理模式已关闭"))

        # if fragment_interval < 0.01:
        #     fragment_interval = 0.01
        #     print(i18n("分段间隔过小,已自动设置为0.01"))

        no_prompt_text = False
        if prompt_text in [None, ""]:
            no_prompt_text = True

        assert text_lang in self.configs.languages
        if not no_prompt_text:
            assert prompt_lang in self.configs.languages

        if no_prompt_text and self.configs.use_vocoder:
            raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3")

        if ref_audio_path in [None, ""] and (
            (self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])
        ):
            raise ValueError(
                "ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()"
            )

        ###### setting reference audio and prompt text preprocessing ########
        t0 = time.perf_counter()
        if (ref_audio_path is not None) and (
            ref_audio_path != self.prompt_cache["ref_audio_path"]
            or (self.is_v2pro and self.prompt_cache["refer_spec"][0][1] is None)
        ):
            if not os.path.exists(ref_audio_path):
                raise ValueError(f"{ref_audio_path} does not exist")
            self.set_ref_audio(ref_audio_path)

        aux_ref_audio_paths = aux_ref_audio_paths if aux_ref_audio_paths is not None else []
        paths = set(aux_ref_audio_paths) & set(self.prompt_cache["aux_ref_audio_paths"])
        if not (len(list(paths)) == len(aux_ref_audio_paths) == len(self.prompt_cache["aux_ref_audio_paths"])):
            self.prompt_cache["aux_ref_audio_paths"] = aux_ref_audio_paths
            self.prompt_cache["refer_spec"] = [self.prompt_cache["refer_spec"][0]]
            for path in aux_ref_audio_paths:
                if path in [None, ""]:
                    continue
                if not os.path.exists(path):
                    print(i18n("音频文件不存在,跳过:"), path)
                    continue
                self.prompt_cache["refer_spec"].append(self._get_ref_spec(path))

        if not no_prompt_text:
            prompt_text = prompt_text.strip("\n")
            if prompt_text[-1] not in splits:
                prompt_text += "。" if prompt_lang != "en" else "."
            print(i18n("实际输入的参考文本:"), prompt_text)
            if self.prompt_cache["prompt_text"] != prompt_text:
                phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(
                    prompt_text, prompt_lang, self.configs.version
                )
                self.prompt_cache["prompt_text"] = prompt_text
                self.prompt_cache["prompt_lang"] = prompt_lang
                self.prompt_cache["phones"] = phones
                self.prompt_cache["bert_features"] = bert_features
                self.prompt_cache["norm_text"] = norm_text
        ###### text preprocessing ########
        t1 = time.perf_counter()
        data: list = None
        if not (return_fragment or streaming_mode):
            data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
            if len(data) == 0:
                yield 16000, np.zeros(int(16000), dtype=np.int16)
                return

            batch_index_list: list = None
            data, batch_index_list = self.to_batch(
                data,
                prompt_data=self.prompt_cache if not no_prompt_text else None,
                batch_size=batch_size,
                threshold=batch_threshold,
                split_bucket=split_bucket,
                device=self.configs.device,
                precision=self.precision,
            )
        else:
            print(f"############ {i18n('切分文本')} ############")
            texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
            data = []
            for i in range(len(texts)):
                if i % batch_size == 0:
                    data.append([])
                data[-1].append(texts[i])

        def make_batch(batch_texts):
            batch_data = []
            print(f"############ {i18n('提取文本Bert特征')} ############")
            for text in tqdm(batch_texts):
                phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(
                    text, text_lang, self.configs.version
                )
                if phones is None:
                    continue
                res = {
                    "phones": phones,
                    "bert_features": bert_features,
                    "norm_text": norm_text,
                }
                batch_data.append(res)
            if len(batch_data) == 0:
                return None
            batch, _ = self.to_batch(
                batch_data,
                prompt_data=self.prompt_cache if not no_prompt_text else None,
                batch_size=batch_size,
                threshold=batch_threshold,
                split_bucket=False,
                device=self.configs.device,
                precision=self.precision,
            )
            return batch[0]
        t2 = time.perf_counter()
        try:
            print("############ Inference ############")
            ###### inference ######
            t_34 = 0.0
            t_45 = 0.0
            t2s_observe_batch_count = 0
            t2s_observe_fastpath_hits = 0
            t2s_observe_generated_tokens = 0
            audio = []
            is_first_package = True
            output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
            for item in data:
                t3 = time.perf_counter()
                if return_fragment or streaming_mode:
                    item = make_batch(item)
                    if item is None:
                        continue

                batch_phones: List[torch.LongTensor] = item["phones"]
                # batch_phones: torch.LongTensor = item["phones"]
                batch_phones_len: torch.LongTensor = item["phones_len"]
                all_phoneme_ids: torch.LongTensor = item["all_phones"]
                all_phoneme_lens: torch.LongTensor = item["all_phones_len"]
                all_bert_features: torch.LongTensor = item["all_bert_features"]
                norm_text: str = item["norm_text"]
                max_len = item["max_len"]

                print(i18n("前端处理后的文本(每句):"), norm_text)
                if no_prompt_text:
                    prompt = None
                else:
                    prompt = (
                        self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
                    )

                refer_audio_spec = []

                sv_emb = [] if self.is_v2pro else None
                for spec, audio_tensor in self.prompt_cache["refer_spec"]:
                    spec = spec.to(dtype=self.precision, device=self.configs.device)
                    refer_audio_spec.append(spec)
                    if self.is_v2pro:
                        sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))

                if not streaming_mode:
                    print(f"############ {i18n('预测语义Token')} ############")
                    pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                        all_phoneme_ids,
                        all_phoneme_lens,
                        prompt,
                        all_bert_features,
                        # prompt_phone_len=ph_offset,
                        top_k=top_k,
                        top_p=top_p,
                        temperature=temperature,
                        early_stop_num=self.configs.hz * self.configs.max_sec,
                        max_len=max_len,
                        repetition_penalty=repetition_penalty,
                    )
                    t4 = time.perf_counter()
                    t_34 += t4 - t3
                    if hasattr(self.t2s_model.model, "get_last_infer_stats"):
                        t2s_stats = self.t2s_model.model.get_last_infer_stats()
                        if t2s_stats:
                            generated_token_count = int(t2s_stats.get("generated_token_count", 0))
                            t2s_total_ms = (t4 - t3) * 1000.0
                            avg_decode_ms_per_token = (
                                t2s_total_ms / generated_token_count if generated_token_count > 0 else 0.0
                            )
                            t2s_observe_batch_count += 1
                            t2s_observe_generated_tokens += generated_token_count
                            if bool(t2s_stats.get("fastpath_hit", False)):
                                t2s_observe_fastpath_hits += 1
                            print(
                                "[t2s_observe] "
                                f"mode={t2s_stats.get('infer_mode')} "
                                f"batch_size={t2s_stats.get('batch_size')} "
                                f"tokens={generated_token_count} "
                                f"t2s_ms={t2s_total_ms:.3f} "
                                f"avg_decode_ms_per_token={avg_decode_ms_per_token:.3f} "
                                f"requested_fastpath={t2s_stats.get('requested_enable_mask_free_fastpath')} "
                                f"prefill_all_visible={t2s_stats.get('prefill_after_mask_all_visible')} "
                                f"fastpath_hit={t2s_stats.get('fastpath_hit')}"
                            )
                    batch_audio_fragment = []

                    # ## VITS parallel inference, method 1
                    # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                    # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
                    # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
                    # max_len = 0
                    # for i in range(0, len(batch_phones)):
                    #     max_len = max(max_len, batch_phones[i].shape[-1])
                    # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
                    # batch_phones = batch_phones.to(self.configs.device)
                    # batch_audio_fragment = (self.vits_model.batched_decode(
                    #     pred_semantic, pred_semantic_len, batch_phones, batch_phones_len, refer_audio_spec
                    # ))
                    print(f"############ {i18n('合成音频')} ############")
                    if not self.configs.use_vocoder:
                        if speed_factor == 1.0:
                            print(f"{i18n('并行合成中')}...")
                            # ## VITS parallel inference, method 2
                            pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                            upsample_rate = math.prod(self.vits_model.upsample_rates)
                            audio_frag_idx = [
                                pred_semantic_list[i].shape[0] * 2 * upsample_rate
                                for i in range(0, len(pred_semantic_list))
                            ]
                            audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))]
                            all_pred_semantic = (
                                torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
                            )
                            _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)

                            _batch_audio_fragment = self.vits_model.decode(
                                all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
                            ).detach()[0, 0, :]

                            audio_frag_end_idx.insert(0, 0)
                            batch_audio_fragment = [
                                _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
                                for i in range(1, len(audio_frag_end_idx))
                            ]
                        else:
                            # ## VITS serial inference
                            for i, idx in enumerate(tqdm(idx_list)):
                                phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
                                _pred_semantic = (
                                    pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                                )  # .unsqueeze(0)  # mq needs one extra unsqueeze
                                audio_fragment = self.vits_model.decode(
                                    _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
                                ).detach()[0, 0, :]
                                batch_audio_fragment.append(audio_fragment)  # try reconstructing without the prompt part
                    else:
                        if parallel_infer:
                            print(f"{i18n('并行合成中')}...")
                            audio_fragments = self.using_vocoder_synthesis_batched_infer(
                                idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
                            )
                            batch_audio_fragment.extend(audio_fragments)
                        else:
                            for i, idx in enumerate(tqdm(idx_list)):
                                phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
                                _pred_semantic = (
                                    pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                                )  # .unsqueeze(0)  # mq needs one extra unsqueeze
                                audio_fragment = self.using_vocoder_synthesis(
                                    _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
                                )
                                batch_audio_fragment.append(audio_fragment)
                else:
                    # refer_audio_spec: torch.Tensor = [
                    #     item.to(dtype=self.precision, device=self.configs.device)
                    #     for item in self.prompt_cache["refer_spec"]
                    # ]
                    semantic_token_generator = self.t2s_model.model.infer_panel(
                        all_phoneme_ids[0].unsqueeze(0),
                        all_phoneme_lens,
                        prompt,
                        all_bert_features[0].unsqueeze(0),
                        top_k=top_k,
                        top_p=top_p,
                        temperature=temperature,
                        early_stop_num=self.configs.hz * self.configs.max_sec,
                        max_len=max_len,
                        repetition_penalty=repetition_penalty,
                        streaming_mode=True,
                        chunk_length=min_chunk_length,
                        mute_emb_sim_matrix=self.configs.mute_emb_sim_matrix if not fixed_length_chunk else None,
                        chunk_split_thershold=chunk_split_thershold,
                    )
                    t4 = time.perf_counter()
                    t_34 += t4 - t3
                    phones = batch_phones[0].unsqueeze(0).to(self.configs.device)
                    is_first_chunk = True

                    if not self.configs.use_vocoder:
                        # if speed_factor == 1.0:
                        #     upsample_rate = math.prod(self.vits_model.upsample_rates) * (2 if self.vits_model.semantic_frame_rate == "25hz" else 1)
                        # else:
                        upsample_rate = math.prod(self.vits_model.upsample_rates) * (
                            (2 if self.vits_model.semantic_frame_rate == "25hz" else 1) / speed_factor
                        )
                    else:
                        # if speed_factor == 1.0:
                        #     upsample_rate = self.vocoder_configs["upsample_rate"] * (3.875 if self.configs.version == "v3" else 4)
                        # else:
                        upsample_rate = self.vocoder_configs["upsample_rate"] * (
                            (3.875 if self.configs.version == "v3" else 4) / speed_factor
                        )
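                    # Illustrative sizing: if math.prod(upsample_rates) were 512 with a
                    # 25 Hz semantic frame rate at speed_factor 1.0, each semantic token
                    # would map to 512 * 2 = 1024 output samples; the actual factors come
                    # from the loaded model's configuration.
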
                    last_audio_chunk = None
                    # last_tokens = None
                    last_latent = None
                    previous_tokens = []
                    overlap_len = overlap_length
                    overlap_size = math.ceil(overlap_length * upsample_rate)
                    for semantic_tokens, is_final in semantic_token_generator:
                        if semantic_tokens is None and last_audio_chunk is not None:
                            yield self.audio_postprocess(
                                [[last_audio_chunk[-overlap_size:]]],
                                output_sr,
                                None,
                                speed_factor,
                                False,
                                0.0,
                                super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
                            )
                            break

                        _semantic_tokens = semantic_tokens
                        print(f"semantic_tokens shape:{semantic_tokens.shape}")

                        previous_tokens.append(semantic_tokens)

                        _semantic_tokens = torch.cat(previous_tokens, dim=-1)

                        if not is_first_chunk and semantic_tokens.shape[-1] < 10:
                            overlap_len = overlap_length + (10 - semantic_tokens.shape[-1])
                        else:
                            overlap_len = overlap_length

                        if not self.configs.use_vocoder:
                            token_padding_length = 0
                            # token_padding_length = int(phones.shape[-1] * 2) - _semantic_tokens.shape[-1]
                            # if token_padding_length > 0:
                            #     _semantic_tokens = F.pad(_semantic_tokens, (0, token_padding_length), "constant", 486)
                            # else:
                            #     token_padding_length = 0

                            audio_chunk, latent, latent_mask = self.vits_model.decode_streaming(
                                _semantic_tokens.unsqueeze(0),
                                phones,
                                refer_audio_spec,
                                speed=speed_factor,
                                sv_emb=sv_emb,
                                result_length=semantic_tokens.shape[-1] + overlap_len if not is_first_chunk else None,
                                overlap_frames=last_latent[:, :, -overlap_len * (2 if self.vits_model.semantic_frame_rate == "25hz" else 1):]
                                if last_latent is not None
                                else None,
                                padding_length=token_padding_length,
                            )
                            audio_chunk = audio_chunk.detach()[0, 0, :]
                        else:
                            raise RuntimeError(i18n("SoVits V3/4模型不支持流式推理模式"))

                        if overlap_len > overlap_length:
                            audio_chunk = audio_chunk[-int((overlap_length + semantic_tokens.shape[-1]) * upsample_rate):]

                        audio_chunk_ = audio_chunk
                        if is_first_chunk and not is_final:
                            is_first_chunk = False
                            audio_chunk_ = audio_chunk_[:-overlap_size]
                        elif is_first_chunk and is_final:
                            is_first_chunk = False
                        elif not is_first_chunk and not is_final:
                            audio_chunk_ = self.sola_algorithm([last_audio_chunk, audio_chunk_], overlap_size)
                            audio_chunk_ = (
                                audio_chunk_[last_audio_chunk.shape[0] - overlap_size : -overlap_size]
                                if not is_final
                                else audio_chunk_[last_audio_chunk.shape[0] - overlap_size :]
                            )

                        last_latent = latent
                        last_audio_chunk = audio_chunk
                        yield self.audio_postprocess(
                            [[audio_chunk_]],
                            output_sr,
                            None,
                            speed_factor,
                            False,
                            0.0,
                            super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
                        )

                        if is_first_package:
                            print(f"first_package_delay: {time.perf_counter() - t0:.3f}")
                            is_first_package = False
                    yield output_sr, np.zeros(int(output_sr * fragment_interval), dtype=np.int16)

                t5 = time.perf_counter()
                t_45 += t5 - t4
                if return_fragment:
                    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
                    yield self.audio_postprocess(
                        [batch_audio_fragment],
                        output_sr,
                        None,
                        speed_factor,
                        False,
                        fragment_interval,
                        super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
                    )
                elif streaming_mode:
                    ...  # audio chunks were already yielded inside the streaming loop above
                else:
                    audio.append(batch_audio_fragment)

                if self.stop_flag:
                    yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
                    return

            if not (return_fragment or streaming_mode):
                print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
                if t2s_observe_batch_count > 0:
                    request_avg_decode_ms_per_token = (
                        (t_34 * 1000.0) / t2s_observe_generated_tokens if t2s_observe_generated_tokens > 0 else 0.0
                    )
                    print(
                        "[t2s_request_observe] "
                        f"batches={t2s_observe_batch_count} "
                        f"fastpath_hits={t2s_observe_fastpath_hits} "
                        f"generated_tokens={t2s_observe_generated_tokens} "
                        f"t2s_total_ms={t_34 * 1000.0:.3f} "
                        f"avg_decode_ms_per_token={request_avg_decode_ms_per_token:.3f}"
                    )
                if len(audio) == 0:
                    yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
                    return
                yield self.audio_postprocess(
                    audio,
                    output_sr,
                    batch_index_list,
                    speed_factor,
                    split_bucket,
                    fragment_interval,
                    super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
                )

        except Exception as e:
            traceback.print_exc()
            # An empty audio clip must be returned here, otherwise GPU memory is not released.
            yield 16000, np.zeros(int(16000), dtype=np.int16)
            # Reset the models, otherwise GPU memory is not fully released.
            del self.t2s_model
            del self.vits_model
            self.t2s_model = None
            self.vits_model = None
            self.init_t2s_weights(self.configs.t2s_weights_path)
            self.init_vits_weights(self.configs.vits_weights_path)
            raise e
        finally:
            self.empty_cache()
    def empty_cache(self):
        try:
            gc.collect()  # Trigger garbage collection so memory does not keep growing.
            if "cuda" in str(self.configs.device):
                torch.cuda.empty_cache()
            elif str(self.configs.device) == "mps":
                torch.mps.empty_cache()
        except:
            pass

    def audio_postprocess(
        self,
        audio: List[torch.Tensor],
        sr: int,
        batch_index_list: list = None,
        speed_factor: float = 1.0,
        split_bucket: bool = True,
        fragment_interval: float = 0.3,
        super_sampling: bool = False,
    ) -> Tuple[int, np.ndarray]:
        if fragment_interval > 0:
            zero_wav = torch.zeros(
                int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
            )

        for i, batch in enumerate(audio):
            for j, audio_fragment in enumerate(batch):
                max_audio = torch.abs(audio_fragment).max()  # simple guard against 16-bit clipping
                if max_audio > 1:
                    audio_fragment /= max_audio
                audio_fragment: torch.Tensor = (
                    torch.cat([audio_fragment, zero_wav], dim=0) if fragment_interval > 0 else audio_fragment
                )
                audio[i][j] = audio_fragment

        if split_bucket:
            audio = self.recovery_order(audio, batch_index_list)
        else:
            # audio = [item for batch in audio for item in batch]
            audio = sum(audio, [])

        audio = torch.cat(audio, dim=0)

        if super_sampling:
            print(f"############ {i18n('音频超采样')} ############")
            t1 = time.perf_counter()
            self.init_sr_model()
            if not self.sr_model_not_exist:
                audio, sr = self.sr_model(audio.unsqueeze(0), sr)
                if isinstance(audio, torch.Tensor):
                    max_audio = float(torch.abs(audio).max().item())
                else:
                    max_audio = float(np.abs(audio).max())
                if max_audio > 1:
                    audio /= max_audio
            t2 = time.perf_counter()
            print(f"Super-sampling time: {t2 - t1:.3f}s")
        if isinstance(audio, torch.Tensor):
            audio = audio.detach().float().cpu().numpy()
        else:
            audio = np.asarray(audio)
        audio = (audio.reshape(-1) * 32768).astype(np.int16)

        # try:
        #     if speed_factor != 1.0:
        #         audio = speed_change(audio, speed=speed_factor, sr=int(sr))
        # except Exception as e:
        #     print(f"Failed to change speed of audio: \n{e}")

        return sr, audio
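    # Illustrative post-processing arithmetic: a float fragment peaking at 1.25
    # is first rescaled to peak 1.0, optional inter-fragment silence of
    # int(sampling_rate * fragment_interval) samples is appended, and the final
    # concatenated waveform is converted to int16 via `* 32768`.
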
    def using_vocoder_synthesis(
        self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32
    ):
        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
        raw_entry = self.prompt_cache["refer_spec"][0]
        if isinstance(raw_entry, tuple):
            raw_entry = raw_entry[0]
        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)

        fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
        ref_sr = self.prompt_cache["raw_sr"]
        ref_audio = ref_audio.to(self.configs.device).float()
        if ref_audio.shape[0] == 2:
            ref_audio = ref_audio.mean(0).unsqueeze(0)

        # tgt_sr = self.vocoder_configs["sr"]
        tgt_sr = 24000 if self.configs.version == "v3" else 32000
        if ref_sr != tgt_sr:
            ref_audio = resample(ref_audio, ref_sr, tgt_sr, self.configs.device)

        mel2 = mel_fn(ref_audio) if self.configs.version == "v3" else mel_fn_v4(ref_audio)
        mel2 = norm_spec(mel2)
        T_min = min(mel2.shape[2], fea_ref.shape[2])
        mel2 = mel2[:, :, :T_min]
        fea_ref = fea_ref[:, :, :T_min]
        T_ref = self.vocoder_configs["T_ref"]
        T_chunk = self.vocoder_configs["T_chunk"]
        if T_min > T_ref:
            mel2 = mel2[:, :, -T_ref:]
            fea_ref = fea_ref[:, :, -T_ref:]
            T_min = T_ref
        chunk_len = T_chunk - T_min

        mel2 = mel2.to(self.precision)
        fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)

        cfm_resss = []
        idx = 0
        while True:
            fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
            if fea_todo_chunk.shape[-1] == 0:
                break
            idx += chunk_len
            fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)

            cfm_res = self.vits_model.cfm.inference(
                fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
            )
            cfm_res = cfm_res[:, :, mel2.shape[2] :]

            mel2 = cfm_res[:, :, -T_min:]
            fea_ref = fea_todo_chunk[:, :, -T_min:]

            cfm_resss.append(cfm_res)
        cfm_res = torch.cat(cfm_resss, 2)
        cfm_res = denorm_spec(cfm_res)

        with torch.inference_mode():
            wav_gen = self.vocoder(cfm_res)
            audio = wav_gen[0][0]  # .cpu().detach().numpy()

        return audio
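    # Worked chunking example (hypothetical values): with T_ref = 468 and
    # T_chunk = 934, a reference of T_min = 468 mel frames leaves
    # chunk_len = 466 new frames per CFM call; each iteration re-conditions on
    # the last T_min frames of its own output, so long utterances are decoded
    # as a sliding sequence of fixed-size windows.
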
    def using_vocoder_synthesis_request_local(
        self,
        semantic_tokens: torch.Tensor,
        phones: torch.Tensor,
        prompt_semantic: torch.Tensor,
        prompt_phones: torch.Tensor,
        refer_audio_spec: torch.Tensor,
        raw_audio: torch.Tensor,
        raw_sr: int,
        speed: float = 1.0,
        sample_steps: int = 32,
    ):
        prompt_semantic_tokens = prompt_semantic.unsqueeze(0).unsqueeze(0).to(self.configs.device)
        prompt_phones = prompt_phones.unsqueeze(0).to(self.configs.device)
        refer_audio_spec = refer_audio_spec.to(dtype=self.precision, device=self.configs.device)

        fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        ref_audio = raw_audio.to(self.configs.device).float()
        if ref_audio.shape[0] == 2:
            ref_audio = ref_audio.mean(0).unsqueeze(0)

        tgt_sr = 24000 if self.configs.version == "v3" else 32000
        if raw_sr != tgt_sr:
            ref_audio = resample(ref_audio, raw_sr, tgt_sr, self.configs.device)

        mel_spec_fn = mel_fn if self.configs.version == "v3" else mel_fn_v4
        mel2 = mel_spec_fn(ref_audio)
        mel2 = norm_spec(mel2)
        T_min = min(mel2.shape[2], fea_ref.shape[2])
        mel2 = mel2[:, :, :T_min]
        fea_ref = fea_ref[:, :, :T_min]
        T_ref = self.vocoder_configs["T_ref"]
        T_chunk = self.vocoder_configs["T_chunk"]
        if T_min > T_ref:
            mel2 = mel2[:, :, -T_ref:]
            fea_ref = fea_ref[:, :, -T_ref:]
            T_min = T_ref
        chunk_len = T_chunk - T_min

        mel2 = mel2.to(self.precision)
        fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)

        cfm_resss = []
        idx = 0
        while True:
            fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
            if fea_todo_chunk.shape[-1] == 0:
                break
            idx += chunk_len
            fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)

            cfm_res = self.vits_model.cfm.inference(
                fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
            )
            cfm_res = cfm_res[:, :, mel2.shape[2] :]

            mel2 = cfm_res[:, :, -T_min:]
            fea_ref = fea_todo_chunk[:, :, -T_min:]

            cfm_resss.append(cfm_res)
        cfm_res = torch.cat(cfm_resss, 2)
        cfm_res = denorm_spec(cfm_res)

        with torch.inference_mode():
            wav_gen = self.vocoder(cfm_res)
            audio = wav_gen[0][0]

        return audio

    @torch.inference_mode()
    def synthesize_audio_request_local(
        self,
        semantic_tokens: torch.Tensor,
        phones: torch.Tensor,
        prompt_semantic: torch.Tensor,
        prompt_phones: torch.Tensor,
        refer_spec: tuple | List[tuple],
        raw_audio: torch.Tensor,
        raw_sr: int,
        speed: float = 1.0,
        sample_steps: int = 32,
    ):
        refer_specs = list(refer_spec) if isinstance(refer_spec, list) else [refer_spec]
        refer_audio_spec, audio_tensor = refer_specs[0]
        if not self.configs.use_vocoder:
            refer_audio_spec_list = [item[0].to(dtype=self.precision, device=self.configs.device) for item in refer_specs]
            sv_emb = None
            if self.is_v2pro:
                sv_emb = []
                for _, audio_tensor_item in refer_specs:
                    if audio_tensor_item is None:
                        raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频"))
                    sv_emb.append(self.sv_model.compute_embedding3(audio_tensor_item).to(self.configs.device))
            return self.vits_model.decode(
                semantic_tokens,
                phones,
                refer_audio_spec_list,
                speed=speed,
                sv_emb=sv_emb,
            ).detach()[0, 0, :]

        return self.using_vocoder_synthesis_request_local(
            semantic_tokens=semantic_tokens,
            phones=phones,
            prompt_semantic=prompt_semantic,
            prompt_phones=prompt_phones,
            refer_audio_spec=refer_audio_spec,
            raw_audio=raw_audio,
            raw_sr=raw_sr,
            speed=speed,
            sample_steps=sample_steps,
        )
    @torch.inference_mode()
    def synthesize_audio_requests_local_batched(
        self,
        semantic_tokens_list: List[torch.Tensor],
        phones_list: List[torch.Tensor],
        refer_specs: List[tuple | List[tuple]],
        speeds: List[float] | None = None,
        sample_steps_list: List[int] | None = None,
    ) -> List[torch.Tensor]:
        batch_size = len(semantic_tokens_list)
        if batch_size == 0:
            return []
        if len(phones_list) != batch_size or len(refer_specs) != batch_size:
            raise ValueError("batched request-local synthesis: input lists have inconsistent lengths")
        if speeds is None:
            speeds = [1.0] * batch_size
        if sample_steps_list is None:
            sample_steps_list = [32] * batch_size
        if len(speeds) != batch_size or len(sample_steps_list) != batch_size:
            raise ValueError("batched request-local synthesis: parameter lists have inconsistent lengths")
        first_speed = float(speeds[0])
        first_sample_steps = int(sample_steps_list[0])
        if any(abs(float(item) - first_speed) > 1e-6 for item in speeds):
            raise ValueError("batched request-local synthesis currently requires a uniform speed")
        if any(int(item) != first_sample_steps for item in sample_steps_list):
            raise ValueError("batched request-local synthesis currently requires uniform sample_steps")
        if self.configs.use_vocoder:
            raise NotImplementedError("request-local batched VITS synthesis does not support vocoder models yet")

        device = self.configs.device
        max_semantic_len = max(int(item.shape[-1]) for item in semantic_tokens_list)
        max_phone_len = max(int(item.shape[-1]) for item in phones_list)
        semantic_batch = torch.zeros((1, batch_size, max_semantic_len), dtype=torch.long, device=device)
        phone_batch = torch.zeros((batch_size, max_phone_len), dtype=torch.long, device=device)
        semantic_lengths = []
        phone_lengths = []
        refer_audio_specs: List[torch.Tensor] = []
        sv_emb_batch = None
        sv_emb_list: List[torch.Tensor] = []

        for batch_index, semantic_tokens in enumerate(semantic_tokens_list):
            semantic_len = int(semantic_tokens.shape[-1])
            phone_len = int(phones_list[batch_index].shape[-1])
            semantic_batch[0, batch_index, :semantic_len] = semantic_tokens.to(device=device, dtype=torch.long)
            phone_batch[batch_index, :phone_len] = phones_list[batch_index].to(device=device, dtype=torch.long)
            semantic_lengths.append(semantic_len)
            phone_lengths.append(phone_len)

            refer_spec_item = refer_specs[batch_index]
            refer_spec_group = list(refer_spec_item) if isinstance(refer_spec_item, list) else [refer_spec_item]
            if len(refer_spec_group) != 1:
                raise ValueError("batched request-local synthesis does not support multiple reference audios per request yet")
            refer_audio_spec, audio_tensor = refer_spec_group[0]
            refer_audio_specs.append(refer_audio_spec.to(dtype=self.precision, device=device))
            if self.is_v2pro:
                if audio_tensor is None:
                    raise ValueError(i18n("v2Pro request-local batched synthesis 缺少 16k 参考音频"))
                sv_emb_list.append(self.sv_model.compute_embedding3(audio_tensor).to(device))

        if self.is_v2pro:
            sv_emb_batch = torch.cat(sv_emb_list, dim=0)

        audio_batch, audio_lengths = self.vits_model.decode_batched_request_local(
            codes=semantic_batch,
            code_lengths=torch.LongTensor(semantic_lengths).to(device),
            text=phone_batch,
            text_lengths=torch.LongTensor(phone_lengths).to(device),
            refer_list=refer_audio_specs,
            speed=first_speed,
            sv_emb=sv_emb_batch,
        )
        audios: List[torch.Tensor] = []
        for batch_index in range(batch_size):
            audio_len = int(audio_lengths[batch_index].item())
            audios.append(audio_batch[batch_index, 0, :audio_len].detach())
        return audios
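    # Minimal usage sketch (illustrative; tensors are placeholders):
    #   audios = tts.synthesize_audio_requests_local_batched(
    #       semantic_tokens_list=[tok_a, tok_b],
    #       phones_list=[ph_a, ph_b],
    #       refer_specs=[(spec_a, wav16k_a), (spec_b, wav16k_b)],
    #   )  # speeds/sample_steps default to a uniform 1.0 / 32 across the batch
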
    def using_vocoder_synthesis_batched_infer(
        self,
        idx_list: List[int],
        semantic_tokens_list: List[torch.Tensor],
        batch_phones: List[torch.Tensor],
        speed: float = 1.0,
        sample_steps: int = 32,
    ) -> List[torch.Tensor]:
        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
        raw_entry = self.prompt_cache["refer_spec"][0]
        if isinstance(raw_entry, tuple):
            raw_entry = raw_entry[0]
        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)

        fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
        ref_sr = self.prompt_cache["raw_sr"]
        ref_audio = ref_audio.to(self.configs.device).float()
        if ref_audio.shape[0] == 2:
            ref_audio = ref_audio.mean(0).unsqueeze(0)

        # tgt_sr = self.vocoder_configs["sr"]
        tgt_sr = 24000 if self.configs.version == "v3" else 32000
        if ref_sr != tgt_sr:
            ref_audio = resample(ref_audio, ref_sr, tgt_sr, self.configs.device)

        mel2 = mel_fn(ref_audio) if self.configs.version == "v3" else mel_fn_v4(ref_audio)
        mel2 = norm_spec(mel2)
        T_min = min(mel2.shape[2], fea_ref.shape[2])
        mel2 = mel2[:, :, :T_min]
        fea_ref = fea_ref[:, :, :T_min]
        T_ref = self.vocoder_configs["T_ref"]
        T_chunk = self.vocoder_configs["T_chunk"]
        if T_min > T_ref:
            mel2 = mel2[:, :, -T_ref:]
            fea_ref = fea_ref[:, :, -T_ref:]
            T_min = T_ref
        chunk_len = T_chunk - T_min

        mel2 = mel2.to(self.precision)

        # #### batched inference
        overlapped_len = self.vocoder_configs["overlapped_len"]
        feat_chunks = []
        feat_lens = []
        feat_list = []

        for i, idx in enumerate(idx_list):
            phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
            semantic_tokens = (
                semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0)
            )  # .unsqueeze(0)  # mq needs one extra unsqueeze
            feat, _ = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
            feat_list.append(feat)
            feat_lens.append(feat.shape[2])

        feats = torch.cat(feat_list, 2)
        feats_padded = F.pad(feats, (overlapped_len, 0), "constant", 0)
        pos = 0
        padding_len = 0
        while True:
            if pos == 0:
                chunk = feats_padded[:, :, pos : pos + chunk_len]
            else:
                pos = pos - overlapped_len
                chunk = feats_padded[:, :, pos : pos + chunk_len]
            pos += chunk_len
            if chunk.shape[-1] == 0:
                break

            # padding for the last chunk
            padding_len = chunk_len - chunk.shape[2]
            if padding_len != 0:
                chunk = F.pad(chunk, (0, padding_len), "constant", 0)
            feat_chunks.append(chunk)

        feat_chunks = torch.cat(feat_chunks, 0)
        bs = feat_chunks.shape[0]
        fea_ref = fea_ref.repeat(bs, 1, 1)
        fea = torch.cat([fea_ref, feat_chunks], 2).transpose(2, 1)
        pred_spec = self.vits_model.cfm.inference(
            fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
        )
        pred_spec = pred_spec[:, :, -chunk_len:]
        dd = pred_spec.shape[1]
        pred_spec = pred_spec.permute(1, 0, 2).contiguous().view(dd, -1).unsqueeze(0)
        # pred_spec = pred_spec[..., :-padding_len]

        pred_spec = denorm_spec(pred_spec)

        with torch.no_grad():
            wav_gen = self.vocoder(pred_spec)
            audio = wav_gen[0][0]  # .cpu().detach().numpy()

        audio_fragments = []
        upsample_rate = self.vocoder_configs["upsample_rate"]
        pos = 0

        while pos < audio.shape[-1]:
            audio_fragment = audio[pos : pos + chunk_len * upsample_rate]
            audio_fragments.append(audio_fragment)
            pos += chunk_len * upsample_rate

        audio = self.sola_algorithm(audio_fragments, overlapped_len * upsample_rate)
        audio = audio[overlapped_len * upsample_rate :]
        if padding_len > 0:
            # Guard: a bare `[-padding_len * upsample_rate:]` slice would return an
            # empty tensor when padding_len == 0, so only trim when padding exists.
            audio = audio[: -padding_len * upsample_rate]

        audio_fragments = []
        for feat_len in feat_lens:
            audio_fragment = audio[: feat_len * upsample_rate]
            audio_fragments.append(audio_fragment)
            audio = audio[feat_len * upsample_rate :]

        return audio_fragments
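    # Chunk layout sketch (hypothetical numbers): with chunk_len = 400 and
    # overlapped_len = 12, chunk 0 covers frames [0, 400), chunk 1 starts at
    # 400 - 12 = 388, chunk 2 at 776, and so on; SOLA then cross-fades the
    # 12 * upsample_rate overlapping samples between neighbouring chunks.
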
    def sola_algorithm(
        self,
        audio_fragments: List[torch.Tensor],
        overlap_len: int,
        search_len: int = 320,
    ):
        # overlap_len -= search_len

        dtype = audio_fragments[0].dtype

        for i in range(len(audio_fragments) - 1):
            f1 = audio_fragments[i].float()
            f2 = audio_fragments[i + 1].float()
            w1 = f1[-overlap_len:]
            w2 = f2[: overlap_len + search_len]
            # w2 = w2[-w2.shape[-1]//2:]
            # assert w1.shape == w2.shape
            corr_norm = F.conv1d(w2.view(1, 1, -1), w1.view(1, 1, -1)).view(-1)

            corr_den = F.conv1d(w2.view(1, 1, -1) ** 2, torch.ones_like(w1).view(1, 1, -1)).view(-1) + 1e-8
            idx = (corr_norm / corr_den.sqrt()).argmax()

            print(f"seg_idx: {idx}")

            # idx = corr.argmax()
            f1_ = f1[:-overlap_len]
            audio_fragments[i] = f1_

            f2_ = f2[idx:]
            window = torch.hann_window(overlap_len * 2, device=f1.device, dtype=f1.dtype)
            f2_[:overlap_len] = (
                window[:overlap_len] * f2_[:overlap_len]
                + window[overlap_len:] * f1[-overlap_len:]
            )

            # window = torch.sin(torch.arange((overlap_len - idx), device=f1.device) * np.pi / (overlap_len - idx))
            # f2_[: (overlap_len - idx)] = (
            #     window * f2_[: (overlap_len - idx)]
            #     + (1 - window) * f1[-(overlap_len - idx):]
            # )

            audio_fragments[i + 1] = f2_

        return torch.cat(audio_fragments, 0).to(dtype)
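    # SOLA in one line: the normalized cross-correlation above finds the offset
    # `idx` (within `search_len` samples) where the head of the next fragment
    # best matches the tail of the previous one, then a Hann cross-fade blends
    # the two over `overlap_len` samples. Usage sketch (illustrative values):
    #   merged = tts.sola_algorithm([chunk_a, chunk_b], overlap_len=1024)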