mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-06-03 18:38:46 +08:00
style: apply ruff formatting to resolution tests
This commit is contained in:
parent
106f987ded
commit
261868fcb8
@ -37,6 +37,7 @@ def resolution_map():
|
||||
"""Fixture providing the RESOLUTION_MAP from cli_demo."""
|
||||
# Import after mocking dependencies
|
||||
from inference.cli_demo import RESOLUTION_MAP
|
||||
|
||||
return RESOLUTION_MAP
|
||||
|
||||
|
||||
@ -44,13 +45,16 @@ def resolution_map():
|
||||
def cli_demo_module():
|
||||
"""Fixture providing the cli_demo module."""
|
||||
import inference.cli_demo as cli_demo
|
||||
|
||||
return cli_demo
|
||||
|
||||
|
||||
class TestResolutionMapFallback:
|
||||
"""Test suite for RESOLUTION_MAP lookup and fallback behavior."""
|
||||
|
||||
def test_valid_model_name_returns_correct_resolution(self, resolution_map, cli_demo_module, caplog):
|
||||
def test_valid_model_name_returns_correct_resolution(
|
||||
self, resolution_map, cli_demo_module, caplog
|
||||
):
|
||||
"""Test that valid model names return their correct resolutions."""
|
||||
|
||||
# Test each valid model in RESOLUTION_MAP
|
||||
@ -70,7 +74,9 @@ class TestResolutionMapFallback:
|
||||
model_name = model_path.split("/")[-1].lower()
|
||||
|
||||
# Verify the model name is in the map
|
||||
assert expected_key in resolution_map, f"Expected key '{expected_key}' not in RESOLUTION_MAP"
|
||||
assert (
|
||||
expected_key in resolution_map
|
||||
), f"Expected key '{expected_key}' not in RESOLUTION_MAP"
|
||||
assert resolution_map[expected_key] == expected_resolution
|
||||
|
||||
# Test that accessing the resolution works without KeyError
|
||||
@ -118,10 +124,14 @@ class TestResolutionMapFallback:
|
||||
# Verify try/except block exists
|
||||
assert "try:" in source, "generate_video should have try/except block"
|
||||
assert "except KeyError:" in source, "Should catch KeyError specifically"
|
||||
assert "RESOLUTION_MAP[model_name]" in source, "Should access RESOLUTION_MAP with model_name"
|
||||
assert (
|
||||
"RESOLUTION_MAP[model_name]" in source
|
||||
), "Should access RESOLUTION_MAP with model_name"
|
||||
|
||||
# Verify fallback logic exists
|
||||
assert "default_resolution" in source or "(480, 720)" in source, "Should define default resolution"
|
||||
assert (
|
||||
"default_resolution" in source or "(480, 720)" in source
|
||||
), "Should define default resolution"
|
||||
assert "logging.warning" in source, "Should log warning on fallback"
|
||||
assert "valid models" in source.lower(), "Warning should mention valid models"
|
||||
|
||||
@ -135,32 +145,31 @@ class TestResolutionMapFallback:
|
||||
"cogvideox-2b",
|
||||
}
|
||||
|
||||
assert set(resolution_map.keys()) == expected_models, (
|
||||
"RESOLUTION_MAP should contain all expected models"
|
||||
)
|
||||
assert (
|
||||
set(resolution_map.keys()) == expected_models
|
||||
), "RESOLUTION_MAP should contain all expected models"
|
||||
|
||||
# Verify all resolutions are tuples of two integers
|
||||
for model_name, resolution in resolution_map.items():
|
||||
assert isinstance(resolution, tuple), (
|
||||
f"Resolution for '{model_name}' should be a tuple"
|
||||
)
|
||||
assert len(resolution) == 2, (
|
||||
f"Resolution for '{model_name}' should have 2 elements (height, width)"
|
||||
)
|
||||
assert all(isinstance(dim, int) for dim in resolution), (
|
||||
f"Resolution for '{model_name}' should contain integers"
|
||||
)
|
||||
assert isinstance(resolution, tuple), f"Resolution for '{model_name}' should be a tuple"
|
||||
assert (
|
||||
len(resolution) == 2
|
||||
), f"Resolution for '{model_name}' should have 2 elements (height, width)"
|
||||
assert all(
|
||||
isinstance(dim, int) for dim in resolution
|
||||
), f"Resolution for '{model_name}' should contain integers"
|
||||
|
||||
def test_fallback_warning_message_format(self, cli_demo_module, caplog):
|
||||
"""Test the format and content of the fallback warning message."""
|
||||
|
||||
# Mock the generate_video call to test just the resolution lookup part
|
||||
with patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline, \
|
||||
patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'), \
|
||||
patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'), \
|
||||
patch.object(cli_demo_module, 'export_to_video'), \
|
||||
caplog.at_level(logging.WARNING):
|
||||
|
||||
with (
|
||||
patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline,
|
||||
patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'),
|
||||
patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'),
|
||||
patch.object(cli_demo_module, 'export_to_video'),
|
||||
caplog.at_level(logging.WARNING),
|
||||
):
|
||||
# Setup mock to avoid actual model loading
|
||||
mock_instance = MagicMock()
|
||||
mock_pipeline.from_pretrained.return_value = mock_instance
|
||||
@ -189,11 +198,14 @@ class TestResolutionMapFallback:
|
||||
# Verify warning message contains key information
|
||||
assert "unknown-model" in warning_msg.lower(), "Should mention the invalid model name"
|
||||
assert "not found in resolution map" in warning_msg.lower(), "Should explain the issue"
|
||||
assert "480" in warning_msg and "720" in warning_msg, "Should mention default resolution"
|
||||
assert (
|
||||
"480" in warning_msg and "720" in warning_msg
|
||||
), "Should mention default resolution"
|
||||
assert "valid models:" in warning_msg.lower(), "Should list valid models"
|
||||
|
||||
# Verify all valid model names are in the warning
|
||||
from inference.cli_demo import RESOLUTION_MAP
|
||||
|
||||
for model_name in RESOLUTION_MAP.keys():
|
||||
assert model_name in warning_msg, f"Warning should list '{model_name}'"
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user