dpo改实验性勾选而非必须。勾选后batch size自动减半。

dpo改实验性勾选而非必须。勾选后batch size自动减半。
This commit is contained in:
RVC-Boss 2024-02-15 20:17:44 +08:00 committed by GitHub
parent 9b5231a317
commit 895fde46e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,7 +11,6 @@ from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):
def __init__(self, config, output_dir, is_train=True):
super().__init__()
@ -35,7 +34,8 @@ class Text2SemanticLightningModule(LightningModule):
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
scheduler = self.lr_schedulers()
loss, acc = self.model.forward(
forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
loss, acc = forward(
batch["phoneme_ids"],
batch["phoneme_ids_len"],
batch["semantic_ids"],