mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
Optimize detail
This commit is contained in:
parent
49427c75bc
commit
dbb6b42fdb
19
api.py
19
api.py
@ -221,7 +221,7 @@ def get_sovits_weights(sovits_path):
|
|||||||
hps.model.version = "v1"
|
hps.model.version = "v1"
|
||||||
else:
|
else:
|
||||||
hps.model.version = "v2"
|
hps.model.version = "v2"
|
||||||
print("sovits版本:",hps.model.version)
|
logger.info(f"模型版本: {hps.model.version}")
|
||||||
model_params_dict = vars(hps.model)
|
model_params_dict = vars(hps.model)
|
||||||
vq_model = SynthesizerTrn(
|
vq_model = SynthesizerTrn(
|
||||||
hps.data.filter_length // 2 + 1,
|
hps.data.filter_length // 2 + 1,
|
||||||
@ -489,8 +489,7 @@ def pack_raw(audio_bytes, data, rate):
|
|||||||
def pack_wav(audio_bytes, rate):
|
def pack_wav(audio_bytes, rate):
|
||||||
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
|
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
|
||||||
wav_bytes = BytesIO()
|
wav_bytes = BytesIO()
|
||||||
sf.write(wav_bytes, data, rate, format='wav')
|
sf.write(wav_bytes, data, rate, format='WAV')
|
||||||
|
|
||||||
return wav_bytes
|
return wav_bytes
|
||||||
|
|
||||||
|
|
||||||
@ -543,6 +542,7 @@ def only_punc(text):
|
|||||||
return not any(t.isalnum() or t.isalpha() for t in text)
|
return not any(t.isalnum() or t.isalpha() for t in text)
|
||||||
|
|
||||||
|
|
||||||
|
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
||||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"):
|
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"):
|
||||||
infer_sovits = speaker_list[spk].sovits
|
infer_sovits = speaker_list[spk].sovits
|
||||||
vq_model = infer_sovits.vq_model
|
vq_model = infer_sovits.vq_model
|
||||||
@ -554,6 +554,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
|
|
||||||
t0 = ttime()
|
t0 = ttime()
|
||||||
prompt_text = prompt_text.strip("\n")
|
prompt_text = prompt_text.strip("\n")
|
||||||
|
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
|
||||||
prompt_language, text = prompt_language, text.strip("\n")
|
prompt_language, text = prompt_language, text.strip("\n")
|
||||||
dtype = torch.float16 if is_half == True else torch.float32
|
dtype = torch.float16 if is_half == True else torch.float32
|
||||||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
|
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
|
||||||
@ -599,6 +600,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
audio_opt = []
|
audio_opt = []
|
||||||
|
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
|
||||||
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
|
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
|
||||||
bert = torch.cat([bert1, bert2], 1)
|
bert = torch.cat([bert1, bert2], 1)
|
||||||
|
|
||||||
@ -607,7 +609,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||||
t2 = ttime()
|
t2 = ttime()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# pred_semantic = t2s_model.model.infer(
|
|
||||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||||
all_phoneme_ids,
|
all_phoneme_ids,
|
||||||
all_phoneme_len,
|
all_phoneme_len,
|
||||||
@ -618,20 +619,20 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
top_p = top_p,
|
top_p = top_p,
|
||||||
temperature = temperature,
|
temperature = temperature,
|
||||||
early_stop_num=hz * max_sec)
|
early_stop_num=hz * max_sec)
|
||||||
|
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
||||||
t3 = ttime()
|
t3 = ttime()
|
||||||
# print(pred_semantic.shape,idx)
|
|
||||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
|
|
||||||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
|
||||||
audio = \
|
audio = \
|
||||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||||
refers,speed=speed).detach().cpu().numpy()[
|
refers,speed=speed).detach().cpu().numpy()[
|
||||||
0, 0] ###试试重建不带上prompt部分
|
0, 0] ###试试重建不带上prompt部分
|
||||||
max_audio=np.abs(audio).max()#简单防止16bit爆音
|
max_audio=np.abs(audio).max()
|
||||||
if max_audio>1:audio/=max_audio
|
if max_audio>1:
|
||||||
|
audio/=max_audio
|
||||||
audio_opt.append(audio)
|
audio_opt.append(audio)
|
||||||
audio_opt.append(zero_wav)
|
audio_opt.append(zero_wav)
|
||||||
t4 = ttime()
|
t4 = ttime()
|
||||||
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
|
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
|
||||||
|
# audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate)
|
||||||
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||||
if stream_mode == "normal":
|
if stream_mode == "normal":
|
||||||
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
|
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user