mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-05-10 09:34:23 +08:00
test: add tests for multi-GPU device_map support
- Add comprehensive tests for device_map argument functionality - Test all device_map options: None, auto, balanced, sequential - Test all pipeline types: t2v, i2v, v2v - Verify CPU offload logic (enabled when device_map=None, disabled otherwise) - Verify from_pretrained receives correct device_map parameter - Test VAE optimizations are always enabled regardless of device_map - Test invalid device_map values raise ValueError - Verify backward compatibility (default behavior unchanged) - Mock heavy dependencies (torch, diffusers) for fast CI execution 47 tests covering all multi-GPU device_map scenarios
This commit is contained in:
parent
207020afea
commit
c5d6f7b557
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Tests for CogVideoX CLI tools
|
||||
581
tests/test_cli_demo_multi_gpu.py
Normal file
581
tests/test_cli_demo_multi_gpu.py
Normal file
@ -0,0 +1,581 @@
|
||||
"""
|
||||
Tests for multi-GPU device_map support in cli_demo.py
|
||||
|
||||
These tests verify:
|
||||
- device_map=None → CPU offload enabled, no device_map passed to from_pretrained
|
||||
- device_map="auto" → No CPU offload, device_map="auto" passed to from_pretrained
|
||||
- device_map="balanced" → No CPU offload, device_map="balanced" passed
|
||||
- device_map="sequential" → No CPU offload, device_map="sequential" passed
|
||||
- All three pipeline types (t2v, i2v, v2v) work correctly
|
||||
- Backward compatibility (default behavior unchanged)
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
import importlib.util
|
||||
import pytest
|
||||
|
||||
|
||||
def create_mocked_cli_demo():
|
||||
"""
|
||||
Load cli_demo.py with mocked heavy dependencies.
|
||||
Returns the module and the mock objects for assertions.
|
||||
"""
|
||||
# Create mock for torch
|
||||
mock_torch = MagicMock()
|
||||
mock_torch.bfloat16 = "bfloat16"
|
||||
mock_torch.float16 = "float16"
|
||||
mock_torch.Generator.return_value.manual_seed.return_value = MagicMock()
|
||||
|
||||
# Create pipeline mocks
|
||||
mock_t2v_pipe = MagicMock()
|
||||
mock_t2v_pipe.scheduler = MagicMock()
|
||||
mock_t2v_pipe.scheduler.config = {}
|
||||
mock_t2v_pipe.vae = MagicMock()
|
||||
mock_t2v_pipe.return_value = MagicMock(frames=[[MagicMock()]])
|
||||
|
||||
mock_i2v_pipe = MagicMock()
|
||||
mock_i2v_pipe.scheduler = MagicMock()
|
||||
mock_i2v_pipe.scheduler.config = {}
|
||||
mock_i2v_pipe.vae = MagicMock()
|
||||
mock_i2v_pipe.return_value = MagicMock(frames=[[MagicMock()]])
|
||||
|
||||
mock_v2v_pipe = MagicMock()
|
||||
mock_v2v_pipe.scheduler = MagicMock()
|
||||
mock_v2v_pipe.scheduler.config = {}
|
||||
mock_v2v_pipe.vae = MagicMock()
|
||||
mock_v2v_pipe.return_value = MagicMock(frames=[[MagicMock()]])
|
||||
|
||||
# Create pipeline class mocks
|
||||
mock_CogVideoXPipeline = MagicMock()
|
||||
mock_CogVideoXPipeline.from_pretrained.return_value = mock_t2v_pipe
|
||||
|
||||
mock_CogVideoXImageToVideoPipeline = MagicMock()
|
||||
mock_CogVideoXImageToVideoPipeline.from_pretrained.return_value = mock_i2v_pipe
|
||||
|
||||
mock_CogVideoXVideoToVideoPipeline = MagicMock()
|
||||
mock_CogVideoXVideoToVideoPipeline.from_pretrained.return_value = mock_v2v_pipe
|
||||
|
||||
mock_CogVideoXDPMScheduler = MagicMock()
|
||||
mock_CogVideoXDPMScheduler.from_config.return_value = MagicMock()
|
||||
|
||||
# Create mock diffusers module
|
||||
mock_diffusers = MagicMock()
|
||||
mock_diffusers.CogVideoXPipeline = mock_CogVideoXPipeline
|
||||
mock_diffusers.CogVideoXImageToVideoPipeline = mock_CogVideoXImageToVideoPipeline
|
||||
mock_diffusers.CogVideoXVideoToVideoPipeline = mock_CogVideoXVideoToVideoPipeline
|
||||
mock_diffusers.CogVideoXDPMScheduler = mock_CogVideoXDPMScheduler
|
||||
|
||||
mock_diffusers_utils = MagicMock()
|
||||
mock_diffusers_utils.export_to_video = MagicMock()
|
||||
mock_diffusers_utils.load_image = MagicMock(return_value=MagicMock())
|
||||
mock_diffusers_utils.load_video = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Save original modules
|
||||
original_modules = {}
|
||||
for mod_name in ['torch', 'diffusers', 'diffusers.utils']:
|
||||
if mod_name in sys.modules:
|
||||
original_modules[mod_name] = sys.modules[mod_name]
|
||||
|
||||
# Remove cli_demo if cached
|
||||
if 'cli_demo' in sys.modules:
|
||||
del sys.modules['cli_demo']
|
||||
|
||||
# Install mocks
|
||||
sys.modules['torch'] = mock_torch
|
||||
sys.modules['diffusers'] = mock_diffusers
|
||||
sys.modules['diffusers.utils'] = mock_diffusers_utils
|
||||
|
||||
try:
|
||||
# Load cli_demo.py using importlib
|
||||
cli_demo_path = os.path.join(os.path.dirname(__file__), '..', 'inference', 'cli_demo.py')
|
||||
spec = importlib.util.spec_from_file_location("cli_demo", cli_demo_path)
|
||||
cli_demo = importlib.util.module_from_spec(spec)
|
||||
sys.modules['cli_demo'] = cli_demo
|
||||
spec.loader.exec_module(cli_demo)
|
||||
|
||||
return {
|
||||
'cli_demo': cli_demo,
|
||||
'mock_torch': mock_torch,
|
||||
'mock_CogVideoXPipeline': mock_CogVideoXPipeline,
|
||||
'mock_CogVideoXImageToVideoPipeline': mock_CogVideoXImageToVideoPipeline,
|
||||
'mock_CogVideoXVideoToVideoPipeline': mock_CogVideoXVideoToVideoPipeline,
|
||||
'mock_CogVideoXDPMScheduler': mock_CogVideoXDPMScheduler,
|
||||
'mock_t2v_pipe': mock_t2v_pipe,
|
||||
'mock_i2v_pipe': mock_i2v_pipe,
|
||||
'mock_v2v_pipe': mock_v2v_pipe,
|
||||
}
|
||||
finally:
|
||||
# Restore original modules
|
||||
for mod_name, mod in original_modules.items():
|
||||
sys.modules[mod_name] = mod
|
||||
|
||||
|
||||
class TestDeviceMapSupport:
|
||||
"""Test suite for multi-GPU device_map functionality."""
|
||||
|
||||
def _get_common_args(self, generate_type="t2v", device_map=None):
|
||||
"""Get common arguments for generate_video function."""
|
||||
args = {
|
||||
"prompt": "A test video",
|
||||
"model_path": "THUDM/CogVideoX-5b",
|
||||
"generate_type": generate_type,
|
||||
"output_path": "./test_output.mp4",
|
||||
"num_inference_steps": 1,
|
||||
"num_frames": 9,
|
||||
"seed": 42,
|
||||
"device_map": device_map,
|
||||
}
|
||||
if generate_type in ("i2v", "v2v"):
|
||||
args["image_or_video_path"] = "/fake/path/image.png"
|
||||
return args
|
||||
|
||||
# =========================================================================
|
||||
# Tests for device_map=None (default behavior - CPU offload)
|
||||
# =========================================================================
|
||||
|
||||
def test_device_map_none_t2v_enables_cpu_offload(self):
|
||||
"""Test that device_map=None enables CPU offload for t2v pipeline."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map=None))
|
||||
|
||||
# Verify from_pretrained was called WITHOUT device_map
|
||||
mocks['mock_CogVideoXPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXPipeline'].from_pretrained.call_args[1]
|
||||
assert 'device_map' not in call_kwargs, "device_map should NOT be passed when None"
|
||||
|
||||
# Verify CPU offload was enabled
|
||||
mocks['mock_t2v_pipe'].enable_sequential_cpu_offload.assert_called_once()
|
||||
|
||||
def test_device_map_none_i2v_enables_cpu_offload(self):
|
||||
"""Test that device_map=None enables CPU offload for i2v pipeline."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="i2v", device_map=None))
|
||||
|
||||
# Verify from_pretrained was called WITHOUT device_map
|
||||
mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.call_args[1]
|
||||
assert 'device_map' not in call_kwargs, "device_map should NOT be passed when None"
|
||||
|
||||
# Verify CPU offload was enabled
|
||||
mocks['mock_i2v_pipe'].enable_sequential_cpu_offload.assert_called_once()
|
||||
|
||||
def test_device_map_none_v2v_enables_cpu_offload(self):
|
||||
"""Test that device_map=None enables CPU offload for v2v pipeline."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="v2v", device_map=None))
|
||||
|
||||
# Verify from_pretrained was called WITHOUT device_map
|
||||
mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.call_args[1]
|
||||
assert 'device_map' not in call_kwargs, "device_map should NOT be passed when None"
|
||||
|
||||
# Verify CPU offload was enabled
|
||||
mocks['mock_v2v_pipe'].enable_sequential_cpu_offload.assert_called_once()
|
||||
|
||||
# =========================================================================
|
||||
# Tests for device_map="auto"
|
||||
# =========================================================================
|
||||
|
||||
def test_device_map_auto_t2v_no_cpu_offload(self):
|
||||
"""Test that device_map='auto' passes device_map and skips CPU offload for t2v."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map="auto"))
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="auto"
|
||||
mocks['mock_CogVideoXPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXPipeline'].from_pretrained.call_args[1]
|
||||
assert call_kwargs['device_map'] == "auto"
|
||||
|
||||
# Verify CPU offload was NOT enabled
|
||||
mocks['mock_t2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
def test_device_map_auto_i2v_no_cpu_offload(self):
|
||||
"""Test that device_map='auto' passes device_map and skips CPU offload for i2v."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="i2v", device_map="auto"))
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="auto"
|
||||
mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.call_args[1]
|
||||
assert call_kwargs['device_map'] == "auto"
|
||||
|
||||
# Verify CPU offload was NOT enabled
|
||||
mocks['mock_i2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
def test_device_map_auto_v2v_no_cpu_offload(self):
|
||||
"""Test that device_map='auto' passes device_map and skips CPU offload for v2v."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="v2v", device_map="auto"))
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="auto"
|
||||
mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.call_args[1]
|
||||
assert call_kwargs['device_map'] == "auto"
|
||||
|
||||
# Verify CPU offload was NOT enabled
|
||||
mocks['mock_v2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# Tests for device_map="balanced"
|
||||
# =========================================================================
|
||||
|
||||
def test_device_map_balanced_t2v_no_cpu_offload(self):
|
||||
"""Test that device_map='balanced' passes device_map and skips CPU offload for t2v."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map="balanced"))
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="balanced"
|
||||
mocks['mock_CogVideoXPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXPipeline'].from_pretrained.call_args[1]
|
||||
assert call_kwargs['device_map'] == "balanced"
|
||||
|
||||
# Verify CPU offload was NOT enabled
|
||||
mocks['mock_t2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
def test_device_map_balanced_i2v_no_cpu_offload(self):
|
||||
"""Test that device_map='balanced' passes device_map and skips CPU offload for i2v."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="i2v", device_map="balanced"))
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="balanced"
|
||||
mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.call_args[1]
|
||||
assert call_kwargs['device_map'] == "balanced"
|
||||
|
||||
# Verify CPU offload was NOT enabled
|
||||
mocks['mock_i2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
def test_device_map_balanced_v2v_no_cpu_offload(self):
|
||||
"""Test that device_map='balanced' passes device_map and skips CPU offload for v2v."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="v2v", device_map="balanced"))
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="balanced"
|
||||
mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.call_args[1]
|
||||
assert call_kwargs['device_map'] == "balanced"
|
||||
|
||||
# Verify CPU offload was NOT enabled
|
||||
mocks['mock_v2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# Tests for device_map="sequential"
|
||||
# =========================================================================
|
||||
|
||||
def test_device_map_sequential_t2v_no_cpu_offload(self):
|
||||
"""Test that device_map='sequential' passes device_map and skips CPU offload for t2v."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map="sequential"))
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="sequential"
|
||||
mocks['mock_CogVideoXPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXPipeline'].from_pretrained.call_args[1]
|
||||
assert call_kwargs['device_map'] == "sequential"
|
||||
|
||||
# Verify CPU offload was NOT enabled
|
||||
mocks['mock_t2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
def test_device_map_sequential_i2v_no_cpu_offload(self):
|
||||
"""Test that device_map='sequential' passes device_map and skips CPU offload for i2v."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="i2v", device_map="sequential"))
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="sequential"
|
||||
mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.call_args[1]
|
||||
assert call_kwargs['device_map'] == "sequential"
|
||||
|
||||
# Verify CPU offload was NOT enabled
|
||||
mocks['mock_i2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
def test_device_map_sequential_v2v_no_cpu_offload(self):
|
||||
"""Test that device_map='sequential' passes device_map and skips CPU offload for v2v."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="v2v", device_map="sequential"))
|
||||
|
||||
# Verify from_pretrained was called WITH device_map="sequential"
|
||||
mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.call_args[1]
|
||||
assert call_kwargs['device_map'] == "sequential"
|
||||
|
||||
# Verify CPU offload was NOT enabled
|
||||
mocks['mock_v2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# VAE optimizations (should always be enabled)
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.parametrize("device_map", [None, "auto", "balanced", "sequential"])
|
||||
def test_vae_optimizations_always_enabled_t2v(self, device_map):
|
||||
"""Test that VAE slicing and tiling are always enabled for t2v regardless of device_map."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map=device_map))
|
||||
|
||||
# VAE optimizations should always be called
|
||||
mocks['mock_t2v_pipe'].vae.enable_slicing.assert_called_once()
|
||||
mocks['mock_t2v_pipe'].vae.enable_tiling.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize("device_map", [None, "auto", "balanced", "sequential"])
|
||||
def test_vae_optimizations_always_enabled_i2v(self, device_map):
|
||||
"""Test that VAE slicing and tiling are always enabled for i2v regardless of device_map."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="i2v", device_map=device_map))
|
||||
|
||||
# VAE optimizations should always be called
|
||||
mocks['mock_i2v_pipe'].vae.enable_slicing.assert_called_once()
|
||||
mocks['mock_i2v_pipe'].vae.enable_tiling.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize("device_map", [None, "auto", "balanced", "sequential"])
|
||||
def test_vae_optimizations_always_enabled_v2v(self, device_map):
|
||||
"""Test that VAE slicing and tiling are always enabled for v2v regardless of device_map."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="v2v", device_map=device_map))
|
||||
|
||||
# VAE optimizations should always be called
|
||||
mocks['mock_v2v_pipe'].vae.enable_slicing.assert_called_once()
|
||||
mocks['mock_v2v_pipe'].vae.enable_tiling.assert_called_once()
|
||||
|
||||
# =========================================================================
|
||||
# Invalid device_map values
|
||||
# =========================================================================
|
||||
|
||||
def test_invalid_device_map_raises_error(self):
|
||||
"""Test that invalid device_map values raise ValueError."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map="invalid"))
|
||||
|
||||
assert "Invalid device_map" in str(exc_info.value)
|
||||
assert "invalid" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.parametrize("invalid_value", ["cuda:0", "cpu", "GPU", "multi", "distributed"])
|
||||
def test_various_invalid_device_map_values(self, invalid_value):
|
||||
"""Test that various invalid device_map values all raise ValueError."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map=invalid_value))
|
||||
|
||||
assert "Invalid device_map" in str(exc_info.value)
|
||||
|
||||
# =========================================================================
|
||||
# Backward compatibility tests
|
||||
# =========================================================================
|
||||
|
||||
def test_default_behavior_unchanged_no_device_map_arg(self):
|
||||
"""Test that default behavior (no device_map argument) remains unchanged."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
# Call without device_map argument at all (relies on default)
|
||||
args = {
|
||||
"prompt": "A test video",
|
||||
"model_path": "THUDM/CogVideoX-5b",
|
||||
"generate_type": "t2v",
|
||||
"output_path": "./test_output.mp4",
|
||||
"num_inference_steps": 1,
|
||||
"num_frames": 9,
|
||||
"seed": 42,
|
||||
# Note: device_map is intentionally NOT included to test default
|
||||
}
|
||||
cli_demo.generate_video(**args)
|
||||
|
||||
# Verify from_pretrained was called WITHOUT device_map
|
||||
mocks['mock_CogVideoXPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXPipeline'].from_pretrained.call_args[1]
|
||||
assert 'device_map' not in call_kwargs
|
||||
|
||||
# Verify CPU offload WAS enabled (default single-GPU behavior)
|
||||
mocks['mock_t2v_pipe'].enable_sequential_cpu_offload.assert_called_once()
|
||||
|
||||
# =========================================================================
|
||||
# Parametrized comprehensive tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [
|
||||
(None, False, True),
|
||||
("auto", True, False),
|
||||
("balanced", True, False),
|
||||
("sequential", True, False),
|
||||
])
|
||||
def test_device_map_behavior_t2v(self, device_map, should_have_device_map, should_cpu_offload):
|
||||
"""Comprehensive test for device_map behavior with t2v pipeline."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="t2v", device_map=device_map))
|
||||
|
||||
# Check from_pretrained call
|
||||
mocks['mock_CogVideoXPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXPipeline'].from_pretrained.call_args[1]
|
||||
|
||||
if should_have_device_map:
|
||||
assert 'device_map' in call_kwargs
|
||||
assert call_kwargs['device_map'] == device_map
|
||||
else:
|
||||
assert 'device_map' not in call_kwargs
|
||||
|
||||
# Check CPU offload
|
||||
if should_cpu_offload:
|
||||
mocks['mock_t2v_pipe'].enable_sequential_cpu_offload.assert_called_once()
|
||||
else:
|
||||
mocks['mock_t2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [
|
||||
(None, False, True),
|
||||
("auto", True, False),
|
||||
("balanced", True, False),
|
||||
("sequential", True, False),
|
||||
])
|
||||
def test_device_map_behavior_i2v(self, device_map, should_have_device_map, should_cpu_offload):
|
||||
"""Comprehensive test for device_map behavior with i2v pipeline."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="i2v", device_map=device_map))
|
||||
|
||||
# Check from_pretrained call
|
||||
mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXImageToVideoPipeline'].from_pretrained.call_args[1]
|
||||
|
||||
if should_have_device_map:
|
||||
assert 'device_map' in call_kwargs
|
||||
assert call_kwargs['device_map'] == device_map
|
||||
else:
|
||||
assert 'device_map' not in call_kwargs
|
||||
|
||||
# Check CPU offload
|
||||
if should_cpu_offload:
|
||||
mocks['mock_i2v_pipe'].enable_sequential_cpu_offload.assert_called_once()
|
||||
else:
|
||||
mocks['mock_i2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize("device_map,should_have_device_map,should_cpu_offload", [
|
||||
(None, False, True),
|
||||
("auto", True, False),
|
||||
("balanced", True, False),
|
||||
("sequential", True, False),
|
||||
])
|
||||
def test_device_map_behavior_v2v(self, device_map, should_have_device_map, should_cpu_offload):
|
||||
"""Comprehensive test for device_map behavior with v2v pipeline."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
cli_demo.generate_video(**self._get_common_args(generate_type="v2v", device_map=device_map))
|
||||
|
||||
# Check from_pretrained call
|
||||
mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.assert_called_once()
|
||||
call_kwargs = mocks['mock_CogVideoXVideoToVideoPipeline'].from_pretrained.call_args[1]
|
||||
|
||||
if should_have_device_map:
|
||||
assert 'device_map' in call_kwargs
|
||||
assert call_kwargs['device_map'] == device_map
|
||||
else:
|
||||
assert 'device_map' not in call_kwargs
|
||||
|
||||
# Check CPU offload
|
||||
if should_cpu_offload:
|
||||
mocks['mock_v2v_pipe'].enable_sequential_cpu_offload.assert_called_once()
|
||||
else:
|
||||
mocks['mock_v2v_pipe'].enable_sequential_cpu_offload.assert_not_called()
|
||||
|
||||
|
||||
class TestValidDeviceMapsConstant:
|
||||
"""Test the VALID_DEVICE_MAPS constant."""
|
||||
|
||||
def test_valid_device_maps_contains_expected_values(self):
|
||||
"""Verify VALID_DEVICE_MAPS contains the expected device map options."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
assert "auto" in cli_demo.VALID_DEVICE_MAPS
|
||||
assert "balanced" in cli_demo.VALID_DEVICE_MAPS
|
||||
assert "sequential" in cli_demo.VALID_DEVICE_MAPS
|
||||
assert len(cli_demo.VALID_DEVICE_MAPS) == 3
|
||||
|
||||
def test_valid_device_maps_is_set(self):
|
||||
"""Verify VALID_DEVICE_MAPS is a set for O(1) lookup."""
|
||||
mocks = create_mocked_cli_demo()
|
||||
cli_demo = mocks['cli_demo']
|
||||
|
||||
assert isinstance(cli_demo.VALID_DEVICE_MAPS, set)
|
||||
|
||||
|
||||
class TestCLIArgumentParsing:
|
||||
"""Test CLI argument parsing for device_map."""
|
||||
|
||||
def test_cli_device_map_choices(self):
|
||||
"""Test that CLI parser accepts valid device_map choices."""
|
||||
import argparse
|
||||
|
||||
# Create a minimal parser matching the CLI
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--device_map",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["auto", "balanced", "sequential"],
|
||||
)
|
||||
|
||||
# Test valid values
|
||||
for valid_value in ["auto", "balanced", "sequential"]:
|
||||
args = parser.parse_args(["--device_map", valid_value])
|
||||
assert args.device_map == valid_value
|
||||
|
||||
# Test default (no argument)
|
||||
args = parser.parse_args([])
|
||||
assert args.device_map is None
|
||||
|
||||
def test_cli_device_map_invalid_choice(self):
|
||||
"""Test that CLI parser rejects invalid device_map choices."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--device_map",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["auto", "balanced", "sequential"],
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
parser.parse_args(["--device_map", "invalid"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Loading…
x
Reference in New Issue
Block a user