diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module.py b/GPT_SoVITS/AR/models/t2s_lightning_module.py index f9dfc64..594b73b 100644 --- a/GPT_SoVITS/AR/models/t2s_lightning_module.py +++ b/GPT_SoVITS/AR/models/t2s_lightning_module.py @@ -11,7 +11,6 @@ from AR.models.t2s_model import Text2SemanticDecoder from AR.modules.lr_schedulers import WarmupCosineLRSchedule from AR.modules.optim import ScaledAdam - class Text2SemanticLightningModule(LightningModule): def __init__(self, config, output_dir, is_train=True): super().__init__() @@ -35,7 +34,8 @@ class Text2SemanticLightningModule(LightningModule): def training_step(self, batch: Dict, batch_idx: int): opt = self.optimizers() scheduler = self.lr_schedulers() - loss, acc = self.model.forward( + forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old + loss, acc = forward( batch["phoneme_ids"], batch["phoneme_ids_len"], batch["semantic_ids"],