diff --git a/tests/test_cli_demo_multi_gpu.py b/tests/test_cli_demo_multi_gpu.py index 0afd83a..2071a8c 100644 --- a/tests/test_cli_demo_multi_gpu.py +++ b/tests/test_cli_demo_multi_gpu.py @@ -287,7 +287,9 @@ class TestDeviceMapSupport: mocks = create_mocked_cli_demo() cli_demo = mocks['cli_demo'] - cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map="sequential")) + cli_demo.generate_video( + **self._get_common_args(generate_type="t2v", device_map="sequential") + ) # Verify from_pretrained was called WITH device_map="sequential" mocks['mock_CogVideoXPipeline'].from_pretrained.assert_called_once() @@ -302,7 +304,9 @@ class TestDeviceMapSupport: mocks = create_mocked_cli_demo() cli_demo = mocks['cli_demo'] - cli_demo.generate_video(**self._get_common_args(generate_type="i2v", device_map="sequential")) + cli_demo.generate_video( + **self._get_common_args(generate_type="i2v", device_map="sequential") + ) # Verify from_pretrained was called WITH device_map="sequential" mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.assert_called_once() @@ -317,7 +321,9 @@ class TestDeviceMapSupport: mocks = create_mocked_cli_demo() cli_demo = mocks['cli_demo'] - cli_demo.generate_video(**self._get_common_args(generate_type="v2v", device_map="sequential")) + cli_demo.generate_video( + **self._get_common_args(generate_type="v2v", device_map="sequential") + ) # Verify from_pretrained was called WITH device_map="sequential" mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.assert_called_once() @@ -377,7 +383,9 @@ class TestDeviceMapSupport: cli_demo = mocks['cli_demo'] with pytest.raises(ValueError) as exc_info: - cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map="invalid")) + cli_demo.generate_video( + **self._get_common_args(generate_type="t2v", device_map="invalid") + ) assert "Invalid device_map" in str(exc_info.value) assert "invalid" in str(exc_info.value) @@ -389,7 +397,9 @@ class TestDeviceMapSupport: cli_demo = mocks['cli_demo'] with pytest.raises(ValueError) as exc_info: - cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map=invalid_value)) + cli_demo.generate_video( + **self._get_common_args(generate_type="t2v", device_map=invalid_value) + ) assert "Invalid device_map" in str(exc_info.value) @@ -427,12 +437,15 @@ class TestDeviceMapSupport: # Parametrized comprehensive tests # ========================================================================= - @pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [ - (None, False, True), - ("auto", True, False), - ("balanced", True, False), - ("sequential", True, False), - ]) + @pytest.mark.parametrize( + "device_map,should_have_device_map,should_cpu_offload", + [ + (None, False, True), + ("auto", True, False), + ("balanced", True, False), + ("sequential", True, False), + ], + ) def test_device_map_behavior_t2v(self, device_map, should_have_device_map, should_cpu_offload): """Comprehensive test for device_map behavior with t2v pipeline.""" mocks = create_mocked_cli_demo() @@ -456,12 +469,15 @@ class TestDeviceMapSupport: else: mocks['mock_t2v_pipe'].enable_sequential_cpu_offload.assert_not_called() - @pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [ - (None, False, True), - ("auto", True, False), - ("balanced", True, False), - ("sequential", True, False), - ]) + @pytest.mark.parametrize( + "device_map,should_have_device_map,should_cpu_offload", + [ + (None, False, True), + ("auto", True, False), + ("balanced", True, False), + ("sequential", True, False), + ], + ) def test_device_map_behavior_i2v(self, device_map, should_have_device_map, should_cpu_offload): """Comprehensive test for device_map behavior with i2v pipeline.""" mocks = create_mocked_cli_demo() @@ -485,12 +501,15 @@ class TestDeviceMapSupport: else: mocks['mock_i2v_pipe'].enable_sequential_cpu_offload.assert_not_called() - @pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [ - (None, False, True), - ("auto", True, False), - ("balanced", True, False), - ("sequential", True, False), - ]) + @pytest.mark.parametrize( + "device_map,should_have_device_map,should_cpu_offload", + [ + (None, False, True), + ("auto", True, False), + ("balanced", True, False), + ("sequential", True, False), + ], + ) def test_device_map_behavior_v2v(self, device_map, should_have_device_map, should_cpu_offload): """Comprehensive test for device_map behavior with v2v pipeline.""" mocks = create_mocked_cli_demo()