mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-05-14 20:49:21 +08:00
Merge c84ebba304c9f00ade27873ae1041fd2158eb898 into 7a1af7154511e0ce4e4be8d62faa8c5e5a3532d2
This commit is contained in:
commit
61f5f38839
@ -713,7 +713,9 @@ class Trainer:
|
||||
else:
|
||||
raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
|
||||
|
||||
def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
|
||||
def __move_components_to_device(self, dtype, ignore_list: List[str] = None):
|
||||
if ignore_list is None:
|
||||
ignore_list = []
|
||||
ignore_list = set(ignore_list)
|
||||
components = self.components.model_dump()
|
||||
for name, component in components.items():
|
||||
@ -723,7 +725,9 @@ class Trainer:
|
||||
self.components, name, component.to(self.accelerator.device, dtype=dtype)
|
||||
)
|
||||
|
||||
def __move_components_to_cpu(self, unload_list: List[str] = []):
|
||||
def __move_components_to_cpu(self, unload_list: List[str] = None):
|
||||
if unload_list is None:
|
||||
unload_list = []
|
||||
unload_list = set(unload_list)
|
||||
components = self.components.model_dump()
|
||||
for name, component in components.items():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user