mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-06-01 09:04:08 +08:00
style: apply ruff formatting to resolution tests
This commit is contained in:
parent
106f987ded
commit
261868fcb8
@ -37,6 +37,7 @@ def resolution_map():
|
|||||||
"""Fixture providing the RESOLUTION_MAP from cli_demo."""
|
"""Fixture providing the RESOLUTION_MAP from cli_demo."""
|
||||||
# Import after mocking dependencies
|
# Import after mocking dependencies
|
||||||
from inference.cli_demo import RESOLUTION_MAP
|
from inference.cli_demo import RESOLUTION_MAP
|
||||||
|
|
||||||
return RESOLUTION_MAP
|
return RESOLUTION_MAP
|
||||||
|
|
||||||
|
|
||||||
@ -44,13 +45,16 @@ def resolution_map():
|
|||||||
def cli_demo_module():
|
def cli_demo_module():
|
||||||
"""Fixture providing the cli_demo module."""
|
"""Fixture providing the cli_demo module."""
|
||||||
import inference.cli_demo as cli_demo
|
import inference.cli_demo as cli_demo
|
||||||
|
|
||||||
return cli_demo
|
return cli_demo
|
||||||
|
|
||||||
|
|
||||||
class TestResolutionMapFallback:
|
class TestResolutionMapFallback:
|
||||||
"""Test suite for RESOLUTION_MAP lookup and fallback behavior."""
|
"""Test suite for RESOLUTION_MAP lookup and fallback behavior."""
|
||||||
|
|
||||||
def test_valid_model_name_returns_correct_resolution(self, resolution_map, cli_demo_module, caplog):
|
def test_valid_model_name_returns_correct_resolution(
|
||||||
|
self, resolution_map, cli_demo_module, caplog
|
||||||
|
):
|
||||||
"""Test that valid model names return their correct resolutions."""
|
"""Test that valid model names return their correct resolutions."""
|
||||||
|
|
||||||
# Test each valid model in RESOLUTION_MAP
|
# Test each valid model in RESOLUTION_MAP
|
||||||
@ -70,7 +74,9 @@ class TestResolutionMapFallback:
|
|||||||
model_name = model_path.split("/")[-1].lower()
|
model_name = model_path.split("/")[-1].lower()
|
||||||
|
|
||||||
# Verify the model name is in the map
|
# Verify the model name is in the map
|
||||||
assert expected_key in resolution_map, f"Expected key '{expected_key}' not in RESOLUTION_MAP"
|
assert (
|
||||||
|
expected_key in resolution_map
|
||||||
|
), f"Expected key '{expected_key}' not in RESOLUTION_MAP"
|
||||||
assert resolution_map[expected_key] == expected_resolution
|
assert resolution_map[expected_key] == expected_resolution
|
||||||
|
|
||||||
# Test that accessing the resolution works without KeyError
|
# Test that accessing the resolution works without KeyError
|
||||||
@ -118,10 +124,14 @@ class TestResolutionMapFallback:
|
|||||||
# Verify try/except block exists
|
# Verify try/except block exists
|
||||||
assert "try:" in source, "generate_video should have try/except block"
|
assert "try:" in source, "generate_video should have try/except block"
|
||||||
assert "except KeyError:" in source, "Should catch KeyError specifically"
|
assert "except KeyError:" in source, "Should catch KeyError specifically"
|
||||||
assert "RESOLUTION_MAP[model_name]" in source, "Should access RESOLUTION_MAP with model_name"
|
assert (
|
||||||
|
"RESOLUTION_MAP[model_name]" in source
|
||||||
|
), "Should access RESOLUTION_MAP with model_name"
|
||||||
|
|
||||||
# Verify fallback logic exists
|
# Verify fallback logic exists
|
||||||
assert "default_resolution" in source or "(480, 720)" in source, "Should define default resolution"
|
assert (
|
||||||
|
"default_resolution" in source or "(480, 720)" in source
|
||||||
|
), "Should define default resolution"
|
||||||
assert "logging.warning" in source, "Should log warning on fallback"
|
assert "logging.warning" in source, "Should log warning on fallback"
|
||||||
assert "valid models" in source.lower(), "Warning should mention valid models"
|
assert "valid models" in source.lower(), "Warning should mention valid models"
|
||||||
|
|
||||||
@ -135,32 +145,31 @@ class TestResolutionMapFallback:
|
|||||||
"cogvideox-2b",
|
"cogvideox-2b",
|
||||||
}
|
}
|
||||||
|
|
||||||
assert set(resolution_map.keys()) == expected_models, (
|
assert (
|
||||||
"RESOLUTION_MAP should contain all expected models"
|
set(resolution_map.keys()) == expected_models
|
||||||
)
|
), "RESOLUTION_MAP should contain all expected models"
|
||||||
|
|
||||||
# Verify all resolutions are tuples of two integers
|
# Verify all resolutions are tuples of two integers
|
||||||
for model_name, resolution in resolution_map.items():
|
for model_name, resolution in resolution_map.items():
|
||||||
assert isinstance(resolution, tuple), (
|
assert isinstance(resolution, tuple), f"Resolution for '{model_name}' should be a tuple"
|
||||||
f"Resolution for '{model_name}' should be a tuple"
|
assert (
|
||||||
)
|
len(resolution) == 2
|
||||||
assert len(resolution) == 2, (
|
), f"Resolution for '{model_name}' should have 2 elements (height, width)"
|
||||||
f"Resolution for '{model_name}' should have 2 elements (height, width)"
|
assert all(
|
||||||
)
|
isinstance(dim, int) for dim in resolution
|
||||||
assert all(isinstance(dim, int) for dim in resolution), (
|
), f"Resolution for '{model_name}' should contain integers"
|
||||||
f"Resolution for '{model_name}' should contain integers"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_fallback_warning_message_format(self, cli_demo_module, caplog):
|
def test_fallback_warning_message_format(self, cli_demo_module, caplog):
|
||||||
"""Test the format and content of the fallback warning message."""
|
"""Test the format and content of the fallback warning message."""
|
||||||
|
|
||||||
# Mock the generate_video call to test just the resolution lookup part
|
# Mock the generate_video call to test just the resolution lookup part
|
||||||
with patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline, \
|
with (
|
||||||
patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'), \
|
patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline,
|
||||||
patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'), \
|
patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'),
|
||||||
patch.object(cli_demo_module, 'export_to_video'), \
|
patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'),
|
||||||
caplog.at_level(logging.WARNING):
|
patch.object(cli_demo_module, 'export_to_video'),
|
||||||
|
caplog.at_level(logging.WARNING),
|
||||||
|
):
|
||||||
# Setup mock to avoid actual model loading
|
# Setup mock to avoid actual model loading
|
||||||
mock_instance = MagicMock()
|
mock_instance = MagicMock()
|
||||||
mock_pipeline.from_pretrained.return_value = mock_instance
|
mock_pipeline.from_pretrained.return_value = mock_instance
|
||||||
@ -189,11 +198,14 @@ class TestResolutionMapFallback:
|
|||||||
# Verify warning message contains key information
|
# Verify warning message contains key information
|
||||||
assert "unknown-model" in warning_msg.lower(), "Should mention the invalid model name"
|
assert "unknown-model" in warning_msg.lower(), "Should mention the invalid model name"
|
||||||
assert "not found in resolution map" in warning_msg.lower(), "Should explain the issue"
|
assert "not found in resolution map" in warning_msg.lower(), "Should explain the issue"
|
||||||
assert "480" in warning_msg and "720" in warning_msg, "Should mention default resolution"
|
assert (
|
||||||
|
"480" in warning_msg and "720" in warning_msg
|
||||||
|
), "Should mention default resolution"
|
||||||
assert "valid models:" in warning_msg.lower(), "Should list valid models"
|
assert "valid models:" in warning_msg.lower(), "Should list valid models"
|
||||||
|
|
||||||
# Verify all valid model names are in the warning
|
# Verify all valid model names are in the warning
|
||||||
from inference.cli_demo import RESOLUTION_MAP
|
from inference.cli_demo import RESOLUTION_MAP
|
||||||
|
|
||||||
for model_name in RESOLUTION_MAP.keys():
|
for model_name in RESOLUTION_MAP.keys():
|
||||||
assert model_name in warning_msg, f"Warning should list '{model_name}'"
|
assert model_name in warning_msg, f"Warning should list '{model_name}'"
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user