diff --git a/finetune/trainer.py b/finetune/trainer.py index 5746fee..a0709a9 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -713,7 +713,9 @@ class Trainer: else: raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}") - def __move_components_to_device(self, dtype, ignore_list: List[str] = []): + def __move_components_to_device(self, dtype, ignore_list: List[str] = None): + if ignore_list is None: + ignore_list = [] ignore_list = set(ignore_list) components = self.components.model_dump() for name, component in components.items(): @@ -723,7 +725,9 @@ class Trainer: self.components, name, component.to(self.accelerator.device, dtype=dtype) ) - def __move_components_to_cpu(self, unload_list: List[str] = []): + def __move_components_to_cpu(self, unload_list: List[str] = None): + if unload_list is None: + unload_list = [] unload_list = set(unload_list) components = self.components.model_dump() for name, component in components.items():