From fed94a861eb0430ad65d13c925cc323ee9d1da2c Mon Sep 17 00:00:00 2001 From: Test User Date: Thu, 19 Feb 2026 03:04:39 +0000 Subject: [PATCH 1/4] fix: Add graceful fallback for unknown models in RESOLUTION_MAP Wraps RESOLUTION_MAP lookup in try/except to handle local model paths or unrecognized model names. Provides helpful error message listing valid model names and falls back to default 480x720 resolution. Fixes KeyError when using local paths like ./local_models/CogVideoX-5B --- inference/cli_demo.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/inference/cli_demo.py b/inference/cli_demo.py index 2e28165..9cfc211 100644 --- a/inference/cli_demo.py +++ b/inference/cli_demo.py @@ -96,7 +96,17 @@ def generate_video( video = None model_name = model_path.split("/")[-1].lower() - desired_resolution = RESOLUTION_MAP[model_name] + try: + desired_resolution = RESOLUTION_MAP[model_name] + except KeyError: + valid_models = ", ".join(RESOLUTION_MAP.keys()) + default_resolution = (480, 720) + logging.warning( + f"\033[1;33mModel '{model_name}' not found in resolution map. " + f"Valid models: {valid_models}. " + f"Falling back to default resolution {default_resolution}.\033[0m" + ) + desired_resolution = default_resolution if width is None or height is None: height, width = desired_resolution logging.info( From 106f987ded49d60847199b379f7c23d7cd8e79ce Mon Sep 17 00:00:00 2001 From: Test User Date: Thu, 19 Feb 2026 03:20:26 +0000 Subject: [PATCH 2/4] test: add tests for RESOLUTION_MAP fallback --- inference/__init__.py | 1 + tests/test_cli_demo_resolution.py | 250 ++++++++++++++++++++++++++++++ 2 files changed, 251 insertions(+) create mode 100644 inference/__init__.py create mode 100644 tests/test_cli_demo_resolution.py diff --git a/inference/__init__.py b/inference/__init__.py new file mode 100644 index 0000000..89bb16c --- /dev/null +++ b/inference/__init__.py @@ -0,0 +1 @@ +"""CogVideoX inference package.""" diff --git a/tests/test_cli_demo_resolution.py b/tests/test_cli_demo_resolution.py new file mode 100644 index 0000000..d01745c --- /dev/null +++ b/tests/test_cli_demo_resolution.py @@ -0,0 +1,250 @@ +""" +Tests for RESOLUTION_MAP fallback behavior in cli_demo.py + +This module tests the graceful fallback mechanism when model names +are not found in the RESOLUTION_MAP dictionary. +""" + +import logging +import sys +from unittest.mock import MagicMock, patch, Mock +import pytest + + +# Mock heavy dependencies before importing cli_demo +@pytest.fixture(scope="module", autouse=True) +def mock_heavy_imports(): + """Mock heavy ML dependencies to allow testing without installation.""" + mock_torch = MagicMock() + mock_torch.dtype = MagicMock() + mock_torch.bfloat16 = MagicMock() + mock_torch.float16 = MagicMock() + + sys.modules['torch'] = mock_torch + sys.modules['diffusers'] = MagicMock() + sys.modules['diffusers.utils'] = MagicMock() + + yield + + # Cleanup + for module in ['torch', 'diffusers', 'diffusers.utils']: + if module in sys.modules: + del sys.modules[module] + + +@pytest.fixture +def resolution_map(): + """Fixture providing the RESOLUTION_MAP from cli_demo.""" + # Import after mocking dependencies + from inference.cli_demo import RESOLUTION_MAP + return RESOLUTION_MAP + + +@pytest.fixture +def cli_demo_module(): + """Fixture providing the cli_demo module.""" + import inference.cli_demo as cli_demo + return cli_demo + + +class TestResolutionMapFallback: + """Test suite for RESOLUTION_MAP lookup and fallback behavior.""" + + def test_valid_model_name_returns_correct_resolution(self, resolution_map, cli_demo_module, caplog): + """Test that valid model names return their correct resolutions.""" + + # Test each valid model in RESOLUTION_MAP + test_cases = [ + ("THUDM/CogVideoX-5b", "cogvideox-5b", (480, 720)), + ("THUDM/CogVideoX-2b", "cogvideox-2b", (480, 720)), + ("THUDM/CogVideoX-5b-I2V", "cogvideox-5b-i2v", (480, 720)), + ("THUDM/CogVideoX1.5-5b", "cogvideox1.5-5b", (768, 1360)), + ("THUDM/CogVideoX1.5-5b-I2V", "cogvideox1.5-5b-i2v", (768, 1360)), + ] + + for model_path, expected_key, expected_resolution in test_cases: + with caplog.at_level(logging.WARNING): + caplog.clear() + + # Extract model name using the same logic as cli_demo + model_name = model_path.split("/")[-1].lower() + + # Verify the model name is in the map + assert expected_key in resolution_map, f"Expected key '{expected_key}' not in RESOLUTION_MAP" + assert resolution_map[expected_key] == expected_resolution + + # Test that accessing the resolution works without KeyError + try: + actual_resolution = resolution_map[model_name] + assert actual_resolution == expected_resolution + except KeyError: + pytest.fail(f"Valid model '{model_name}' raised KeyError") + + def test_invalid_model_name_raises_key_error(self, resolution_map): + """Test that invalid model names raise KeyError (which is then caught by try/except).""" + + invalid_models = [ + "invalidmodel-xyz", + "unknown-5b", + "cogvideox-5b-custom", # Not in map + ] + + for invalid_model in invalid_models: + with pytest.raises(KeyError): + _ = resolution_map[invalid_model] + + def test_local_path_model_name_extraction(self): + """Test that local paths are parsed correctly to extract model names.""" + + local_paths = [ + ("./local_models/CogVideoX-5B", "cogvideox-5b"), + ("/path/to/custom_model", "custom_model"), + ("~/models/my-cogvideo-model", "my-cogvideo-model"), + ("./CogVideoX-Custom", "cogvideox-custom"), + ] + + for local_path, expected_model_name in local_paths: + # Extract model name using the same logic as cli_demo + model_name = local_path.split("/")[-1].lower() + assert model_name == expected_model_name + + def test_fallback_code_structure(self, cli_demo_module): + """Test that the fallback code structure is present in generate_video function.""" + import inspect + + # Get the source code of generate_video + source = inspect.getsource(cli_demo_module.generate_video) + + # Verify try/except block exists + assert "try:" in source, "generate_video should have try/except block" + assert "except KeyError:" in source, "Should catch KeyError specifically" + assert "RESOLUTION_MAP[model_name]" in source, "Should access RESOLUTION_MAP with model_name" + + # Verify fallback logic exists + assert "default_resolution" in source or "(480, 720)" in source, "Should define default resolution" + assert "logging.warning" in source, "Should log warning on fallback" + assert "valid models" in source.lower(), "Warning should mention valid models" + + def test_resolution_map_completeness(self, resolution_map): + """Test that RESOLUTION_MAP contains expected models.""" + expected_models = { + "cogvideox1.5-5b-i2v", + "cogvideox1.5-5b", + "cogvideox-5b-i2v", + "cogvideox-5b", + "cogvideox-2b", + } + + assert set(resolution_map.keys()) == expected_models, ( + "RESOLUTION_MAP should contain all expected models" + ) + + # Verify all resolutions are tuples of two integers + for model_name, resolution in resolution_map.items(): + assert isinstance(resolution, tuple), ( + f"Resolution for '{model_name}' should be a tuple" + ) + assert len(resolution) == 2, ( + f"Resolution for '{model_name}' should have 2 elements (height, width)" + ) + assert all(isinstance(dim, int) for dim in resolution), ( + f"Resolution for '{model_name}' should contain integers" + ) + + def test_fallback_warning_message_format(self, cli_demo_module, caplog): + """Test the format and content of the fallback warning message.""" + + # Mock the generate_video call to test just the resolution lookup part + with patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline, \ + patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'), \ + patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'), \ + patch.object(cli_demo_module, 'export_to_video'), \ + caplog.at_level(logging.WARNING): + + # Setup mock to avoid actual model loading + mock_instance = MagicMock() + mock_pipeline.from_pretrained.return_value = mock_instance + mock_instance.return_value = MagicMock(frames=[]) + + caplog.clear() + + try: + cli_demo_module.generate_video( + prompt="Test prompt", + model_path="invalid/unknown-model", + generate_type="t2v", + num_inference_steps=1, + ) + except Exception: + # We're only interested in the warning, not successful execution + pass + + # Check that a warning was logged + warning_records = [r for r in caplog.records if r.levelname == "WARNING"] + assert len(warning_records) > 0, "Should log at least one warning" + + # Get the first warning message + warning_msg = warning_records[0].message + + # Verify warning message contains key information + assert "unknown-model" in warning_msg.lower(), "Should mention the invalid model name" + assert "not found in resolution map" in warning_msg.lower(), "Should explain the issue" + assert "480" in warning_msg and "720" in warning_msg, "Should mention default resolution" + assert "valid models:" in warning_msg.lower(), "Should list valid models" + + # Verify all valid model names are in the warning + from inference.cli_demo import RESOLUTION_MAP + for model_name in RESOLUTION_MAP.keys(): + assert model_name in warning_msg, f"Warning should list '{model_name}'" + + def test_case_insensitive_model_name_matching(self, resolution_map): + """Test that model names are converted to lowercase for matching.""" + + # All keys in RESOLUTION_MAP should be lowercase + for key in resolution_map.keys(): + assert key == key.lower(), f"RESOLUTION_MAP key '{key}' should be lowercase" + + # Test that various case inputs would work after .lower() conversion + test_paths = [ + ("THUDM/COGVIDEOX-5B", "cogvideox-5b"), + ("thudm/cogvideox-5b", "cogvideox-5b"), + ("ThUdM/CoGvIdEoX-5b", "cogvideox-5b"), + ] + + for model_path, expected_key in test_paths: + model_name = model_path.split("/")[-1].lower() + assert model_name == expected_key + assert model_name in resolution_map + + def test_default_resolution_value(self, cli_demo_module): + """Test that the default fallback resolution is (480, 720).""" + import inspect + + source = inspect.getsource(cli_demo_module.generate_video) + + # The default should be defined as (480, 720) + assert "(480, 720)" in source, "Default resolution should be (480, 720)" + + def test_resolution_dimensions_order(self, resolution_map): + """Test that resolutions are in (height, width) format.""" + + # Based on the code comments and typical usage + for model_name, (dim1, dim2) in resolution_map.items(): + # Height typically comes first + # Verify dimensions are reasonable (not negative or zero) + assert dim1 > 0, f"First dimension for {model_name} should be positive" + assert dim2 > 0, f"Second dimension for {model_name} should be positive" + + # CogVideoX1.5 models should have larger resolutions + if "1.5" in model_name: + assert dim1 >= 768, f"CogVideoX1.5 model {model_name} should have height >= 768" + assert dim2 >= 1360, f"CogVideoX1.5 model {model_name} should have width >= 1360" + else: + # Standard CogVideoX models + assert dim1 == 480, f"Standard CogVideoX model {model_name} should have height 480" + assert dim2 == 720, f"Standard CogVideoX model {model_name} should have width 720" + + +if __name__ == "__main__": + # Allow running tests directly + pytest.main([__file__, "-v"]) From 261868fcb85729cf2762349def9a971083cf2bd3 Mon Sep 17 00:00:00 2001 From: Test User Date: Thu, 19 Feb 2026 03:35:53 +0000 Subject: [PATCH 3/4] style: apply ruff formatting to resolution tests --- tests/test_cli_demo_resolution.py | 120 ++++++++++++++++-------------- 1 file changed, 66 insertions(+), 54 deletions(-) diff --git a/tests/test_cli_demo_resolution.py b/tests/test_cli_demo_resolution.py index d01745c..aae665d 100644 --- a/tests/test_cli_demo_resolution.py +++ b/tests/test_cli_demo_resolution.py @@ -19,13 +19,13 @@ def mock_heavy_imports(): mock_torch.dtype = MagicMock() mock_torch.bfloat16 = MagicMock() mock_torch.float16 = MagicMock() - + sys.modules['torch'] = mock_torch sys.modules['diffusers'] = MagicMock() sys.modules['diffusers.utils'] = MagicMock() - + yield - + # Cleanup for module in ['torch', 'diffusers', 'diffusers.utils']: if module in sys.modules: @@ -37,6 +37,7 @@ def resolution_map(): """Fixture providing the RESOLUTION_MAP from cli_demo.""" # Import after mocking dependencies from inference.cli_demo import RESOLUTION_MAP + return RESOLUTION_MAP @@ -44,15 +45,18 @@ def resolution_map(): def cli_demo_module(): """Fixture providing the cli_demo module.""" import inference.cli_demo as cli_demo + return cli_demo class TestResolutionMapFallback: """Test suite for RESOLUTION_MAP lookup and fallback behavior.""" - def test_valid_model_name_returns_correct_resolution(self, resolution_map, cli_demo_module, caplog): + def test_valid_model_name_returns_correct_resolution( + self, resolution_map, cli_demo_module, caplog + ): """Test that valid model names return their correct resolutions.""" - + # Test each valid model in RESOLUTION_MAP test_cases = [ ("THUDM/CogVideoX-5b", "cogvideox-5b", (480, 720)), @@ -61,18 +65,20 @@ class TestResolutionMapFallback: ("THUDM/CogVideoX1.5-5b", "cogvideox1.5-5b", (768, 1360)), ("THUDM/CogVideoX1.5-5b-I2V", "cogvideox1.5-5b-i2v", (768, 1360)), ] - + for model_path, expected_key, expected_resolution in test_cases: with caplog.at_level(logging.WARNING): caplog.clear() - + # Extract model name using the same logic as cli_demo model_name = model_path.split("/")[-1].lower() - + # Verify the model name is in the map - assert expected_key in resolution_map, f"Expected key '{expected_key}' not in RESOLUTION_MAP" + assert ( + expected_key in resolution_map + ), f"Expected key '{expected_key}' not in RESOLUTION_MAP" assert resolution_map[expected_key] == expected_resolution - + # Test that accessing the resolution works without KeyError try: actual_resolution = resolution_map[model_name] @@ -82,27 +88,27 @@ class TestResolutionMapFallback: def test_invalid_model_name_raises_key_error(self, resolution_map): """Test that invalid model names raise KeyError (which is then caught by try/except).""" - + invalid_models = [ "invalidmodel-xyz", "unknown-5b", "cogvideox-5b-custom", # Not in map ] - + for invalid_model in invalid_models: with pytest.raises(KeyError): _ = resolution_map[invalid_model] def test_local_path_model_name_extraction(self): """Test that local paths are parsed correctly to extract model names.""" - + local_paths = [ ("./local_models/CogVideoX-5B", "cogvideox-5b"), ("/path/to/custom_model", "custom_model"), ("~/models/my-cogvideo-model", "my-cogvideo-model"), ("./CogVideoX-Custom", "cogvideox-custom"), ] - + for local_path, expected_model_name in local_paths: # Extract model name using the same logic as cli_demo model_name = local_path.split("/")[-1].lower() @@ -111,17 +117,21 @@ class TestResolutionMapFallback: def test_fallback_code_structure(self, cli_demo_module): """Test that the fallback code structure is present in generate_video function.""" import inspect - + # Get the source code of generate_video source = inspect.getsource(cli_demo_module.generate_video) - + # Verify try/except block exists assert "try:" in source, "generate_video should have try/except block" assert "except KeyError:" in source, "Should catch KeyError specifically" - assert "RESOLUTION_MAP[model_name]" in source, "Should access RESOLUTION_MAP with model_name" - + assert ( + "RESOLUTION_MAP[model_name]" in source + ), "Should access RESOLUTION_MAP with model_name" + # Verify fallback logic exists - assert "default_resolution" in source or "(480, 720)" in source, "Should define default resolution" + assert ( + "default_resolution" in source or "(480, 720)" in source + ), "Should define default resolution" assert "logging.warning" in source, "Should log warning on fallback" assert "valid models" in source.lower(), "Warning should mention valid models" @@ -134,40 +144,39 @@ class TestResolutionMapFallback: "cogvideox-5b", "cogvideox-2b", } - - assert set(resolution_map.keys()) == expected_models, ( - "RESOLUTION_MAP should contain all expected models" - ) - + + assert ( + set(resolution_map.keys()) == expected_models + ), "RESOLUTION_MAP should contain all expected models" + # Verify all resolutions are tuples of two integers for model_name, resolution in resolution_map.items(): - assert isinstance(resolution, tuple), ( - f"Resolution for '{model_name}' should be a tuple" - ) - assert len(resolution) == 2, ( - f"Resolution for '{model_name}' should have 2 elements (height, width)" - ) - assert all(isinstance(dim, int) for dim in resolution), ( - f"Resolution for '{model_name}' should contain integers" - ) + assert isinstance(resolution, tuple), f"Resolution for '{model_name}' should be a tuple" + assert ( + len(resolution) == 2 + ), f"Resolution for '{model_name}' should have 2 elements (height, width)" + assert all( + isinstance(dim, int) for dim in resolution + ), f"Resolution for '{model_name}' should contain integers" def test_fallback_warning_message_format(self, cli_demo_module, caplog): """Test the format and content of the fallback warning message.""" - + # Mock the generate_video call to test just the resolution lookup part - with patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline, \ - patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'), \ - patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'), \ - patch.object(cli_demo_module, 'export_to_video'), \ - caplog.at_level(logging.WARNING): - + with ( + patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline, + patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'), + patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'), + patch.object(cli_demo_module, 'export_to_video'), + caplog.at_level(logging.WARNING), + ): # Setup mock to avoid actual model loading mock_instance = MagicMock() mock_pipeline.from_pretrained.return_value = mock_instance mock_instance.return_value = MagicMock(frames=[]) - + caplog.clear() - + try: cli_demo_module.generate_video( prompt="Test prompt", @@ -178,39 +187,42 @@ class TestResolutionMapFallback: except Exception: # We're only interested in the warning, not successful execution pass - + # Check that a warning was logged warning_records = [r for r in caplog.records if r.levelname == "WARNING"] assert len(warning_records) > 0, "Should log at least one warning" - + # Get the first warning message warning_msg = warning_records[0].message - + # Verify warning message contains key information assert "unknown-model" in warning_msg.lower(), "Should mention the invalid model name" assert "not found in resolution map" in warning_msg.lower(), "Should explain the issue" - assert "480" in warning_msg and "720" in warning_msg, "Should mention default resolution" + assert ( + "480" in warning_msg and "720" in warning_msg + ), "Should mention default resolution" assert "valid models:" in warning_msg.lower(), "Should list valid models" - + # Verify all valid model names are in the warning from inference.cli_demo import RESOLUTION_MAP + for model_name in RESOLUTION_MAP.keys(): assert model_name in warning_msg, f"Warning should list '{model_name}'" def test_case_insensitive_model_name_matching(self, resolution_map): """Test that model names are converted to lowercase for matching.""" - + # All keys in RESOLUTION_MAP should be lowercase for key in resolution_map.keys(): assert key == key.lower(), f"RESOLUTION_MAP key '{key}' should be lowercase" - + # Test that various case inputs would work after .lower() conversion test_paths = [ ("THUDM/COGVIDEOX-5B", "cogvideox-5b"), ("thudm/cogvideox-5b", "cogvideox-5b"), ("ThUdM/CoGvIdEoX-5b", "cogvideox-5b"), ] - + for model_path, expected_key in test_paths: model_name = model_path.split("/")[-1].lower() assert model_name == expected_key @@ -219,22 +231,22 @@ class TestResolutionMapFallback: def test_default_resolution_value(self, cli_demo_module): """Test that the default fallback resolution is (480, 720).""" import inspect - + source = inspect.getsource(cli_demo_module.generate_video) - + # The default should be defined as (480, 720) assert "(480, 720)" in source, "Default resolution should be (480, 720)" def test_resolution_dimensions_order(self, resolution_map): """Test that resolutions are in (height, width) format.""" - + # Based on the code comments and typical usage for model_name, (dim1, dim2) in resolution_map.items(): # Height typically comes first # Verify dimensions are reasonable (not negative or zero) assert dim1 > 0, f"First dimension for {model_name} should be positive" assert dim2 > 0, f"Second dimension for {model_name} should be positive" - + # CogVideoX1.5 models should have larger resolutions if "1.5" in model_name: assert dim1 >= 768, f"CogVideoX1.5 model {model_name} should have height >= 768" From c312b95925df9ed358c0219ce7e9746a14bc9ddb Mon Sep 17 00:00:00 2001 From: Test User Date: Thu, 19 Feb 2026 03:36:37 +0000 Subject: [PATCH 4/4] chore: allow tests/ directory in gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index ad4bbeb..303dbae 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ logs/ .idea output* test* +!tests/ +!tests/** venv **/.swp **/*.log