Update utils.py

fixed random error on DPO
This commit is contained in:
XXXXRT666 2024-05-02 00:00:35 +01:00 committed by GitHub
parent 0b806dba37
commit d2f991bb6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -209,7 +209,7 @@ def make_reject_y(y_o, y_lens):
reject_y = [] reject_y = []
reject_y_lens = [] reject_y_lens = []
for b in range(bs): for b in range(bs):
process_item_idx = torch.randint(0, 1, size=(1, ))[0] process_item_idx = torch.randint(0, 2, size=(1, ))[0]
if process_item_idx == 0: if process_item_idx == 0:
new_y = repeat_P(y_o[b]) new_y = repeat_P(y_o[b])
reject_y.append(new_y) reject_y.append(new_y)