mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-05-07 23:40:23 +08:00
Merge 5a69df93f3274f4af54cd810ba14fe3c6e88d1eb into 7a1af7154511e0ce4e4be8d62faa8c5e5a3532d2
This commit is contained in:
commit
fe088f71b3
2
.gitignore
vendored
2
.gitignore
vendored
@ -8,6 +8,8 @@ logs/
|
||||
.idea
|
||||
output*
|
||||
test*
|
||||
!tests/
|
||||
!tests/**
|
||||
venv
|
||||
**/.swp
|
||||
**/*.log
|
||||
|
||||
@ -11,11 +11,24 @@ Running the Script:
|
||||
To run the script, use the following command with appropriate arguments:
|
||||
|
||||
```bash
|
||||
# Single GPU (default behavior, uses CPU offload):
|
||||
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX1.5-5b --generate_type "t2v"
|
||||
|
||||
# Multi-GPU with balanced device mapping:
|
||||
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX1.5-5b --generate_type "t2v" --device_map balanced
|
||||
|
||||
# Multi-GPU with auto device mapping:
|
||||
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX1.5-5b --generate_type "t2v" --device_map auto
|
||||
```
|
||||
|
||||
You can change `pipe.enable_sequential_cpu_offload()` to `pipe.enable_model_cpu_offload()` to speed up inference, but this will use more GPU memory.
|
||||
|
||||
Multi-GPU Support:
|
||||
- Use `--device_map balanced` to distribute the model evenly across available GPUs (recommended for inference)
|
||||
- Use `--device_map auto` for automatic device placement by accelerate
|
||||
- Use `--device_map sequential` to fill GPUs sequentially (useful for uneven memory)
|
||||
- Default behavior (no --device_map) uses CPU offload for single-GPU setups
|
||||
|
||||
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
|
||||
|
||||
"""
|
||||
@ -48,6 +61,9 @@ RESOLUTION_MAP = {
|
||||
"cogvideox-2b": (480, 720),
|
||||
}
|
||||
|
||||
# Valid device_map options for multi-GPU support
|
||||
VALID_DEVICE_MAPS = {"auto", "balanced", "sequential"}
|
||||
|
||||
|
||||
def generate_video(
|
||||
prompt: str,
|
||||
@ -66,6 +82,7 @@ def generate_video(
|
||||
generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video
|
||||
seed: int = 42,
|
||||
fps: int = 16,
|
||||
device_map: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Generates a video based on the given prompt and saves it to the specified path.
|
||||
@ -86,11 +103,23 @@ def generate_video(
|
||||
- generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v').
|
||||
- seed (int): The seed for reproducibility.
|
||||
- fps (int): The frames per second for the generated video.
|
||||
- device_map (str): Device placement strategy for multi-GPU support. Options:
|
||||
- None (default): Uses sequential CPU offload for single-GPU setups
|
||||
- "balanced": Distributes model layers evenly across available GPUs (recommended)
|
||||
- "auto": Automatic device placement by accelerate library
|
||||
- "sequential": Fills GPUs one by one in order
|
||||
|
||||
Multi-GPU Usage Examples:
|
||||
# Balanced distribution across GPUs (recommended):
|
||||
generate_video(prompt="...", model_path="...", device_map="balanced")
|
||||
|
||||
# Automatic device placement:
|
||||
generate_video(prompt="...", model_path="...", device_map="auto")
|
||||
"""
|
||||
|
||||
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
|
||||
# add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload()
|
||||
# function to use Multi GPUs.
|
||||
# When device_map is specified, the model is distributed across multiple GPUs.
|
||||
# When device_map is None (default), CPU offload is enabled for single-GPU setups.
|
||||
|
||||
image = None
|
||||
video = None
|
||||
@ -115,13 +144,25 @@ def generate_video(
|
||||
)
|
||||
height, width = desired_resolution
|
||||
|
||||
# Validate device_map if provided
|
||||
if device_map is not None and device_map not in VALID_DEVICE_MAPS:
|
||||
raise ValueError(
|
||||
f"Invalid device_map '{device_map}'. Must be one of: {', '.join(sorted(VALID_DEVICE_MAPS))} or None"
|
||||
)
|
||||
|
||||
# Build kwargs for from_pretrained - add device_map only when specified
|
||||
load_kwargs = {"torch_dtype": dtype}
|
||||
if device_map is not None:
|
||||
load_kwargs["device_map"] = device_map
|
||||
logging.info(f"\033[1mUsing device_map='{device_map}' for multi-GPU inference\033[0m")
|
||||
|
||||
if generate_type == "i2v":
|
||||
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, **load_kwargs)
|
||||
image = load_image(image=image_or_video_path)
|
||||
elif generate_type == "t2v":
|
||||
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||
pipe = CogVideoXPipeline.from_pretrained(model_path, **load_kwargs)
|
||||
else:
|
||||
pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||
pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, **load_kwargs)
|
||||
video = load_video(image_or_video_path)
|
||||
|
||||
# If you're using with lora, add this code
|
||||
@ -141,13 +182,16 @@ def generate_video(
|
||||
pipe.scheduler.config, timestep_spacing="trailing"
|
||||
)
|
||||
|
||||
# 3. Enable CPU offload for the model.
|
||||
# turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
|
||||
# and enable to("cuda")
|
||||
# pipe.to("cuda")
|
||||
# 3. Enable CPU offload for the model (only when not using multi-GPU device_map).
|
||||
# When device_map is specified, the model is already distributed across GPUs,
|
||||
# so CPU offload is not needed and would conflict with device placement.
|
||||
if device_map is None:
|
||||
# Single-GPU mode: use CPU offload to manage memory
|
||||
# Turn off if you have multiple GPUs or enough GPU memory (such as H100)
|
||||
# pipe.enable_model_cpu_offload()
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
# pipe.enable_model_cpu_offload()
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
# VAE optimizations work in both single and multi-GPU modes
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
@ -248,6 +292,13 @@ if __name__ == "__main__":
|
||||
"--dtype", type=str, default="bfloat16", help="The data type for computation"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
|
||||
parser.add_argument(
|
||||
"--device_map",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["auto", "balanced", "sequential"],
|
||||
help="Device placement strategy for multi-GPU inference. Options: 'balanced' (recommended, distributes evenly), 'auto' (automatic placement), 'sequential' (fills GPUs in order). Default: None (uses CPU offload for single-GPU)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
|
||||
@ -268,4 +319,5 @@ if __name__ == "__main__":
|
||||
generate_type=args.generate_type,
|
||||
seed=args.seed,
|
||||
fps=args.fps,
|
||||
device_map=args.device_map,
|
||||
)
|
||||
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Tests for CogVideoX CLI tools
|
||||
600
tests/test_cli_demo_multi_gpu.py
Normal file
600
tests/test_cli_demo_multi_gpu.py
Normal file
@ -0,0 +1,600 @@
|
||||
"""
|
||||
Tests for multi-GPU device_map support in cli_demo.py
|
||||
|
||||
These tests verify:
|
||||
- device_map=None → CPU offload enabled, no device_map passed to from_pretrained
|
||||
- device_map="auto" → No CPU offload, device_map="auto" passed to from_pretrained
|
||||
- device_map="balanced" → No CPU offload, device_map="balanced" passed
|
||||
- device_map="sequential" → No CPU offload, device_map="sequential" passed
|
||||
- All three pipeline types (t2v, i2v, v2v) work correctly
|
||||
- Backward compatibility (default behavior unchanged)
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
import importlib.util
|
||||
import pytest
|
||||
|
||||
|
||||
def create_mocked_cli_demo():
    """
    Load cli_demo.py with mocked heavy dependencies.

    ``torch``, ``diffusers`` and ``diffusers.utils`` are temporarily replaced in
    ``sys.modules`` by ``MagicMock`` objects while ``../inference/cli_demo.py``
    (relative to this file) is executed, so the module binds to the mocks
    instead of the real libraries.

    Afterwards ``sys.modules`` is put back exactly as it was for those three
    names: modules that existed before are restored, and names that did not
    exist before are removed again so the mocks cannot leak into later imports.
    If loading fails, the partially initialised ``cli_demo`` entry is removed
    as well; on success it stays registered under ``"cli_demo"``.

    Returns:
        dict: the loaded module under ``'cli_demo'`` plus every mock object
        needed for assertions (``'mock_torch'``, the three pipeline class
        mocks, ``'mock_CogVideoXDPMScheduler'`` and the three pipeline
        instance mocks).

    Raises:
        Whatever executing cli_demo.py raises (e.g. ``FileNotFoundError`` when
        the file is missing).
    """

    def _make_pipeline_mock():
        # One pipeline instance: scheduler with an empty config, a VAE, and a
        # call result exposing ``frames`` as a nested list with one frame.
        pipe = MagicMock()
        pipe.scheduler = MagicMock()
        pipe.scheduler.config = {}
        pipe.vae = MagicMock()
        pipe.return_value = MagicMock(frames=[[MagicMock()]])
        return pipe

    # Create mock for torch
    mock_torch = MagicMock()
    mock_torch.bfloat16 = "bfloat16"
    mock_torch.float16 = "float16"
    mock_torch.Generator.return_value.manual_seed.return_value = MagicMock()

    # Create pipeline mocks
    mock_t2v_pipe = _make_pipeline_mock()
    mock_i2v_pipe = _make_pipeline_mock()
    mock_v2v_pipe = _make_pipeline_mock()

    # Create pipeline class mocks
    mock_CogVideoXPipeline = MagicMock()
    mock_CogVideoXPipeline.from_pretrained.return_value = mock_t2v_pipe

    mock_CogVideoXImageToVideoPipeline = MagicMock()
    mock_CogVideoXImageToVideoPipeline.from_pretrained.return_value = mock_i2v_pipe

    mock_CogVideoXVideoToVideoPipeline = MagicMock()
    mock_CogVideoXVideoToVideoPipeline.from_pretrained.return_value = mock_v2v_pipe

    mock_CogVideoXDPMScheduler = MagicMock()
    mock_CogVideoXDPMScheduler.from_config.return_value = MagicMock()

    # Create mock diffusers module
    mock_diffusers = MagicMock()
    mock_diffusers.CogVideoXPipeline = mock_CogVideoXPipeline
    mock_diffusers.CogVideoXImageToVideoPipeline = mock_CogVideoXImageToVideoPipeline
    mock_diffusers.CogVideoXVideoToVideoPipeline = mock_CogVideoXVideoToVideoPipeline
    mock_diffusers.CogVideoXDPMScheduler = mock_CogVideoXDPMScheduler

    mock_diffusers_utils = MagicMock()
    mock_diffusers_utils.export_to_video = MagicMock()
    mock_diffusers_utils.load_image = MagicMock(return_value=MagicMock())
    mock_diffusers_utils.load_video = MagicMock(return_value=MagicMock())

    mocked_modules = {
        "torch": mock_torch,
        "diffusers": mock_diffusers,
        "diffusers.utils": mock_diffusers_utils,
    }

    # Remember the previous state of every name we are about to replace. A
    # sentinel marks "was not imported" so that case can be undone too.
    missing = object()
    saved_modules = {name: sys.modules.get(name, missing) for name in mocked_modules}

    # Remove cli_demo if cached so it is re-executed against the fresh mocks
    sys.modules.pop("cli_demo", None)

    # Install mocks
    sys.modules.update(mocked_modules)

    loaded = False
    try:
        # Load cli_demo.py using importlib
        cli_demo_path = os.path.join(os.path.dirname(__file__), "..", "inference", "cli_demo.py")
        spec = importlib.util.spec_from_file_location("cli_demo", cli_demo_path)
        cli_demo = importlib.util.module_from_spec(spec)
        sys.modules["cli_demo"] = cli_demo
        spec.loader.exec_module(cli_demo)
        loaded = True

        return {
            "cli_demo": cli_demo,
            "mock_torch": mock_torch,
            "mock_CogVideoXPipeline": mock_CogVideoXPipeline,
            "mock_CogVideoXImageToVideoPipeline": mock_CogVideoXImageToVideoPipeline,
            "mock_CogVideoXVideoToVideoPipeline": mock_CogVideoXVideoToVideoPipeline,
            "mock_CogVideoXDPMScheduler": mock_CogVideoXDPMScheduler,
            "mock_t2v_pipe": mock_t2v_pipe,
            "mock_i2v_pipe": mock_i2v_pipe,
            "mock_v2v_pipe": mock_v2v_pipe,
        }
    finally:
        if not loaded:
            # Do not leave a half-initialised module registered.
            sys.modules.pop("cli_demo", None)
        # Restore original modules, and drop mocks for names that had no
        # original so they cannot shadow a later real import.
        for name, original in saved_modules.items():
            if original is missing:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = original
|
||||
|
||||
|
||||
class TestDeviceMapSupport:
    """Test suite for multi-GPU device_map functionality.

    Every test loads a fresh mocked ``cli_demo`` module via
    ``create_mocked_cli_demo()``, calls ``generate_video`` and then inspects
    the mocks. The repeated "run, then check ``from_pretrained`` kwargs and
    CPU offload" sequence lives in private helpers so each test states only
    the scenario it covers.
    """

    # generate_type -> (pipeline class mock key, pipeline instance mock key)
    # in the dict returned by create_mocked_cli_demo().
    _PIPELINE_MOCK_KEYS = {
        "t2v": ("mock_CogVideoXPipeline", "mock_t2v_pipe"),
        "i2v": ("mock_CogVideoXImageToVideoPipeline", "mock_i2v_pipe"),
        "v2v": ("mock_CogVideoXVideoToVideoPipeline", "mock_v2v_pipe"),
    }

    # Every value accepted for device_map, including the default.
    _ALL_DEVICE_MAPS = [None, "auto", "balanced", "sequential"]

    # (device_map, should_have_device_map, should_cpu_offload)
    _BEHAVIOR_CASES = [
        (None, False, True),
        ("auto", True, False),
        ("balanced", True, False),
        ("sequential", True, False),
    ]

    def _get_common_args(self, generate_type="t2v", device_map=None):
        """Get common arguments for generate_video function."""
        args = {
            "prompt": "A test video",
            "model_path": "THUDM/CogVideoX-5b",
            "generate_type": generate_type,
            "output_path": "./test_output.mp4",
            "num_inference_steps": 1,
            "num_frames": 9,
            "seed": 42,
            "device_map": device_map,
        }
        if generate_type in ("i2v", "v2v"):
            args["image_or_video_path"] = "/fake/path/image.png"
        return args

    def _run_generate_video(self, generate_type, device_map):
        """Run generate_video on fresh mocks.

        Asserts that the matching pipeline class's ``from_pretrained`` was
        called exactly once and returns ``(call_kwargs, pipe_mock)``: the
        keyword arguments of that call and the pipeline instance mock.
        """
        mocks = create_mocked_cli_demo()
        mocks["cli_demo"].generate_video(
            **self._get_common_args(generate_type=generate_type, device_map=device_map)
        )

        class_key, pipe_key = self._PIPELINE_MOCK_KEYS[generate_type]
        from_pretrained = mocks[class_key].from_pretrained
        from_pretrained.assert_called_once()
        return from_pretrained.call_args[1], mocks[pipe_key]

    def _check_device_map_behavior(
        self, generate_type, device_map, should_have_device_map, should_cpu_offload
    ):
        """Run generate_video and verify device_map forwarding and CPU offload."""
        call_kwargs, pipe = self._run_generate_video(generate_type, device_map)

        # Check from_pretrained call
        if should_have_device_map:
            assert "device_map" in call_kwargs
            assert call_kwargs["device_map"] == device_map
        else:
            assert "device_map" not in call_kwargs, "device_map should NOT be passed when None"

        # Check CPU offload
        if should_cpu_offload:
            pipe.enable_sequential_cpu_offload.assert_called_once()
        else:
            pipe.enable_sequential_cpu_offload.assert_not_called()

    def _check_vae_optimizations(self, generate_type, device_map):
        """Run generate_video and verify VAE slicing and tiling were each enabled once."""
        _, pipe = self._run_generate_video(generate_type, device_map)
        pipe.vae.enable_slicing.assert_called_once()
        pipe.vae.enable_tiling.assert_called_once()

    # =========================================================================
    # Tests for device_map=None (default behavior - CPU offload)
    # =========================================================================

    def test_device_map_none_t2v_enables_cpu_offload(self):
        """Test that device_map=None enables CPU offload for t2v pipeline."""
        self._check_device_map_behavior("t2v", None, False, True)

    def test_device_map_none_i2v_enables_cpu_offload(self):
        """Test that device_map=None enables CPU offload for i2v pipeline."""
        self._check_device_map_behavior("i2v", None, False, True)

    def test_device_map_none_v2v_enables_cpu_offload(self):
        """Test that device_map=None enables CPU offload for v2v pipeline."""
        self._check_device_map_behavior("v2v", None, False, True)

    # =========================================================================
    # Tests for device_map="auto"
    # =========================================================================

    def test_device_map_auto_t2v_no_cpu_offload(self):
        """Test that device_map='auto' passes device_map and skips CPU offload for t2v."""
        self._check_device_map_behavior("t2v", "auto", True, False)

    def test_device_map_auto_i2v_no_cpu_offload(self):
        """Test that device_map='auto' passes device_map and skips CPU offload for i2v."""
        self._check_device_map_behavior("i2v", "auto", True, False)

    def test_device_map_auto_v2v_no_cpu_offload(self):
        """Test that device_map='auto' passes device_map and skips CPU offload for v2v."""
        self._check_device_map_behavior("v2v", "auto", True, False)

    # =========================================================================
    # Tests for device_map="balanced"
    # =========================================================================

    def test_device_map_balanced_t2v_no_cpu_offload(self):
        """Test that device_map='balanced' passes device_map and skips CPU offload for t2v."""
        self._check_device_map_behavior("t2v", "balanced", True, False)

    def test_device_map_balanced_i2v_no_cpu_offload(self):
        """Test that device_map='balanced' passes device_map and skips CPU offload for i2v."""
        self._check_device_map_behavior("i2v", "balanced", True, False)

    def test_device_map_balanced_v2v_no_cpu_offload(self):
        """Test that device_map='balanced' passes device_map and skips CPU offload for v2v."""
        self._check_device_map_behavior("v2v", "balanced", True, False)

    # =========================================================================
    # Tests for device_map="sequential"
    # =========================================================================

    def test_device_map_sequential_t2v_no_cpu_offload(self):
        """Test that device_map='sequential' passes device_map and skips CPU offload for t2v."""
        self._check_device_map_behavior("t2v", "sequential", True, False)

    def test_device_map_sequential_i2v_no_cpu_offload(self):
        """Test that device_map='sequential' passes device_map and skips CPU offload for i2v."""
        self._check_device_map_behavior("i2v", "sequential", True, False)

    def test_device_map_sequential_v2v_no_cpu_offload(self):
        """Test that device_map='sequential' passes device_map and skips CPU offload for v2v."""
        self._check_device_map_behavior("v2v", "sequential", True, False)

    # =========================================================================
    # VAE optimizations (should always be enabled)
    # =========================================================================

    @pytest.mark.parametrize("device_map", _ALL_DEVICE_MAPS)
    def test_vae_optimizations_always_enabled_t2v(self, device_map):
        """Test that VAE slicing and tiling are always enabled for t2v regardless of device_map."""
        self._check_vae_optimizations("t2v", device_map)

    @pytest.mark.parametrize("device_map", _ALL_DEVICE_MAPS)
    def test_vae_optimizations_always_enabled_i2v(self, device_map):
        """Test that VAE slicing and tiling are always enabled for i2v regardless of device_map."""
        self._check_vae_optimizations("i2v", device_map)

    @pytest.mark.parametrize("device_map", _ALL_DEVICE_MAPS)
    def test_vae_optimizations_always_enabled_v2v(self, device_map):
        """Test that VAE slicing and tiling are always enabled for v2v regardless of device_map."""
        self._check_vae_optimizations("v2v", device_map)

    # =========================================================================
    # Invalid device_map values
    # =========================================================================

    def test_invalid_device_map_raises_error(self):
        """Test that invalid device_map values raise ValueError."""
        cli_demo = create_mocked_cli_demo()["cli_demo"]

        with pytest.raises(ValueError) as exc_info:
            cli_demo.generate_video(
                **self._get_common_args(generate_type="t2v", device_map="invalid")
            )

        # The message must name both the problem and the offending value.
        assert "Invalid device_map" in str(exc_info.value)
        assert "invalid" in str(exc_info.value)

    @pytest.mark.parametrize("invalid_value", ["cuda:0", "cpu", "GPU", "multi", "distributed"])
    def test_various_invalid_device_map_values(self, invalid_value):
        """Test that various invalid device_map values all raise ValueError."""
        cli_demo = create_mocked_cli_demo()["cli_demo"]

        with pytest.raises(ValueError) as exc_info:
            cli_demo.generate_video(
                **self._get_common_args(generate_type="t2v", device_map=invalid_value)
            )

        assert "Invalid device_map" in str(exc_info.value)

    # =========================================================================
    # Backward compatibility tests
    # =========================================================================

    def test_default_behavior_unchanged_no_device_map_arg(self):
        """Test that default behavior (no device_map argument) remains unchanged."""
        mocks = create_mocked_cli_demo()

        # Call without device_map argument at all (relies on default), so this
        # test deliberately does not go through _get_common_args().
        args = {
            "prompt": "A test video",
            "model_path": "THUDM/CogVideoX-5b",
            "generate_type": "t2v",
            "output_path": "./test_output.mp4",
            "num_inference_steps": 1,
            "num_frames": 9,
            "seed": 42,
        }
        mocks["cli_demo"].generate_video(**args)

        # Verify from_pretrained was called WITHOUT device_map
        from_pretrained = mocks["mock_CogVideoXPipeline"].from_pretrained
        from_pretrained.assert_called_once()
        assert "device_map" not in from_pretrained.call_args[1]

        # Verify CPU offload WAS enabled (default single-GPU behavior)
        mocks["mock_t2v_pipe"].enable_sequential_cpu_offload.assert_called_once()

    # =========================================================================
    # Parametrized comprehensive tests
    # =========================================================================

    @pytest.mark.parametrize(
        "device_map,should_have_device_map,should_cpu_offload", _BEHAVIOR_CASES
    )
    def test_device_map_behavior_t2v(self, device_map, should_have_device_map, should_cpu_offload):
        """Comprehensive test for device_map behavior with t2v pipeline."""
        self._check_device_map_behavior(
            "t2v", device_map, should_have_device_map, should_cpu_offload
        )

    @pytest.mark.parametrize(
        "device_map,should_have_device_map,should_cpu_offload", _BEHAVIOR_CASES
    )
    def test_device_map_behavior_i2v(self, device_map, should_have_device_map, should_cpu_offload):
        """Comprehensive test for device_map behavior with i2v pipeline."""
        self._check_device_map_behavior(
            "i2v", device_map, should_have_device_map, should_cpu_offload
        )

    @pytest.mark.parametrize(
        "device_map,should_have_device_map,should_cpu_offload", _BEHAVIOR_CASES
    )
    def test_device_map_behavior_v2v(self, device_map, should_have_device_map, should_cpu_offload):
        """Comprehensive test for device_map behavior with v2v pipeline."""
        self._check_device_map_behavior(
            "v2v", device_map, should_have_device_map, should_cpu_offload
        )
|
||||
|
||||
|
||||
class TestValidDeviceMapsConstant:
    """Checks on the module-level VALID_DEVICE_MAPS constant of cli_demo."""

    def test_valid_device_maps_contains_expected_values(self):
        """VALID_DEVICE_MAPS holds exactly the three supported strategies."""
        valid_maps = create_mocked_cli_demo()['cli_demo'].VALID_DEVICE_MAPS

        for option in ("auto", "balanced", "sequential"):
            assert option in valid_maps
        # Three expected members plus a length of three rules out extras.
        assert len(valid_maps) == 3

    def test_valid_device_maps_is_set(self):
        """VALID_DEVICE_MAPS is a set, so membership checks are hash lookups."""
        valid_maps = create_mocked_cli_demo()['cli_demo'].VALID_DEVICE_MAPS

        assert isinstance(valid_maps, set)
|
||||
|
||||
|
||||
class TestCLIArgumentParsing:
|
||||
"""Test CLI argument parsing for device_map."""
|
||||
|
||||
def test_cli_device_map_choices(self):
|
||||
"""Test that CLI parser accepts valid device_map choices."""
|
||||
import argparse
|
||||
|
||||
# Create a minimal parser matching the CLI
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--device_map",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["auto", "balanced", "sequential"],
|
||||
)
|
||||
|
||||
# Test valid values
|
||||
for valid_value in ["auto", "balanced", "sequential"]:
|
||||
args = parser.parse_args(["--device_map", valid_value])
|
||||
assert args.device_map == valid_value
|
||||
|
||||
# Test default (no argument)
|
||||
args = parser.parse_args([])
|
||||
assert args.device_map is None
|
||||
|
||||
def test_cli_device_map_invalid_choice(self):
|
||||
"""Test that CLI parser rejects invalid device_map choices."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--device_map",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["auto", "balanced", "sequential"],
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
parser.parse_args(["--device_map", "invalid"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Loading…
x
Reference in New Issue
Block a user