fix inference issue

wangzeyuan 2025-02-16 21:09:35 +08:00
parent c09af5ab7d
commit 0d11b60fb8


@@ -677,7 +677,7 @@ class Text2SemanticDecoder(nn.Module):
             # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
             for i in removed_idx_of_batch_for_y:
                 batch_index = batch_idx_map[i]
-                idx_list[batch_index] = idx - 1
+                idx_list[batch_index] = idx
                 y_list[batch_index] = y[i, :-1]
             batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
@@ -857,7 +857,7 @@ class Text2SemanticDecoder(nn.Module):
         if ref_free:
             return y[:, :-1], 0
-        return y[:, :-1], idx - 1
+        return y[:, :-1], idx
     def infer_panel(
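
Both hunks remove the same off-by-one: when decoding stops, the EOS token is trimmed with y[:, :-1] (or y[i, :-1]), and the index returned or stored alongside it is now idx rather than idx - 1. A minimal sketch of why this matters, under the assumption that idx is meant to count the usable generated semantic tokens and that a downstream caller recovers them with pred_semantic[:, -idx:] (a hypothetical consumer, not shown in this diff):

import torch

# Toy sequence: prompt tokens, tokens produced during decoding, then EOS.
prompt = torch.tensor([[10, 11, 12]])
generated = torch.tensor([[20, 21, 22, 23]])
EOS = 1024
y = torch.cat([prompt, generated, torch.tensor([[EOS]])], dim=1)

idx = generated.shape[1]   # assumed meaning: number of usable generated tokens
pred_semantic = y[:, :-1]  # drop the EOS, as the diff does

# Fixed behaviour: slicing with idx recovers every generated token.
assert torch.equal(pred_semantic[:, -idx:], generated)
# Old behaviour: idx - 1 silently drops one generated token.
assert not torch.equal(pred_semantic[:, -(idx - 1):], generated)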