mirror of https://github.com/RVC-Boss/GPT-SoVITS.git

commit 26228402e3 (parent 0c5f61f98c)
feat: solve unified kv cache shape handling, todo: clean up upper level to unify first and following step
@@ -128,7 +128,7 @@ class T2SFirstStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
 
-    def forward(self, x, prompt, top_k = None, top_p = None, repetition_penalty = None, temperature = None):
+    def forward(self, x, prompt, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None):
         if top_k is None:
             top_k = torch.LongTensor([15]).to(device=x.device)
         if top_p is None:
@@ -146,7 +146,7 @@ class T2SFirstStageDecoder(nn.Module):
             "k": None,
             "v": None,
             "y_emb": None,
-            "first_infer": 1,
+            "first_infer": first_infer,
             "stage": 0,
         }
 
@@ -216,7 +216,7 @@ class T2SStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
 
-    def forward(self, y, k, v, y_emb, x_example, top_k = None, top_p = None, repetition_penalty = None, temperature = None):
+    def forward(self, y, k, v, y_emb, x_example, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None):
         if top_k is None:
             top_k = torch.LongTensor([15]).to(device=y.device)
         if top_p is None:
@@ -231,7 +231,7 @@ class T2SStageDecoder(nn.Module):
             "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
             "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
             "y_emb": y_emb,
-            "first_infer": 0,
+            "first_infer": first_infer,
             "stage": 0,
         }
 
@@ -336,11 +336,11 @@ class Text2SemanticDecoder(nn.Module):
         prefix_len = prompts.shape[1]
 
         x = self.onnx_encoder(x, bert_feature)
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, top_k=top_k)
+        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, top_k=top_k, first_infer=torch.LongTensor([1]))
 
         stop = False
         for idx in tqdm(range(1, 1500)):
-            enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k)
+            enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, first_infer=torch.LongTensor([0]))
             y, k, v, y_emb, logits, samples = enco
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
@@ -48,33 +48,23 @@ def multi_head_attention_forward_patched(
     proj_qkv = linear(query, in_proj_weight, in_proj_bias)
     proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
     q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
-    # The first round of qkv produces several batches, while every later round produces only one,
-    # and ONNX export copes very poorly with the output-shape change caused by that batch change.
-    # A where-based approach and an indexing approach were both tried; they run dynamically and export fine,
-    # but neither handles the kv cache shape correctly under ONNX Runtime, so errors are thrown.
-    # This implementation needs a full rewrite that hands kv-cache growth and prefill to the external caller.
-    if cache["first_infer"] == 1:
-        cache["k"][cache["stage"]] = k
-        cache["v"][cache["stage"]] = v
-    else:
-        cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
-        cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
-    k = cache["k"][cache["stage"]]
-    v = cache["v"][cache["stage"]]
 
+    # Use dynamic shape inference to unify the kv cache shape difference between the first step and the following steps
     # # k,v : [N, 1, 512] at first time, [1, 1, 512] afterwards
     # # cache_k, cache_v : [1, N, 1, 512] size increasement is prepared outside
-    # first_infer_mask = cache["first_infer"]
-    # cache_k = cache["k"][cache["stage"]]
-    # cache_v = cache["v"][cache["stage"]]
-    # # Magic to get an index of either -1 or -N according to if first_infer_mask is set
-    # index_offset = torch.min(torch.tensor([-1]).to(k.device).to(torch.int64), -1 * first_infer_mask * k.shape[0])
-    # cache_k[0, index_offset :, :, :] = k
-    # cache_v[0, index_offset :, :, :] = v
-    # cache["k"][cache["stage"]] = cache_k
-    # cache["v"][cache["stage"]] = cache_v
-    # k = cache_k
-    # v = cache_v
+    first_infer_mask = cache["first_infer"]
+    cache_k = cache["k"][cache["stage"]]
+    cache_v = cache["v"][cache["stage"]]
+    # Magic to get an index of either -1 or -N according to if first_infer_mask is set
+    minus_one = torch.tensor([-1]).to(k.device).to(torch.int64)
+    multipled = minus_one * first_infer_mask * torch.onnx.operators.shape_as_tensor(query)[0]
+    index_offset = torch.min(minus_one, multipled)
+    cache_k[index_offset :, :, :] = k
+    cache_v[index_offset :, :, :] = v
+    cache["k"][cache["stage"]] = cache_k
+    cache["v"][cache["stage"]] = cache_v
+    k = cache_k
+    v = cache_v
 
     cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
 
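The new write path above leans on a single arithmetic trick: index_offset evaluates to -N on the first step (the prefill writes all N new rows at the tail of the pre-grown cache) and to -1 afterwards (each later step writes one row), so the same slice assignment covers both cases and the graph keeps a fixed structure. A minimal standalone sketch of that arithmetic, with invented tensor sizes that are not taken from the model:

import torch

# Toy reproduction of the unified cache write; shapes are made up for illustration.
cache_k = torch.zeros(5, 1, 512)            # stands in for cache["k"][stage]
minus_one = torch.tensor([-1], dtype=torch.int64)

# First step: first_infer_mask = 1, k carries N = 3 prefill entries.
k = torch.randn(3, 1, 512)
first_infer_mask = torch.LongTensor([1])
index_offset = torch.min(minus_one, minus_one * first_infer_mask * k.shape[0])
cache_k[index_offset:, :, :] = k            # index_offset == -3: fill the last 3 rows

# Later step: first_infer_mask = 0, k carries a single new entry.
k = torch.randn(1, 1, 512)
first_infer_mask = torch.LongTensor([0])
index_offset = torch.min(minus_one, minus_one * first_infer_mask * k.shape[0])
cache_k[index_offset:, :, :] = k            # index_offset == -1: fill only the last row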
@@ -112,14 +112,15 @@ class T2SInitStep(nn.Module):
         self.fsdc = t2s.first_stage_decoder
         self.vits = vits
 
-    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
+    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None, first_infer=None):
+        first_infer = first_infer.to(torch.int64)
         codes = self.vits.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
         bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
         all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
         bert = bert.unsqueeze(0)
         prompt = prompt_semantic.unsqueeze(0)
-        [y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
+        [y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=first_infer)
         fake_logits = torch.zeros((1, 1025), dtype=torch.float32) # Dummy logits for ONNX export
         fake_samples = torch.zeros((1, 1), dtype=torch.int32) # Dummy samples for ONNX export
         return y, k, v, y_emb, x_example, fake_logits, fake_samples
@@ -129,8 +130,9 @@ class T2SStageStep(nn.Module):
         super().__init__()
         self.stage_decoder = stage_decoder
 
-    def forward(self, iy, ik, iv, iy_emb, ix_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
-        [y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
+    def forward(self, iy, ik, iv, iy_emb, ix_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None, first_infer=None):
+        first_infer = first_infer.to(torch.int64)
+        [y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=first_infer)
         fake_x_example = torch.zeros((1, 512), dtype=torch.float32) # Dummy x_example for ONNX export
         return y, k, v, y_emb, fake_x_example, logits, samples
 
@@ -155,10 +157,10 @@ class T2SModel(nn.Module):
 
     def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
         # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
-        y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
+        y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.LongTensor([1]))
 
         for idx in range(5): # This is a fake one! DO NOT take this as reference
-            enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
+            enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.LongTensor([0]))
             y, k, v, y_emb, logits, samples = enco
             # if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
             #     break
@@ -168,9 +170,9 @@ class T2SModel(nn.Module):
     def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
         torch.onnx.export(
             self.init_step,
-            (ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k, top_p, repetition_penalty, temperature),
+            (ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k, top_p, repetition_penalty, temperature, torch.Tensor([True]).to(torch.bool)),
             f"onnx/{project_name}/{project_name}_t2s_init_step.onnx",
-            input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content", "top_k", "top_p", "repetition_penalty", "temperature"],
+            input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step"],
             output_names=["y", "k", "v", "y_emb", "x_example", 'logits', 'samples'],
             dynamic_axes={
                 "ref_text_phones": {1: "ref_length"},
@@ -180,16 +182,17 @@ class T2SModel(nn.Module):
                 "hubert_ssl_content": {2: "ssl_length"},
             },
             opset_version=16,
+            do_constant_folding=False
         )
-        simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
-        y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
+        # simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
+        y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.Tensor([True]).to(torch.bool))
 
         stage_step = T2SStageStep(self.stage_decoder)
         torch.onnx.export(
             stage_step,
-            (y, k, v, y_emb, x_example, top_k, top_p, repetition_penalty, temperature),
+            (y, k, v, y_emb, x_example, top_k, top_p, repetition_penalty, temperature, torch.Tensor([False]).to(torch.bool)),
             f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx",
-            input_names=["iy", "ik", "iv", "iy_emb", "ix_example", "top_k", "top_p", "repetition_penalty", "temperature"],
+            input_names=["iy", "ik", "iv", "iy_emb", "ix_example", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step"],
             output_names=["y", "k", "v", "y_emb","x_example", "logits", "samples"],
             dynamic_axes={
                 "iy": {1: "iy_length"},
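Both export calls above gain a trailing boolean tensor in the example inputs and a matching "if_init_step" entry in input_names, so the flag survives as a named graph input rather than being folded away. A minimal, self-contained sketch of that pattern, using a toy module and output file name that are not part of the repository:

import torch

class ToyStep(torch.nn.Module):
    def forward(self, x, first_infer=None):
        # cast the boolean flag so it can take part in arithmetic, as in the patched attention
        first_infer = first_infer.to(torch.int64)
        return x * first_infer

torch.onnx.export(
    ToyStep(),
    (torch.randn(1, 4), torch.Tensor([True]).to(torch.bool)),
    "toy_step.onnx",
    input_names=["x", "if_init_step"],
    output_names=["y"],
    dynamic_axes={"x": {1: "length"}},
    opset_version=16,
    do_constant_folding=False,
)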
@@ -329,7 +332,7 @@ def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combi
 
     # Define the inputs for the main graph
    # 1. The boolean condition to select the branch
-    cond_input = helper.make_tensor_value_info('if_init_step', TensorProto.BOOL, [])
+    # cond_input = helper.make_tensor_value_info('if_init_step', TensorProto.BOOL, [])
 
     main_outputs = [output for output in init_step_model.graph.output]
 
@@ -346,7 +349,7 @@ def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combi
     main_graph = helper.make_graph(
         nodes=[if_node],
         name="t2s_combined_graph",
-        inputs=[cond_input] + data_inputs_init + data_inputs_stage,
+        inputs= data_inputs_init + data_inputs_stage,
         outputs=main_outputs
     )
 
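combineInitStepAndStageStep (only partially shown in this diff) merges the two exported graphs under a single ONNX If node that dispatches on if_init_step. A toy, self-contained version of that wiring, with constant branches standing in for the real init-step and stage-step subgraphs; the branch contents, names, and the way the condition input is declared are simplified for illustration and are not the repository's actual graph:

import onnx
from onnx import TensorProto, helper

def branch(name, value):
    # A branch is a GraphProto with no inputs; it may also capture outer-scope names.
    const = helper.make_node(
        "Constant", inputs=[], outputs=["out"],
        value=helper.make_tensor(name + "_t", TensorProto.FLOAT, [1], [value]),
    )
    return helper.make_graph([const], name, [], [helper.make_tensor_value_info("out", TensorProto.FLOAT, [1])])

cond_input = helper.make_tensor_value_info("if_init_step", TensorProto.BOOL, [])
if_node = helper.make_node(
    "If", inputs=["if_init_step"], outputs=["out"],
    then_branch=branch("init_branch", 1.0),   # would be the init-step graph
    else_branch=branch("stage_branch", 0.0),  # would be the stage-step graph
)
main_graph = helper.make_graph(
    nodes=[if_node],
    name="t2s_combined_toy",
    inputs=[cond_input],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [1])],
)
model = helper.make_model(main_graph, opset_imports=[helper.make_opsetid("", 16)])
onnx.checker.check_model(model)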
@@ -483,29 +486,29 @@ if __name__ == "__main__":
 
     # IO here is very frequent, which can make the model export fail (very noticeable on WSL); please retry if it does
 
-    gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
-    vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
-    exp_path = "v1_export"
-    version = "v1"
-    export(vits_path, gpt_path, exp_path, version)
+    # gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
+    # exp_path = "v1_export"
+    # version = "v1"
+    # export(vits_path, gpt_path, exp_path, version)
 
     gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
     vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
     exp_path = "v2_export"
     version = "v2"
-    export(vits_path, gpt_path, exp_path, version)
+    export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True)
 
 
-    gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
-    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
-    exp_path = "v2pro_export"
-    version = "v2Pro"
-    export(vits_path, gpt_path, exp_path, version)
+    # gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
+    # exp_path = "v2pro_export"
+    # version = "v2Pro"
+    # export(vits_path, gpt_path, exp_path, version)
 
-    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
-    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
-    exp_path = "v2proplus_export"
-    version = "v2ProPlus"
-    export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True, half_precision=True)
+    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
+    # exp_path = "v2proplus_export"
+    # version = "v2ProPlus"
+    # export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True, half_precision=True)
 
 
@@ -98,7 +98,8 @@ t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx")
     "top_k": top_k,
     "top_p": top_p,
     "repetition_penalty": repetition_penalty,
-    "temperature": temperature
+    "temperature": temperature,
+    "if_init_step": np.array([True], dtype=bool)
 })
 
 # t2s_stage_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
@@ -120,7 +121,8 @@ for idx in tqdm(range(1, 1500)):
         "top_k": top_k,
         "top_p": top_p,
         "repetition_penalty": repetition_penalty,
-        "temperature": temperature
+        "temperature": temperature,
+        "if_init_step": np.array([False], dtype=bool)
     })
     if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: # 1024 is the EOS token
         break