From 261868fcb85729cf2762349def9a971083cf2bd3 Mon Sep 17 00:00:00 2001 From: Test User Date: Thu, 19 Feb 2026 03:35:53 +0000 Subject: [PATCH] style: apply ruff formatting to resolution tests --- tests/test_cli_demo_resolution.py | 120 ++++++++++++++++-------------- 1 file changed, 66 insertions(+), 54 deletions(-) diff --git a/tests/test_cli_demo_resolution.py b/tests/test_cli_demo_resolution.py index d01745c..aae665d 100644 --- a/tests/test_cli_demo_resolution.py +++ b/tests/test_cli_demo_resolution.py @@ -19,13 +19,13 @@ def mock_heavy_imports(): mock_torch.dtype = MagicMock() mock_torch.bfloat16 = MagicMock() mock_torch.float16 = MagicMock() - + sys.modules['torch'] = mock_torch sys.modules['diffusers'] = MagicMock() sys.modules['diffusers.utils'] = MagicMock() - + yield - + # Cleanup for module in ['torch', 'diffusers', 'diffusers.utils']: if module in sys.modules: @@ -37,6 +37,7 @@ def resolution_map(): """Fixture providing the RESOLUTION_MAP from cli_demo.""" # Import after mocking dependencies from inference.cli_demo import RESOLUTION_MAP + return RESOLUTION_MAP @@ -44,15 +45,18 @@ def resolution_map(): def cli_demo_module(): """Fixture providing the cli_demo module.""" import inference.cli_demo as cli_demo + return cli_demo class TestResolutionMapFallback: """Test suite for RESOLUTION_MAP lookup and fallback behavior.""" - def test_valid_model_name_returns_correct_resolution(self, resolution_map, cli_demo_module, caplog): + def test_valid_model_name_returns_correct_resolution( + self, resolution_map, cli_demo_module, caplog + ): """Test that valid model names return their correct resolutions.""" - + # Test each valid model in RESOLUTION_MAP test_cases = [ ("THUDM/CogVideoX-5b", "cogvideox-5b", (480, 720)), @@ -61,18 +65,20 @@ class TestResolutionMapFallback: ("THUDM/CogVideoX1.5-5b", "cogvideox1.5-5b", (768, 1360)), ("THUDM/CogVideoX1.5-5b-I2V", "cogvideox1.5-5b-i2v", (768, 1360)), ] - + for model_path, expected_key, expected_resolution in test_cases: with caplog.at_level(logging.WARNING): caplog.clear() - + # Extract model name using the same logic as cli_demo model_name = model_path.split("/")[-1].lower() - + # Verify the model name is in the map - assert expected_key in resolution_map, f"Expected key '{expected_key}' not in RESOLUTION_MAP" + assert ( + expected_key in resolution_map + ), f"Expected key '{expected_key}' not in RESOLUTION_MAP" assert resolution_map[expected_key] == expected_resolution - + # Test that accessing the resolution works without KeyError try: actual_resolution = resolution_map[model_name] @@ -82,27 +88,27 @@ class TestResolutionMapFallback: def test_invalid_model_name_raises_key_error(self, resolution_map): """Test that invalid model names raise KeyError (which is then caught by try/except).""" - + invalid_models = [ "invalidmodel-xyz", "unknown-5b", "cogvideox-5b-custom", # Not in map ] - + for invalid_model in invalid_models: with pytest.raises(KeyError): _ = resolution_map[invalid_model] def test_local_path_model_name_extraction(self): """Test that local paths are parsed correctly to extract model names.""" - + local_paths = [ ("./local_models/CogVideoX-5B", "cogvideox-5b"), ("/path/to/custom_model", "custom_model"), ("~/models/my-cogvideo-model", "my-cogvideo-model"), ("./CogVideoX-Custom", "cogvideox-custom"), ] - + for local_path, expected_model_name in local_paths: # Extract model name using the same logic as cli_demo model_name = local_path.split("/")[-1].lower() @@ -111,17 +117,21 @@ class TestResolutionMapFallback: def test_fallback_code_structure(self, cli_demo_module): """Test that the fallback code structure is present in generate_video function.""" import inspect - + # Get the source code of generate_video source = inspect.getsource(cli_demo_module.generate_video) - + # Verify try/except block exists assert "try:" in source, "generate_video should have try/except block" assert "except KeyError:" in source, "Should catch KeyError specifically" - assert "RESOLUTION_MAP[model_name]" in source, "Should access RESOLUTION_MAP with model_name" - + assert ( + "RESOLUTION_MAP[model_name]" in source + ), "Should access RESOLUTION_MAP with model_name" + # Verify fallback logic exists - assert "default_resolution" in source or "(480, 720)" in source, "Should define default resolution" + assert ( + "default_resolution" in source or "(480, 720)" in source + ), "Should define default resolution" assert "logging.warning" in source, "Should log warning on fallback" assert "valid models" in source.lower(), "Warning should mention valid models" @@ -134,40 +144,39 @@ class TestResolutionMapFallback: "cogvideox-5b", "cogvideox-2b", } - - assert set(resolution_map.keys()) == expected_models, ( - "RESOLUTION_MAP should contain all expected models" - ) - + + assert ( + set(resolution_map.keys()) == expected_models + ), "RESOLUTION_MAP should contain all expected models" + # Verify all resolutions are tuples of two integers for model_name, resolution in resolution_map.items(): - assert isinstance(resolution, tuple), ( - f"Resolution for '{model_name}' should be a tuple" - ) - assert len(resolution) == 2, ( - f"Resolution for '{model_name}' should have 2 elements (height, width)" - ) - assert all(isinstance(dim, int) for dim in resolution), ( - f"Resolution for '{model_name}' should contain integers" - ) + assert isinstance(resolution, tuple), f"Resolution for '{model_name}' should be a tuple" + assert ( + len(resolution) == 2 + ), f"Resolution for '{model_name}' should have 2 elements (height, width)" + assert all( + isinstance(dim, int) for dim in resolution + ), f"Resolution for '{model_name}' should contain integers" def test_fallback_warning_message_format(self, cli_demo_module, caplog): """Test the format and content of the fallback warning message.""" - + # Mock the generate_video call to test just the resolution lookup part - with patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline, \ - patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'), \ - patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'), \ - patch.object(cli_demo_module, 'export_to_video'), \ - caplog.at_level(logging.WARNING): - + with ( + patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline, + patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'), + patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'), + patch.object(cli_demo_module, 'export_to_video'), + caplog.at_level(logging.WARNING), + ): # Setup mock to avoid actual model loading mock_instance = MagicMock() mock_pipeline.from_pretrained.return_value = mock_instance mock_instance.return_value = MagicMock(frames=[]) - + caplog.clear() - + try: cli_demo_module.generate_video( prompt="Test prompt", @@ -178,39 +187,42 @@ class TestResolutionMapFallback: except Exception: # We're only interested in the warning, not successful execution pass - + # Check that a warning was logged warning_records = [r for r in caplog.records if r.levelname == "WARNING"] assert len(warning_records) > 0, "Should log at least one warning" - + # Get the first warning message warning_msg = warning_records[0].message - + # Verify warning message contains key information assert "unknown-model" in warning_msg.lower(), "Should mention the invalid model name" assert "not found in resolution map" in warning_msg.lower(), "Should explain the issue" - assert "480" in warning_msg and "720" in warning_msg, "Should mention default resolution" + assert ( + "480" in warning_msg and "720" in warning_msg + ), "Should mention default resolution" assert "valid models:" in warning_msg.lower(), "Should list valid models" - + # Verify all valid model names are in the warning from inference.cli_demo import RESOLUTION_MAP + for model_name in RESOLUTION_MAP.keys(): assert model_name in warning_msg, f"Warning should list '{model_name}'" def test_case_insensitive_model_name_matching(self, resolution_map): """Test that model names are converted to lowercase for matching.""" - + # All keys in RESOLUTION_MAP should be lowercase for key in resolution_map.keys(): assert key == key.lower(), f"RESOLUTION_MAP key '{key}' should be lowercase" - + # Test that various case inputs would work after .lower() conversion test_paths = [ ("THUDM/COGVIDEOX-5B", "cogvideox-5b"), ("thudm/cogvideox-5b", "cogvideox-5b"), ("ThUdM/CoGvIdEoX-5b", "cogvideox-5b"), ] - + for model_path, expected_key in test_paths: model_name = model_path.split("/")[-1].lower() assert model_name == expected_key @@ -219,22 +231,22 @@ class TestResolutionMapFallback: def test_default_resolution_value(self, cli_demo_module): """Test that the default fallback resolution is (480, 720).""" import inspect - + source = inspect.getsource(cli_demo_module.generate_video) - + # The default should be defined as (480, 720) assert "(480, 720)" in source, "Default resolution should be (480, 720)" def test_resolution_dimensions_order(self, resolution_map): """Test that resolutions are in (height, width) format.""" - + # Based on the code comments and typical usage for model_name, (dim1, dim2) in resolution_map.items(): # Height typically comes first # Verify dimensions are reasonable (not negative or zero) assert dim1 > 0, f"First dimension for {model_name} should be positive" assert dim2 > 0, f"Second dimension for {model_name} should be positive" - + # CogVideoX1.5 models should have larger resolutions if "1.5" in model_name: assert dim1 >= 768, f"CogVideoX1.5 model {model_name} should have height >= 768"