Merge c84ebba304c9f00ade27873ae1041fd2158eb898 into 7a1af7154511e0ce4e4be8d62faa8c5e5a3532d2

This commit is contained in:
Harikrishna KP 2026-02-12 00:13:08 +05:30 committed by GitHub
commit 61f5f38839
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -713,7 +713,9 @@ class Trainer:
else:
raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
def __move_components_to_device(self, dtype, ignore_list: List[str] = None):
if ignore_list is None:
ignore_list = []
ignore_list = set(ignore_list)
components = self.components.model_dump()
for name, component in components.items():
@ -723,7 +725,9 @@ class Trainer:
self.components, name, component.to(self.accelerator.device, dtype=dtype)
)
def __move_components_to_cpu(self, unload_list: List[str] = []):
def __move_components_to_cpu(self, unload_list: List[str] = None):
if unload_list is None:
unload_list = []
unload_list = set(unload_list)
components = self.components.model_dump()
for name, component in components.items():