Inference batch optimization

This commit is contained in:
Watchtower-Liu 2024-03-08 11:02:49 +08:00
parent 78ab26ea17
commit 9a2f4dc697
8 changed files with 80 additions and 45 deletions

View File

@@ -321,6 +321,11 @@ class Text2SemanticDecoder(nn.Module):
# shift by one position (inputs vs. targets)
return targets[:, :-1], targets[:, 1:]
def check_end(self, logits, samples):
if torch.all(torch.argmax(logits, dim=-1) == self.EOS) or torch.all(samples[:, 0] == self.EOS):
return True
def infer_panel(
self,
x, ##### all text tokens
@@ -338,8 +343,8 @@ class Text2SemanticDecoder(nn.Module):
# AR Decoder
y = prompts
x_len = x.shape[1]
ref_y_len = y.shape[1] if y is not None else 0
x_len = x_lens.max()
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
@@ -385,7 +390,7 @@ class Text2SemanticDecoder(nn.Module):
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
x.device
)
batch_end_index = [0 for _ in range(x.shape[0])]
for idx in tqdm(range(1500)):
@@ -397,17 +402,18 @@ class Text2SemanticDecoder(nn.Module):
if(idx==0):### the first step must not sample EOS, otherwise nothing gets generated
logits = logits[:, :-1] ### drop the probability of the 1024 EOS token
samples = sample(
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0].unsqueeze(0)
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0]
# the semantic_ids generated this step are concatenated with the previous y to form the new y
# print(samples.shape)#[1,1]# the first 1 is bs
y = torch.concat([y, samples], dim=1)
y = torch.concat([y, samples], dim=1)
# print(y)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
if self.check_end(logits, samples):
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
@@ -443,6 +449,7 @@ class Text2SemanticDecoder(nn.Module):
xy_attn_mask = torch.zeros(
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
)
batch_end_index = [min(torch.where(by == self.EOS)[0]) if self.EOS in by else idx for by in y]
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx-1
return y[:, :-1], [0 for _ in range(y.shape[0])], batch_end_index
return y[:, :-1], [ref_y_len for _ in range(y.shape[0])], batch_end_index
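The batched infer_panel above now returns three values: the generated tokens, a per-item start offset (0 when ref_free, otherwise the reference prompt length ref_y_len), and batch_end_index, each item's first EOS position. A minimal consumption sketch (model, phoneme_ids, phoneme_lens, and prompt are hypothetical placeholders, not names from this commit):

# hedged sketch: slice each batch row back to its valid span,
# assuming pred has shape (B, T) and the index lists align per item
pred, start_idx, end_idx = model.infer_panel(phoneme_ids, phoneme_lens, prompt)
valid = [pred[b, s:e] for b, (s, e) in enumerate(zip(start_idx, end_idx))]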

View File

@@ -115,17 +115,17 @@ def logits_to_probs(
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
if previous_tokens is not None:
previous_tokens = previous_tokens.squeeze()
# if previous_tokens is not None:
# previous_tokens = previous_tokens.squeeze()
# print(logits.shape,previous_tokens.shape)
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -133,9 +133,9 @@ def logits_to_probs(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0, index=sorted_indices, src=sorted_indices_to_remove
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@@ -143,7 +143,7 @@ def logits_to_probs(
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v.select(-1, -1).unsqueeze(-1)
pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
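Because the sampling path now keeps batched 2-D logits of shape (batch, vocab) instead of squeezing to 1-D, gather and scatter_ move from dim=0 to dim=1, and the top-p keep mask indexes column 0 of every row. A self-contained sketch of the batched repetition penalty under those assumed shapes (toy numbers, not the repository's exact values):

import torch

logits = torch.randn(2, 8)                        # (B, V): one row per batch item
previous_tokens = torch.tensor([[1, 3], [0, 7]])  # (B, T): token history per item
penalty = 1.35
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(score < 0, score * penalty, score / penalty)
logits.scatter_(dim=1, index=previous_tokens, src=score)  # write penalized scores back row-wise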

View File

@@ -448,7 +448,6 @@ def multi_head_attention_forward_patched(
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)

View File

@@ -66,6 +66,7 @@ from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
from tools.i18n.i18n import I18nAuto
from utils import tensor_padding
i18n = I18nAuto()
@@ -311,7 +312,7 @@ def merge_short_text_in_array(texts, threshold):
result[len(result) - 1] += text
return result
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False, infer_batch = 1):
if prompt_text is None or len(prompt_text) == 0:
ref_free = True
t0 = ttime()
@@ -371,13 +372,28 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
if not ref_free:
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
for text in texts:
batch_num = len(texts) // infer_batch
if len(texts) % infer_batch != 0:
batch_num += 1
for i in range(batch_num):
if i == batch_num - 1:
text = texts[i * infer_batch:]
else:
text = texts[i * infer_batch: (i + 1) * infer_batch]
# for text in texts:
# fix the error caused by blank lines in the input target text
if (len(text.strip()) == 0):
# filter each text segment to be synthesized
all_phoneme_ids_batch = []
all_bert_batch = []
all_phoneme_len_batch = []
all_vits_phones2 = []
for T in text:
if (len(T.strip()) == 0):
continue
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
print(i18n("实际输入的目标文本(每句):"), text)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
if (T[-1] not in splits): T += "。" if text_language != "en" else "."
print(i18n("实际输入的目标文本(每句):"), T)
phones2,bert2,norm_text2=get_phones_and_bert(T, text_language)
all_vits_phones2 += phones2
print(i18n("前端处理后的文本(每句):"), norm_text2)
if not ref_free:
bert = torch.cat([bert1, bert2], 1)
@@ -385,14 +401,19 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
else:
bert = bert2
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
all_phoneme_ids_batch.append(all_phoneme_ids)
all_bert_batch.append(bert)
all_phoneme_len_batch.append(all_phoneme_ids.shape[-1])
all_phoneme_ids = tensor_padding(all_phoneme_ids_batch)
bert = tensor_padding(all_bert_batch)
all_phoneme_len = torch.tensor(all_phoneme_len_batch).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device).expand(all_phoneme_ids.shape[0], -1)  # broadcast the (1, T) reference prompt to (B, T)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel(
pred_semantic, sidx, eidx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
None if ref_free else prompt,
@@ -405,9 +426,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
0
) # .unsqueeze(0)# mq needs one extra unsqueeze
pred_semantic = torch.cat([pred[si:ei] for pred, si, ei in zip(pred_semantic, sidx, eidx)], dim=0).unsqueeze(0).unsqueeze(0) # .unsqueeze(0)# mq needs one extra unsqueeze
refer = get_spepc(hps, ref_wav_path) # .to(device)
if is_half == True:
refer = refer.half().to(device)
@@ -416,7 +435,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
pred_semantic, torch.LongTensor(all_vits_phones2).to(device).unsqueeze(0), refer
)
.detach()
.cpu()
@@ -588,12 +607,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
with gr.Row():
batch_size = gr.Slider(minimum=1,maximum=32,step=1,label=i18n("batch_size"),value=1,interactive=True)
inference_button = gr.Button(i18n("合成语音"), variant="primary")
output = gr.Audio(label=i18n("输出的语音"))
inference_button.click(
get_tts_wav,
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free, batch_size],
[output],
)
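The loop added above walks texts in groups of infer_batch sentences, using ceil division so the final partial batch is kept. An equivalent, more compact chunking sketch with toy values (not the commit's code):

texts, infer_batch = ["s1", "s2", "s3", "s4", "s5"], 2
# split texts into ceil(len(texts) / infer_batch) chunks; the last one may be shorter
chunks = [texts[i:i + infer_batch] for i in range(0, len(texts), infer_batch)]
print(chunks)  # [['s1', 's2'], ['s3', 's4'], ['s5']]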

View File

@@ -6,6 +6,7 @@ import logging
import json
import subprocess
import traceback
from typing import List
import librosa
import numpy as np
@@ -270,7 +271,6 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim
del_routine = lambda x: [os.remove(x), del_info(x)]
rs = [del_routine(fn) for fn in to_del]
def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r") as f:
@@ -331,6 +331,12 @@ def get_logger(model_dir, filename="train.log"):
return logger
def tensor_padding(tensor_list: List[torch.Tensor]):
max_len = max([t.shape[-1] for t in tensor_list])
tensor_list = [torch.cat([t, torch.zeros(*list(t.shape[:-1]), max_len - t.size(-1)).to(t)], dim=-1) for t in tensor_list]
return torch.cat(tensor_list, dim=0)
class HParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():
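tensor_padding right-pads each tensor with zeros along its last dimension to the longest length in the list, then concatenates along dim 0; this is what lets the webui stack per-sentence phoneme and BERT tensors into one batch. A quick usage sketch (toy shapes; tensor_padding is the helper defined above):

import torch

a, b, c = torch.ones(1, 3), torch.ones(1, 5), torch.ones(1, 4)
batch = tensor_padding([a, b, c])
print(batch.shape)  # torch.Size([3, 5]); rows shorter than 5 are zero-padded on the right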

gweight.txt Normal file
View File

@@ -0,0 +1 @@
GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt

sweight.txt Normal file
View File

@@ -0,0 +1 @@
GPT_SoVITS/pretrained_models/s2G488k.pth