mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-05-11 01:48:38 +08:00
style: apply ruff formatting
This commit is contained in:
parent
c5d6f7b557
commit
4215b632d2
@ -287,7 +287,9 @@ class TestDeviceMapSupport:
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map="sequential"))
|
||||
cli_demo.generate_video(
|
||||
**self._get_common_args(generate_type="t2v", device_map="sequential")
|
||||
)
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="sequential"
|
||||
mocks['mock_CogVideoXPipeline'].from_pretrained.assert_called_once()
|
||||
@ -302,7 +304,9 @@ class TestDeviceMapSupport:
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="i2v", device_map="sequential"))
|
||||
cli_demo.generate_video(
|
||||
**self._get_common_args(generate_type="i2v", device_map="sequential")
|
||||
)
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="sequential"
|
||||
mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
@ -317,7 +321,9 @@ class TestDeviceMapSupport:
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="v2v", device_map="sequential"))
|
||||
cli_demo.generate_video(
|
||||
**self._get_common_args(generate_type="v2v", device_map="sequential")
|
||||
)
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="sequential"
|
||||
mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
@ -377,7 +383,9 @@ class TestDeviceMapSupport:
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map="invalid"))
|
||||
cli_demo.generate_video(
|
||||
**self._get_common_args(generate_type="t2v", device_map="invalid")
|
||||
)
|
||||
|
||||
assert "Invalid device_map" in str(exc_info.value)
|
||||
assert "invalid" in str(exc_info.value)
|
||||
@ -389,7 +397,9 @@ class TestDeviceMapSupport:
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map=invalid_value))
|
||||
cli_demo.generate_video(
|
||||
**self._get_common_args(generate_type="t2v", device_map=invalid_value)
|
||||
)
|
||||
|
||||
assert "Invalid device_map" in str(exc_info.value)
|
||||
|
||||
@ -427,12 +437,15 @@ class TestDeviceMapSupport:
|
||||
# Parametrized comprehensive tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [
|
||||
(None, False, True),
|
||||
("auto", True, False),
|
||||
("balanced", True, False),
|
||||
("sequential", True, False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"device_map,should_have_device_map,should_cpu_offload",
|
||||
[
|
||||
(None, False, True),
|
||||
("auto", True, False),
|
||||
("balanced", True, False),
|
||||
("sequential", True, False),
|
||||
],
|
||||
)
|
||||
def test_device_map_behavior_t2v(self, device_map, should_have_device_map, should_cpu_offload):
|
||||
"""Comprehensive test for device_map behavior with t2v pipeline."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
@ -456,12 +469,15 @@ class TestDeviceMapSupport:
|
||||
else:
|
||||
mocks['mock_t2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [
|
||||
(None, False, True),
|
||||
("auto", True, False),
|
||||
("balanced", True, False),
|
||||
("sequential", True, False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"device_map,should_have_device_map,should_cpu_offload",
|
||||
[
|
||||
(None, False, True),
|
||||
("auto", True, False),
|
||||
("balanced", True, False),
|
||||
("sequential", True, False),
|
||||
],
|
||||
)
|
||||
def test_device_map_behavior_i2v(self, device_map, should_have_device_map, should_cpu_offload):
|
||||
"""Comprehensive test for device_map behavior with i2v pipeline."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
@ -485,12 +501,15 @@ class TestDeviceMapSupport:
|
||||
else:
|
||||
mocks['mock_i2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [
|
||||
(None, False, True),
|
||||
("auto", True, False),
|
||||
("balanced", True, False),
|
||||
("sequential", True, False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"device_map,should_have_device_map,should_cpu_offload",
|
||||
[
|
||||
(None, False, True),
|
||||
("auto", True, False),
|
||||
("balanced", True, False),
|
||||
("sequential", True, False),
|
||||
],
|
||||
)
|
||||
def test_device_map_behavior_v2v(self, device_map, should_have_device_map, should_cpu_offload):
|
||||
"""Comprehensive test for device_map behavior with v2v pipeline."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user