mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-06-01 00:38:16 +08:00
style: apply ruff formatting to resolution tests
This commit is contained in:
parent
106f987ded
commit
261868fcb8
@ -19,13 +19,13 @@ def mock_heavy_imports():
|
|||||||
mock_torch.dtype = MagicMock()
|
mock_torch.dtype = MagicMock()
|
||||||
mock_torch.bfloat16 = MagicMock()
|
mock_torch.bfloat16 = MagicMock()
|
||||||
mock_torch.float16 = MagicMock()
|
mock_torch.float16 = MagicMock()
|
||||||
|
|
||||||
sys.modules['torch'] = mock_torch
|
sys.modules['torch'] = mock_torch
|
||||||
sys.modules['diffusers'] = MagicMock()
|
sys.modules['diffusers'] = MagicMock()
|
||||||
sys.modules['diffusers.utils'] = MagicMock()
|
sys.modules['diffusers.utils'] = MagicMock()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
for module in ['torch', 'diffusers', 'diffusers.utils']:
|
for module in ['torch', 'diffusers', 'diffusers.utils']:
|
||||||
if module in sys.modules:
|
if module in sys.modules:
|
||||||
@ -37,6 +37,7 @@ def resolution_map():
|
|||||||
"""Fixture providing the RESOLUTION_MAP from cli_demo."""
|
"""Fixture providing the RESOLUTION_MAP from cli_demo."""
|
||||||
# Import after mocking dependencies
|
# Import after mocking dependencies
|
||||||
from inference.cli_demo import RESOLUTION_MAP
|
from inference.cli_demo import RESOLUTION_MAP
|
||||||
|
|
||||||
return RESOLUTION_MAP
|
return RESOLUTION_MAP
|
||||||
|
|
||||||
|
|
||||||
@ -44,15 +45,18 @@ def resolution_map():
|
|||||||
def cli_demo_module():
|
def cli_demo_module():
|
||||||
"""Fixture providing the cli_demo module."""
|
"""Fixture providing the cli_demo module."""
|
||||||
import inference.cli_demo as cli_demo
|
import inference.cli_demo as cli_demo
|
||||||
|
|
||||||
return cli_demo
|
return cli_demo
|
||||||
|
|
||||||
|
|
||||||
class TestResolutionMapFallback:
|
class TestResolutionMapFallback:
|
||||||
"""Test suite for RESOLUTION_MAP lookup and fallback behavior."""
|
"""Test suite for RESOLUTION_MAP lookup and fallback behavior."""
|
||||||
|
|
||||||
def test_valid_model_name_returns_correct_resolution(self, resolution_map, cli_demo_module, caplog):
|
def test_valid_model_name_returns_correct_resolution(
|
||||||
|
self, resolution_map, cli_demo_module, caplog
|
||||||
|
):
|
||||||
"""Test that valid model names return their correct resolutions."""
|
"""Test that valid model names return their correct resolutions."""
|
||||||
|
|
||||||
# Test each valid model in RESOLUTION_MAP
|
# Test each valid model in RESOLUTION_MAP
|
||||||
test_cases = [
|
test_cases = [
|
||||||
("THUDM/CogVideoX-5b", "cogvideox-5b", (480, 720)),
|
("THUDM/CogVideoX-5b", "cogvideox-5b", (480, 720)),
|
||||||
@ -61,18 +65,20 @@ class TestResolutionMapFallback:
|
|||||||
("THUDM/CogVideoX1.5-5b", "cogvideox1.5-5b", (768, 1360)),
|
("THUDM/CogVideoX1.5-5b", "cogvideox1.5-5b", (768, 1360)),
|
||||||
("THUDM/CogVideoX1.5-5b-I2V", "cogvideox1.5-5b-i2v", (768, 1360)),
|
("THUDM/CogVideoX1.5-5b-I2V", "cogvideox1.5-5b-i2v", (768, 1360)),
|
||||||
]
|
]
|
||||||
|
|
||||||
for model_path, expected_key, expected_resolution in test_cases:
|
for model_path, expected_key, expected_resolution in test_cases:
|
||||||
with caplog.at_level(logging.WARNING):
|
with caplog.at_level(logging.WARNING):
|
||||||
caplog.clear()
|
caplog.clear()
|
||||||
|
|
||||||
# Extract model name using the same logic as cli_demo
|
# Extract model name using the same logic as cli_demo
|
||||||
model_name = model_path.split("/")[-1].lower()
|
model_name = model_path.split("/")[-1].lower()
|
||||||
|
|
||||||
# Verify the model name is in the map
|
# Verify the model name is in the map
|
||||||
assert expected_key in resolution_map, f"Expected key '{expected_key}' not in RESOLUTION_MAP"
|
assert (
|
||||||
|
expected_key in resolution_map
|
||||||
|
), f"Expected key '{expected_key}' not in RESOLUTION_MAP"
|
||||||
assert resolution_map[expected_key] == expected_resolution
|
assert resolution_map[expected_key] == expected_resolution
|
||||||
|
|
||||||
# Test that accessing the resolution works without KeyError
|
# Test that accessing the resolution works without KeyError
|
||||||
try:
|
try:
|
||||||
actual_resolution = resolution_map[model_name]
|
actual_resolution = resolution_map[model_name]
|
||||||
@ -82,27 +88,27 @@ class TestResolutionMapFallback:
|
|||||||
|
|
||||||
def test_invalid_model_name_raises_key_error(self, resolution_map):
|
def test_invalid_model_name_raises_key_error(self, resolution_map):
|
||||||
"""Test that invalid model names raise KeyError (which is then caught by try/except)."""
|
"""Test that invalid model names raise KeyError (which is then caught by try/except)."""
|
||||||
|
|
||||||
invalid_models = [
|
invalid_models = [
|
||||||
"invalidmodel-xyz",
|
"invalidmodel-xyz",
|
||||||
"unknown-5b",
|
"unknown-5b",
|
||||||
"cogvideox-5b-custom", # Not in map
|
"cogvideox-5b-custom", # Not in map
|
||||||
]
|
]
|
||||||
|
|
||||||
for invalid_model in invalid_models:
|
for invalid_model in invalid_models:
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
_ = resolution_map[invalid_model]
|
_ = resolution_map[invalid_model]
|
||||||
|
|
||||||
def test_local_path_model_name_extraction(self):
|
def test_local_path_model_name_extraction(self):
|
||||||
"""Test that local paths are parsed correctly to extract model names."""
|
"""Test that local paths are parsed correctly to extract model names."""
|
||||||
|
|
||||||
local_paths = [
|
local_paths = [
|
||||||
("./local_models/CogVideoX-5B", "cogvideox-5b"),
|
("./local_models/CogVideoX-5B", "cogvideox-5b"),
|
||||||
("/path/to/custom_model", "custom_model"),
|
("/path/to/custom_model", "custom_model"),
|
||||||
("~/models/my-cogvideo-model", "my-cogvideo-model"),
|
("~/models/my-cogvideo-model", "my-cogvideo-model"),
|
||||||
("./CogVideoX-Custom", "cogvideox-custom"),
|
("./CogVideoX-Custom", "cogvideox-custom"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for local_path, expected_model_name in local_paths:
|
for local_path, expected_model_name in local_paths:
|
||||||
# Extract model name using the same logic as cli_demo
|
# Extract model name using the same logic as cli_demo
|
||||||
model_name = local_path.split("/")[-1].lower()
|
model_name = local_path.split("/")[-1].lower()
|
||||||
@ -111,17 +117,21 @@ class TestResolutionMapFallback:
|
|||||||
def test_fallback_code_structure(self, cli_demo_module):
|
def test_fallback_code_structure(self, cli_demo_module):
|
||||||
"""Test that the fallback code structure is present in generate_video function."""
|
"""Test that the fallback code structure is present in generate_video function."""
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
# Get the source code of generate_video
|
# Get the source code of generate_video
|
||||||
source = inspect.getsource(cli_demo_module.generate_video)
|
source = inspect.getsource(cli_demo_module.generate_video)
|
||||||
|
|
||||||
# Verify try/except block exists
|
# Verify try/except block exists
|
||||||
assert "try:" in source, "generate_video should have try/except block"
|
assert "try:" in source, "generate_video should have try/except block"
|
||||||
assert "except KeyError:" in source, "Should catch KeyError specifically"
|
assert "except KeyError:" in source, "Should catch KeyError specifically"
|
||||||
assert "RESOLUTION_MAP[model_name]" in source, "Should access RESOLUTION_MAP with model_name"
|
assert (
|
||||||
|
"RESOLUTION_MAP[model_name]" in source
|
||||||
|
), "Should access RESOLUTION_MAP with model_name"
|
||||||
|
|
||||||
# Verify fallback logic exists
|
# Verify fallback logic exists
|
||||||
assert "default_resolution" in source or "(480, 720)" in source, "Should define default resolution"
|
assert (
|
||||||
|
"default_resolution" in source or "(480, 720)" in source
|
||||||
|
), "Should define default resolution"
|
||||||
assert "logging.warning" in source, "Should log warning on fallback"
|
assert "logging.warning" in source, "Should log warning on fallback"
|
||||||
assert "valid models" in source.lower(), "Warning should mention valid models"
|
assert "valid models" in source.lower(), "Warning should mention valid models"
|
||||||
|
|
||||||
@ -134,40 +144,39 @@ class TestResolutionMapFallback:
|
|||||||
"cogvideox-5b",
|
"cogvideox-5b",
|
||||||
"cogvideox-2b",
|
"cogvideox-2b",
|
||||||
}
|
}
|
||||||
|
|
||||||
assert set(resolution_map.keys()) == expected_models, (
|
assert (
|
||||||
"RESOLUTION_MAP should contain all expected models"
|
set(resolution_map.keys()) == expected_models
|
||||||
)
|
), "RESOLUTION_MAP should contain all expected models"
|
||||||
|
|
||||||
# Verify all resolutions are tuples of two integers
|
# Verify all resolutions are tuples of two integers
|
||||||
for model_name, resolution in resolution_map.items():
|
for model_name, resolution in resolution_map.items():
|
||||||
assert isinstance(resolution, tuple), (
|
assert isinstance(resolution, tuple), f"Resolution for '{model_name}' should be a tuple"
|
||||||
f"Resolution for '{model_name}' should be a tuple"
|
assert (
|
||||||
)
|
len(resolution) == 2
|
||||||
assert len(resolution) == 2, (
|
), f"Resolution for '{model_name}' should have 2 elements (height, width)"
|
||||||
f"Resolution for '{model_name}' should have 2 elements (height, width)"
|
assert all(
|
||||||
)
|
isinstance(dim, int) for dim in resolution
|
||||||
assert all(isinstance(dim, int) for dim in resolution), (
|
), f"Resolution for '{model_name}' should contain integers"
|
||||||
f"Resolution for '{model_name}' should contain integers"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_fallback_warning_message_format(self, cli_demo_module, caplog):
|
def test_fallback_warning_message_format(self, cli_demo_module, caplog):
|
||||||
"""Test the format and content of the fallback warning message."""
|
"""Test the format and content of the fallback warning message."""
|
||||||
|
|
||||||
# Mock the generate_video call to test just the resolution lookup part
|
# Mock the generate_video call to test just the resolution lookup part
|
||||||
with patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline, \
|
with (
|
||||||
patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'), \
|
patch.object(cli_demo_module, 'CogVideoXPipeline') as mock_pipeline,
|
||||||
patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'), \
|
patch.object(cli_demo_module, 'CogVideoXImageToVideoPipeline'),
|
||||||
patch.object(cli_demo_module, 'export_to_video'), \
|
patch.object(cli_demo_module, 'CogVideoXVideoToVideoPipeline'),
|
||||||
caplog.at_level(logging.WARNING):
|
patch.object(cli_demo_module, 'export_to_video'),
|
||||||
|
caplog.at_level(logging.WARNING),
|
||||||
|
):
|
||||||
# Setup mock to avoid actual model loading
|
# Setup mock to avoid actual model loading
|
||||||
mock_instance = MagicMock()
|
mock_instance = MagicMock()
|
||||||
mock_pipeline.from_pretrained.return_value = mock_instance
|
mock_pipeline.from_pretrained.return_value = mock_instance
|
||||||
mock_instance.return_value = MagicMock(frames=[])
|
mock_instance.return_value = MagicMock(frames=[])
|
||||||
|
|
||||||
caplog.clear()
|
caplog.clear()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cli_demo_module.generate_video(
|
cli_demo_module.generate_video(
|
||||||
prompt="Test prompt",
|
prompt="Test prompt",
|
||||||
@ -178,39 +187,42 @@ class TestResolutionMapFallback:
|
|||||||
except Exception:
|
except Exception:
|
||||||
# We're only interested in the warning, not successful execution
|
# We're only interested in the warning, not successful execution
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Check that a warning was logged
|
# Check that a warning was logged
|
||||||
warning_records = [r for r in caplog.records if r.levelname == "WARNING"]
|
warning_records = [r for r in caplog.records if r.levelname == "WARNING"]
|
||||||
assert len(warning_records) > 0, "Should log at least one warning"
|
assert len(warning_records) > 0, "Should log at least one warning"
|
||||||
|
|
||||||
# Get the first warning message
|
# Get the first warning message
|
||||||
warning_msg = warning_records[0].message
|
warning_msg = warning_records[0].message
|
||||||
|
|
||||||
# Verify warning message contains key information
|
# Verify warning message contains key information
|
||||||
assert "unknown-model" in warning_msg.lower(), "Should mention the invalid model name"
|
assert "unknown-model" in warning_msg.lower(), "Should mention the invalid model name"
|
||||||
assert "not found in resolution map" in warning_msg.lower(), "Should explain the issue"
|
assert "not found in resolution map" in warning_msg.lower(), "Should explain the issue"
|
||||||
assert "480" in warning_msg and "720" in warning_msg, "Should mention default resolution"
|
assert (
|
||||||
|
"480" in warning_msg and "720" in warning_msg
|
||||||
|
), "Should mention default resolution"
|
||||||
assert "valid models:" in warning_msg.lower(), "Should list valid models"
|
assert "valid models:" in warning_msg.lower(), "Should list valid models"
|
||||||
|
|
||||||
# Verify all valid model names are in the warning
|
# Verify all valid model names are in the warning
|
||||||
from inference.cli_demo import RESOLUTION_MAP
|
from inference.cli_demo import RESOLUTION_MAP
|
||||||
|
|
||||||
for model_name in RESOLUTION_MAP.keys():
|
for model_name in RESOLUTION_MAP.keys():
|
||||||
assert model_name in warning_msg, f"Warning should list '{model_name}'"
|
assert model_name in warning_msg, f"Warning should list '{model_name}'"
|
||||||
|
|
||||||
def test_case_insensitive_model_name_matching(self, resolution_map):
|
def test_case_insensitive_model_name_matching(self, resolution_map):
|
||||||
"""Test that model names are converted to lowercase for matching."""
|
"""Test that model names are converted to lowercase for matching."""
|
||||||
|
|
||||||
# All keys in RESOLUTION_MAP should be lowercase
|
# All keys in RESOLUTION_MAP should be lowercase
|
||||||
for key in resolution_map.keys():
|
for key in resolution_map.keys():
|
||||||
assert key == key.lower(), f"RESOLUTION_MAP key '{key}' should be lowercase"
|
assert key == key.lower(), f"RESOLUTION_MAP key '{key}' should be lowercase"
|
||||||
|
|
||||||
# Test that various case inputs would work after .lower() conversion
|
# Test that various case inputs would work after .lower() conversion
|
||||||
test_paths = [
|
test_paths = [
|
||||||
("THUDM/COGVIDEOX-5B", "cogvideox-5b"),
|
("THUDM/COGVIDEOX-5B", "cogvideox-5b"),
|
||||||
("thudm/cogvideox-5b", "cogvideox-5b"),
|
("thudm/cogvideox-5b", "cogvideox-5b"),
|
||||||
("ThUdM/CoGvIdEoX-5b", "cogvideox-5b"),
|
("ThUdM/CoGvIdEoX-5b", "cogvideox-5b"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for model_path, expected_key in test_paths:
|
for model_path, expected_key in test_paths:
|
||||||
model_name = model_path.split("/")[-1].lower()
|
model_name = model_path.split("/")[-1].lower()
|
||||||
assert model_name == expected_key
|
assert model_name == expected_key
|
||||||
@ -219,22 +231,22 @@ class TestResolutionMapFallback:
|
|||||||
def test_default_resolution_value(self, cli_demo_module):
|
def test_default_resolution_value(self, cli_demo_module):
|
||||||
"""Test that the default fallback resolution is (480, 720)."""
|
"""Test that the default fallback resolution is (480, 720)."""
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
source = inspect.getsource(cli_demo_module.generate_video)
|
source = inspect.getsource(cli_demo_module.generate_video)
|
||||||
|
|
||||||
# The default should be defined as (480, 720)
|
# The default should be defined as (480, 720)
|
||||||
assert "(480, 720)" in source, "Default resolution should be (480, 720)"
|
assert "(480, 720)" in source, "Default resolution should be (480, 720)"
|
||||||
|
|
||||||
def test_resolution_dimensions_order(self, resolution_map):
|
def test_resolution_dimensions_order(self, resolution_map):
|
||||||
"""Test that resolutions are in (height, width) format."""
|
"""Test that resolutions are in (height, width) format."""
|
||||||
|
|
||||||
# Based on the code comments and typical usage
|
# Based on the code comments and typical usage
|
||||||
for model_name, (dim1, dim2) in resolution_map.items():
|
for model_name, (dim1, dim2) in resolution_map.items():
|
||||||
# Height typically comes first
|
# Height typically comes first
|
||||||
# Verify dimensions are reasonable (not negative or zero)
|
# Verify dimensions are reasonable (not negative or zero)
|
||||||
assert dim1 > 0, f"First dimension for {model_name} should be positive"
|
assert dim1 > 0, f"First dimension for {model_name} should be positive"
|
||||||
assert dim2 > 0, f"Second dimension for {model_name} should be positive"
|
assert dim2 > 0, f"Second dimension for {model_name} should be positive"
|
||||||
|
|
||||||
# CogVideoX1.5 models should have larger resolutions
|
# CogVideoX1.5 models should have larger resolutions
|
||||||
if "1.5" in model_name:
|
if "1.5" in model_name:
|
||||||
assert dim1 >= 768, f"CogVideoX1.5 model {model_name} should have height >= 768"
|
assert dim1 >= 768, f"CogVideoX1.5 model {model_name} should have height >= 768"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user