fix: correct torch.randint upper bound to include both values

This commit is contained in:
Mr-Neutr0n 2026-02-11 23:45:46 +05:30
parent 2d9193b0d3
commit 5503a5891b

View File

@ -262,7 +262,7 @@ def make_reject_y(y_o, y_lens):
reject_y = []
reject_y_lens = []
for b in range(bs):
process_item_idx = torch.randint(0, 1, size=(1,))[0]
process_item_idx = torch.randint(0, 2, size=(1,))[0]
if process_item_idx == 0:
new_y = repeat_P(y_o[b])
reject_y.append(new_y)