From 9f840832cb85accd0f476b33e28ef081158bac03 Mon Sep 17 00:00:00 2001
From: XXXXRT
Date: Wed, 3 Apr 2024 13:16:07 +0100
Subject: [PATCH] MAKE API GREAT AGAIN!

---
 api.py | 726 ++++++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 570 insertions(+), 156 deletions(-)

diff --git a/api.py b/api.py
index ea0e39d0..b1052a33 100644
--- a/api.py
+++ b/api.py
@@ -33,12 +33,12 @@ endpoint: `/`
 
 使用执行参数指定的参考音频:
 GET:
-    `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
+    `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` #从zh,en,ja,auto中选择
 POST:
 ```json
 {
     "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
-    "text_language": "zh"
+    "text_language": "zh" #从zh,en,ja,auto中选择
 }
 ```
 
@@ -49,20 +49,20 @@ POST:
 ```json
 {
     "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
-    "text_language": "zh",
+    "text_language": "zh", #从zh,en,ja,auto中选择
     "cut_punc": ",。",
 }
 ```
 
 手动指定当次推理所使用的参考音频:
 GET:
-    `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
+    `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` #从zh,en,ja,auto中选择
 POST:
 ```json
 {
     "refer_wav_path": "123.wav",
     "prompt_text": "一二三。",
-    "prompt_language": "zh",
+    "prompt_language": "zh", #从zh,en,ja,auto中选择
     "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
     "text_language": "zh"
 }
@@ -116,10 +116,16 @@ RESP:
 无
 
 """
-
 import argparse
 import os,re
 import sys
+
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+sys.path.append("%s/GPT_SoVITS" % (now_dir))  # 神奇位置,防止import的问题
+
+
 import signal
 import LangSegment
 from time import time as ttime
@@ -142,17 +148,13 @@ from my_utils import load_audio
 import config as global_config
 import logging
 import subprocess
+from typing import Dict, List, Tuple
+from tools.i18n.i18n import I18nAuto
+import traceback
+import math
+i18n = I18nAuto()
 
-class DefaultRefer:
-    def __init__(self, path, text, language):
-        self.path = args.default_refer_path
-        self.text = args.default_refer_text
-        self.language = args.default_refer_language
-
-    def is_ready(self) -> bool:
-        return is_full(self.path, self.text, self.language)
-
 
 def is_empty(*items):  # 任意一项不为空返回False
     for item in items:
@@ -191,12 +193,17 @@ def change_sovits_weights(sovits_path):
 
 
 def change_gpt_weights(gpt_path):
-    global hz, max_sec, t2s_model, config
+    global hz, max_sec, t2s_model, config, is_fast_inference
     hz = 50
    dict_s1 = torch.load(gpt_path, map_location="cpu")
     config = dict_s1["config"]
     max_sec = config["data"]["max_sec"]
-    t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+    try:
+        t2s_model = Text2SemanticLightningModule(config, "****", is_train=False, flash_attn_enabled=flash_atten)
+        is_fast_inference = True
+    except TypeError:
+        t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+        is_fast_inference = False
     t2s_model.load_state_dict(dict_s1["weight"])
     if is_half == True:
         t2s_model = t2s_model.half()
@@ -233,16 +240,20 @@ def get_bert_inf(phones, word2ph, norm_text, language):
     language=language.replace("all_","")
     if language == "zh":
         bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
+        bert = torch.zeros(
+            (1024, len(phones)),
+            dtype=precision,
+        ).to(device)
     else:
         bert = torch.zeros(
             (1024, len(phones)),
-            dtype=torch.float16 if is_half == True else torch.float32,
+            dtype=precision,
         ).to(device)
 
     return bert
 
 
-def get_phones_and_bert(text,language):
+def get_phones_and_bert(text:str, language:str):
     if language in {"en","all_zh","all_ja"}:
         language = language.replace("all_","")
         if language == "en":
@@ -259,7 +270,7 @@ def get_phones_and_bert(text,language):
         else:
             bert = torch.zeros(
                 (1024, len(phones)),
-                dtype=torch.float16 if is_half == True else torch.float32,
+                dtype=precision,
             ).to(device)
     elif language in {"zh", "ja","auto"}:
         textlist=[]
@@ -300,6 +311,14 @@ def get_phones_and_bert(text,language):
     return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text
 
 
+def extract_feature_for_text(textlist:list, langlist:list)->Tuple[list, torch.Tensor, str]:
+    if len(textlist) == 0:
+        return None, None, None
+
+    phones, bert_features, norm_text = get_phones_and_bert(textlist, langlist)
+    return phones, bert_features, norm_text
+
+
 class DictToAttrRecursive:
     def __init__(self, input_dict):
         for key, value in input_dict.items():
@@ -310,14 +329,75 @@ class DictToAttrRecursive:
             setattr(self, key, value)
 
 
-def get_spepc(hps, filename):
-    audio = load_audio(filename, int(hps.data.sampling_rate))
-    audio = torch.FloatTensor(audio)
-    audio_norm = audio
-    audio_norm = audio_norm.unsqueeze(0)
-    spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
-                             hps.data.win_length, center=False)
-    return spec
+class REF:
+    def __init__(self, ref_path="", ref_text="", ref_language=""):
+        ref_text = ref_text.strip("\n")
+        if ref_text:
+            if (ref_text[-1] not in splits): ref_text += "。" if ref_language != "en" else "."
+        if ref_language:
+            ref_language = dict_language[ref_language.lower()]
+        self.path = ref_path
+        self.text = ref_text
+        self.language = ref_language
+
+    def set_prompt_semantic(self, ref_wav_path:str):
+        zero_wav = np.zeros(
+            int(hps.data.sampling_rate * 0.3),
+            dtype=np.float16 if is_half else np.float32,
+        )
+        with torch.no_grad():
+            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
+            if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
+                raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
+            wav16k = torch.from_numpy(wav16k)
+            zero_wav_torch = torch.from_numpy(zero_wav)
+            wav16k = wav16k.to(device)
+            zero_wav_torch = zero_wav_torch.to(device)
+            if is_half:
+                wav16k = wav16k.half()
+                zero_wav_torch = zero_wav_torch.half()
+
+            wav16k = torch.cat([wav16k, zero_wav_torch])
+            ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
+            codes = vq_model.extract_latent(ssl_content)
+
+            prompt_semantic = codes[0, 0].to(device)
+            self.prompt_semantic = prompt_semantic
+            self.codes = codes
+            self.ssl_content = ssl_content
+
+    def set_ref_spec(self, ref_audio_path):
+        audio = load_audio(ref_audio_path, int(hps.data.sampling_rate))
+        audio = torch.FloatTensor(audio)
+        audio_norm = audio
+        audio_norm = audio_norm.unsqueeze(0)
+        spec = spectrogram_torch(
+            audio_norm,
+            hps.data.filter_length,
+            hps.data.sampling_rate,
+            hps.data.hop_length,
+            hps.data.win_length,
+            center=False,
+        )
+        spec = spec.to(device)
+        if is_half:
+            spec = spec.half()
+        # self.refer_spec = spec
+        self.refer_spec = spec
+
+    def set_ref_audio(self):
+        '''
+        To set the reference audio for the TTS model,
+        including the prompt_semantic and refer_spec.
+        Args:
+            ref_audio_path: str, the path of the reference audio.
+ ''' + self.set_prompt_semantic(self.path) + self.set_ref_spec(self.path) + self.phone, self.bert_feature, self.norm_text = get_phones_and_bert(self.text, self.language) + + def is_ready(self) -> bool: + return is_full(self.path, self.text, self.language) def pack_audio(audio_bytes, data, rate): @@ -402,29 +482,329 @@ def only_punc(text): return not any(t.isalnum() or t.isalpha() for t in text) -def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language): - t0 = ttime() - prompt_text = prompt_text.strip("\n") - prompt_language, text = prompt_language, text.strip("\n") - zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - if (is_half == True): - wav16k = wav16k.half().to(device) - zero_wav_torch = zero_wav_torch.half().to(device) +def preprocess(text:list, lang:str)->List[Dict]: + result = [] + print(i18n("############ 提取文本Bert特征 ############")) + for _text in text: + phones, bert_features, norm_text = extract_feature_for_text(_text, lang) + if phones is None: + continue + res={ + "phones": phones, + "bert_features": bert_features, + "norm_text": norm_text, + } + result.append(res) + return result + + +def audio_postprocess( + audio:List[torch.Tensor], + sr:int, + batch_index_list:list=None, + fragment_interval:float=0.3 + ): + zero_wav = torch.zeros( + int(hps.data.sampling_rate * fragment_interval), + dtype=precision, + device=device + ) + + audio_bytes = BytesIO() + + for i, batch in enumerate(audio): + for j, audio_fragment in enumerate(batch): + max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音 + if max_audio>1: audio_fragment/=max_audio + audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) + audio[i][j] = audio_fragment.cpu().numpy() + + + if split_bucket: + audio = recovery_order(audio, batch_index_list) + else: + # audio = [item for batch in audio for item in batch] + audio = sum(audio, []) + + + audio = pack_audio(audio_bytes,(np.concatenate(audio, 0) * 32768).astype(np.int16),hps.data.sampling_rate) + + if media_type == "wav": + audio_bytes = pack_wav(audio,hps.data.sampling_rate) + return audio_bytes.getvalue() + + +def batch_sequences(sequences: List[torch.Tensor], axis:int = 0, pad_value:int = 0, max_length:int=None): + seq = sequences[0] + ndim = seq.dim() + if axis < 0: + axis += ndim + dtype:torch.dtype = seq.dtype + pad_value = torch.tensor(pad_value, dtype=dtype) + seq_lengths = [seq.shape[axis] for seq in sequences] + if max_length is None: + max_length = max(seq_lengths) else: - wav16k = wav16k.to(device) - zero_wav_torch = zero_wav_torch.to(device) - wav16k = torch.cat([wav16k, zero_wav_torch]) - ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() - codes = vq_model.extract_latent(ssl_content) - prompt_semantic = codes[0, 0] + max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length + + padded_sequences = [] + for seq, length in zip(sequences, seq_lengths): + padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1) + padded_seq = torch.nn.functional.pad(seq, padding, value=pad_value) + padded_sequences.append(padded_seq) + batch = torch.stack(padded_sequences) + return batch + + +def to_batch(data:list, ref:REF, + threshold:float=0.75, + ): + + _data:list = [] + index_and_len_list = [] + for idx, item in enumerate(data): + 
norm_text_len = len(item["norm_text"]) + index_and_len_list.append([idx, norm_text_len]) + + batch_index_list = [] + if split_bucket: + index_and_len_list.sort(key=lambda x: x[1]) + index_and_len_list = np.array(index_and_len_list, dtype=np.int64) + + batch_index_list_len = 0 + pos = 0 + while pos =threshold) or (pos_end-pos==1): + batch_index=index_and_len_list[pos:pos_end, 0].tolist() + batch_index_list_len += len(batch_index) + batch_index_list.append(batch_index) + pos = pos_end + break + pos_end=pos_end-1 + + assert batch_index_list_len == len(data) + + else: + for i in range(len(data)): + if i%batch_size == 0: + batch_index_list.append([]) + batch_index_list[-1].append(i) + + + for batch_idx, index_list in enumerate(batch_index_list): + item_list = [data[idx] for idx in index_list] + phones_list = [] + phones_len_list = [] + # bert_features_list = [] + all_phones_list = [] + all_phones_len_list = [] + all_bert_features_list = [] + norm_text_batch = [] + bert_max_len = 0 + phones_max_len = 0 + for item in item_list: + all_bert_features = torch.cat([ref.bert_feature, item["bert_features"]], 1).to(dtype=precision, device=device) + all_phones = torch.LongTensor(ref.phone+item["phones"]).to(device) + phones = torch.LongTensor(item["phones"]).to(device) + # norm_text = ref.norm_text+item["norm_text"] + + bert_max_len = max(bert_max_len, all_bert_features.shape[-1]) + phones_max_len = max(phones_max_len, phones.shape[-1]) + + phones_list.append(phones) + phones_len_list.append(phones.shape[-1]) + all_phones_list.append(all_phones) + all_phones_len_list.append(all_phones.shape[-1]) + all_bert_features_list.append(all_bert_features) + norm_text_batch.append(item["norm_text"]) + + phones_batch = phones_list + all_phones_batch = all_phones_list + all_bert_features_batch = all_bert_features_list + + + batch = { + "phones": phones_batch, + "phones_len": torch.LongTensor(phones_len_list).to(device), + "all_phones": all_phones_batch, + "all_phones_len": torch.LongTensor(all_phones_len_list).to(device), + "all_bert_features": all_bert_features_batch, + "norm_text": norm_text_batch + } + _data.append(batch) + + return _data, batch_index_list + + +def recovery_order(data:list, batch_index_list:list)->list: + ''' + Recovery the order of the audio according to the batch_index_list. + + Args: + data (List[list(np.ndarray)]): the out of order audio . + batch_index_list (List[list[int]]): the batch index list. + + Returns: + list (List[np.ndarray]): the data in the original order. 
+ ''' + length = len(sum(batch_index_list, [])) + _data = [None]*length + for i, index_list in enumerate(batch_index_list): + for j, index in enumerate(index_list): + _data[index] = data[i][j] + return _data + + +def run(ref:REF, text, text_lang): + logger.info("run") + + ########## variables initialization ########### + if not is_fast_inference: + batch_size = 1 + top_k = 5 + top_p = 1 + temperature = 1 + batch_threshold = 0.75 + fragment_interval = 0.3 + text_lang = dict_language[text_lang.lower()] + + + if ref.path in [None, ""] or \ + ((ref.prompt_semantic is None) or (ref.refer_spec is None)): + raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()") + + + t0 = ttime() + ###### text preprocessing ######## + t1 = ttime() + data:list = None + if not return_fragment: + data = text.split("\n") + if len(data) == 0: + yield np.zeros(int(hps.data.sampling_rate), type=np.int16) + return + + batch_index_list:list = None + data = preprocess(data, text_lang) + data, batch_index_list = to_batch(data, ref, + threshold=batch_threshold, + ) + else: + print(i18n("############ 切分文本 ############")) + texts = text.split("\n") + data = [] + for i in range(len(texts)): + if i%batch_size == 0: + data.append([]) + data[-1].append(texts[i]) + + def make_batch(batch_texts): + batch_data = [] + print(i18n("############ 提取文本Bert特征 ############")) + batch_data = preprocess(batch_texts, text_lang) + if len(batch_data) == 0: + return None + batch, _ = to_batch(batch_data, ref, + threshold=batch_threshold, + ) + return batch[0] + + t2 = ttime() + try: + print("############ 推理 ############") + ###### inference ###### + t_34 = 0.0 + t_45 = 0.0 + audio = [] + for item in data: + t3 = ttime() + if return_fragment: + item = make_batch(item) + if item is None: + continue + + batch_phones:List[torch.LongTensor] = item["phones"] + batch_phones_len:torch.LongTensor = item["phones_len"] + all_phoneme_ids:List[torch.LongTensor] = item["all_phones"] + all_phoneme_lens:torch.LongTensor = item["all_phones_len"] + all_bert_features:List[torch.LongTensor] = item["all_bert_features"] + norm_text:str = item["norm_text"] + + print(norm_text) + + prompt = ref.prompt_semantic.expand(len(all_phoneme_ids), -1).to(device) + + with torch.no_grad(): + pred_semantic_list, idx_list = t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_lens, + prompt, + all_bert_features, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=hz * max_sec, + ) + t4 = ttime() + t_34 += t4 - t3 + + refer_audio_spec:torch.Tensor = ref.refer_spec.to(dtype=precision, device=device) + + batch_audio_fragment = [] + + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + upsample_rate = math.prod(vq_model.upsample_rates) + audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] + audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] + all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(device) + _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(device) + _batch_audio_fragment = (vq_model.decode( + all_pred_semantic, _batch_phones,refer_audio_spec + ).detach()[0, 0, :]) + audio_frag_end_idx.insert(0, 0) + batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] + + + t5 = ttime() + t_45 += t5 - t4 + if return_fragment: + logger.info("%.3f\t%.3f\t%.3f\t%.3f" 
% (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) + yield audio_postprocess([batch_audio_fragment], + hps.data.sampling_rate, + None, + fragment_interval + ) + else: + audio.append(batch_audio_fragment) + + logger.info("return_fragment:"+str(return_fragment)+" split_bucket:"+str(split_bucket)+" batch_size"+str(batch_size)+" media_type:"+media_type) + if not return_fragment: + logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) + yield audio_postprocess(audio, + hps.data.sampling_rate, + batch_index_list, + fragment_interval + ) + + except Exception as e: + traceback.print_exc() + # 必须返回一个空音频, 否则会导致显存不释放。 + yield np.zeros(int(hps.data.sampling_rate), dtype=np.int16) + finally: + pass + + +def get_tts_wav(ref:REF, text, text_language): + logger.info("get_tts_wav") + t0 = ttime() t1 = ttime() - prompt_language = dict_language[prompt_language.lower()] text_language = dict_language[text_language.lower()] - phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language) + phones1, bert1, norm_text1 = ref.phone, ref.bert_feature, ref.norm_text texts = text.split("\n") audio_bytes = BytesIO() @@ -432,6 +812,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) # 简单防止纯符号引发参考音频泄露 if only_punc(text): continue + print(text) audio_opt = [] phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language) @@ -440,8 +821,11 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) bert = bert.to(device).unsqueeze(0) all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) - prompt = prompt_semantic.unsqueeze(0).to(device) + prompt = ref.prompt_semantic.unsqueeze(0).to(device) t2 = ttime() + + zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) + with torch.no_grad(): # pred_semantic = t2s_model.model.infer( pred_semantic, idx = t2s_model.model.infer_panel( @@ -454,106 +838,37 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) early_stop_num=hz * max_sec) t3 = ttime() # print(pred_semantic.shape,idx) - pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 - refer = get_spepc(hps, ref_wav_path) # .to(device) - if (is_half == True): - refer = refer.half().to(device) + if isinstance(pred_semantic, list) and isinstance(pred_semantic, list): # 神秘代码,有些时候sys.path会出问题,import的是fast inference分支的AR + pred_semantic = pred_semantic[0] + idx=idx[0] + pred_semantic = pred_semantic[-idx:] + pred_semantic = pred_semantic.unsqueeze(0).unsqueeze(0) else: - refer = refer.to(device) + pred_semantic = pred_semantic[:,-idx:] + pred_semantic = pred_semantic.unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 + # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] audio = \ vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), - refer).detach().cpu().numpy()[ - 0, 0] ###试试重建不带上prompt部分 + ref.refer_spec).detach().cpu().numpy()[0, 0] ###试试重建不带上prompt部分 audio_opt.append(audio) audio_opt.append(zero_wav) t4 = ttime() audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate) - # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) - if stream_mode == "normal": + logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) + if return_fragment: audio_bytes, audio_chunk = 
read_clean_buffer(audio_bytes)
             yield audio_chunk
 
-    if not stream_mode == "normal":
+    if not return_fragment:
         if media_type == "wav":
             audio_bytes = pack_wav(audio_bytes,hps.data.sampling_rate)
         yield audio_bytes.getvalue()
-
-
-def handle_control(command):
-    if command == "restart":
-        os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
-    elif command == "exit":
-        os.kill(os.getpid(), signal.SIGTERM)
-        exit(0)
-
-
-def handle_change(path, text, language):
-    if is_empty(path, text, language):
-        return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)
-
-    if path != "" or path is not None:
-        default_refer.path = path
-    if text != "" or text is not None:
-        default_refer.text = text
-    if language != "" or language is not None:
-        default_refer.language = language
-
-    logger.info(f"当前默认参考音频路径: {default_refer.path}")
-    logger.info(f"当前默认参考音频文本: {default_refer.text}")
-    logger.info(f"当前默认参考音频语种: {default_refer.language}")
-    logger.info(f"is_ready: {default_refer.is_ready()}")
-
-
-    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
-
-
-def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc):
-    if (
-            refer_wav_path == "" or refer_wav_path is None
-            or prompt_text == "" or prompt_text is None
-            or prompt_language == "" or prompt_language is None
-    ):
-        refer_wav_path, prompt_text, prompt_language = (
-            default_refer.path,
-            default_refer.text,
-            default_refer.language,
-        )
-        if not default_refer.is_ready():
-            return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
-
-    if cut_punc == None:
-        text = cut_text(text,default_cut_punc)
-    else:
-        text = cut_text(text,cut_punc)
-
-    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language), media_type="audio/"+media_type)
-
-
-
 
 # --------------------------------
 # 初始化部分
 # --------------------------------
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-sys.path.append("%s/GPT_SoVITS" % (now_dir))
 
-dict_language = {
-    "中文": "all_zh",
-    "英文": "en",
-    "日文": "all_ja",
-    "中英混合": "zh",
-    "日英混合": "ja",
-    "多语种混合": "auto",  #多语种启动切分识别语种
-    "all_zh": "all_zh",
-    "en": "en",
-    "all_ja": "all_ja",
-    "zh": "zh",
-    "ja": "ja",
-    "auto": "auto",
-}
 
 # logger
 logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
@@ -573,11 +888,14 @@ parser.add_argument("-dl", "--default_refer_language", type=str, default="", hel
 parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
 parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
 parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
+parser.add_argument("-bs", "--batch_size", type=int, default=1, help="批处理大小")
 parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
 parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
+parser.add_argument("-rf", "--return_fragment", action="store_true", default=False, help="是否开启碎片返回")
+parser.add_argument("-sb", "--split_bucket", action="store_true", default=False, help="是否将批处理分成多个桶")
+parser.add_argument("-fa", "--flash_atten", action="store_true", default=False, help="是否开启flash_attention")
 # bool值的用法为 `python ./api.py -fp ...`
 # 此时 full_precision==True, half_precision==False
-parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
 parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac")
 parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…")
 # 切割常用分句符为 `python ./api.py -cp ".?!。?!"`
@@ -593,9 +911,30 @@ host = args.bind_addr
 cnhubert_base_path = args.hubert_path
 bert_path = args.bert_path
 default_cut_punc = args.cut_punc
+batch_size = args.batch_size
+return_fragment = args.return_fragment
+split_bucket = args.split_bucket
+flash_atten = args.flash_atten
+
+dict_language = {
+    "中文": "all_zh",
+    "英文": "en",
+    "英语": "en",
+    "日文": "all_ja",
+    "日语": "all_ja",
+    "中英混合": "zh",
+    "日英混合": "ja",
+    "多语种混合": "auto",  #多语种启动切分识别语种
+    "all_zh": "all_zh",
+    "en": "en",
+    "all_ja": "all_ja",
+    "zh": "zh",
+    "ja": "ja",
+    "auto": "auto",
+}
+splits = [",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"]
+is_fast_inference = True
 
-# 应用参数配置
-default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
 
 # 模型路径检查
 if sovits_path == "":
@@ -605,15 +944,6 @@ if gpt_path == "":
     gpt_path = g_config.pretrained_gpt_path
     logger.warn(f"未指定GPT模型路径, fallback后当前值: {gpt_path}")
 
-# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
-if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
-    default_refer.path, default_refer.text, default_refer.language = "", "", ""
-    logger.info("未指定默认参考音频")
-else:
-    logger.info(f"默认参考音频路径: {default_refer.path}")
-    logger.info(f"默认参考音频文本: {default_refer.text}")
-    logger.info(f"默认参考音频语种: {default_refer.language}")
-
 # 获取半精度
 is_half = g_config.is_half
 if args.full_precision:
@@ -624,22 +954,20 @@ if args.full_precision and args.half_precision:
     is_half = g_config.is_half  # 炒饭fallback
 logger.info(f"半精: {is_half}")
 
-# 流式返回模式
-if args.stream_mode.lower() in ["normal","n"]:
-    stream_mode = "normal"
-    logger.info("流式返回已开启")
-else:
-    stream_mode = "close"
+precision = torch.float16 if is_half else torch.float32
+device = torch.device(device)
+
 
 # 音频编码格式
 if args.media_type.lower() in ["aac","ogg"]:
     media_type = args.media_type.lower()
-elif stream_mode == "close":
+elif not return_fragment:
     media_type = "wav"
else:
     media_type = "ogg"
 logger.info(f"编码格式: {media_type}")
+
 
 # 初始化模型
 cnhubert.cnhubert_base_path = cnhubert_base_path
 tokenizer = AutoTokenizer.from_pretrained(bert_path)
@@ -655,6 +983,92 @@ change_sovits_weights(sovits_path)
 change_gpt_weights(gpt_path)
 
+# ?????
+if split_bucket and is_fast_inference:
+    return_fragment = False
+    logger.info("分桶处理已开启")
+    logger.info("碎片返回已关闭")
+
+if return_fragment:
+    logger.info("碎片返回已开启")
+
+if batch_size != 1 and is_fast_inference:
+    logger.info("批处理已开启")
+    logger.info(f"批处理大小:{batch_size}")
+else:
+    logger.info("批处理已关闭")
+
+
+# 应用参数配置
+default_refer = REF(args.default_refer_path, args.default_refer_text, args.default_refer_language)
+
+
+# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
+if not default_refer.is_ready():
+    default_refer.path, default_refer.text, default_refer.language = "", "", ""
+    logger.info("未指定默认参考音频")
+else:
+    logger.info(f"默认参考音频路径: {default_refer.path}")
+    logger.info(f"默认参考音频文本: {default_refer.text}")
+    logger.info(f"默认参考音频语种: {default_refer.language}")
+    default_refer.set_ref_audio()
+
+
+def handle_control(command):
+    if command == "restart":
+        os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
+    elif command == "exit":
+        os.kill(os.getpid(), signal.SIGTERM)
+        exit(0)
+
+
+def handle_change(path, text, language):
+    global default_refer
+    if is_empty(path, text, language):
+        return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)
+
+    if path and text and language:  # 三项齐全时才替换默认参考音频, 并立即提取特征
+        default_refer = REF(path, text, language)
+        default_refer.set_ref_audio()
+
+    logger.info(f"当前默认参考音频路径: {default_refer.path}")
+    logger.info(f"当前默认参考音频文本: {default_refer.text}")
+    logger.info(f"当前默认参考音频语种: {default_refer.language}")
+    logger.info(f"is_ready: {default_refer.is_ready()}")
+
+    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
+
+
+def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc):
+    if (
+            refer_wav_path == "" or refer_wav_path is None
+            or prompt_text == "" or prompt_text is None
+            or prompt_language == "" or prompt_language is None
+    ):
+        ref = default_refer
+        if not default_refer.is_ready():
+            return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
+    elif (refer_wav_path != default_refer.path) or\
+         (prompt_text != default_refer.text) or\
+         (prompt_language != default_refer.language):
+        ref = REF(refer_wav_path, prompt_text, prompt_language)
+        ref.set_ref_audio()  # 新建的参考音频需先提取 prompt_semantic / refer_spec / bert 特征
+    else:
+        ref = default_refer
+
+    if cut_punc == None:
+        text = cut_text(text,default_cut_punc)
+    else:
+        text = cut_text(text,cut_punc)
+
+    if is_fast_inference:
+        return StreamingResponse(run(ref, text, text_language), media_type="audio/"+media_type)
+    else:
+        return StreamingResponse(get_tts_wav(ref, text, text_language), media_type="audio/"+media_type)
 
 
 # --------------------------------
@@ -720,15 +1134,15 @@ async def tts_endpoint(request: Request):
 
 @app.get("/")
 async def tts_endpoint(
-        refer_wav_path: str = None,
-        prompt_text: str = None,
-        prompt_language: str = None,
-        text: str = None,
-        text_language: str = None,
-        cut_punc: str = None,
+        refer_wav_path: str = "",
+        prompt_text: str = "",
+        prompt_language: str = "",
+        text: str = "",
+        text_language: str = "",
+        cut_punc: str = "",
 ):
     return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc)
 
 
 if __name__ == "__main__":
-    uvicorn.run(app, host=host, port=port, workers=1)
+    uvicorn.run(app, host=host, port=port, workers=1)
\ No newline at end of file
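
For quick manual testing of the endpoint documented in the module docstring above, a minimal client sketch follows. It assumes the server runs on the default 127.0.0.1:9880, that the `requests` package is available, and uses the placeholder reference file `123.wav` from the docstring example; none of this is part of the patch itself.

```python
# Hypothetical client for the "/" endpoint documented in api.py's docstring.
import requests

BASE = "http://127.0.0.1:9880"

# 1) Synthesize with the default reference audio configured at startup (-dr / -dt / -dl).
resp = requests.post(BASE, json={
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_language": "zh",   # zh / en / ja / auto
    "cut_punc": ",。",        # optional per-request split punctuation
})
resp.raise_for_status()
with open("output.wav", "wb") as f:   # -mt defaults to wav
    f.write(resp.content)

# 2) Synthesize with a reference audio supplied for this request only.
resp = requests.post(BASE, json={
    "refer_wav_path": "123.wav",      # placeholder path, as in the docstring example
    "prompt_text": "一二三。",
    "prompt_language": "zh",
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_language": "zh",
})
resp.raise_for_status()
with open("output_ref.wav", "wb") as f:
    f.write(resp.content)
```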