From 895fde46e420040ed26aaf0c5b7e99359d9b199b Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Thu, 15 Feb 2024 20:17:44 +0800 Subject: [PATCH] =?UTF-8?q?dpo=E6=94=B9=E5=AE=9E=E9=AA=8C=E6=80=A7?= =?UTF-8?q?=E5=8B=BE=E9=80=89=E8=80=8C=E9=9D=9E=E5=BF=85=E9=A1=BB=E3=80=82?= =?UTF-8?q?=E5=8B=BE=E9=80=89=E5=90=8Ebatch=20size=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E5=87=8F=E5=8D=8A=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit dpo改实验性勾选而非必须。勾选后batch size自动减半。 --- GPT_SoVITS/AR/models/t2s_lightning_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module.py b/GPT_SoVITS/AR/models/t2s_lightning_module.py index f9dfc64..594b73b 100644 --- a/GPT_SoVITS/AR/models/t2s_lightning_module.py +++ b/GPT_SoVITS/AR/models/t2s_lightning_module.py @@ -11,7 +11,6 @@ from AR.models.t2s_model import Text2SemanticDecoder from AR.modules.lr_schedulers import WarmupCosineLRSchedule from AR.modules.optim import ScaledAdam - class Text2SemanticLightningModule(LightningModule): def __init__(self, config, output_dir, is_train=True): super().__init__() @@ -35,7 +34,8 @@ class Text2SemanticLightningModule(LightningModule): def training_step(self, batch: Dict, batch_idx: int): opt = self.optimizers() scheduler = self.lr_schedulers() - loss, acc = self.model.forward( + forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old + loss, acc = forward( batch["phoneme_ids"], batch["phoneme_ids_len"], batch["semantic_ids"],