From f45fff72d9efd50f83865464be6ef33ab37ecbe4 Mon Sep 17 00:00:00 2001
From: XXXXRT
Date: Wed, 3 Apr 2024 23:33:02 +0100
Subject: [PATCH] Fix some bugs

---
 api.py | 716 +++++++++++++++++++++++++++++----------------------------
 1 file changed, 368 insertions(+), 348 deletions(-)

diff --git a/api.py b/api.py
index 89a84a83..f2409de5 100644
--- a/api.py
+++ b/api.py
@@ -158,6 +158,78 @@ import math
 i18n = I18nAuto()
 
+class REF:
+    def __init__(self, ref_path="", ref_text="", ref_language=""):
+        if ref_text:
+            ref_text = ref_text.strip("\n")
+            if (ref_text[-1] not in splits): ref_text += "。" if ref_language != "en" else "."
+        if ref_language:
+            ref_language = dict_language[ref_language.lower()]
+        self.path = ref_path
+        self.text = ref_text
+        self.language = ref_language
+        self.prompt_semantic = None
+        self.refer_spec = None
+
+    def set_prompt_semantic(self, ref_wav_path:str):
+        zero_wav = np.zeros(
+            int(hps.data.sampling_rate * 0.3),
+            dtype=np.float16 if is_half else np.float32,
+        )
+        with torch.no_grad():
+            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
+            if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
+                raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
+            wav16k = torch.from_numpy(wav16k)
+            zero_wav_torch = torch.from_numpy(zero_wav)
+            wav16k = wav16k.to(device)
+            zero_wav_torch = zero_wav_torch.to(device)
+            if is_half:
+                wav16k = wav16k.half()
+                zero_wav_torch = zero_wav_torch.half()
+
+            wav16k = torch.cat([wav16k, zero_wav_torch])
+            ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
+            codes = vq_model.extract_latent(ssl_content)
+
+            prompt_semantic = codes[0, 0].to(device)
+            self.prompt_semantic = prompt_semantic
+            self.codes = codes
+            self.ssl_content = ssl_content
+
+    def set_ref_spec(self, ref_audio_path):
+        audio = load_audio(ref_audio_path, int(hps.data.sampling_rate))
+        audio = torch.FloatTensor(audio)
+        audio_norm = audio
+        audio_norm = audio_norm.unsqueeze(0)
+        spec = spectrogram_torch(
+            audio_norm,
+            hps.data.filter_length,
+            hps.data.sampling_rate,
+            hps.data.hop_length,
+            hps.data.win_length,
+            center=False,
+        )
+        spec = spec.to(device)
+        if is_half:
+            spec = spec.half()
+        # cache the reference spectrogram on the instance for reuse
+        self.refer_spec = spec
+
+    def set_ref_audio(self):
+        '''
+        Prepare the reference audio for the TTS model:
+        computes prompt_semantic and refer_spec, plus the
+        phone/BERT features of the reference text, from
+        self.path, self.text and self.language.
+        '''
+        self.set_prompt_semantic(self.path)
+        self.set_ref_spec(self.path)
+        self.phone, self.bert_feature, self.norm_text = get_phones_and_bert(self.text, self.language)
+
+    def is_ready(self) -> bool:
+        return is_full(self.path, self.text, self.language)
+
 def is_empty(*items):  # 任意一项不为空返回False
     for item in items:
@@ -193,6 +265,7 @@ def change_sovits_weights(sovits_path):
     vq_model = vq_model.to(device)
     vq_model.eval()
     vq_model.load_state_dict(dict_s2["weight"], strict=False)
+
 
 
 def change_gpt_weights(gpt_path):
@@ -243,10 +316,6 @@ def get_bert_inf(phones, word2ph, norm_text, language):
     language=language.replace("all_","")
     if language == "zh":
         bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
-        bert = torch.zeros(
-            (1024, len(phones)),
-            dtype=precision,
-        ).to(device)
     else:
         bert = torch.zeros(
             (1024, len(phones)),
@@ -332,77 +401,6 @@ class DictToAttrRecursive:
             setattr(self, key, value)
 
 
-class REF:
-    def __init__(self, ref_path="", ref_text="", ref_language=""):
-        ref_text = ref_text.strip("\n")
-        if ref_text:
-            if (ref_text[-1] not in splits): ref_text += "。" if ref_language != "en" else "."
- if ref_language: - ref_language = dict_language[ref_language.lower()] - self.path = ref_path - self.text = ref_text - self.language = ref_language - - def set_prompt_semantic(self, ref_wav_path:str): - zero_wav = np.zeros( - int(hps.data.sampling_rate * 0.3), - dtype=np.float16 if is_half else np.float32, - ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000): - raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - wav16k = wav16k.to(device) - zero_wav_torch = zero_wav_torch.to(device) - if is_half: - wav16k = wav16k.half() - zero_wav_torch = zero_wav_torch.half() - - wav16k = torch.cat([wav16k, zero_wav_torch]) - ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() - codes = vq_model.extract_latent(ssl_content) - - prompt_semantic = codes[0, 0].to(device) - self.prompt_semantic = prompt_semantic - self.codes = codes - self.ssl_content = ssl_content - - def set_ref_spec(self, ref_audio_path): - audio = load_audio(ref_audio_path, int(hps.data.sampling_rate)) - audio = torch.FloatTensor(audio) - audio_norm = audio - audio_norm = audio_norm.unsqueeze(0) - spec = spectrogram_torch( - audio_norm, - hps.data.filter_length, - hps.data.sampling_rate, - hps.data.hop_length, - hps.data.win_length, - center=False, - ) - spec = spec.to(device) - if is_half: - spec = spec.half() - # self.refer_spec = spec - self.refer_spec = spec - - def set_ref_audio(self): - ''' - To set the reference audio for the TTS model, - including the prompt_semantic and refer_spec. - Args: - ref_audio_path: str, the path of the reference audio. - ''' - self.set_prompt_semantic(self.path) - self.set_ref_spec(self.path) - self.phone, self.bert_feature, self.norm_text = get_phones_and_bert(self.text, self.language) - - def is_ready(self) -> bool: - return is_full(self.path, self.text, self.language) - - def pack_audio(audio_bytes, data, rate): if media_type == "ogg": audio_bytes = pack_ogg(audio_bytes, data, rate) @@ -485,6 +483,73 @@ def only_punc(text): return not any(t.isalnum() or t.isalpha() for t in text) +def get_tts_wav(ref:REF, text, text_language): + logger.info("get_tts_wav") + t0 = ttime() + t1 = ttime() + text_language = dict_language[text_language.lower()] + phones1, bert1, norm_text1 = ref.phone, ref.bert_feature, ref.norm_text + texts = text.split("\n") + audio_bytes = BytesIO() + + for text in texts: + # 简单防止纯符号引发参考音频泄露 + if only_punc(text): + continue + print(text) + + audio_opt = [] + phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language) + bert = torch.cat([bert1, bert2], 1) + + all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) + bert = bert.to(device).unsqueeze(0) + all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) + prompt = ref.prompt_semantic.unsqueeze(0).to(device) + t2 = ttime() + + zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) + + with torch.no_grad(): + # pred_semantic = t2s_model.model.infer( + pred_semantic, idx = t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_len, + prompt, + bert, + # prompt_phone_len=ph_offset, + top_k=config['inference']['top_k'], + early_stop_num=hz * max_sec) + t3 = ttime() + # print(pred_semantic.shape,idx) + if isinstance(pred_semantic, list) and isinstance(pred_semantic, list): # 神秘代码,有些时候sys.path会出问题,import的是fast inference分支的AR + 
pred_semantic = pred_semantic[0] + idx=idx[0] + pred_semantic = pred_semantic[-idx:] + pred_semantic = pred_semantic.unsqueeze(0).unsqueeze(0) + else: + pred_semantic = pred_semantic[:,-idx:] + pred_semantic = pred_semantic.unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 + + # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] + audio = \ + vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), + ref.refer_spec).detach().cpu().numpy()[0, 0] ###试试重建不带上prompt部分 + audio_opt.append(audio) + audio_opt.append(zero_wav) + t4 = ttime() + audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate) + logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) + if return_fragment: + audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) + yield audio_chunk + + if not return_fragment: + if media_type == "wav": + audio_bytes = pack_wav(audio_bytes,hps.data.sampling_rate) + yield audio_bytes.getvalue() + + def preprocess(text:list, lang:str)->List[Dict]: result = [] for _text in text: @@ -559,88 +624,88 @@ def batch_sequences(sequences: List[torch.Tensor], axis:int = 0, pad_value:int = def to_batch(data:list, ref:REF, - threshold:float=0.75, - ): + threshold:float=0.75, + ): - _data:list = [] - index_and_len_list = [] - for idx, item in enumerate(data): - norm_text_len = len(item["norm_text"]) - index_and_len_list.append([idx, norm_text_len]) + _data:list = [] + index_and_len_list = [] + for idx, item in enumerate(data): + norm_text_len = len(item["norm_text"]) + index_and_len_list.append([idx, norm_text_len]) - batch_index_list = [] - if split_bucket: - index_and_len_list.sort(key=lambda x: x[1]) - index_and_len_list = np.array(index_and_len_list, dtype=np.int64) + batch_index_list = [] + if split_bucket: + index_and_len_list.sort(key=lambda x: x[1]) + index_and_len_list = np.array(index_and_len_list, dtype=np.int64) + logger.info("batch_size: "+str(batch_size)) + batch_index_list_len = 0 + pos = 0 + while pos =threshold) or (pos_end-pos==1): + batch_index=index_and_len_list[pos:pos_end, 0].tolist() + batch_index_list_len += len(batch_index) + batch_index_list.append(batch_index) + pos = pos_end + break + pos_end=pos_end-1 - batch_index_list_len = 0 - pos = 0 - while pos =threshold) or (pos_end-pos==1): - batch_index=index_and_len_list[pos:pos_end, 0].tolist() - batch_index_list_len += len(batch_index) - batch_index_list.append(batch_index) - pos = pos_end - break - pos_end=pos_end-1 + assert batch_index_list_len == len(data) - assert batch_index_list_len == len(data) - - else: - for i in range(len(data)): - if i%batch_size == 0: - batch_index_list.append([]) - batch_index_list[-1].append(i) + else: + for i in range(len(data)): + if i%batch_size == 0: + batch_index_list.append([]) + batch_index_list[-1].append(i) - for batch_idx, index_list in enumerate(batch_index_list): - item_list = [data[idx] for idx in index_list] - phones_list = [] - phones_len_list = [] - # bert_features_list = [] - all_phones_list = [] - all_phones_len_list = [] - all_bert_features_list = [] - norm_text_batch = [] - bert_max_len = 0 - phones_max_len = 0 - for item in item_list: - all_bert_features = torch.cat([ref.bert_feature, item["bert_features"]], 1).to(dtype=precision, device=device) - all_phones = torch.LongTensor(ref.phone+item["phones"]).to(device) - phones = torch.LongTensor(item["phones"]).to(device) - # norm_text = ref.norm_text+item["norm_text"] + for batch_idx, index_list in 
enumerate(batch_index_list): + item_list = [data[idx] for idx in index_list] + phones_list = [] + phones_len_list = [] + # bert_features_list = [] + all_phones_list = [] + all_phones_len_list = [] + all_bert_features_list = [] + norm_text_batch = [] + bert_max_len = 0 + phones_max_len = 0 + for item in item_list: + all_bert_features = torch.cat([ref.bert_feature, item["bert_features"]], 1).to(dtype=precision, device=device) + all_phones = torch.LongTensor(ref.phone+item["phones"]).to(device) + phones = torch.LongTensor(item["phones"]).to(device) + # norm_text = ref.norm_text+item["norm_text"] - bert_max_len = max(bert_max_len, all_bert_features.shape[-1]) - phones_max_len = max(phones_max_len, phones.shape[-1]) + bert_max_len = max(bert_max_len, all_bert_features.shape[-1]) + phones_max_len = max(phones_max_len, phones.shape[-1]) - phones_list.append(phones) - phones_len_list.append(phones.shape[-1]) - all_phones_list.append(all_phones) - all_phones_len_list.append(all_phones.shape[-1]) - all_bert_features_list.append(all_bert_features) - norm_text_batch.append(item["norm_text"]) + phones_list.append(phones) + phones_len_list.append(phones.shape[-1]) + all_phones_list.append(all_phones) + all_phones_len_list.append(all_phones.shape[-1]) + all_bert_features_list.append(all_bert_features) + norm_text_batch.append(item["norm_text"]) - phones_batch = phones_list - all_phones_batch = all_phones_list - all_bert_features_batch = all_bert_features_list + phones_batch = phones_list + all_phones_batch = all_phones_list + all_bert_features_batch = all_bert_features_list - batch = { - "phones": phones_batch, - "phones_len": torch.LongTensor(phones_len_list).to(device), - "all_phones": all_phones_batch, - "all_phones_len": torch.LongTensor(all_phones_len_list).to(device), - "all_bert_features": all_bert_features_batch, - "norm_text": norm_text_batch + batch = { + "phones": phones_batch, + "phones_len": torch.LongTensor(phones_len_list).to(device), + "all_phones": all_phones_batch, + "all_phones_len": torch.LongTensor(all_phones_len_list).to(device), + "all_bert_features": all_bert_features_batch, + "norm_text": norm_text_batch } - _data.append(batch) + _data.append(batch) - return _data, batch_index_list + return _data, batch_index_list def recovery_order(data:list, batch_index_list:list)->list: @@ -663,204 +728,138 @@ def recovery_order(data:list, batch_index_list:list)->list: def run(ref:REF, text, text_lang): - logger.info("run") + logger.info("run") + logger.info(f"batch_size: {batch_size}") - ########## variables initialization ########### - top_k = 5 - top_p = 1 - temperature = 1 - batch_threshold = 0.75 - fragment_interval = 0.3 - text_lang = dict_language[text_lang.lower()] + ########## variables initialization ########### + top_k = 5 + top_p = 1 + temperature = 1 + batch_threshold = 0.75 + fragment_interval = 0.3 + text_lang = dict_language[text_lang.lower()] - if ref.path in [None, ""] or \ - ((ref.prompt_semantic is None) or (ref.refer_spec is None)): - raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()") + if ref.path in [None, ""] or \ + ((ref.prompt_semantic is None) or (ref.refer_spec is None)): + raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()") - t0 = ttime() - ###### text preprocessing ######## - t1 = ttime() - data:list = None - if not return_fragment: - data = text.split("\n") - if len(data) == 0: - yield np.zeros(int(hps.data.sampling_rate), type=np.int16) - return - 
- batch_index_list:list = None - data = preprocess(data, text_lang) - data, batch_index_list = to_batch(data, ref, - threshold=batch_threshold, - ) - else: - texts = text.split("\n") - data = [] - for i in range(len(texts)): - if i%batch_size == 0: - data.append([]) - data[-1].append(texts[i]) - - def make_batch(batch_texts): - batch_data = [] - batch_data = preprocess(batch_texts, text_lang) - if len(batch_data) == 0: - return None - batch, _ = to_batch(batch_data, ref, - threshold=batch_threshold, - ) - return batch[0] - - t2 = ttime() - try: - ###### inference ###### - t_34 = 0.0 - t_45 = 0.0 - audio = [] - for item in data: - t3 = ttime() - if return_fragment: - item = make_batch(item) - if item is None: - continue - - batch_phones:List[torch.LongTensor] = item["phones"] - batch_phones_len:torch.LongTensor = item["phones_len"] - all_phoneme_ids:List[torch.LongTensor] = item["all_phones"] - all_phoneme_lens:torch.LongTensor = item["all_phones_len"] - all_bert_features:List[torch.LongTensor] = item["all_bert_features"] - norm_text:str = item["norm_text"] - - print(norm_text) - - prompt = ref.prompt_semantic.expand(len(all_phoneme_ids), -1).to(device) - - with torch.no_grad(): - pred_semantic_list, idx_list = t2s_model.model.infer_panel( - all_phoneme_ids, - all_phoneme_lens, - prompt, - all_bert_features, - top_k=top_k, - top_p=top_p, - temperature=temperature, - early_stop_num=hz * max_sec, - ) - t4 = ttime() - t_34 += t4 - t3 - - refer_audio_spec:torch.Tensor = ref.refer_spec.to(dtype=precision, device=device) - - batch_audio_fragment = [] - - pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] - upsample_rate = math.prod(vq_model.upsample_rates) - audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] - audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] - all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(device) - _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(device) - _batch_audio_fragment = (vq_model.decode( - all_pred_semantic, _batch_phones,refer_audio_spec - ).detach()[0, 0, :]) - audio_frag_end_idx.insert(0, 0) - batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] - - - t5 = ttime() - t_45 += t5 - t4 - if return_fragment: - logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) - yield audio_postprocess([batch_audio_fragment], - hps.data.sampling_rate, - None, - fragment_interval - ) - else: - audio.append(batch_audio_fragment) - - logger.info("return_fragment:"+str(return_fragment)+" split_bucket:"+str(split_bucket)+" batch_size"+str(batch_size)+" media_type:"+media_type) - if not return_fragment: - logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) - yield audio_postprocess(audio, - hps.data.sampling_rate, - batch_index_list, - fragment_interval - ) - - except Exception as e: - traceback.print_exc() - # 必须返回一个空音频, 否则会导致显存不释放。 - yield np.zeros(int(hps.data.sampling_rate), dtype=np.int16) - finally: - pass - - -def get_tts_wav(ref:REF, text, text_language): - logger.info("get_tts_wav") t0 = ttime() + ###### text preprocessing ######## t1 = ttime() - text_language = dict_language[text_language.lower()] - phones1, bert1, norm_text1 = ref.phone, ref.bert_feature, ref.norm_text - texts = text.split("\n") - audio_bytes = BytesIO() + data:list = None + if not return_fragment: + data = 
text.split("\n") + if len(data) == 0: + yield np.zeros(int(hps.data.sampling_rate), type=np.int16) + return + + batch_index_list:list = None + data = preprocess(data, text_lang) + data, batch_index_list = to_batch(data, ref, + threshold=batch_threshold, + ) + else: + texts = text.split("\n") + data = [] + for i in range(len(texts)): + if i%batch_size == 0: + data.append([]) + data[-1].append(texts[i]) + + def make_batch(batch_texts): + batch_data = [] + batch_data = preprocess(batch_texts, text_lang) + if len(batch_data) == 0: + return None + batch, _ = to_batch(batch_data, ref, + threshold=batch_threshold, + ) + return batch[0] + + t2 = ttime() + try: + ###### inference ###### + t_34 = 0.0 + t_45 = 0.0 + audio = [] + for item in data: + t3 = ttime() + if return_fragment: + item = make_batch(item) + if item is None: + continue + + batch_phones:List[torch.LongTensor] = item["phones"] + batch_phones_len:torch.LongTensor = item["phones_len"] + all_phoneme_ids:List[torch.LongTensor] = item["all_phones"] + all_phoneme_lens:torch.LongTensor = item["all_phones_len"] + all_bert_features:List[torch.LongTensor] = item["all_bert_features"] + norm_text:str = item["norm_text"] + + print(norm_text) + + prompt = ref.prompt_semantic.expand(len(all_phoneme_ids), -1).to(device) + + with torch.no_grad(): + pred_semantic_list, idx_list = t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_lens, + prompt, + all_bert_features, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=hz * max_sec, + ) + t4 = ttime() + t_34 += t4 - t3 + + refer_audio_spec:torch.Tensor = ref.refer_spec.to(dtype=precision, device=device) + + batch_audio_fragment = [] - for text in texts: - # 简单防止纯符号引发参考音频泄露 - if only_punc(text): - continue - print(text) + pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] + upsample_rate = math.prod(vq_model.upsample_rates) + audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] + audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] + all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(device) + _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(device) + _batch_audio_fragment = (vq_model.decode( + all_pred_semantic, _batch_phones,refer_audio_spec + ).detach()[0, 0, :]) + audio_frag_end_idx.insert(0, 0) + batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] + + + t5 = ttime() + t_45 += t5 - t4 + if return_fragment: + logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4)) + yield audio_postprocess([batch_audio_fragment], + hps.data.sampling_rate, + None, + fragment_interval + ) + else: + audio.append(batch_audio_fragment) + + if not return_fragment: + logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) + yield audio_postprocess(audio, + hps.data.sampling_rate, + batch_index_list, + fragment_interval + ) - audio_opt = [] - phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language) - bert = torch.cat([bert1, bert2], 1) + except Exception as e: + traceback.print_exc() + # 必须返回一个空音频, 否则会导致显存不释放。 + yield np.zeros(int(hps.data.sampling_rate), dtype=np.int16) + finally: + pass - all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) - bert = bert.to(device).unsqueeze(0) - all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) - prompt = 
ref.prompt_semantic.unsqueeze(0).to(device) - t2 = ttime() - - zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) - - with torch.no_grad(): - # pred_semantic = t2s_model.model.infer( - pred_semantic, idx = t2s_model.model.infer_panel( - all_phoneme_ids, - all_phoneme_len, - prompt, - bert, - # prompt_phone_len=ph_offset, - top_k=config['inference']['top_k'], - early_stop_num=hz * max_sec) - t3 = ttime() - # print(pred_semantic.shape,idx) - if isinstance(pred_semantic, list) and isinstance(pred_semantic, list): # 神秘代码,有些时候sys.path会出问题,import的是fast inference分支的AR - pred_semantic = pred_semantic[0] - idx=idx[0] - pred_semantic = pred_semantic[-idx:] - pred_semantic = pred_semantic.unsqueeze(0).unsqueeze(0) - else: - pred_semantic = pred_semantic[:,-idx:] - pred_semantic = pred_semantic.unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 - - # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] - audio = \ - vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), - ref.refer_spec).detach().cpu().numpy()[0, 0] ###试试重建不带上prompt部分 - audio_opt.append(audio) - audio_opt.append(zero_wav) - t4 = ttime() - audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate) - logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) - if return_fragment: - audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) - yield audio_chunk - - if not return_fragment: - if media_type == "wav": - audio_bytes = pack_wav(audio_bytes,hps.data.sampling_rate) - yield audio_bytes.getvalue() # -------------------------------- # 初始化部分 @@ -885,10 +884,10 @@ parser.add_argument("-dl", "--default_refer_language", type=str, default="", hel parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu") parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0") parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880") -parser.add_argument("-bs", "--batch_size", type=int, default=1, help="批处理大小") +parser.add_argument("-bs", "--batch_size", type=int, default=2, help="批处理大小") parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度") parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度") -parser.add_argument("-rf", "--return_fragment", action="store_true", default=False, help="是否开启碎片返回") +parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") parser.add_argument("-sb", "--split_bucket", action="store_true", default=False, help="是否将批处理分成多个桶") parser.add_argument("-fa", "--flash_atten", action="store_true", default=False, help="是否开启flash_attention") # bool值的用法为 `python ./api.py -fp ...` @@ -909,7 +908,7 @@ cnhubert_base_path = args.hubert_path bert_path = args.bert_path default_cut_punc = args.cut_punc batch_size = args.batch_size -return_fragment = args.return_fragment +stream_mode = args.stream_mode split_bucket = args.split_bucket flash_atten = args.flash_atten @@ -955,6 +954,16 @@ precision = torch.float16 if is_half else torch.float32 device = torch.device(device) +##流式返回 +if stream_mode.lower() in ["normal","n"]: + stream_mode = "normal" + return_fragment = True + logger.info("流式返回已开启") +else: + stream_mode = "close" + return_fragment = False + + # 音频编码格式 
 if args.media_type.lower() in ["aac","ogg"]:
     media_type = args.media_type.lower()
@@ -1019,8 +1028,8 @@ def handle_control(command):
         exit(0)
 
 
-def handle_change(path, text, language):
-    global default_refer
+def handle_change(path, text, language, cut_punc):
+    global default_refer, default_cut_punc
     if is_empty(path, text, language):
         return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)
@@ -1028,10 +1037,14 @@ def handle_change(path, text, language):
     (text != "" or text is not None) and\
     (language != "" or language is not None):
         default_refer = REF(path, text, language)
+        default_refer.set_ref_audio()
+        if cut_punc is not None and cut_punc != "":
+            default_cut_punc = cut_punc
 
     logger.info(f"当前默认参考音频路径: {default_refer.path}")
     logger.info(f"当前默认参考音频文本: {default_refer.text}")
     logger.info(f"当前默认参考音频语种: {default_refer.language}")
+    logger.info(f"当前默认切分符号: {default_cut_punc}")
     logger.info(f"is_ready: {default_refer.is_ready()}")
 
 
@@ -1043,6 +1056,8 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
         (prompt_text != default_refer.text) or\
         (prompt_language != default_refer.language):
         ref = REF(refer_wav_path, prompt_text, prompt_language)
+        if ref.is_ready():
+            ref.set_ref_audio()
     else:
         ref = default_refer
 
@@ -1055,11 +1070,12 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
         if not default_refer.is_ready():
             return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
 
-    if cut_punc == None:
+    if cut_punc == "" or cut_punc is None:
         text = cut_text(text,default_cut_punc)
     else:
         text = cut_text(text,cut_punc)
 
+
     if is_fast_inference:
@@ -1080,9 +1096,11 @@ async def set_model(request: Request):
     gpt_path=json_post_raw.get("gpt_model_path")
     global sovits_path
     sovits_path=json_post_raw.get("sovits_model_path")
-    logger.info("gptpath"+gpt_path+";vitspath"+sovits_path)
+    logger.info("gptpath: "+gpt_path)
+    logger.info("vitspath: "+sovits_path)
     change_sovits_weights(sovits_path)
     change_gpt_weights(gpt_path)
+    default_refer.set_ref_audio()
     return "ok"
 
 
@@ -1103,7 +1121,8 @@ async def change_refer(request: Request):
     return handle_change(
         json_post_raw.get("refer_wav_path"),
         json_post_raw.get("prompt_text"),
-        json_post_raw.get("prompt_language")
+        json_post_raw.get("prompt_language"),
+        json_post_raw.get("cut_punc")
     )
 
 
@@ -1111,9 +1130,10 @@ async def change_refer(request: Request):
 async def change_refer(
         refer_wav_path: str = None,
         prompt_text: str = None,
-        prompt_language: str = None
+        prompt_language: str = None,
+        cut_punc: str = None
 ):
-    return handle_change(refer_wav_path, prompt_text, prompt_language)
+    return handle_change(refer_wav_path, prompt_text, prompt_language, cut_punc)
 
 
 @app.post("/")
@@ -1131,12 +1151,12 @@ async def tts_endpoint(request: Request):
 @app.get("/")
 async def tts_endpoint(
-        refer_wav_path: str = "",
-        prompt_text: str = "",
-        prompt_language: str = "",
-        text: str = "",
-        text_language: str = "",
-        cut_punc: str = "",
+        refer_wav_path: str = None,
+        prompt_text: str = None,
+        prompt_language: str = None,
+        text: str = None,
+        text_language: str = None,
+        cut_punc: str = None,
 ):
     return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc)
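
Usage note: the hunks above replace the --return_fragment flag with -sm/--stream_mode,
add an optional cut_punc field to /change_refer, and re-prepare the default reference
after /set_model. The client sketch below shows how the patched endpoints might be
exercised. It is illustrative only: the host/port (argparse defaults above are
0.0.0.0:9880), file paths, and prompt/text strings are assumptions; the endpoint paths,
parameter names, and the -sm flag come from the diff itself.

    # Hypothetical client for the patched api.py, assuming the server was started
    # with streaming enabled, e.g.:  python api.py -sm normal -bs 2
    import requests

    BASE = "http://127.0.0.1:9880"  # assumed reachable address for the default port

    # Register a default reference clip; cut_punc is the new optional field.
    requests.post(BASE + "/change_refer", json={
        "refer_wav_path": "ref.wav",                           # placeholder path (3-10 s clip)
        "prompt_text": "This is the reference transcript.",    # placeholder transcript
        "prompt_language": "en",
        "cut_punc": ",.;?!",                                   # new: default split punctuation
    })

    # Synthesize; with -sm normal the audio is returned in fragments, so stream the body.
    # refer_wav_path/prompt_text/prompt_language are omitted, so the default reference
    # registered above is used.
    resp = requests.get(BASE + "/", params={
        "text": "Hello from the patched API.",                 # placeholder text
        "text_language": "en",
    }, stream=True)
    with open("out_audio", "wb") as f:                         # container depends on media_type
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:
                f.write(chunk)

Started without -sm normal, the same GET returns the whole clip in one response body
(packed with a wav header when media_type is wav, per the pack_wav branch above).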