style: apply ruff formatting to resolution tests

This commit is contained in:
Test User 2026-02-19 03:35:53 +00:00
parent 106f987ded
commit 261868fcb8

View File

@ -19,13 +19,13 @@ def mock_heavy_imports():
mock_torch.dtype = MagicMock() mock_torch.dtype = MagicMock()
mock_torch.bfloat16 = MagicMock() mock_torch.bfloat16 = MagicMock()
mock_torch.float16 = MagicMock() mock_torch.float16 = MagicMock()
sys.modules['torch'] = mock_torch sys.modules['torch'] = mock_torch
sys.modules['diffusers'] = MagicMock() sys.modules['diffusers'] = MagicMock()
sys.modules['diffusers.utils'] = MagicMock() sys.modules['diffusers.utils'] = MagicMock()
yield yield
# Cleanup # Cleanup
for module in ['torch', 'diffusers', 'diffusers.utils']: for module in ['torch', 'diffusers', 'diffusers.utils']:
if module in sys.modules: if module in sys.modules:
@ -37,6 +37,7 @@ def resolution_map():
"""Fixture providing the RESOLUTION_MAP from cli_demo.""" """Fixture providing the RESOLUTION_MAP from cli_demo."""
# Import after mocking dependencies # Import after mocking dependencies
from inference.cli_demo import RESOLUTION_MAP from inference.cli_demo import RESOLUTION_MAP
return RESOLUTION_MAP return RESOLUTION_MAP
@ -44,15 +45,18 @@ def resolution_map():
def cli_demo_module(): def cli_demo_module():
"""Fixture providing the cli_demo module.""" """Fixture providing the cli_demo module."""
import inference.cli_demo as cli_demo import inference.cli_demo as cli_demo
return cli_demo return cli_demo
class TestResolutionMapFallback: class TestResolutionMapFallback:
"""Test suite for RESOLUTION_MAP lookup and fallback behavior.""" """Test suite for RESOLUTION_MAP lookup and fallback behavior."""
def test_valid_model_name_returns_correct_resolution(self, resolution_map, cli_demo_module, caplog): def test_valid_model_name_returns_correct_resolution(
self, resolution_map, cli_demo_module, caplog
):
"""Test that valid model names return their correct resolutions.""" """Test that valid model names return their correct resolutions."""
# Test each valid model in RESOLUTION_MAP # Test each valid model in RESOLUTION_MAP
test_cases = [ test_cases = [
("THUDM/CogVideoX-5b", "cogvideox-5b", (480, 720)), ("THUDM/CogVideoX-5b", "cogvideox-5b", (480, 720)),
@ -61,18 +65,20 @@ class TestResolutionMapFallback:
("THUDM/CogVideoX1.5-5b", "cogvideox1.5-5b", (768, 1360)), ("THUDM/CogVideoX1.5-5b", "cogvideox1.5-5b", (768, 1360)),
("THUDM/CogVideoX1.5-5b-I2V", "cogvideox1.5-5b-i2v", (768, 1360)), ("THUDM/CogVideoX1.5-5b-I2V", "cogvideox1.5-5b-i2v", (768, 1360)),
] ]
for model_path, expected_key, expected_resolution in test_cases: for model_path, expected_key, expected_resolution in test_cases:
with caplog.at_level(logging.WARNING): with caplog.at_level(logging.WARNING):
caplog.clear() caplog.clear()
# Extract model name using the same logic as cli_demo # Extract model name using the same logic as cli_demo
model_name = model_path.split("/")[-1].lower() model_name = model_path.split("/")[-1].lower()
# Verify the model name is in the map # Verify the model name is in the map
assert expected_key in resolution_map, f"Expected key '{expected_key}' not in RESOLUTION_MAP" assert (
expected_key in resolution_map
), f"Expected key '{expected_key}' not in RESOLUTION_MAP"
assert resolution_map[expected_key] == expected_resolution assert resolution_map[expected_key] == expected_resolution
# Test that accessing the resolution works without KeyError # Test that accessing the resolution works without KeyError
try: try:
actual_resolution = resolution_map[model_name] actual_resolution = resolution_map[model_name]
@ -82,27 +88,27 @@ class TestResolutionMapFallback:
def test_invalid_model_name_raises_key_error(self, resolution_map): def test_invalid_model_name_raises_key_error(self, resolution_map):
"""Test that invalid model names raise KeyError (which is then caught by try/except).""" """Test that invalid model names raise KeyError (which is then caught by try/except)."""
invalid_models = [ invalid_models = [
"invalidmodel-xyz", "invalidmodel-xyz",
"unknown-5b", "unknown-5b",
"cogvideox-5b-custom", # Not in map "cogvideox-5b-custom", # Not in map
] ]
for invalid_model in invalid_models: for invalid_model in invalid_models:
with pytest.raises(KeyError): with pytest.raises(KeyError):
_ = resolution_map[invalid_model] _ = resolution_map[invalid_model]
def test_local_path_model_name_extraction(self): def test_local_path_model_name_extraction(self):
"""Test that local paths are parsed correctly to extract model names.""" """Test that local paths are parsed correctly to extract model names."""
local_paths = [ local_paths = [
("./local_models/CogVideoX-5B", "cogvideox-5b"), ("./local_models/CogVideoX-5B", "cogvideox-5b"),
("/path/to/custom_model", "custom_model"), ("/path/to/custom_model", "custom_model"),
("~/models/my-cogvideo-model", "my-cogvideo-model"), ("~/models/my-cogvideo-model", "my-cogvideo-model"),
("./CogVideoX-Custom", "cogvideox-custom"), ("./CogVideoX-Custom", "cogvideox-custom"),
] ]
for local_path, expected_model_name in local_paths: for local_path, expected_model_name in local_paths:
# Extract model name using the same logic as cli_demo # Extract model name using the same logic as cli_demo
model_name = local_path.split("/")[-1].lower() model_name = local_path.split("/")[-1].lower()
@ -111,17 +117,21 @@ class TestResolutionMapFallback:
def test_fallback_code_structure(self, cli_demo_module): def test_fallback_code_structure(self, cli_demo_module):
"""Test that the fallback code structure is present in generate_video function.""" """Test that the fallback code structure is present in generate_video function."""
import inspect import inspect
# Get the source code of generate_video # Get the source code of generate_video
source = inspect.getsource(cli_demo_module.generate_video) source = inspect.getsource(cli_demo_module.generate_video)
# Verify try/except block exists # Verify try/except block exists
assert "try:" in source, "generate_video should have try/except block" assert "try:" in source, "generate_video should have try/except block"
assert "except KeyError:" in source, "Should catch KeyError specifically" assert "except KeyError:" in source, "Should catch KeyError specifically"
assert "RESOLUTION_MAP[model_name]" in source, "Should access RESOLUTION_MAP with model_name" assert (
"RESOLUTION_MAP[model_name]" in source
), "Should access RESOLUTION_MAP with model_name"
# Verify fallback logic exists # Verify fallback logic exists
assert "default_resolution" in source or "(480, 720)" in source, "Should define default resolution" assert (
"default_resolution" in source or "(480, 720)" in source
), "Should define default resolution"
assert "logging.warning" in source, "Should log warning on fallback" assert "logging.warning" in source, "Should log warning on fallback"
assert "valid models" in source.lower(), "Warning should mention valid models" assert "valid models" in source.lower(), "Warning should mention valid models"
@ -134,40 +144,39 @@ class TestResolutionMapFallback:
"cogvideox-5b", "cogvideox-5b",
"cogvideox-2b", "cogvideox-2b",
} }
assert set(resolution_map.keys()) == expected_models, ( assert (
"RESOLUTION_MAP should contain all expected models" set(resolution_map.keys()) == expected_models
) ), "RESOLUTION_MAP should contain all expected models"
# Verify all resolutions are tuples of two integers # Verify all resolutions are tuples of two integers
for model_name, resolution in resolution_map.items(): for model_name, resolution in resolution_map.items():
assert isinstance(resolution, tuple), ( assert isinstance(resolution, tuple), f"Resolution for '{model_name}' should be a tuple"
f"Resolution for '{model_name}' should be a tuple" assert (
) len(resolution) == 2
assert len(resolution) == 2, ( ), f"Resolution for '{model_name}' should have 2 elements (height, width)"
f"Resolution for '{model_name}' should have 2 elements (height, width)" assert all(
) isinstance(dim, int) for dim in resolution
assert all(isinstance(dim, int) for dim in resolution), ( ), f"Resolution for '{model_name}' should contain integers"
f"Resolution for '{model_name}' should contain integers"
)
def test_fallback_warning_message_format(self, cli_demo_module, caplog): def test_fallback_warning_message_format(self, cli_demo_module, caplog):
"""Test the format and content of the fallback warning message.""" """Test the format and content of the fallback warning message."""
# Mock the generate_video call to test just the resolution lookup part # Mock the generate_video call to test just the resolution lookup part
with patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline, \ with (
patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'), \ patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline,
patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'), \ patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'),
patch.object(cli_demo_module, 'export_to_video'), \ patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'),
caplog.at_level(logging.WARNING): patch.object(cli_demo_module, 'export_to_video'),
caplog.at_level(logging.WARNING),
):
# Setup mock to avoid actual model loading # Setup mock to avoid actual model loading
mock_instance = MagicMock() mock_instance = MagicMock()
mock_pipeline.from_pretrained.return_value = mock_instance mock_pipeline.from_pretrained.return_value = mock_instance
mock_instance.return_value = MagicMock(frames=[]) mock_instance.return_value = MagicMock(frames=[])
caplog.clear() caplog.clear()
try: try:
cli_demo_module.generate_video( cli_demo_module.generate_video(
prompt="Test prompt", prompt="Test prompt",
@ -178,39 +187,42 @@ class TestResolutionMapFallback:
except Exception: except Exception:
# We're only interested in the warning, not successful execution # We're only interested in the warning, not successful execution
pass pass
# Check that a warning was logged # Check that a warning was logged
warning_records = [r for r in caplog.records if r.levelname == "WARNING"] warning_records = [r for r in caplog.records if r.levelname == "WARNING"]
assert len(warning_records) > 0, "Should log at least one warning" assert len(warning_records) > 0, "Should log at least one warning"
# Get the first warning message # Get the first warning message
warning_msg = warning_records[0].message warning_msg = warning_records[0].message
# Verify warning message contains key information # Verify warning message contains key information
assert "unknown-model" in warning_msg.lower(), "Should mention the invalid model name" assert "unknown-model" in warning_msg.lower(), "Should mention the invalid model name"
assert "not found in resolution map" in warning_msg.lower(), "Should explain the issue" assert "not found in resolution map" in warning_msg.lower(), "Should explain the issue"
assert "480" in warning_msg and "720" in warning_msg, "Should mention default resolution" assert (
"480" in warning_msg and "720" in warning_msg
), "Should mention default resolution"
assert "valid models:" in warning_msg.lower(), "Should list valid models" assert "valid models:" in warning_msg.lower(), "Should list valid models"
# Verify all valid model names are in the warning # Verify all valid model names are in the warning
from inference.cli_demo import RESOLUTION_MAP from inference.cli_demo import RESOLUTION_MAP
for model_name in RESOLUTION_MAP.keys(): for model_name in RESOLUTION_MAP.keys():
assert model_name in warning_msg, f"Warning should list '{model_name}'" assert model_name in warning_msg, f"Warning should list '{model_name}'"
def test_case_insensitive_model_name_matching(self, resolution_map): def test_case_insensitive_model_name_matching(self, resolution_map):
"""Test that model names are converted to lowercase for matching.""" """Test that model names are converted to lowercase for matching."""
# All keys in RESOLUTION_MAP should be lowercase # All keys in RESOLUTION_MAP should be lowercase
for key in resolution_map.keys(): for key in resolution_map.keys():
assert key == key.lower(), f"RESOLUTION_MAP key '{key}' should be lowercase" assert key == key.lower(), f"RESOLUTION_MAP key '{key}' should be lowercase"
# Test that various case inputs would work after .lower() conversion # Test that various case inputs would work after .lower() conversion
test_paths = [ test_paths = [
("THUDM/COGVIDEOX-5B", "cogvideox-5b"), ("THUDM/COGVIDEOX-5B", "cogvideox-5b"),
("thudm/cogvideox-5b", "cogvideox-5b"), ("thudm/cogvideox-5b", "cogvideox-5b"),
("ThUdM/CoGvIdEoX-5b", "cogvideox-5b"), ("ThUdM/CoGvIdEoX-5b", "cogvideox-5b"),
] ]
for model_path, expected_key in test_paths: for model_path, expected_key in test_paths:
model_name = model_path.split("/")[-1].lower() model_name = model_path.split("/")[-1].lower()
assert model_name == expected_key assert model_name == expected_key
@ -219,22 +231,22 @@ class TestResolutionMapFallback:
def test_default_resolution_value(self, cli_demo_module): def test_default_resolution_value(self, cli_demo_module):
"""Test that the default fallback resolution is (480, 720).""" """Test that the default fallback resolution is (480, 720)."""
import inspect import inspect
source = inspect.getsource(cli_demo_module.generate_video) source = inspect.getsource(cli_demo_module.generate_video)
# The default should be defined as (480, 720) # The default should be defined as (480, 720)
assert "(480, 720)" in source, "Default resolution should be (480, 720)" assert "(480, 720)" in source, "Default resolution should be (480, 720)"
def test_resolution_dimensions_order(self, resolution_map): def test_resolution_dimensions_order(self, resolution_map):
"""Test that resolutions are in (height, width) format.""" """Test that resolutions are in (height, width) format."""
# Based on the code comments and typical usage # Based on the code comments and typical usage
for model_name, (dim1, dim2) in resolution_map.items(): for model_name, (dim1, dim2) in resolution_map.items():
# Height typically comes first # Height typically comes first
# Verify dimensions are reasonable (not negative or zero) # Verify dimensions are reasonable (not negative or zero)
assert dim1 > 0, f"First dimension for {model_name} should be positive" assert dim1 > 0, f"First dimension for {model_name} should be positive"
assert dim2 > 0, f"Second dimension for {model_name} should be positive" assert dim2 > 0, f"Second dimension for {model_name} should be positive"
# CogVideoX1.5 models should have larger resolutions # CogVideoX1.5 models should have larger resolutions
if "1.5" in model_name: if "1.5" in model_name:
assert dim1 >= 768, f"CogVideoX1.5 model {model_name} should have height >= 768" assert dim1 >= 768, f"CogVideoX1.5 model {model_name} should have height >= 768"