mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 23:48:48 +08:00
推理batch优化
This commit is contained in:
parent
78ab26ea17
commit
9a2f4dc697
@ -321,6 +321,11 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
# 错位
|
# 错位
|
||||||
return targets[:, :-1], targets[:, 1:]
|
return targets[:, :-1], targets[:, 1:]
|
||||||
|
|
||||||
|
def check_end(self, logits, samples):
    """Return True when every sequence in the batch has produced EOS this step.

    A sequence counts as finished when either its greedy choice (argmax of
    ``logits`` over the last dimension) or its sampled token
    (``samples[:, 0]``) equals ``self.EOS``. The two signals are combined
    per sequence first and only then reduced over the batch, so a batch in
    which some rows end via argmax and others via sampling is still
    recognised as finished. For a batch of one this is the same check the
    single-sequence decoder used.

    Args:
        logits: scores for the current step; assumed (batch, vocab) --
            TODO confirm against the caller.
        samples: sampled token ids, indexed as ``samples[:, 0]``; assumed
            (batch, 1) -- TODO confirm against the caller.

    Returns:
        bool: True if all sequences hit EOS at this step, otherwise False.
    """
    greedy_is_eos = torch.argmax(logits, dim=-1) == self.EOS
    sampled_is_eos = samples[:, 0] == self.EOS
    # Per-sequence OR, then all() across the batch.
    return bool(torch.all(greedy_is_eos | sampled_is_eos))
