dpo改实验性勾选而非必须。勾选后batch size自动减半。

dpo改实验性勾选而非必须。勾选后batch size自动减半。
This commit is contained in:
RVC-Boss 2024-02-15 20:17:33 +08:00 committed by GitHub
parent ccb9b08be3
commit 9b5231a317
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -41,7 +41,8 @@ class Text2SemanticDataModule(LightningDataModule):
# pad_val=self.config['data']['pad_val'])
def train_dataloader(self):
batch_size = max(min(self.config["train"]["batch_size"],len(self._train_dataset)//4),1)#防止不保存
batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
return DataLoader(
self._train_dataset,