From 9b5231a317f2940355fcf7e1f1d609b2031774cd Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Thu, 15 Feb 2024 20:17:33 +0800 Subject: [PATCH] =?UTF-8?q?dpo=E6=94=B9=E5=AE=9E=E9=AA=8C=E6=80=A7?= =?UTF-8?q?=E5=8B=BE=E9=80=89=E8=80=8C=E9=9D=9E=E5=BF=85=E9=A1=BB=E3=80=82?= =?UTF-8?q?=E5=8B=BE=E9=80=89=E5=90=8Ebatch=20size=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E5=87=8F=E5=8D=8A=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit dpo改实验性勾选而非必须。勾选后batch size自动减半。 --- GPT_SoVITS/AR/data/data_module.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/GPT_SoVITS/AR/data/data_module.py b/GPT_SoVITS/AR/data/data_module.py index 037484a..54d4634 100644 --- a/GPT_SoVITS/AR/data/data_module.py +++ b/GPT_SoVITS/AR/data/data_module.py @@ -41,7 +41,8 @@ class Text2SemanticDataModule(LightningDataModule): # pad_val=self.config['data']['pad_val']) def train_dataloader(self): - batch_size = max(min(self.config["train"]["batch_size"],len(self._train_dataset)//4),1)#防止不保存 + batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"] + batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存 sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size) return DataLoader( self._train_dataset,