# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
# reference: https://github.com/lifeiteng/vall-e
from pytorch_lightning import LightningModule

from .t2s_model_onnx import Text2SemanticDecoder


class Text2SemanticLightningModule(LightningModule):
    """Lightning wrapper around ``Text2SemanticDecoder`` from ``.t2s_model_onnx``.

    Only construction is defined here. The module stores ``config`` and builds
    the wrapped decoder as ``self.model``. No training or inference hooks are
    defined in this class.
    """

    def __init__(self, config, output_dir, is_train=True, *, top_k=3):
        """Build the wrapped text-to-semantic decoder.

        Args:
            config: Model configuration. It is stored on ``self.config`` and
                passed unchanged to ``Text2SemanticDecoder``.
            output_dir: Accepted but not used by this constructor.
                NOTE(review): presumably kept so the signature matches the
                training variant of this module; confirm against callers.
            is_train: Accepted but not used by this constructor. The same
                note as for ``output_dir`` applies.
            top_k: Value stored on ``self.top_k`` and forwarded to
                ``Text2SemanticDecoder``. It defaults to 3, the value that was
                previously hard-coded. Presumably this is the decoder's top-k
                sampling cutoff; confirm against ``Text2SemanticDecoder``.
        """
        super().__init__()
        self.config = config
        # Keep the value on the module as well as passing it to the decoder,
        # so existing code that reads ``self.top_k`` still works.
        self.top_k = top_k
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)