fix inference issue (#2061)

Co-authored-by: wangzeyuan <wangzeyuan@agora.io>
This commit is contained in:
wzy3650 2025-02-17 10:38:04 +08:00 committed by GitHub
parent f454834cbb
commit 087fd24579
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -677,7 +677,7 @@ class Text2SemanticDecoder(nn.Module):
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y] # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y: for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i] batch_index = batch_idx_map[i]
idx_list[batch_index] = idx - 1 idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1] y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()] batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
@ -857,7 +857,7 @@ class Text2SemanticDecoder(nn.Module):
if ref_free: if ref_free:
return y[:, :-1], 0 return y[:, :-1], 0
return y[:, :-1], idx - 1 return y[:, :-1], idx
def infer_panel( def infer_panel(