mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
75 lines
2.6 KiB
Python
75 lines
2.6 KiB
Python
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
|
|
from pytorch_lightning import LightningDataModule
|
|
from AR.data.bucket_sampler import DistributedBucketSampler
|
|
from AR.data.dataset import Text2SemanticDataset
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
class Text2SemanticDataModule(LightningDataModule):
    """Lightning data module feeding the text-to-semantic model.

    Builds a single ``Text2SemanticDataset`` from the training phoneme /
    semantic files. Validation and test loaders currently iterate over that
    same training dataset; the ``dev_*`` paths are stored but not used.

    Args:
        config: Nested config mapping. Reads ``config["data"]["num_workers"]``,
            ``config["data"]["max_sec"]``, ``config["data"]["pad_val"]`` and
            ``config["train"]["batch_size"]``.
        train_semantic_path: Passed to the dataset as ``semantic_path``.
        train_phoneme_path: Passed to the dataset as ``phoneme_path``.
        dev_semantic_path: Stored on the instance; not used by ``setup``.
        dev_phoneme_path: Stored on the instance; not used by ``setup``.
    """

    def __init__(
        self,
        config,
        train_semantic_path,
        train_phoneme_path,
        dev_semantic_path=None,
        dev_phoneme_path=None,
    ):
        super().__init__()
        self.config = config
        self.train_semantic_path = train_semantic_path
        self.train_phoneme_path = train_phoneme_path
        self.dev_semantic_path = dev_semantic_path
        self.dev_phoneme_path = dev_phoneme_path
        self.num_workers = self.config["data"]["num_workers"]

    @staticmethod
    def _worker_kwargs(num_workers):
        """Return DataLoader worker options that are valid for ``num_workers``.

        ``torch.utils.data.DataLoader`` raises ``ValueError`` when
        ``persistent_workers=True`` or a ``prefetch_factor`` is supplied while
        ``num_workers == 0``. Those options are therefore only added when
        worker processes are actually spawned.
        """
        if num_workers > 0:
            return {
                "num_workers": num_workers,
                "persistent_workers": True,
                "prefetch_factor": 16,
            }
        return {"num_workers": num_workers}

    def prepare_data(self):
        """No download or preprocessing step; data is read in ``setup``."""
        pass

    def setup(self, stage=None, output_logs=False):
        """Create the training dataset and reuse it for validation.

        ``stage`` and ``output_logs`` are accepted for API compatibility but
        are not used.
        """
        self._train_dataset = Text2SemanticDataset(
            phoneme_path=self.train_phoneme_path,
            semantic_path=self.train_semantic_path,
            max_sec=self.config["data"]["max_sec"],
            pad_val=self.config["data"]["pad_val"],
        )
        # Validation reuses the training set. A separate dev dataset built
        # from dev_phoneme_path / dev_semantic_path was previously disabled:
        # self._dev_dataset = Text2SemanticDataset(
        #     phoneme_path=self.dev_phoneme_path,
        #     semantic_path=self.dev_semantic_path,
        #     max_sample=self.config['data']['max_eval_sample'],
        #     max_sec=self.config['data']['max_sec'],
        #     pad_val=self.config['data']['pad_val'])
        self._dev_dataset = self._train_dataset

    def train_dataloader(self):
        """Return the training loader, ordered by ``DistributedBucketSampler``."""
        batch_size = self.config["train"]["batch_size"]
        sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
        return DataLoader(
            self._train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self._train_dataset.collate,
            **self._worker_kwargs(self.num_workers),
        )

    def val_dataloader(self):
        """Return an unshuffled, batch-size-1 loader over the dev dataset."""
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
            # NOTE(review): ``max`` forces at least 12 workers regardless of
            # config; kept as-is, but confirm ``min`` (a cap) wasn't intended.
            **self._worker_kwargs(max(self.num_workers, 12)),
        )

    # Original note (translated): "Will this ever be used?" -- presumably
    # only relevant if a test loop is run; confirm before relying on it.
    def test_dataloader(self):
        """Return an unshuffled, batch-size-1 loader over the dev dataset."""
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
        )
|