mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-05-13 00:08:59 +08:00
dpo改实验性勾选而非必须。勾选后batch size自动减半。
dpo改实验性勾选而非必须。勾选后batch size自动减半。
This commit is contained in:
parent
9b5231a317
commit
895fde46e4
@ -11,7 +11,6 @@ from AR.models.t2s_model import Text2SemanticDecoder
|
|||||||
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
||||||
from AR.modules.optim import ScaledAdam
|
from AR.modules.optim import ScaledAdam
|
||||||
|
|
||||||
|
|
||||||
class Text2SemanticLightningModule(LightningModule):
|
class Text2SemanticLightningModule(LightningModule):
|
||||||
def __init__(self, config, output_dir, is_train=True):
|
def __init__(self, config, output_dir, is_train=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -35,7 +34,8 @@ class Text2SemanticLightningModule(LightningModule):
|
|||||||
def training_step(self, batch: Dict, batch_idx: int):
|
def training_step(self, batch: Dict, batch_idx: int):
|
||||||
opt = self.optimizers()
|
opt = self.optimizers()
|
||||||
scheduler = self.lr_schedulers()
|
scheduler = self.lr_schedulers()
|
||||||
loss, acc = self.model.forward(
|
forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
|
||||||
|
loss, acc = forward(
|
||||||
batch["phoneme_ids"],
|
batch["phoneme_ids"],
|
||||||
batch["phoneme_ids_len"],
|
batch["phoneme_ids_len"],
|
||||||
batch["semantic_ids"],
|
batch["semantic_ids"],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user