Merge 7478a8a85622d2a4ff953b168cdd3748ccc49ce3 into 35e755427da174037da246642cab6987876c74fa

Kevin Zhang 2024-05-16 16:20:40 +08:00 committed by GitHub
commit a93d60cb1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 797 additions and 288 deletions

View File

@@ -34,9 +34,6 @@ RUN if [ "$IMAGE_TYPE" != "elite" ]; then \
    fi

# Copy the rest of the application
COPY . /workspace

View File

@@ -5,6 +5,7 @@ import random
import traceback
from tqdm import tqdm

now_dir = os.getcwd()
sys.path.append(now_dir)
import ffmpeg
@@ -26,6 +27,7 @@ from my_utils import load_audio
from module.mel_processing import spectrogram_torch
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor

i18n = I18nAuto()

# configs/tts_infer.yaml
@@ -49,7 +51,8 @@ custom:
"""


def set_seed(seed: int):
    seed = int(seed)
    seed = seed if seed != -1 else random.randrange(1 << 32)
    print(f"Set seed to {seed}")
@@ -70,50 +73,52 @@ def set_seed(seed:int):
        except:
            pass
    return seed


class TTS_Config:
    default_configs = {
        "device": "cpu",
        "is_half": False,
        "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
        "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
        "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
        "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        "load_base": True,
    }
    configs: dict = None

    def __init__(self, configs: Union[dict, str] = None):
        # set the default config file path
        configs_base_path: str = "GPT_SoVITS/configs/"
        os.makedirs(configs_base_path, exist_ok=True)
        self.configs_path: str = os.path.join(configs_base_path, "tts_infer.yaml")

        if configs in ["", None]:
            if not os.path.exists(self.configs_path):
                self.save_configs()
                print(f"Create default config file at {self.configs_path}")
            configs: dict = {"default": deepcopy(self.default_configs)}

        if isinstance(configs, str):
            self.configs_path = configs
            configs: dict = self._load_configs(self.configs_path)

        assert isinstance(configs, dict)
        default_configs: dict = configs.get("default", None)
        if default_configs is not None:
            self.default_configs = default_configs

        self.configs: dict = configs.get("custom", deepcopy(self.default_configs))
        self.device = self.configs.get("device", torch.device("cpu"))
        self.is_half = self.configs.get("is_half", False)
        self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
        self.vits_weights_path = self.configs.get("vits_weights_path", None)
        self.bert_base_path = self.configs.get("bert_base_path", None)
        self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
        self.load_base = self.configs.get("load_base", True)

        if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
            self.t2s_weights_path = self.default_configs['t2s_weights_path']
            print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
@@ -127,34 +132,32 @@ class TTS_Config:
            self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
            print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
        self.update_configs()

        self.max_sec = None
        self.hz: int = 50
        self.semantic_frame_rate: str = "25hz"
        self.segment_size: int = 20480
        self.filter_length: int = 2048
        self.sampling_rate: int = 32000
        self.hop_length: int = 640
        self.win_length: int = 2048
        self.n_speakers: int = 300
        self.languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]

    def _load_configs(self, configs_path: str) -> dict:
        with open(configs_path, 'r') as f:
            configs = yaml.load(f, Loader=yaml.FullLoader)

        return configs

    def save_configs(self, configs_path: str = None) -> None:
        configs = {
            "default": self.default_configs,
        }
        if self.configs is not None:
            configs["custom"] = self.update_configs()

        if configs_path is None:
            configs_path = self.configs_path
        with open(configs_path, 'w') as f:
@@ -162,15 +165,16 @@ class TTS_Config:
    def update_configs(self):
        self.config = {
            "device": str(self.device),
            "is_half": self.is_half,
            "t2s_weights_path": self.t2s_weights_path,
            "vits_weights_path": self.vits_weights_path,
            "bert_base_path": self.bert_base_path,
            "cnhuhbert_base_path": self.cnhuhbert_base_path,
            "load_base": self.load_base,
        }
        return self.config

    def __str__(self):
        self.configs = self.update_configs()
        string = "TTS Config".center(100, '-') + '\n'
@@ -178,75 +182,87 @@ class TTS_Config:
            string += f"{str(k).ljust(20)}: {str(v)}\n"
        string += "-" * 100 + '\n'
        return string

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        return hash(self.configs_path)

    def __eq__(self, other):
        return isinstance(other, TTS_Config) and self.configs_path == other.configs_path
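
    # These two methods make TTS_Config hashable (keyed on configs_path), which is
    # what lets api_v3.py below cache TTS instances per config file with
    # functools.lru_cache.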

class TTS:
    bert_tokenizer: AutoTokenizer = None
    bert_model: AutoModelForMaskedLM = None
    cnhuhbert_model: CNHubert = None

    def __init__(self, configs: Union[dict, str, TTS_Config]):
        if isinstance(configs, TTS_Config):
            self.configs = configs
        else:
            self.configs: TTS_Config = TTS_Config(configs)

        self.t2s_model: Text2SemanticLightningModule = None
        self.vits_model: SynthesizerTrn = None
        # self.bert_tokenizer:AutoTokenizer = None
        # self.bert_model:AutoModelForMaskedLM = None
        # self.cnhuhbert_model:CNHubert = None

        self._init_models()

        self.text_preprocessor: TextPreprocessor = \
            TextPreprocessor(TTS.bert_model,
                             TTS.bert_tokenizer,
                             self.configs.device)

        self.prompt_cache: dict = {
            "ref_audio_path": None,
            "prompt_semantic": None,
            "refer_spec": None,
            "prompt_text": None,
            "prompt_lang": None,
            "phones": None,
            "bert_features": None,
            "norm_text": None,
        }

        self.stop_flag: bool = False
        self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32

    def _init_models(self):
        self.init_t2s_weights(self.configs.t2s_weights_path)
        self.init_vits_weights(self.configs.vits_weights_path)
        if self.configs.load_base:
            TTS.init_bert_weights(self.configs)
            TTS.init_cnhuhbert_weights(self.configs)
        # self.enable_half_precision(self.configs.is_half)

    @staticmethod
    def init_base_models(configs: TTS_Config):
        TTS.init_bert_weights(configs)
        TTS.init_cnhuhbert_weights(configs)

    @staticmethod
    def init_cnhuhbert_weights(configs: TTS_Config):
        print(f"Loading CNHuBERT weights from {configs.cnhuhbert_base_path}")
        TTS.cnhuhbert_model = CNHubert(configs.cnhuhbert_base_path)
        TTS.cnhuhbert_model = TTS.cnhuhbert_model.eval()
        TTS.cnhuhbert_model = TTS.cnhuhbert_model.to(configs.device)
        if configs.is_half and str(configs.device) != "cpu":
            TTS.cnhuhbert_model = TTS.cnhuhbert_model.half()

    @staticmethod
    def init_bert_weights(configs: TTS_Config):
        print(f"Loading BERT weights from {configs.bert_base_path}")
        TTS.bert_tokenizer = AutoTokenizer.from_pretrained(configs.bert_base_path)
        TTS.bert_model = AutoModelForMaskedLM.from_pretrained(configs.bert_base_path)
        TTS.bert_model = TTS.bert_model.eval()
        TTS.bert_model = TTS.bert_model.to(configs.device)
        if configs.is_half and str(configs.device) != "cpu":
            TTS.bert_model = TTS.bert_model.half()

    def init_vits_weights(self, weights_path: str):
        print(f"Loading VITS weights from {weights_path}")
        self.configs.vits_weights_path = weights_path
@@ -255,7 +271,7 @@ class TTS:
        hps = dict_s2["config"]
        self.configs.filter_length = hps["data"]["filter_length"]
        self.configs.segment_size = hps["train"]["segment_size"]
        self.configs.sampling_rate = hps["data"]["sampling_rate"]
        self.configs.hop_length = hps["data"]["hop_length"]
        self.configs.win_length = hps["data"]["win_length"]
        self.configs.n_speakers = hps["data"]["n_speakers"]
@ -270,15 +286,14 @@ class TTS:
# if ("pretrained" not in weights_path): # if ("pretrained" not in weights_path):
if hasattr(vits_model, "enc_q"): if hasattr(vits_model, "enc_q"):
del vits_model.enc_q del vits_model.enc_q
vits_model = vits_model.to(self.configs.device) vits_model = vits_model.to(self.configs.device)
vits_model = vits_model.eval() vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False) vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model self.vits_model = vits_model
if self.configs.is_half and str(self.configs.device)!="cpu": if self.configs.is_half and str(self.configs.device) != "cpu":
self.vits_model = self.vits_model.half() self.vits_model = self.vits_model.half()
def init_t2s_weights(self, weights_path: str): def init_t2s_weights(self, weights_path: str):
print(f"Loading Text2Semantic weights from {weights_path}") print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path self.configs.t2s_weights_path = weights_path
@@ -292,9 +307,9 @@ class TTS:
        t2s_model = t2s_model.to(self.configs.device)
        t2s_model = t2s_model.eval()
        self.t2s_model = t2s_model
        if self.configs.is_half and str(self.configs.device) != "cpu":
            self.t2s_model = self.t2s_model.half()

    def enable_half_precision(self, enable: bool = True):
        '''
        To enable half precision for the TTS model.
@@ -305,29 +320,29 @@ class TTS:
        if str(self.configs.device) == "cpu" and enable:
            print("Half precision is not supported on CPU.")
            return

        self.configs.is_half = enable
        self.precision = torch.float16 if enable else torch.float32
        self.configs.save_configs()
        if enable:
            if self.t2s_model is not None:
                self.t2s_model = self.t2s_model.half()
            if self.vits_model is not None:
                self.vits_model = self.vits_model.half()
            if TTS.bert_model is not None:
                TTS.bert_model = TTS.bert_model.half()
            if TTS.cnhuhbert_model is not None:
                TTS.cnhuhbert_model = TTS.cnhuhbert_model.half()
        else:
            if self.t2s_model is not None:
                self.t2s_model = self.t2s_model.float()
            if self.vits_model is not None:
                self.vits_model = self.vits_model.float()
            if TTS.bert_model is not None:
                TTS.bert_model = TTS.bert_model.float()
            if TTS.cnhuhbert_model is not None:
                TTS.cnhuhbert_model = TTS.cnhuhbert_model.float()

    def set_device(self, device: torch.device):
        '''
        To set the device for all models.
@@ -340,12 +355,12 @@ class TTS:
            self.t2s_model = self.t2s_model.to(device)
        if self.vits_model is not None:
            self.vits_model = self.vits_model.to(device)
        if TTS.bert_model is not None:
            TTS.bert_model = TTS.bert_model.to(device)
        if TTS.cnhuhbert_model is not None:
            TTS.cnhuhbert_model = TTS.cnhuhbert_model.to(device)

    def set_ref_audio(self, ref_audio_path: str):
        '''
        To set the reference audio for the TTS model,
        including the prompt_semantic and refer_spec.
@@ -354,7 +369,7 @@ class TTS:
        '''
        self._set_prompt_semantic(ref_audio_path)
        self._set_ref_spec(ref_audio_path)

    def _set_ref_spec(self, ref_audio_path):
        audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
        audio = torch.FloatTensor(audio)
@@ -373,9 +388,8 @@ class TTS:
            spec = spec.half()
        # self.refer_spec = spec
        self.prompt_cache["refer_spec"] = spec

    def _set_prompt_semantic(self, ref_wav_path: str):
        zero_wav = np.zeros(
            int(self.configs.sampling_rate * 0.3),
            dtype=np.float16 if self.configs.is_half else np.float32,
@@ -399,16 +413,16 @@ class TTS:
                1, 2
            )  # .float()
            codes = self.vits_model.extract_latent(hubert_feature)

            prompt_semantic = codes[0, 0].to(self.configs.device)
            self.prompt_cache["prompt_semantic"] = prompt_semantic

    def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None):
        seq = sequences[0]
        ndim = seq.dim()
        if axis < 0:
            axis += ndim
        dtype: torch.dtype = seq.dtype
        pad_value = torch.tensor(pad_value, dtype=dtype)
        seq_lengths = [seq.shape[axis] for seq in sequences]
        if max_length is None:
@@ -423,16 +437,16 @@ class TTS:
            padded_sequences.append(padded_seq)
        batch = torch.stack(padded_sequences)
        return batch

    def to_batch(self, data: list,
                 prompt_data: dict = None,
                 batch_size: int = 5,
                 threshold: float = 0.75,
                 split_bucket: bool = True,
                 device: torch.device = torch.device("cpu"),
                 precision: torch.dtype = torch.float32,
                 ):
        _data: list = []
        index_and_len_list = []
        for idx, item in enumerate(data):
            norm_text_len = len(item["norm_text"])
@@ -441,33 +455,32 @@ class TTS:
        batch_index_list = []
        if split_bucket:
            index_and_len_list.sort(key=lambda x: x[1])
            index_and_len_list = np.array(index_and_len_list, dtype=np.int64)

            batch_index_list_len = 0
            pos = 0
            while pos < index_and_len_list.shape[0]:
                # batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
                pos_end = min(pos + batch_size, index_and_len_list.shape[0])
                while pos < pos_end:
                    batch = index_and_len_list[pos:pos_end, 1].astype(np.float32)
                    score = batch[(pos_end - pos) // 2] / (batch.mean() + 1e-8)
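                    # score compares the bucket's middle text length with its mean length:
                    # near 1 means the lengths are uniform enough to pad into one batch;
                    # otherwise the bucket is shrunk by one item and re-scored.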
                    if (score >= threshold) or (pos_end - pos == 1):
                        batch_index = index_and_len_list[pos:pos_end, 0].tolist()
                        batch_index_list_len += len(batch_index)
                        batch_index_list.append(batch_index)
                        pos = pos_end
                        break
                    pos_end = pos_end - 1

            assert batch_index_list_len == len(data)

        else:
            for i in range(len(data)):
                if i % batch_size == 0:
                    batch_index_list.append([])
                batch_index_list[-1].append(i)

        for batch_idx, index_list in enumerate(batch_index_list):
            item_list = [data[idx] for idx in index_list]
            phones_list = []
@@ -481,33 +494,32 @@ class TTS:
            phones_max_len = 0
            for item in item_list:
                if prompt_data is not None:
                    all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1) \
                        .to(dtype=precision, device=device)
                    all_phones = torch.LongTensor(prompt_data["phones"] + item["phones"]).to(device)
                    phones = torch.LongTensor(item["phones"]).to(device)
                    # norm_text = prompt_data["norm_text"]+item["norm_text"]
                else:
                    all_bert_features = item["bert_features"] \
                        .to(dtype=precision, device=device)
                    phones = torch.LongTensor(item["phones"]).to(device)
                    all_phones = phones
                    # norm_text = item["norm_text"]

                bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
                phones_max_len = max(phones_max_len, phones.shape[-1])

                phones_list.append(phones)
                phones_len_list.append(phones.shape[-1])
                all_phones_list.append(all_phones)
                all_phones_len_list.append(all_phones.shape[-1])
                all_bert_features_list.append(all_bert_features)
                norm_text_batch.append(item["norm_text"])

            phones_batch = phones_list
            all_phones_batch = all_phones_list
            all_bert_features_batch = all_bert_features_list

            max_len = max(bert_max_len, phones_max_len)
            # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
            #### pad phones and bert_features directly; the padding strategy affects what the T2S model
            #### generates but not the repetition rate directly, which is mainly driven by the masking strategy
@@ -516,16 +528,16 @@ class TTS:
            # all_bert_features_batch = torch.zeros((len(all_bert_features_list), 1024, max_len), dtype=precision, device=device)
            # for idx, item in enumerate(all_bert_features_list):
            #     all_bert_features_batch[idx, :, : item.shape[-1]] = item

            # #### alternative: embed the phones and project the bert_features first, then pad to the same
            # #### length; again the padding strategy affects the T2S output but not the repetition rate
            # #### directly, which is mainly driven by the masking strategy
            # all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
            # all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
            # all_phones_batch = torch.stack(all_phones_list, dim=0)

            # all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
            # all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
            # all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)

            batch = {
                "phones": phones_batch,
                "phones_len": torch.LongTensor(phones_len_list).to(device),
@@ -536,10 +548,10 @@ class TTS:
                "max_len": max_len,
            }
            _data.append(batch)

        return _data, batch_index_list

    def recovery_order(self, data: list, batch_index_list: list) -> list:
        '''
        Recover the order of the audio according to the batch_index_list.
@@ -551,20 +563,20 @@ class TTS:
            list (List[np.ndarray]): the data in the original order.
        '''
        length = len(sum(batch_index_list, []))
        _data = [None] * length
        for i, index_list in enumerate(batch_index_list):
            for j, index in enumerate(index_list):
                _data[index] = data[i][j]
        return _data

    def stop(self, ):
        '''
        Stop the inference process.
        '''
        self.stop_flag = True

    @torch.no_grad()
    def run(self, inputs: dict):
        """
        Text to speech inference.
@@ -594,16 +606,16 @@ class TTS:
            tuple[int, np.ndarray]: sampling rate and audio data.
        """
        ########## variables initialization ###########
        self.stop_flag: bool = False
        text: str = inputs.get("text", "")
        text_lang: str = inputs.get("text_lang", "")
        ref_audio_path: str = inputs.get("ref_audio_path", "")
        prompt_text: str = inputs.get("prompt_text", "")
        prompt_lang: str = inputs.get("prompt_lang", "")
        top_k: int = inputs.get("top_k", 5)
        top_p: float = inputs.get("top_p", 1)
        temperature: float = inputs.get("temperature", 1)
        text_split_method: str = inputs.get("text_split_method", "cut0")
        batch_size = inputs.get("batch_size", 1)
        batch_threshold = inputs.get("batch_threshold", 0.75)
        speed_factor = inputs.get("speed_factor", 1.0)
@@ -632,7 +644,7 @@ class TTS:
        if split_bucket:
            print(i18n("分桶处理模式已开启"))

        if fragment_interval < 0.01:
            fragment_interval = 0.01
            print(i18n("分段间隔过小已自动设置为0.01"))
@@ -645,8 +657,9 @@ class TTS:
        assert prompt_lang in self.configs.languages

        if ref_audio_path in [None, ""] and \
                ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] is None)):
            raise ValueError(
                "ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")

        ###### setting reference audio and prompt text preprocessing ########
        t0 = ttime()
@@ -662,48 +675,49 @@ class TTS:
            self.prompt_cache["prompt_lang"] = prompt_lang
            phones, bert_features, norm_text = \
                self.text_preprocessor.segment_and_extract_feature_for_text(
                    prompt_text,
                    prompt_lang)
            self.prompt_cache["phones"] = phones
            self.prompt_cache["bert_features"] = bert_features
            self.prompt_cache["norm_text"] = norm_text

        ###### text preprocessing ########
        t1 = ttime()
        data: list = None
        if not return_fragment:
            data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
            if len(data) == 0:
                yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
                                                           dtype=np.int16)
                return

            batch_index_list: list = None
            data, batch_index_list = self.to_batch(data,
                                                   prompt_data=self.prompt_cache if not no_prompt_text else None,
                                                   batch_size=batch_size,
                                                   threshold=batch_threshold,
                                                   split_bucket=split_bucket,
                                                   device=self.configs.device,
                                                   precision=self.precision
                                                   )
        else:
            print(i18n("############ 切分文本 ############"))
            texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
            data = []
            for i in range(len(texts)):
                if i % batch_size == 0:
                    data.append([])
                data[-1].append(texts[i])

            def make_batch(batch_texts):
                batch_data = []
                print(i18n("############ 提取文本Bert特征 ############"))
                for text in tqdm(batch_texts):
                    phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(
                        text, text_lang)
                    if phones is None:
                        continue
                    res = {
                        "phones": phones,
                        "bert_features": bert_features,
                        "norm_text": norm_text,
                    }
@@ -711,17 +725,16 @@ class TTS:
                    batch_data.append(res)
                if len(batch_data) == 0:
                    return None
                batch, _ = self.to_batch(batch_data,
                                         prompt_data=self.prompt_cache if not no_prompt_text else None,
                                         batch_size=batch_size,
                                         threshold=batch_threshold,
                                         split_bucket=False,
                                         device=self.configs.device,
                                         precision=self.precision
                                         )
                return batch[0]

        t2 = ttime()
        try:
            print("############ 推理 ############")
@@ -736,21 +749,21 @@ class TTS:
                if item is None:
                    continue

                batch_phones: List[torch.LongTensor] = item["phones"]
                # batch_phones:torch.LongTensor = item["phones"]
                batch_phones_len: torch.LongTensor = item["phones_len"]
                all_phoneme_ids: torch.LongTensor = item["all_phones"]
                all_phoneme_lens: torch.LongTensor = item["all_phones_len"]
                all_bert_features: torch.LongTensor = item["all_bert_features"]
                norm_text: str = item["norm_text"]
                max_len = item["max_len"]

                print(i18n("前端处理后的文本(每句):"), norm_text)
                if no_prompt_text:
                    prompt = None
                else:
                    prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(
                        self.configs.device)

                pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                    all_phoneme_ids,
@@ -768,14 +781,14 @@ class TTS:
                t4 = ttime()
                t_34 += t4 - t3

                refer_audio_spec: torch.Tensor = self.prompt_cache["refer_spec"] \
                    .to(dtype=self.precision, device=self.configs.device)

                batch_audio_fragment = []

                # remember to wrap this in torch.no_grad(), otherwise it is a lot slower
                # with torch.no_grad():

                # ## vits parallel inference, method 1
                # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
@@ -792,15 +805,17 @@ class TTS:
                # ## vits parallel inference, method 2
                pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                upsample_rate = math.prod(self.vits_model.upsample_rates)
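                # each semantic token decodes to 2 * upsample_rate waveform samples, so the
                # per-sentence fragment boundaries can be recovered after decoding the whole
                # batch in a single vits_model.decode() call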
                audio_frag_idx = [pred_semantic_list[i].shape[0] * 2 * upsample_rate
                                  for i in range(0, len(pred_semantic_list))]
                audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))]
                all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
                _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
                _batch_audio_fragment = (self.vits_model.decode(
                    all_pred_semantic, _batch_phones, refer_audio_spec
                ).detach()[0, 0, :])
                audio_frag_end_idx.insert(0, 0)
                batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i - 1]:audio_frag_end_idx[i]]
                                        for i in range(1, len(audio_frag_end_idx))]

                # ## vits serial inference
                # for i, idx in enumerate(idx_list):
@@ -817,36 +832,36 @@ class TTS:
                t_45 += t5 - t4
                if return_fragment:
                    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
                    yield self.audio_postprocess([batch_audio_fragment],
                                                 self.configs.sampling_rate,
                                                 None,
                                                 speed_factor,
                                                 False,
                                                 fragment_interval
                                                 )
                else:
                    audio.append(batch_audio_fragment)

                if self.stop_flag:
                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
                                                               dtype=np.int16)
                    return

            if not return_fragment:
                print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
                yield self.audio_postprocess(audio,
                                             self.configs.sampling_rate,
                                             batch_index_list,
                                             speed_factor,
                                             split_bucket,
                                             fragment_interval
                                             )

        except Exception as e:
            traceback.print_exc()
            # must yield an empty audio chunk, otherwise GPU memory is not released
            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
                                                       dtype=np.int16)
            # reset the models, otherwise GPU memory is not fully released
            del self.t2s_model
            del self.vits_model
@@ -857,60 +872,56 @@ class TTS:
            raise e
        finally:
            self.empty_cache()

    def empty_cache(self):
        try:
            if "cuda" in str(self.configs.device):
                torch.cuda.empty_cache()
            elif str(self.configs.device) == "mps":
                torch.mps.empty_cache()
        except:
            pass

    def audio_postprocess(self,
                          audio: List[torch.Tensor],
                          sr: int,
                          batch_index_list: list = None,
                          speed_factor: float = 1.0,
                          split_bucket: bool = True,
                          fragment_interval: float = 0.3
                          ) -> tuple[int, np.ndarray]:
        zero_wav = torch.zeros(
            int(self.configs.sampling_rate * fragment_interval),
            dtype=self.precision,
            device=self.configs.device
        )

        for i, batch in enumerate(audio):
            for j, audio_fragment in enumerate(batch):
                max_audio = torch.abs(audio_fragment).max()  # simple guard against 16-bit clipping
                if max_audio > 1: audio_fragment /= max_audio
                audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
                audio[i][j] = audio_fragment.cpu().numpy()

        if split_bucket:
            audio = self.recovery_order(audio, batch_index_list)
        else:
            # audio = [item for batch in audio for item in batch]
            audio = sum(audio, [])

        audio = np.concatenate(audio, 0)
        audio = (audio * 32768).astype(np.int16)

        try:
            if speed_factor != 1.0:
                audio = speed_change(audio, speed=speed_factor, sr=int(sr))
        except Exception as e:
            print(f"Failed to change speed of audio: \n{e}")

        return sr, audio


def speed_change(input_audio: np.ndarray, speed: float, sr: int):
    # convert the NumPy array to a raw PCM byte stream
    raw_audio = input_audio.astype(np.int16).tobytes()
@@ -929,4 +940,4 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
    # decode the pipe output back into a NumPy array
    processed_audio = np.frombuffer(out, np.int16)

    return processed_audio
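
The net effect of the changes in this file is that the BERT and CNHubert base models become class attributes shared by every TTS instance in the process, instead of being reloaded per instance. A minimal sketch of the intended usage (the per-voice yaml filename below is hypothetical):

```python
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

# Load the shared base models once per process; this populates the class
# attributes TTS.bert_model, TTS.bert_tokenizer and TTS.cnhuhbert_model.
TTS.init_base_models(TTS_Config())

# A config with "load_base: false" skips the base-model loading inside
# _init_models() and only loads the per-voice t2s/vits weights.
voice = TTS(TTS_Config("GPT_SoVITS/configs/tts_infer_voice1.yaml"))  # hypothetical path
```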

View File

@@ -1,6 +1,7 @@
custom:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  load_base: true
  device: cuda
  is_half: true
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt

View File

@@ -0,0 +1,15 @@
custom:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  load_base: false
  t2s_weights_path: GPT_weights/voice1-e10.ckpt
  vits_weights_path: SoVITS_weights/voice1_e8_s192.pth
default:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
  vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth

View File

@@ -0,0 +1,15 @@
custom:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  load_base: false
  t2s_weights_path: GPT_weights/voice1-e10.ckpt
  vits_weights_path: SoVITS_weights/voice1_e8_s192.pth
default:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
  vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
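
These two new per-voice configs set `load_base: false` in their `custom` section, so a TTS instance built from them reuses the base models loaded at server start-up rather than loading its own copies. A quick sketch of what TTS_Config reads out of such a file (the filename is hypothetical, and the weight paths are kept only if the files actually exist; otherwise TTS_Config falls back to the defaults):

```python
from GPT_SoVITS.TTS_infer_pack.TTS import TTS_Config

cfg = TTS_Config("GPT_SoVITS/configs/tts_infer_voice1.yaml")  # hypothetical filename
print(cfg.load_base)          # False: TTS() will not reload BERT/CNHubert
print(cfg.t2s_weights_path)   # GPT_weights/voice1-e10.ckpt
print(cfg.vits_weights_path)  # SoVITS_weights/voice1_e8_s192.pth
```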

api_v3.py (new file, 470 lines)
View File

@@ -0,0 +1,470 @@
"""
# WebAPI文档 (3.0) - 使用了缓存技术初始化时使用LRU Cache TTS 实例,缓存加载模型的世界,达到减少切换不同语音时的推理时间
` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `
## 执行参数:
`-a` - `绑定地址, 默认"127.0.0.1"`
`-p` - `绑定端口, 默认9880`
`-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"`
## 调用:
### 推理
endpoint: `/tts`
GET:
```
http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是罗浮云骑将军景元不必拘谨将军只是一时的身份你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true
```
POST:
```json
{
"text": "", # str.(required) text to be synthesized
"text_lang": "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path.
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 5, # int.(optional) top k sampling
"top_p": 1, # float.(optional) top p sampling
"temperature": 1, # float.(optional) temperature for sampling
"text_split_method": "cut5", # str.(optional) text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int.(optional) batch size for inference
"batch_threshold": 0.75, # float.(optional) threshold for batch splitting.
"split_bucket": true, # bool.(optional) whether to split the batch into multiple buckets.
"speed_factor":1.0, # float.(optional) control the speed of the synthesized audio.
"fragment_interval":0.3, # float.(optional) to control the interval of the audio fragment.
"seed": -1, # int.(optional) random seed for reproducibility.
"media_type": "wav", # str.(optional) media type of the output audio, support "wav", "raw", "ogg", "aac".
"streaming_mode": false, # bool.(optional) whether to return a streaming response.
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
"repetition_penalty": 1.35, # float.(optional) repetition penalty for T2S model.
"tts_infer_yaml_path": GPT_SoVITS/configs/tts_infer.yaml # str.(optional) tts infer yaml path
}
```
RESP:
成功: 直接返回 wav 音频流 http code 200
失败: 返回包含错误信息的 json, http code 400
### 命令控制
endpoint: `/control`
command:
"restart": 重新运行
"exit": 结束运行
GET:
```
http://127.0.0.1:9880/control?command=restart
```
POST:
```json
{
"command": "restart"
}
```
RESP:
### 切换GPT模型
endpoint: `/set_gpt_weights`
GET:
```
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
```
RESP:
成功: 返回"success", http code 200
失败: 返回包含错误信息的 json, http code 400
### 切换Sovits模型
endpoint: `/set_sovits_weights`
GET:
```
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
```
RESP:
成功: 返回"success", http code 200
失败: 返回包含错误信息的 json, http code 400
"""
import os
import sys
import traceback
from typing import Generator

import torch

now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))

import argparse
import subprocess
import wave
import signal
import numpy as np
import soundfile as sf
from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi import FastAPI
import uvicorn
from io import BytesIO
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from functools import lru_cache

cut_method_names = get_cut_method_names()

parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
args = parser.parse_args()

port = args.port
host = args.bind_addr
argv = sys.argv

default_tts_config = TTS_Config()
TTS.init_base_models(default_tts_config)

APP = FastAPI()
class TTS_Request(BaseModel):
    text: str = None
    text_lang: str = None
    ref_audio_path: str = None
    prompt_lang: str = None
    prompt_text: str = ""
    top_k: int = 5
    top_p: float = 1
    temperature: float = 1
    text_split_method: str = "cut5"
    batch_size: int = 1
    batch_threshold: float = 0.75
    split_bucket: bool = True
    speed_factor: float = 1.0
    fragment_interval: float = 0.3
    seed: int = -1
    media_type: str = "wav"
    streaming_mode: bool = False
    parallel_infer: bool = True
    repetition_penalty: float = 1.35
    tts_infer_yaml_path: str = None
    """Path to the yaml config of the voice model to load at inference time, e.g. GPT_SoVITS/configs/tts_infer.yaml"""
@lru_cache(maxsize=10)
def get_tts_instance(tts_config: TTS_Config) -> TTS:
    print(f"load tts config from {tts_config.configs_path}")
    return TTS(tts_config)
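
# lru_cache hashes its argument, so caching here relies on TTS_Config.__hash__ /
# __eq__ (keyed on configs_path): requests that reference the same yaml file reuse
# the already-loaded TTS instance instead of reloading the voice weights.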
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """modified from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files"""
    with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
        audio_file.write(data)
    return io_buffer


def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
    io_buffer.write(data.tobytes())
    return io_buffer


def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    io_buffer = BytesIO()
    sf.write(io_buffer, data, rate, format='wav')
    return io_buffer
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
    process = subprocess.Popen([
        'ffmpeg',
        '-f', 's16le',      # input format: 16-bit signed little-endian PCM
        '-ar', str(rate),   # input sampling rate
        '-ac', '1',         # mono
        '-i', 'pipe:0',     # read input from the pipe
        '-c:a', 'aac',      # encode audio as AAC
        '-b:a', '192k',     # bitrate
        '-vn',              # no video
        '-f', 'adts',       # output as an ADTS AAC stream
        'pipe:1'            # write output to the pipe
    ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, _ = process.communicate(input=data.tobytes())
    io_buffer.write(out)
    return io_buffer
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
    if media_type == "ogg":
        io_buffer = pack_ogg(io_buffer, data, rate)
    elif media_type == "aac":
        io_buffer = pack_aac(io_buffer, data, rate)
    elif media_type == "wav":
        io_buffer = pack_wav(io_buffer, data, rate)
    else:
        io_buffer = pack_raw(io_buffer, data, rate)
    io_buffer.seek(0)
    return io_buffer
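
# Note: in streaming mode the server sends a wav header once (wave_header_chunk
# below) and then packs every subsequent chunk as "raw" PCM, so clients can play
# the stream progressively.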
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
    # This will create a wave header and then append the frame input.
    # It should come first in a streaming wav file; other frames should not
    # carry it (or you will hear artifacts at the start of each chunk).
    wav_buf = BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)
    wav_buf.seek(0)
    return wav_buf.read()
def handle_control(command: str):
    if command == "restart":
        os.execl(sys.executable, sys.executable, *argv)
    elif command == "exit":
        os.kill(os.getpid(), signal.SIGTERM)
        exit(0)
def check_params(req: dict, tts_config: TTS_Config):
    text: str = req.get("text", "")
    text_lang: str = req.get("text_lang", "")
    ref_audio_path: str = req.get("ref_audio_path", "")
    streaming_mode: bool = req.get("streaming_mode", False)
    media_type: str = req.get("media_type", "wav")
    prompt_lang: str = req.get("prompt_lang", "")
    text_split_method: str = req.get("text_split_method", "cut5")

    if ref_audio_path in [None, ""]:
        return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
    if text in [None, ""]:
        return JSONResponse(status_code=400, content={"message": "text is required"})
    if (text_lang in [None, ""]):
        return JSONResponse(status_code=400, content={"message": "text_lang is required"})
    elif text_lang.lower() not in tts_config.languages:
        return JSONResponse(status_code=400, content={"message": "text_lang is not supported"})
    if (prompt_lang in [None, ""]):
        return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
    elif prompt_lang.lower() not in tts_config.languages:
        return JSONResponse(status_code=400, content={"message": "prompt_lang is not supported"})
    if media_type not in ["wav", "raw", "ogg", "aac"]:
        return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
    elif media_type == "ogg" and not streaming_mode:
        return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
    if text_split_method not in cut_method_names:
        return JSONResponse(status_code=400,
                            content={"message": f"text_split_method:{text_split_method} is not supported"})

    return None
async def tts_handle(req: dict):
    """
    Text to speech handler.

    Args:
        req (dict):
            {
                "text": "",                   # str.(required) text to be synthesized
                "text_lang": "",              # str.(required) language of the text to be synthesized
                "ref_audio_path": "",         # str.(required) reference audio path
                "prompt_text": "",            # str.(optional) prompt text for the reference audio
                "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
                "top_k": 5,                   # int. top k sampling
                "top_p": 1,                   # float. top p sampling
                "temperature": 1,             # float. temperature for sampling
                "text_split_method": "cut5",  # str. text split method, see text_segmentation_method.py for details.
                "batch_size": 1,              # int. batch size for inference
                "batch_threshold": 0.75,      # float. threshold for batch splitting.
                "split_bucket": True,         # bool. whether to split the batch into multiple buckets.
                "speed_factor": 1.0,          # float. control the speed of the synthesized audio.
                "fragment_interval": 0.3,     # float. control the interval of the audio fragments.
                "seed": -1,                   # int. random seed for reproducibility.
                "media_type": "wav",          # str. media type of the output audio, supports "wav", "raw", "ogg", "aac".
                "streaming_mode": False,      # bool. whether to return a streaming response.
                "parallel_infer": True,       # bool.(optional) whether to use parallel inference.
                "repetition_penalty": 1.35    # float.(optional) repetition penalty for the T2S model.
            }
    returns:
        StreamingResponse: audio stream response.
    """
    streaming_mode = req.get("streaming_mode", False)
    media_type = req.get("media_type", "wav")

    tts_infer_yaml_path = req.get("tts_infer_yaml_path", "GPT_SoVITS/configs/tts_infer.yaml")
    tts_config = TTS_Config(tts_infer_yaml_path)

    check_res = check_params(req, tts_config)
    if check_res is not None:
        return check_res

    if streaming_mode:
        req["return_fragment"] = True

    try:
        tts_instance = get_tts_instance(tts_config)
        move_to_gpu(tts_instance, tts_config)
        tts_generator = tts_instance.run(req)

        if streaming_mode:
            def streaming_generator(tts_generator: Generator, media_type: str):
                if media_type == "wav":
                    yield wave_header_chunk()
                    media_type = "raw"
                for sr, chunk in tts_generator:
                    yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
                move_to_cpu(tts_instance)

            # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
            return StreamingResponse(streaming_generator(tts_generator, media_type), media_type=f"audio/{media_type}")

        else:
            sr, audio_data = next(tts_generator)
            audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
            move_to_cpu(tts_instance)
            return Response(audio_data, media_type=f"audio/{media_type}")
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
def move_to_cpu(tts):
    cpu_device = torch.device('cpu')
    tts.set_device(cpu_device)
    print("Moved TTS models to CPU to save GPU memory.")


def move_to_gpu(tts: TTS, tts_config: TTS_Config):
    tts.set_device(tts_config.device)
    print("Moved TTS models back to GPU for performance.")
@APP.get("/control")
async def control(command: str = None):
if command is None:
return JSONResponse(status_code=400, content={"message": "command is required"})
handle_control(command)
@APP.get("/tts")
async def tts_get_endpoint(
text: str = None,
text_lang: str = None,
ref_audio_path: str = None,
prompt_lang: str = None,
prompt_text: str = "",
top_k: int = 5,
top_p: float = 1,
temperature: float = 1,
text_split_method: str = "cut0",
batch_size: int = 1,
batch_threshold: float = 0.75,
split_bucket: bool = True,
speed_factor: float = 1.0,
fragment_interval: float = 0.3,
seed: int = -1,
media_type: str = "wav",
streaming_mode: bool = False,
parallel_infer: bool = True,
repetition_penalty: float = 1.35,
tts_infer_yaml_path: str = "GPT_SoVITS/configs/tts_infer.yaml"
):
req = {
"text": text,
"text_lang": text_lang.lower(),
"ref_audio_path": ref_audio_path,
"prompt_text": prompt_text,
"prompt_lang": prompt_lang.lower(),
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"text_split_method": text_split_method,
"batch_size": int(batch_size),
"batch_threshold": float(batch_threshold),
"speed_factor": float(speed_factor),
"split_bucket": split_bucket,
"fragment_interval": fragment_interval,
"seed": seed,
"media_type": media_type,
"streaming_mode": streaming_mode,
"parallel_infer": parallel_infer,
"repetition_penalty": float(repetition_penalty),
"tts_infer_yaml_path": tts_infer_yaml_path
}
return await tts_handle(req)
@APP.post("/tts")
async def tts_post_endpoint(request: TTS_Request):
req = request.dict()
return await tts_handle(req)
@APP.get("/set_refer_audio")
async def set_refer_audio(refer_audio_path: str = None, tts_infer_yaml_path: str = "GPT_SoVITS/configs/tts_infer.yaml"):
try:
tts_config = TTS_Config(tts_infer_yaml_path)
tts_instance = get_tts_instance(tts_config)
tts_instance.set_ref_audio(refer_audio_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
return JSONResponse(status_code=200, content={"message": "success"})
@APP.get("/set_gpt_weights")
async def set_gpt_weights(weights_path: str = None, tts_infer_yaml_path: str = "GPT_SoVITS/configs/tts_infer.yaml"):
try:
if weights_path in ["", None]:
return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
tts_config = TTS_Config(tts_infer_yaml_path)
tts_instance = get_tts_instance(tts_config)
tts_instance.init_t2s_weights(weights_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": f"change gpt weight failed", "Exception": str(e)})
return JSONResponse(status_code=200, content={"message": "success"})
@APP.get("/set_sovits_weights")
async def set_sovits_weights(weights_path: str = None, tts_infer_yaml_path: str = "GPT_SoVITS/configs/tts_infer.yaml"):
try:
if weights_path in ["", None]:
return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
tts_config = TTS_Config(tts_infer_yaml_path)
tts_instance = get_tts_instance(tts_config)
tts_instance.init_vits_weights(weights_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": f"change sovits weight failed", "Exception": str(e)})
return JSONResponse(status_code=200, content={"message": "success"})
if __name__ == "__main__":
try:
uvicorn.run(APP, host=host, port=port, workers=1)
except Exception as e:
traceback.print_exc()
os.kill(os.getpid(), signal.SIGTERM)
exit(0)
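
For reference, a minimal client call against the new endpoint; this assumes the server above is running locally, the reference wav exists on the server, and the `requests` package is installed (none of which is part of the commit):

```python
import requests

resp = requests.post("http://127.0.0.1:9880/tts", json={
    "text": "先帝创业未半而中道崩殂",
    "text_lang": "zh",
    "ref_audio_path": "archive_jingyuan_1.wav",
    "prompt_text": "我是罗浮云骑将军景元",
    "prompt_lang": "zh",
    "media_type": "wav",
    # selects which cached voice/config the server should use
    "tts_infer_yaml_path": "GPT_SoVITS/configs/tts_infer.yaml",
})
resp.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(resp.content)
```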