diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py index 9678c7e1..98d48645 100644 --- a/GPT_SoVITS/AR/models/utils.py +++ b/GPT_SoVITS/AR/models/utils.py @@ -209,7 +209,7 @@ def make_reject_y(y_o, y_lens): reject_y = [] reject_y_lens = [] for b in range(bs): - process_item_idx = torch.randint(0, 1, size=(1, ))[0] + process_item_idx = torch.randint(0, 2, size=(1, ))[0] if process_item_idx == 0: new_y = repeat_P(y_o[b]) reject_y.append(new_y)