mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 23:48:48 +08:00
推理batch优化
This commit is contained in:
parent
78ab26ea17
commit
9a2f4dc697
@ -321,6 +321,11 @@ class Text2SemanticDecoder(nn.Module):
|
||||
# 错位
|
||||
return targets[:, :-1], targets[:, 1:]
|
||||
|
||||
def check_end(self, logits, samples):
|
||||
if torch.all(torch.argmax(logits, dim=-1) == self.EOS) or torch.all(samples[:, 0] == self.EOS):
|
||||
|
||||
return True
|
||||
|
||||
def infer_panel(
|
||||
self,
|
||||
x, #####全部文本token
|
||||
@ -338,8 +343,8 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
# AR Decoder
|
||||
y = prompts
|
||||
|
||||
x_len = x.shape[1]
|
||||
ref_y_len = y.shape[1] if y is not None else 0
|
||||
x_len = x_lens.max()
|
||||
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
||||
stop = False
|
||||
# print(1111111,self.num_layers)
|
||||
@ -385,10 +390,10 @@ class Text2SemanticDecoder(nn.Module):
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
|
||||
x.device
|
||||
)
|
||||
|
||||
batch_end_index = [0 for _ in range(x.shape[0])]
|
||||
|
||||
for idx in tqdm(range(1500)):
|
||||
|
||||
|
||||
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
|
||||
logits = self.ar_predict_layer(
|
||||
xy_dec[:, -1]
|
||||
@ -397,17 +402,18 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if(idx==0):###第一次跑不能EOS否则没有了
|
||||
logits = logits[:, :-1] ###刨除1024终止符号的概率
|
||||
samples = sample(
|
||||
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
|
||||
)[0].unsqueeze(0)
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
|
||||
)[0]
|
||||
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
|
||||
# print(samples.shape)#[1,1]#第一个1是bs
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
# print(y)
|
||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
print("use early stop num:", early_stop_num)
|
||||
stop = True
|
||||
|
||||
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
||||
if self.check_end(logits, samples):
|
||||
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
||||
stop = True
|
||||
if stop:
|
||||
@ -443,6 +449,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
xy_attn_mask = torch.zeros(
|
||||
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
|
||||
)
|
||||
batch_end_index = [min(torch.where(by == self.EOS)[0]) if self.EOS in by else idx for by in y ]
|
||||
if ref_free:
|
||||
return y[:, :-1], 0
|
||||
return y[:, :-1], idx-1
|
||||
return y[:, :-1], [0 for _ in range(y.shape[0])], batch_end_index
|
||||
return y[:, :-1], [ref_y_len for _ in range(y.shape[0])], batch_end_index
|
||||
|
@ -115,17 +115,17 @@ def logits_to_probs(
|
||||
top_p: Optional[int] = None,
|
||||
repetition_penalty: float = 1.0,
|
||||
):
|
||||
if previous_tokens is not None:
|
||||
previous_tokens = previous_tokens.squeeze()
|
||||
# if previous_tokens is not None:
|
||||
# previous_tokens = previous_tokens.squeeze()
|
||||
# print(logits.shape,previous_tokens.shape)
|
||||
# pdb.set_trace()
|
||||
if previous_tokens is not None and repetition_penalty != 1.0:
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=0, index=previous_tokens)
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
)
|
||||
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
@ -133,9 +133,9 @@ def logits_to_probs(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[0] = False # keep at least one option
|
||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
@ -143,7 +143,7 @@ def logits_to_probs(
|
||||
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
pivot = v.select(-1, -1).unsqueeze(-1)
|
||||
pivot = v[: , -1].unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
|
@ -448,10 +448,9 @@ def multi_head_attention_forward_patched(
|
||||
k = k.view(bsz, num_heads, src_len, head_dim)
|
||||
v = v.view(bsz, num_heads, src_len, head_dim)
|
||||
|
||||
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
||||
attn_output = scaled_dot_product_attention(
|
||||
q, k, v, attn_mask, dropout_p, is_causal
|
||||
)
|
||||
)
|
||||
|
||||
attn_output = (
|
||||
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
||||
|
@ -297,7 +297,7 @@ class TransformerEncoderLayer(nn.Module):
|
||||
raise AssertionError(
|
||||
"only bool and floating types of key_padding_mask are supported"
|
||||
)
|
||||
|
||||
|
||||
if self.norm_first:
|
||||
x = x + self._sa_block(
|
||||
self.norm1(x, stage_embedding),
|
||||
|
@ -66,6 +66,7 @@ from time import time as ttime
|
||||
from module.mel_processing import spectrogram_torch
|
||||
from my_utils import load_audio
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from utils import tensor_padding
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
@ -311,7 +312,7 @@ def merge_short_text_in_array(texts, threshold):
|
||||
result[len(result) - 1] += text
|
||||
return result
|
||||
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False, infer_batch = 1):
|
||||
if prompt_text is None or len(prompt_text) == 0:
|
||||
ref_free = True
|
||||
t0 = ttime()
|
||||
@ -371,28 +372,48 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
if not ref_free:
|
||||
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
|
||||
|
||||
for text in texts:
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
if (len(text.strip()) == 0):
|
||||
continue
|
||||
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
|
||||
print(i18n("实际输入的目标文本(每句):"), text)
|
||||
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
|
||||
print(i18n("前端处理后的文本(每句):"), norm_text2)
|
||||
if not ref_free:
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
|
||||
batch_num = len(texts) // infer_batch
|
||||
if len(texts) % infer_batch != 0:
|
||||
batch_num += 1
|
||||
for i in range(batch_num):
|
||||
if i == batch_num - 1:
|
||||
text = texts[i * infer_batch:]
|
||||
else:
|
||||
bert = bert2
|
||||
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
||||
text = texts[i * infer_batch: (i + 1) * infer_batch]
|
||||
# for text in texts:
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
# 过滤每一个需要合成的文本
|
||||
all_phoneme_ids_batch = []
|
||||
all_bert_batch = []
|
||||
all_phoneme_len_batch = []
|
||||
all_vits_phones2 = []
|
||||
for T in text:
|
||||
if (len(T.strip()) == 0):
|
||||
continue
|
||||
if (T[-1] not in splits): T += "。" if text_language != "en" else "."
|
||||
print(i18n("实际输入的目标文本(每句):"), T)
|
||||
phones2,bert2,norm_text2=get_phones_and_bert(T, text_language)
|
||||
all_vits_phones2+=phones2
|
||||
print(i18n("前端处理后的文本(每句):"), norm_text2)
|
||||
if not ref_free:
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
|
||||
else:
|
||||
bert = bert2
|
||||
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
||||
bert = bert.to(device).unsqueeze(0)
|
||||
all_phoneme_ids_batch.append(all_phoneme_ids)
|
||||
all_bert_batch.append(bert)
|
||||
all_phoneme_len_batch.append(all_phoneme_ids.shape[-1])
|
||||
|
||||
bert = bert.to(device).unsqueeze(0)
|
||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
all_phoneme_ids = tensor_padding(all_phoneme_ids_batch)
|
||||
bert = tensor_padding(all_bert_batch)
|
||||
all_phoneme_len = torch.tensor(all_phoneme_len_batch).to(device)
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device).expand(all_phoneme_ids.shape[0], -1, -1)
|
||||
t2 = ttime()
|
||||
with torch.no_grad():
|
||||
# pred_semantic = t2s_model.model.infer(
|
||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||
pred_semantic, sidx, eidx = t2s_model.model.infer_panel(
|
||||
all_phoneme_ids,
|
||||
all_phoneme_len,
|
||||
None if ref_free else prompt,
|
||||
@ -405,9 +426,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
)
|
||||
t3 = ttime()
|
||||
# print(pred_semantic.shape,idx)
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
|
||||
0
|
||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
pred_semantic = torch.cat([pred[si:ei] for pred, si, ei in zip(pred_semantic, sidx, eidx)], dim=0).unsqueeze(0).unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
||||
if is_half == True:
|
||||
refer = refer.half().to(device)
|
||||
@ -416,7 +435,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
||||
audio = (
|
||||
vq_model.decode(
|
||||
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
|
||||
pred_semantic, torch.LongTensor(all_vits_phones2).to(device).unsqueeze(0), refer
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
@ -588,12 +607,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
||||
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
||||
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
|
||||
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(minimum=1,maximum=32,step=1,label=i18n("batch_size"),value=1,interactive=True)
|
||||
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
||||
output = gr.Audio(label=i18n("输出的语音"))
|
||||
|
||||
inference_button.click(
|
||||
get_tts_wav,
|
||||
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
|
||||
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free, batch_size],
|
||||
[output],
|
||||
)
|
||||
|
||||
|
@ -6,6 +6,7 @@ import logging
|
||||
import json
|
||||
import subprocess
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
@ -270,7 +271,6 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim
|
||||
del_routine = lambda x: [os.remove(x), del_info(x)]
|
||||
rs = [del_routine(fn) for fn in to_del]
|
||||
|
||||
|
||||
def get_hparams_from_dir(model_dir):
|
||||
config_save_path = os.path.join(model_dir, "config.json")
|
||||
with open(config_save_path, "r") as f:
|
||||
@ -331,6 +331,12 @@ def get_logger(model_dir, filename="train.log"):
|
||||
return logger
|
||||
|
||||
|
||||
def tensor_padding(tensor_list: List[torch.Tensor]):
|
||||
max_len = max([t.shape[-1] for t in tensor_list])
|
||||
tensor_list = [torch.cat([t, torch.zeros(*list(t.shape[:-1]), max_len - t.size(-1)).to(t)], dim=-1) for t in tensor_list]
|
||||
|
||||
return torch.cat(tensor_list, dim=0)
|
||||
|
||||
class HParams:
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
|
1
gweight.txt
Normal file
1
gweight.txt
Normal file
@ -0,0 +1 @@
|
||||
GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
1
sweight.txt
Normal file
1
sweight.txt
Normal file
@ -0,0 +1 @@
|
||||
GPT_SoVITS/pretrained_models/s2G488k.pth
|
Loading…
x
Reference in New Issue
Block a user