|
||||||
|
|
||||||
def infer_panel(
|
def infer_panel(
|
||||||
self,
|
self,
|
||||||
x, #####全部文本token
|
x, #####全部文本token
|
||||||
@ -338,8 +343,8 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
# AR Decoder
|
# AR Decoder
|
||||||
y = prompts
|
y = prompts
|
||||||
|
ref_y_len = y.shape[1] if y is not None else 0
|
||||||
x_len = x.shape[1]
|
x_len = x_lens.max()
|
||||||
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
||||||
stop = False
|
stop = False
|
||||||
# print(1111111,self.num_layers)
|
# print(1111111,self.num_layers)
|
||||||
@ -385,7 +390,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
|
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
|
||||||
x.device
|
x.device
|
||||||
)
|
)
|
||||||
|
batch_end_index = [0 for _ in range(x.shape[0])]
|
||||||
|
|
||||||
for idx in tqdm(range(1500)):
|
for idx in tqdm(range(1500)):
|
||||||
|
|
||||||
@ -397,17 +402,18 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
if(idx==0):###第一次跑不能EOS否则没有了
|
if(idx==0):###第一次跑不能EOS否则没有了
|
||||||
logits = logits[:, :-1] ###刨除1024终止符号的概率
|
logits = logits[:, :-1] ###刨除1024终止符号的概率
|
||||||
samples = sample(
|
samples = sample(
|
||||||
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
|
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
|
||||||
)[0].unsqueeze(0)
|
)[0]
|
||||||
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
|
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
|
||||||
# print(samples.shape)#[1,1]#第一个1是bs
|
# print(samples.shape)#[1,1]#第一个1是bs
|
||||||
y = torch.concat([y, samples], dim=1)
|
|
||||||
|
|
||||||
|
y = torch.concat([y, samples], dim=1)
|
||||||
|
# print(y)
|
||||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||||
print("use early stop num:", early_stop_num)
|
print("use early stop num:", early_stop_num)
|
||||||
stop = True
|
stop = True
|
||||||
|
|
||||||
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
if self.check_end(logits, samples):
|
||||||
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
||||||
stop = True
|
stop = True
|
||||||
if stop:
|
if stop:
|
||||||
@ -443,6 +449,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
xy_attn_mask = torch.zeros(
|
xy_attn_mask = torch.zeros(
|
||||||
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
|
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
|
||||||
)
|
)
|
||||||
|
batch_end_index = [min(torch.where(by == self.EOS)[0]) if self.EOS in by else idx for by in y ]
|
||||||
if ref_free:
|
if ref_free:
|
||||||
return y[:, :-1], 0
|
return y[:, :-1], [0 for _ in range(y.shape[0])], batch_end_index
|
||||||
return y[:, :-1], idx-1
|
return y[:, :-1], [ref_y_len for _ in range(y.shape[0])], batch_end_index
|
||||||
|
@ -115,17 +115,17 @@ def logits_to_probs(
|
|||||||
top_p: Optional[int] = None,
|
top_p: Optional[int] = None,
|
||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.0,
|
||||||
):
|
):
|
||||||
if previous_tokens is not None:
|
# if previous_tokens is not None:
|
||||||
previous_tokens = previous_tokens.squeeze()
|
# previous_tokens = previous_tokens.squeeze()
|
||||||
# print(logits.shape,previous_tokens.shape)
|
# print(logits.shape,previous_tokens.shape)
|
||||||
# pdb.set_trace()
|
# pdb.set_trace()
|
||||||
if previous_tokens is not None and repetition_penalty != 1.0:
|
if previous_tokens is not None and repetition_penalty != 1.0:
|
||||||
previous_tokens = previous_tokens.long()
|
previous_tokens = previous_tokens.long()
|
||||||
score = torch.gather(logits, dim=0, index=previous_tokens)
|
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||||
score = torch.where(
|
score = torch.where(
|
||||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||||
)
|
)
|
||||||
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||||
|
|
||||||
if top_p is not None and top_p < 1.0:
|
if top_p is not None and top_p < 1.0:
|
||||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
@ -133,9 +133,9 @@ def logits_to_probs(
|
|||||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||||
)
|
)
|
||||||
sorted_indices_to_remove = cum_probs > top_p
|
sorted_indices_to_remove = cum_probs > top_p
|
||||||
sorted_indices_to_remove[0] = False # keep at least one option
|
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||||
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||||
)
|
)
|
||||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||||
|
|
||||||
@ -143,7 +143,7 @@ def logits_to_probs(
|
|||||||
|
|
||||||
if top_k is not None:
|
if top_k is not None:
|
||||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||||
pivot = v.select(-1, -1).unsqueeze(-1)
|
pivot = v[: , -1].unsqueeze(-1)
|
||||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||||
|
|
||||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
|
@ -448,10 +448,9 @@ def multi_head_attention_forward_patched(
|
|||||||
k = k.view(bsz, num_heads, src_len, head_dim)
|
k = k.view(bsz, num_heads, src_len, head_dim)
|
||||||
v = v.view(bsz, num_heads, src_len, head_dim)
|
v = v.view(bsz, num_heads, src_len, head_dim)
|
||||||
|
|
||||||
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
|
||||||
attn_output = scaled_dot_product_attention(
|
attn_output = scaled_dot_product_attention(
|
||||||
q, k, v, attn_mask, dropout_p, is_causal
|
q, k, v, attn_mask, dropout_p, is_causal
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = (
|
attn_output = (
|
||||||
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
||||||
|
@ -66,6 +66,7 @@ from time import time as ttime
|
|||||||
from module.mel_processing import spectrogram_torch
|
from module.mel_processing import spectrogram_torch
|
||||||
from my_utils import load_audio
|
from my_utils import load_audio
|
||||||
from tools.i18n.i18n import I18nAuto
|
from tools.i18n.i18n import I18nAuto
|
||||||
|
from utils import tensor_padding
|
||||||
|
|
||||||
i18n = I18nAuto()
|
i18n = I18nAuto()
|
||||||
|
|
||||||
@ -311,7 +312,7 @@ def merge_short_text_in_array(texts, threshold):
|
|||||||
result[len(result) - 1] += text
|
result[len(result) - 1] += text
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
|
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False, infer_batch = 1):
|
||||||
if prompt_text is None or len(prompt_text) == 0:
|
if prompt_text is None or len(prompt_text) == 0:
|
||||||
ref_free = True
|
ref_free = True
|
||||||
t0 = ttime()
|
t0 = ttime()
|
||||||
@ -371,28 +372,48 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
if not ref_free:
|
if not ref_free:
|
||||||
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
|
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
|
||||||
|
|
||||||
for text in texts:
|
batch_num = len(texts) // infer_batch
|
||||||
# 解决输入目标文本的空行导致报错的问题
|
if len(texts) % infer_batch != 0:
|
||||||
if (len(text.strip()) == 0):
|
batch_num += 1
|
||||||
continue
|
for i in range(batch_num):
|
||||||
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
|
if i == batch_num - 1:
|
||||||
print(i18n("实际输入的目标文本(每句):"), text)
|
text = texts[i * infer_batch:]
|
||||||
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
|
|
||||||
print(i18n("前端处理后的文本(每句):"), norm_text2)
|
|
||||||
if not ref_free:
|
|
||||||
bert = torch.cat([bert1, bert2], 1)
|
|
||||||
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
|
|
||||||
else:
|
else:
|
||||||
bert = bert2
|
text = texts[i * infer_batch: (i + 1) * infer_batch]
|
||||||
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
# for text in texts:
|
||||||
|
# 解决输入目标文本的空行导致报错的问题
|
||||||
|
# 过滤每一个需要合成的文本
|
||||||
|
all_phoneme_ids_batch = []
|
||||||
|
all_bert_batch = []
|
||||||
|
all_phoneme_len_batch = []
|
||||||
|
all_vits_phones2 = []
|
||||||
|
for T in text:
|
||||||
|
if (len(T.strip()) == 0):
|
||||||
|
continue
|
||||||
|
if (T[-1] not in splits): T += "。" if text_language != "en" else "."
|
||||||
|
print(i18n("实际输入的目标文本(每句):"), T)
|
||||||
|
phones2,bert2,norm_text2=get_phones_and_bert(T, text_language)
|
||||||
|
all_vits_phones2+=phones2
|
||||||
|
print(i18n("前端处理后的文本(每句):"), norm_text2)
|
||||||
|
if not ref_free:
|
||||||
|
bert = torch.cat([bert1, bert2], 1)
|
||||||
|
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
bert = bert2
|
||||||
|
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
||||||
|
bert = bert.to(device).unsqueeze(0)
|
||||||
|
all_phoneme_ids_batch.append(all_phoneme_ids)
|
||||||
|
all_bert_batch.append(bert)
|
||||||
|
all_phoneme_len_batch.append(all_phoneme_ids.shape[-1])
|
||||||
|
|
||||||
bert = bert.to(device).unsqueeze(0)
|
all_phoneme_ids = tensor_padding(all_phoneme_ids_batch)
|
||||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
bert = tensor_padding(all_bert_batch)
|
||||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
all_phoneme_len = torch.tensor(all_phoneme_len_batch).to(device)
|
||||||
|
prompt = prompt_semantic.unsqueeze(0).to(device).expand(all_phoneme_ids.shape[0], -1, -1)
|
||||||
t2 = ttime()
|
t2 = ttime()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# pred_semantic = t2s_model.model.infer(
|
# pred_semantic = t2s_model.model.infer(
|
||||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
pred_semantic, sidx, eidx = t2s_model.model.infer_panel(
|
||||||
all_phoneme_ids,
|
all_phoneme_ids,
|
||||||
all_phoneme_len,
|
all_phoneme_len,
|
||||||
None if ref_free else prompt,
|
None if ref_free else prompt,
|
||||||
@ -405,9 +426,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
)
|
)
|
||||||
t3 = ttime()
|
t3 = ttime()
|
||||||
# print(pred_semantic.shape,idx)
|
# print(pred_semantic.shape,idx)
|
||||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
|
pred_semantic = torch.cat([pred[si:ei] for pred, si, ei in zip(pred_semantic, sidx, eidx)], dim=0).unsqueeze(0).unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||||
0
|
|
||||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
|
||||||
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
||||||
if is_half == True:
|
if is_half == True:
|
||||||
refer = refer.half().to(device)
|
refer = refer.half().to(device)
|
||||||
@ -416,7 +435,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
||||||
audio = (
|
audio = (
|
||||||
vq_model.decode(
|
vq_model.decode(
|
||||||
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
|
pred_semantic, torch.LongTensor(all_vits_phones2).to(device).unsqueeze(0), refer
|
||||||
)
|
)
|
||||||
.detach()
|
.detach()
|
||||||
.cpu()
|
.cpu()
|
||||||
@ -588,12 +607,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
|||||||
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
||||||
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
||||||
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
|
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
|
||||||
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
with gr.Row():
|
||||||
|
batch_size = gr.Slider(minimum=1,maximum=32,step=1,label=i18n("batch_size"),value=1,interactive=True)
|
||||||
|
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
||||||
output = gr.Audio(label=i18n("输出的语音"))
|
output = gr.Audio(label=i18n("输出的语音"))
|
||||||
|
|
||||||
inference_button.click(
|
inference_button.click(
|
||||||
get_tts_wav,
|
get_tts_wav,
|
||||||
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
|
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free, batch_size],
|
||||||
[output],
|
[output],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import logging
|
|||||||
import json
|
import json
|
||||||
import subprocess
|
import subprocess
|
||||||
import traceback
|
import traceback
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -270,7 +271,6 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim
|
|||||||
del_routine = lambda x: [os.remove(x), del_info(x)]
|
del_routine = lambda x: [os.remove(x), del_info(x)]
|
||||||
rs = [del_routine(fn) for fn in to_del]
|
rs = [del_routine(fn) for fn in to_del]
|
||||||
|
|
||||||
|
|
||||||
def get_hparams_from_dir(model_dir):
|
def get_hparams_from_dir(model_dir):
|
||||||
config_save_path = os.path.join(model_dir, "config.json")
|
config_save_path = os.path.join(model_dir, "config.json")
|
||||||
with open(config_save_path, "r") as f:
|
with open(config_save_path, "r") as f:
|
||||||
@ -331,6 +331,12 @@ def get_logger(model_dir, filename="train.log"):
|
|||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_padding(tensor_list: List[torch.Tensor]):
    """Right-pad tensors with zeros along the last dim and concatenate on dim 0.

    Every tensor is extended along its last dimension to the longest last
    dimension found in ``tensor_list``; the padded tensors are then
    concatenated along dimension 0. All other dimensions must already be
    compatible for ``torch.cat``.

    Args:
        tensor_list: non-empty list of tensors.

    Returns:
        torch.Tensor: the concatenation of the zero-padded tensors.

    Raises:
        ValueError: if ``tensor_list`` is empty.
    """
    if not tensor_list:
        raise ValueError("tensor_padding() requires at least one tensor")

    max_len = max(t.shape[-1] for t in tensor_list)
    padded = []
    for t in tensor_list:
        # new_zeros allocates the padding with t's dtype and device directly,
        # avoiding a default-dtype CPU allocation followed by a conversion.
        pad = t.new_zeros(*t.shape[:-1], max_len - t.shape[-1])
        padded.append(torch.cat([t, pad], dim=-1))

    return torch.cat(padded, dim=0)
|
||||||
|
|
||||||
class HParams:
|
class HParams:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
|
1
gweight.txt
Normal file
1
gweight.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
1
sweight.txt
Normal file
1
sweight.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
GPT_SoVITS/pretrained_models/s2G488k.pth
|
Loading…
x
Reference in New Issue
Block a user