mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-05-31 00:08:14 +08:00
style: apply ruff formatting
This commit is contained in:
parent
c5d6f7b557
commit
4215b632d2
@ -287,7 +287,9 @@ class TestDeviceMapSupport:
|
|||||||
mocks = create_mocked_cli_demo()
|
mocks = create_mocked_cli_demo()
|
||||||
cli_demo = mocks['cli_demo']
|
cli_demo = mocks['cli_demo']
|
||||||
|
|
||||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map="sequential"))
|
cli_demo.generate_video(
|
||||||
|
**self._get_common_args(generate_type="t2v", device_map="sequential")
|
||||||
|
)
|
||||||
|
|
||||||
# Verify from_pretrained was called WITH device_map="sequential"
|
# Verify from_pretrained was called WITH device_map="sequential"
|
||||||
mocks['mock_CogVideoXPipeline'].from_pretrained.assert_called_once()
|
mocks['mock_CogVideoXPipeline'].from_pretrained.assert_called_once()
|
||||||
@ -302,7 +304,9 @@ class TestDeviceMapSupport:
|
|||||||
mocks = create_mocked_cli_demo()
|
mocks = create_mocked_cli_demo()
|
||||||
cli_demo = mocks['cli_demo']
|
cli_demo = mocks['cli_demo']
|
||||||
|
|
||||||
cli_demo.generate_video(**self._get_common_args(generate_type="i2v", device_map="sequential"))
|
cli_demo.generate_video(
|
||||||
|
**self._get_common_args(generate_type="i2v", device_map="sequential")
|
||||||
|
)
|
||||||
|
|
||||||
# Verify from_pretrained was called WITH device_map="sequential"
|
# Verify from_pretrained was called WITH device_map="sequential"
|
||||||
mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.assert_called_once()
|
mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.assert_called_once()
|
||||||
@ -317,7 +321,9 @@ class TestDeviceMapSupport:
|
|||||||
mocks = create_mocked_cli_demo()
|
mocks = create_mocked_cli_demo()
|
||||||
cli_demo = mocks['cli_demo']
|
cli_demo = mocks['cli_demo']
|
||||||
|
|
||||||
cli_demo.generate_video(**self._get_common_args(generate_type="v2v", device_map="sequential"))
|
cli_demo.generate_video(
|
||||||
|
**self._get_common_args(generate_type="v2v", device_map="sequential")
|
||||||
|
)
|
||||||
|
|
||||||
# Verify from_pretrained was called WITH device_map="sequential"
|
# Verify from_pretrained was called WITH device_map="sequential"
|
||||||
mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.assert_called_once()
|
mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.assert_called_once()
|
||||||
@ -377,7 +383,9 @@ class TestDeviceMapSupport:
|
|||||||
cli_demo = mocks['cli_demo']
|
cli_demo = mocks['cli_demo']
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map="invalid"))
|
cli_demo.generate_video(
|
||||||
|
**self._get_common_args(generate_type="t2v", device_map="invalid")
|
||||||
|
)
|
||||||
|
|
||||||
assert "Invalid device_map" in str(exc_info.value)
|
assert "Invalid device_map" in str(exc_info.value)
|
||||||
assert "invalid" in str(exc_info.value)
|
assert "invalid" in str(exc_info.value)
|
||||||
@ -389,7 +397,9 @@ class TestDeviceMapSupport:
|
|||||||
cli_demo = mocks['cli_demo']
|
cli_demo = mocks['cli_demo']
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map=invalid_value))
|
cli_demo.generate_video(
|
||||||
|
**self._get_common_args(generate_type="t2v", device_map=invalid_value)
|
||||||
|
)
|
||||||
|
|
||||||
assert "Invalid device_map" in str(exc_info.value)
|
assert "Invalid device_map" in str(exc_info.value)
|
||||||
|
|
||||||
@ -427,12 +437,15 @@ class TestDeviceMapSupport:
|
|||||||
# Parametrized comprehensive tests
|
# Parametrized comprehensive tests
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
||||||
@pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [
|
@pytest.mark.parametrize(
|
||||||
(None, False, True),
|
"device_map,should_have_device_map,should_cpu_offload",
|
||||||
("auto", True, False),
|
[
|
||||||
("balanced", True, False),
|
(None, False, True),
|
||||||
("sequential", True, False),
|
("auto", True, False),
|
||||||
])
|
("balanced", True, False),
|
||||||
|
("sequential", True, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_device_map_behavior_t2v(self, device_map, should_have_device_map, should_cpu_offload):
|
def test_device_map_behavior_t2v(self, device_map, should_have_device_map, should_cpu_offload):
|
||||||
"""Comprehensive test for device_map behavior with t2v pipeline."""
|
"""Comprehensive test for device_map behavior with t2v pipeline."""
|
||||||
mocks = create_mocked_cli_demo()
|
mocks = create_mocked_cli_demo()
|
||||||
@ -456,12 +469,15 @@ class TestDeviceMapSupport:
|
|||||||
else:
|
else:
|
||||||
mocks['mock_t2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
mocks['mock_t2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [
|
@pytest.mark.parametrize(
|
||||||
(None, False, True),
|
"device_map,should_have_device_map,should_cpu_offload",
|
||||||
("auto", True, False),
|
[
|
||||||
("balanced", True, False),
|
(None, False, True),
|
||||||
("sequential", True, False),
|
("auto", True, False),
|
||||||
])
|
("balanced", True, False),
|
||||||
|
("sequential", True, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_device_map_behavior_i2v(self, device_map, should_have_device_map, should_cpu_offload):
|
def test_device_map_behavior_i2v(self, device_map, should_have_device_map, should_cpu_offload):
|
||||||
"""Comprehensive test for device_map behavior with i2v pipeline."""
|
"""Comprehensive test for device_map behavior with i2v pipeline."""
|
||||||
mocks = create_mocked_cli_demo()
|
mocks = create_mocked_cli_demo()
|
||||||
@ -485,12 +501,15 @@ class TestDeviceMapSupport:
|
|||||||
else:
|
else:
|
||||||
mocks['mock_i2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
mocks['mock_i2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [
|
@pytest.mark.parametrize(
|
||||||
(None, False, True),
|
"device_map,should_have_device_map,should_cpu_offload",
|
||||||
("auto", True, False),
|
[
|
||||||
("balanced", True, False),
|
(None, False, True),
|
||||||
("sequential", True, False),
|
("auto", True, False),
|
||||||
])
|
("balanced", True, False),
|
||||||
|
("sequential", True, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_device_map_behavior_v2v(self, device_map, should_have_device_map, should_cpu_offload):
|
def test_device_map_behavior_v2v(self, device_map, should_have_device_map, should_cpu_offload):
|
||||||
"""Comprehensive test for device_map behavior with v2v pipeline."""
|
"""Comprehensive test for device_map behavior with v2v pipeline."""
|
||||||
mocks = create_mocked_cli_demo()
|
mocks = create_mocked_cli_demo()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user