diff --git a/GPT_SoVITS/AR/data/data_module.py b/GPT_SoVITS/AR/data/data_module.py index 037484a..54d4634 100644 --- a/GPT_SoVITS/AR/data/data_module.py +++ b/GPT_SoVITS/AR/data/data_module.py @@ -41,7 +41,8 @@ class Text2SemanticDataModule(LightningDataModule): # pad_val=self.config['data']['pad_val']) def train_dataloader(self): - batch_size = max(min(self.config["train"]["batch_size"],len(self._train_dataset)//4),1)#防止不保存 + batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"] + batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存 sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size) return DataLoader( self._train_dataset,