diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py index cc4f24d8..f3b08a9f 100644 --- a/GPT_SoVITS/AR/models/utils.py +++ b/GPT_SoVITS/AR/models/utils.py @@ -262,7 +262,7 @@ def make_reject_y(y_o, y_lens): reject_y = [] reject_y_lens = [] for b in range(bs): - process_item_idx = torch.randint(0, 1, size=(1,))[0] + process_item_idx = torch.randint(0, 2, size=(1,))[0] if process_item_idx == 0: new_y = repeat_P(y_o[b]) reject_y.append(new_y)