Fixed some bug

XXXXRT 2024-04-03 23:33:02 +01:00 committed by XXXXRT666
parent b50ab5419f
commit f45fff72d9

716
api.py

@@ -158,6 +158,78 @@ import math
i18n = I18nAuto()

class REF:
    def __init__(self, ref_path="", ref_text="", ref_language=""):
        if ref_text:
            ref_text = ref_text.strip("\n")
            if (ref_text[-1] not in splits): ref_text += "。" if ref_language != "en" else "."
        if ref_language:
            ref_language = dict_language[ref_language.lower()]
        self.path = ref_path
        self.text = ref_text
        self.language = ref_language
        self.prompt_semantic = None
        self.refer_spec = None

    def set_prompt_semantic(self, ref_wav_path:str):
        zero_wav = np.zeros(
            int(hps.data.sampling_rate * 0.3),
            dtype=np.float16 if is_half else np.float32,
        )
        with torch.no_grad():
            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
            if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
                raise OSError(i18n("参考音频在3~10秒范围外，请更换！"))
            wav16k = torch.from_numpy(wav16k)
            zero_wav_torch = torch.from_numpy(zero_wav)
            wav16k = wav16k.to(device)
            zero_wav_torch = zero_wav_torch.to(device)
            if is_half:
                wav16k = wav16k.half()
                zero_wav_torch = zero_wav_torch.half()
            wav16k = torch.cat([wav16k, zero_wav_torch])
            ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
            codes = vq_model.extract_latent(ssl_content)
            prompt_semantic = codes[0, 0].to(device)
            self.prompt_semantic = prompt_semantic
            self.codes = codes
            self.ssl_content = ssl_content

    def set_ref_spec(self, ref_audio_path):
        audio = load_audio(ref_audio_path, int(hps.data.sampling_rate))
        audio = torch.FloatTensor(audio)
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        spec = spectrogram_torch(
            audio_norm,
            hps.data.filter_length,
            hps.data.sampling_rate,
            hps.data.hop_length,
            hps.data.win_length,
            center=False,
        )
        spec = spec.to(device)
        if is_half:
            spec = spec.half()
        # self.refer_spec = spec
        self.refer_spec = spec

    def set_ref_audio(self):
        '''
        Set the reference audio for the TTS model,
        including the prompt_semantic and refer_spec.
        Args:
            ref_audio_path: str, the path of the reference audio.
        '''
        self.set_prompt_semantic(self.path)
        self.set_ref_spec(self.path)
        self.phone, self.bert_feature, self.norm_text = get_phones_and_bert(self.text, self.language)

    def is_ready(self) -> bool:
        return is_full(self.path, self.text, self.language)

def is_empty(*items):  # returns False if any item is non-empty
    for item in items:
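For context, a minimal usage sketch for the REF class added above, assuming the module-level models (hps, ssl_model, vq_model, bert) and device have already been initialized; the file path and prompt text below are placeholders, not values from this commit.

# Hypothetical usage of the new REF class; path and text are placeholders.
ref = REF(ref_path="samples/ref.wav", ref_text="参考音频对应的文本", ref_language="zh")
if ref.is_ready():
    ref.set_ref_audio()  # fills prompt_semantic, refer_spec and phone/bert_feature/norm_text
assert ref.prompt_semantic is not None and ref.refer_spec is not None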
@@ -193,6 +265,7 @@ def change_sovits_weights(sovits_path)
    vq_model = vq_model.to(device)
    vq_model.eval()

    vq_model.load_state_dict(dict_s2["weight"], strict=False)

def change_gpt_weights(gpt_path):
@@ -243,10 +316,6 @@ def get_bert_inf(phones, word2ph, norm_text, language):
    language = language.replace("all_", "")
    if language == "zh":
        bert = get_bert_feature(norm_text, word2ph).to(device)  # .to(dtype)
        bert = torch.zeros(
            (1024, len(phones)),
            dtype=precision,
        ).to(device)
    else:
        bert = torch.zeros(
            (1024, len(phones)),
@@ -332,77 +401,6 @@ class DictToAttrRecursive:
        setattr(self, key, value)

class REF:
    def __init__(self, ref_path="", ref_text="", ref_language=""):
        ref_text = ref_text.strip("\n")
        if ref_text:
            if (ref_text[-1] not in splits): ref_text += "。" if ref_language != "en" else "."
        if ref_language:
            ref_language = dict_language[ref_language.lower()]
        self.path = ref_path
        self.text = ref_text
        self.language = ref_language

    def set_prompt_semantic(self, ref_wav_path:str):
        zero_wav = np.zeros(
            int(hps.data.sampling_rate * 0.3),
            dtype=np.float16 if is_half else np.float32,
        )
        with torch.no_grad():
            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
            if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
                raise OSError(i18n("参考音频在3~10秒范围外，请更换！"))
            wav16k = torch.from_numpy(wav16k)
            zero_wav_torch = torch.from_numpy(zero_wav)
            wav16k = wav16k.to(device)
            zero_wav_torch = zero_wav_torch.to(device)
            if is_half:
                wav16k = wav16k.half()
                zero_wav_torch = zero_wav_torch.half()
            wav16k = torch.cat([wav16k, zero_wav_torch])
            ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
            codes = vq_model.extract_latent(ssl_content)
            prompt_semantic = codes[0, 0].to(device)
            self.prompt_semantic = prompt_semantic
            self.codes = codes
            self.ssl_content = ssl_content

    def set_ref_spec(self, ref_audio_path):
        audio = load_audio(ref_audio_path, int(hps.data.sampling_rate))
        audio = torch.FloatTensor(audio)
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        spec = spectrogram_torch(
            audio_norm,
            hps.data.filter_length,
            hps.data.sampling_rate,
            hps.data.hop_length,
            hps.data.win_length,
            center=False,
        )
        spec = spec.to(device)
        if is_half:
            spec = spec.half()
        # self.refer_spec = spec
        self.refer_spec = spec

    def set_ref_audio(self):
        '''
        To set the reference audio for the TTS model,
        including the prompt_semantic and refer_spec.
        Args:
            ref_audio_path: str, the path of the reference audio.
        '''
        self.set_prompt_semantic(self.path)
        self.set_ref_spec(self.path)
        self.phone, self.bert_feature, self.norm_text = get_phones_and_bert(self.text, self.language)

    def is_ready(self) -> bool:
        return is_full(self.path, self.text, self.language)

def pack_audio(audio_bytes, data, rate):
    if media_type == "ogg":
        audio_bytes = pack_ogg(audio_bytes, data, rate)
@@ -485,6 +483,73 @@ def only_punc(text):
    return not any(t.isalnum() or t.isalpha() for t in text)

def get_tts_wav(ref:REF, text, text_language):
    logger.info("get_tts_wav")
    t0 = ttime()
    t1 = ttime()
    text_language = dict_language[text_language.lower()]
    phones1, bert1, norm_text1 = ref.phone, ref.bert_feature, ref.norm_text
    texts = text.split("\n")
    audio_bytes = BytesIO()

    for text in texts:
        # crude guard against punctuation-only segments leaking the reference audio
        if only_punc(text):
            continue

        print(text)
        audio_opt = []
        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
        bert = torch.cat([bert1, bert2], 1)

        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
        prompt = ref.prompt_semantic.unsqueeze(0).to(device)
        t2 = ttime()
        zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
        with torch.no_grad():
            # pred_semantic = t2s_model.model.infer(
            pred_semantic, idx = t2s_model.model.infer_panel(
                all_phoneme_ids,
                all_phoneme_len,
                prompt,
                bert,
                # prompt_phone_len=ph_offset,
                top_k=config['inference']['top_k'],
                early_stop_num=hz * max_sec)
        t3 = ttime()
        # print(pred_semantic.shape, idx)
        if isinstance(pred_semantic, list) and isinstance(idx, list):  # mysterious code: sometimes sys.path breaks and the AR model from the fast-inference branch gets imported
            pred_semantic = pred_semantic[0]
            idx = idx[0]
            pred_semantic = pred_semantic[-idx:]
            pred_semantic = pred_semantic.unsqueeze(0).unsqueeze(0)
        else:
            pred_semantic = pred_semantic[:, -idx:]
            pred_semantic = pred_semantic.unsqueeze(0)  # .unsqueeze(0)  # mq needs one extra unsqueeze
        # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
        audio = \
            vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
                            ref.refer_spec).detach().cpu().numpy()[0, 0]  # try reconstructing without the prompt part
        audio_opt.append(audio)
        audio_opt.append(zero_wav)
        t4 = ttime()
        audio_bytes = pack_audio(audio_bytes, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16), hps.data.sampling_rate)
        logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
        if return_fragment:
            audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
            yield audio_chunk

    if not return_fragment:
        if media_type == "wav":
            audio_bytes = pack_wav(audio_bytes, hps.data.sampling_rate)
        yield audio_bytes.getvalue()
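A minimal sketch of consuming the get_tts_wav generator above, with ref prepared as in the REF sketch earlier, assuming return_fragment is False and media_type is "wav" so a single packed buffer is yielded; the output filename and text are placeholders.

# Hypothetical consumer of get_tts_wav(); assumes return_fragment=False and media_type="wav".
with open("output.wav", "wb") as f:
    for audio_bytes in get_tts_wav(ref, "需要合成的文本", "zh"):
        f.write(audio_bytes)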
def preprocess(text:list, lang:str)->List[Dict]:
    result = []
    for _text in text:
@@ -559,88 +624,88 @@ def batch_sequences(sequences: List[torch.Tensor], axis:int = 0, pad_value:int =
def to_batch(data:list, ref:REF,
             threshold:float=0.75,
             ):
    _data:list = []
    index_and_len_list = []
    for idx, item in enumerate(data):
        norm_text_len = len(item["norm_text"])
        index_and_len_list.append([idx, norm_text_len])

    batch_index_list = []
    if split_bucket:
        index_and_len_list.sort(key=lambda x: x[1])
        index_and_len_list = np.array(index_and_len_list, dtype=np.int64)

        logger.info("batch_size: "+str(batch_size))
        batch_index_list_len = 0
        pos = 0
        while pos < index_and_len_list.shape[0]:
            # batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
            pos_end = min(pos+batch_size, index_and_len_list.shape[0])
            while pos < pos_end:
                batch = index_and_len_list[pos:pos_end, 1].astype(np.float32)
                score = batch[(pos_end-pos)//2] / (batch.mean()+1e-8)
                if (score >= threshold) or (pos_end-pos == 1):
                    batch_index = index_and_len_list[pos:pos_end, 0].tolist()
                    batch_index_list_len += len(batch_index)
                    batch_index_list.append(batch_index)
                    pos = pos_end
                    break
                pos_end = pos_end-1

        assert batch_index_list_len == len(data)

    else:
        for i in range(len(data)):
            if i % batch_size == 0:
                batch_index_list.append([])
            batch_index_list[-1].append(i)

    for batch_idx, index_list in enumerate(batch_index_list):
        item_list = [data[idx] for idx in index_list]
        phones_list = []
        phones_len_list = []
        # bert_features_list = []
        all_phones_list = []
        all_phones_len_list = []
        all_bert_features_list = []
        norm_text_batch = []
        bert_max_len = 0
        phones_max_len = 0
        for item in item_list:
            all_bert_features = torch.cat([ref.bert_feature, item["bert_features"]], 1).to(dtype=precision, device=device)
            all_phones = torch.LongTensor(ref.phone + item["phones"]).to(device)
            phones = torch.LongTensor(item["phones"]).to(device)
            # norm_text = ref.norm_text + item["norm_text"]

            bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
            phones_max_len = max(phones_max_len, phones.shape[-1])

            phones_list.append(phones)
            phones_len_list.append(phones.shape[-1])
            all_phones_list.append(all_phones)
            all_phones_len_list.append(all_phones.shape[-1])
            all_bert_features_list.append(all_bert_features)
            norm_text_batch.append(item["norm_text"])

        phones_batch = phones_list
        all_phones_batch = all_phones_list
        all_bert_features_batch = all_bert_features_list

        batch = {
            "phones": phones_batch,
            "phones_len": torch.LongTensor(phones_len_list).to(device),
            "all_phones": all_phones_batch,
            "all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
            "all_bert_features": all_bert_features_batch,
            "norm_text": norm_text_batch
        }
        _data.append(batch)

    return _data, batch_index_list
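A worked illustration of the split_bucket scoring loop above, using hypothetical normalized-text lengths (not taken from this commit), batch_size=4 and threshold=0.75:

# sorted lengths [2, 3, 4, 60]:
#   window [2, 3, 4, 60] -> middle element 4, mean 17.25, score ~0.23 < 0.75 -> shrink window
#   window [2, 3, 4]     -> middle element 3, mean 3.0,   score 1.0 >= 0.75  -> accepted as one batch
#   window [60]          -> single item, accepted as its own batch
# so short texts are batched together instead of being padded up to the longest text.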
def recovery_order(data:list, batch_index_list:list)->list:
@@ -663,204 +728,138 @@ def recovery_order(data:list, batch_index_list:list)->list:
def run(ref:REF, text, text_lang):
    logger.info("run")
    logger.info(f"batch_size: {batch_size}")

    ########## variables initialization ###########
    top_k = 5
    top_p = 1
    temperature = 1
    batch_threshold = 0.75
    fragment_interval = 0.3
    text_lang = dict_language[text_lang.lower()]

    if ref.path in [None, ""] or \
            ((ref.prompt_semantic is None) or (ref.refer_spec is None)):
        raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")

    t0 = ttime()
    ###### text preprocessing ########
    t1 = ttime()
    data:list = None
    if not return_fragment:
        data = text.split("\n")
        if len(data) == 0:
            yield np.zeros(int(hps.data.sampling_rate), dtype=np.int16)
            return

        batch_index_list:list = None
        data = preprocess(data, text_lang)
        data, batch_index_list = to_batch(data, ref,
                                          threshold=batch_threshold,
                                          )
    else:
        texts = text.split("\n")
        data = []
        for i in range(len(texts)):
            if i % batch_size == 0:
                data.append([])
            data[-1].append(texts[i])

        def make_batch(batch_texts):
            batch_data = []
            batch_data = preprocess(batch_texts, text_lang)
            if len(batch_data) == 0:
                return None
            batch, _ = to_batch(batch_data, ref,
                                threshold=batch_threshold,
                                )
            return batch[0]

    t2 = ttime()
    try:
        ###### inference ######
        t_34 = 0.0
        t_45 = 0.0
        audio = []
        for item in data:
            t3 = ttime()
            if return_fragment:
                item = make_batch(item)
                if item is None:
                    continue

            batch_phones:List[torch.LongTensor] = item["phones"]
            batch_phones_len:torch.LongTensor = item["phones_len"]
            all_phoneme_ids:List[torch.LongTensor] = item["all_phones"]
            all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
            all_bert_features:List[torch.LongTensor] = item["all_bert_features"]
            norm_text:str = item["norm_text"]
            print(norm_text)

            prompt = ref.prompt_semantic.expand(len(all_phoneme_ids), -1).to(device)

            with torch.no_grad():
                pred_semantic_list, idx_list = t2s_model.model.infer_panel(
                    all_phoneme_ids,
                    all_phoneme_lens,
                    prompt,
                    all_bert_features,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    early_stop_num=hz * max_sec,
                )
            t4 = ttime()
            t_34 += t4 - t3

            refer_audio_spec:torch.Tensor = ref.refer_spec.to(dtype=precision, device=device)

            batch_audio_fragment = []

            pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
            upsample_rate = math.prod(vq_model.upsample_rates)
            audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
            audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
            all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(device)
            _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(device)
            _batch_audio_fragment = (vq_model.decode(
                    all_pred_semantic, _batch_phones, refer_audio_spec
                ).detach()[0, 0, :])
            audio_frag_end_idx.insert(0, 0)
            batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]

            t5 = ttime()
            t_45 += t5 - t4
            if return_fragment:
                logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
                yield audio_postprocess([batch_audio_fragment],
                                        hps.data.sampling_rate,
                                        None,
                                        fragment_interval
                                        )
            else:
                audio.append(batch_audio_fragment)

        logger.info("return_fragment:"+str(return_fragment)+" split_bucket:"+str(split_bucket)+" batch_size"+str(batch_size)+" media_type:"+media_type)
        if not return_fragment:
            logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
            yield audio_postprocess(audio,
                                    hps.data.sampling_rate,
                                    batch_index_list,
                                    fragment_interval
                                    )
    except Exception as e:
        traceback.print_exc()
        # An empty audio clip must be yielded here, otherwise GPU memory is not released.
        yield np.zeros(int(hps.data.sampling_rate), dtype=np.int16)
    finally:
        pass
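The fragment boundaries in run() follow from each predicted semantic token mapping to 2 * prod(vq_model.upsample_rates) output samples; a rough arithmetic sketch, assuming for illustration upsample rates of (10, 8, 2, 2) and a 32 kHz sampling rate, neither of which is asserted by this commit:

# Illustration only; the real values depend on the loaded SoVITS weights and hps.
import math
upsample_rate = math.prod((10, 8, 2, 2))  # 320 samples per latent frame (assumed)
samples_per_token = 2 * upsample_rate     # 640 samples per predicted semantic token
print(250 * samples_per_token)            # a 250-token fragment -> 160000 samples, i.e. 5 s at 32 kHz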
# --------------------------------
# Initialization
@@ -885,10 +884,10 @@ parser.add_argument("-dl", "--default_refer_language", type=str, default="", hel
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-bs", "--batch_size", type=int, default=1, help="批处理大小")
parser.add_argument("-bs", "--batch_size", type=int, default=2, help="批处理大小")
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
parser.add_argument("-rf", "--return_fragment", action="store_true", default=False, help="是否开启碎片返回")
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
parser.add_argument("-sb", "--split_bucket", action="store_true", default=False, help="是否将批处理分成多个桶")
parser.add_argument("-fa", "--flash_atten", action="store_true", default=False, help="是否开启flash_attention")
# boolean flags are passed as `python ./api.py -fp ...`
@@ -909,7 +908,7 @@ cnhubert_base_path = args.hubert_path
bert_path = args.bert_path
default_cut_punc = args.cut_punc
batch_size = args.batch_size
return_fragment = args.return_fragment
stream_mode = args.stream_mode
split_bucket = args.split_bucket
flash_atten = args.flash_atten
@@ -955,6 +954,16 @@ precision = torch.float16 if is_half else torch.float32
device = torch.device(device)
## streaming response mode
if stream_mode.lower() in ["normal","n"]:
    stream_mode = "normal"
    return_fragment = True
    logger.info("流式返回已开启")
else:
    stream_mode = "close"
    return_fragment = False

# audio encoding format
if args.media_type.lower() in ["aac","ogg"]:
    media_type = args.media_type.lower()
@@ -1019,8 +1028,8 @@ def handle_control(command):
        exit(0)

def handle_change(path, text, language):
    global default_refer

def handle_change(path, text, language, cut_punc):
    global default_refer, default_cut_punc

    if is_empty(path, text, language):
        return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)
@@ -1028,10 +1037,14 @@ def handle_change(path, text, language):
(text != "" or text is not None) and\
(language != "" or language is not None):
default_refer = REF(path, text, language)
default_refer.set_ref_audio()
if (cut_punc !="" or cut_punc is not None):
default_cut_punc = cut_punc
logger.info(f"当前默认参考音频路径: {default_refer.path}")
logger.info(f"当前默认参考音频文本: {default_refer.text}")
logger.info(f"当前默认参考音频语种: {default_refer.language}")
logger.info(f"当前默认切分符号: {default_cut_punc}")
logger.info(f"is_ready: {default_refer.is_ready()}")
@@ -1043,6 +1056,8 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
        (prompt_text != default_refer.text) or \
        (prompt_language != default_refer.language):
        ref = REF(refer_wav_path, prompt_text, prompt_language)
        if ref.is_ready():
            ref.set_ref_audio()
    else:
        ref = default_refer
@@ -1055,11 +1070,12 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
        if not default_refer.is_ready():
            return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)

    if cut_punc == None:
    if cut_punc == "" or cut_punc is None:
        text = cut_text(text, default_cut_punc)
    else:
        text = cut_text(text, cut_punc)

    if is_fast_inference:
@@ -1080,9 +1096,11 @@ async def set_model(request: Request):
    gpt_path = json_post_raw.get("gpt_model_path")
    global sovits_path
    sovits_path = json_post_raw.get("sovits_model_path")
    logger.info("gptpath"+gpt_path+";vitspath"+sovits_path)
    logger.info("gptpath: "+gpt_path)
    logger.info("vitspath: "+sovits_path)
    change_sovits_weights(sovits_path)
    change_gpt_weights(gpt_path)
    default_refer.set_ref_audio()
    return "ok"
@@ -1103,7 +1121,8 @@ async def change_refer(request: Request):
    return handle_change(
        json_post_raw.get("refer_wav_path"),
        json_post_raw.get("prompt_text"),
        json_post_raw.get("prompt_language")
        json_post_raw.get("prompt_language"),
        json_post_raw.get("cut_punc")
    )
@@ -1111,9 +1130,10 @@ async def change_refer(
async def change_refer(
        refer_wav_path: str = None,
        prompt_text: str = None,
        prompt_language: str = None
        prompt_language: str = None,
        cut_punc: str = None
):
    return handle_change(refer_wav_path, prompt_text, prompt_language)
    return handle_change(refer_wav_path, prompt_text, prompt_language, cut_punc)
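A minimal sketch of calling the updated GET change_refer signature with the new cut_punc parameter; the /change_refer route path is assumed from the handler name and all values are placeholders.

# Hypothetical client call; endpoint path and parameter values are assumptions.
import requests
requests.get("http://127.0.0.1:9880/change_refer", params={
    "refer_wav_path": "samples/ref.wav",
    "prompt_text": "参考音频对应的文本",
    "prompt_language": "zh",
    "cut_punc": "。！？",
})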
@app.post("/")
@@ -1131,12 +1151,12 @@ async def tts_endpoint(request: Request):
@app.get("/")
async def tts_endpoint(
        refer_wav_path: str = "",
        prompt_text: str = "",
        prompt_language: str = "",
        text: str = "",
        text_language: str = "",
        cut_punc: str = "",
        refer_wav_path: str = None,
        prompt_text: str = None,
        prompt_language: str = None,
        text: str = None,
        text_language: str = None,
        cut_punc: str = None,
):
    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc)
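A minimal sketch of a client for the GET "/" endpoint above, assuming the server was started with `-sm normal` so fragments are streamed, and using the default port 9880 from the argparse block; paths, texts and the output filename are placeholders.

# Hypothetical streaming client; all values are placeholders.
import requests
resp = requests.get("http://127.0.0.1:9880/", params={
    "refer_wav_path": "samples/ref.wav",
    "prompt_text": "参考音频对应的文本",
    "prompt_language": "zh",
    "text": "需要合成的文本",
    "text_language": "zh",
}, stream=True)
with open("output_audio", "wb") as f:
    for chunk in resp.iter_content(chunk_size=4096):
        f.write(chunk)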