From 15a4b403c4dae9c6f7cbc7d6411a31f5b2ac1163 Mon Sep 17 00:00:00 2001 From: Test User Date: Thu, 19 Feb 2026 03:24:16 +0000 Subject: [PATCH] test: add comprehensive tests for batch inference pipeline --- tests/test_batch_inference.py | 1194 +++++++++++++++++++++++++++++++++ 1 file changed, 1194 insertions(+) create mode 100644 tests/test_batch_inference.py diff --git a/tests/test_batch_inference.py b/tests/test_batch_inference.py new file mode 100644 index 0000000..1225cb5 --- /dev/null +++ b/tests/test_batch_inference.py @@ -0,0 +1,1194 @@ +""" +Comprehensive tests for the batch inference pipeline. + +Tests cover: +- JSONL parsing (valid and invalid input) +- Batch processing logic (iteration, state tracking) +- Resume capability (read/write .batch_state.json) +- Error handling (skip failed items, continue processing) +- Multi-GPU job distribution (--gpu_id, --num_gpus) +- Output path creation +- Progress tracking +- All generation types: t2v, i2v, v2v +- Edge cases: empty batch, corrupted state, missing files, invalid JSONL +""" + +import importlib.util +import json +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch +import pytest + + +# ============================================================================= +# Mock heavy dependencies before importing batch_inference +# ============================================================================= + +# Create mock torch module +mock_torch = MagicMock() +mock_torch.bfloat16 = "bfloat16" +mock_torch.float16 = "float16" +mock_torch.Generator = MagicMock + +# Create mock diffusers module +mock_diffusers = MagicMock() + +# Create mock tqdm that works as a context manager +class MockTqdm: + """Mock tqdm that works as a context manager and returns a dummy progress bar.""" + def __init__(self, iterable=None, total=None, **kwargs): + self.iterable = iterable + self.n = 0 + + def __iter__(self): + if self.iterable is not None: + for item in self.iterable: + yield item + self.n += 1 + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def update(self, n=1): + self.n += n + + def set_postfix(self, ordered_dict=None, refresh=True, **kwargs): + pass + + +mock_tqdm = MagicMock() +mock_tqdm.tqdm = MockTqdm + +# Install mocks before importing +sys.modules["torch"] = mock_torch +sys.modules["diffusers"] = mock_diffusers +sys.modules["diffusers.utils"] = mock_diffusers.utils +sys.modules["tqdm"] = mock_tqdm + + +def _load_batch_inference_module(): + """Load the batch_inference module from tools directory.""" + module_path = Path(__file__).parent.parent / "tools" / "batch_inference.py" + spec = importlib.util.spec_from_file_location("batch_inference", module_path) + module = importlib.util.module_from_spec(spec) + sys.modules["batch_inference"] = module + spec.loader.exec_module(module) + return module + + +batch_inference = _load_batch_inference_module() + +# Import the symbols we need +BatchJob = batch_inference.BatchJob +BatchState = batch_inference.BatchState +load_batch_file = batch_inference.load_batch_file +load_pipeline = batch_inference.load_pipeline +generate_single_video = batch_inference.generate_single_video +run_batch = batch_inference.run_batch +RESOLUTION_MAP = batch_inference.RESOLUTION_MAP + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def sample_job_dict(): + """A valid job dictionary.""" + return { + "prompt": "A cat playing piano", + "output_name": "cat_piano.mp4", + "num_frames": 49, + "guidance_scale": 7.0, + "seed": 123, + } + + +@pytest.fixture +def sample_i2v_job_dict(): + """A valid i2v job dictionary.""" + return { + "prompt": "Transform this image", + "output_name": "i2v_output.mp4", + "image_path": "/path/to/image.jpg", + } + + +@pytest.fixture +def sample_v2v_job_dict(): + """A valid v2v job dictionary.""" + return { + "prompt": "Enhance this video", + "output_name": "v2v_output.mp4", + "video_path": "/path/to/video.mp4", + } + + +@pytest.fixture +def valid_jsonl_content(): + """Valid JSONL content with multiple jobs.""" + jobs = [ + {"prompt": "A cat playing piano", "output_name": "cat_piano.mp4"}, + {"prompt": "Waves crashing on beach", "output_name": "beach.mp4", "num_frames": 49}, + {"prompt": "A dog running", "output_name": "dog.mp4", "seed": 999}, + ] + return "\n".join(json.dumps(j) for j in jobs) + + +@pytest.fixture +def batch_file(tmp_path, valid_jsonl_content): + """Create a temporary batch file with valid JSONL content.""" + batch_path = tmp_path / "batch.jsonl" + batch_path.write_text(valid_jsonl_content) + return batch_path + + +@pytest.fixture +def mock_pipeline(): + """Create a mock pipeline that returns fake video frames.""" + mock = MagicMock() + # Mock the pipeline call to return a fake frames result + mock_result = MagicMock() + mock_result.frames = [[MagicMock()]] # Nested list like real output + mock.return_value = mock_result + mock.scheduler = MagicMock() + mock.vae = MagicMock() + return mock + + +# ============================================================================= +# BatchJob Tests +# ============================================================================= + + +class TestBatchJob: + """Tests for the BatchJob dataclass.""" + + def test_from_dict_basic(self, sample_job_dict): + """Test creating a BatchJob from a dictionary.""" + job = BatchJob.from_dict(sample_job_dict, line_number=1) + + assert job.prompt == "A cat playing piano" + assert job.output_name == "cat_piano.mp4" + assert job.num_frames == 49 + assert job.guidance_scale == 7.0 + assert job.seed == 123 + assert job.line_number == 1 + assert job.status == "pending" + + def test_from_dict_minimal(self): + """Test creating a BatchJob with only required fields.""" + data = {"prompt": "Test prompt", "output_name": "test.mp4"} + job = BatchJob.from_dict(data) + + assert job.prompt == "Test prompt" + assert job.output_name == "test.mp4" + assert job.num_frames is None + assert job.image_path is None + assert job.video_path is None + + def test_from_dict_with_image_path(self, sample_i2v_job_dict): + """Test creating a BatchJob with image_path for i2v.""" + job = BatchJob.from_dict(sample_i2v_job_dict) + + assert job.image_path == "/path/to/image.jpg" + assert job.video_path is None + + def test_from_dict_with_video_path(self, sample_v2v_job_dict): + """Test creating a BatchJob with video_path for v2v.""" + job = BatchJob.from_dict(sample_v2v_job_dict) + + assert job.video_path == "/path/to/video.mp4" + assert job.image_path is None + + def test_validate_valid_job(self, sample_job_dict): + """Test validation of a valid job returns no errors.""" + job = BatchJob.from_dict(sample_job_dict) + errors = job.validate() + + assert errors == [] + + def test_validate_missing_prompt(self): + """Test validation catches missing prompt.""" + job = BatchJob.from_dict({"output_name": "test.mp4"}) + errors = job.validate() + + assert "Missing required field: prompt" in errors + + def test_validate_missing_output_name(self): + """Test validation catches missing output_name.""" + job = BatchJob.from_dict({"prompt": "Test"}) + errors = job.validate() + + assert "Missing required field: output_name" in errors + + def test_validate_missing_both_required(self): + """Test validation catches both missing required fields.""" + job = BatchJob.from_dict({}) + errors = job.validate() + + assert len(errors) == 2 + assert "Missing required field: prompt" in errors + assert "Missing required field: output_name" in errors + + +# ============================================================================= +# BatchState Tests +# ============================================================================= + + +class TestBatchState: + """Tests for the BatchState dataclass.""" + + def test_create_new_state(self): + """Test creating a new BatchState.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="THUDM/CogVideoX1.5-5B", + generate_type="t2v", + ) + + assert state.batch_file == "batch.jsonl" + assert state.completed == [] + assert state.failed == [] + + def test_mark_completed(self): + """Test marking a job as completed.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + ) + + state.mark_completed("video1.mp4") + state.mark_completed("video2.mp4") + + assert "video1.mp4" in state.completed + assert "video2.mp4" in state.completed + + def test_mark_completed_no_duplicates(self): + """Test that marking the same job twice doesn't duplicate.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + ) + + state.mark_completed("video1.mp4") + state.mark_completed("video1.mp4") + + assert state.completed.count("video1.mp4") == 1 + + def test_mark_failed(self): + """Test marking a job as failed.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + ) + + state.mark_failed("video1.mp4", "CUDA out of memory") + + assert len(state.failed) == 1 + assert state.failed[0]["output_name"] == "video1.mp4" + assert state.failed[0]["error"] == "CUDA out of memory" + + def test_is_completed(self): + """Test checking if a job is completed.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + ) + + state.mark_completed("video1.mp4") + + assert state.is_completed("video1.mp4") is True + assert state.is_completed("video2.mp4") is False + + def test_save_and_load(self, tmp_path): + """Test saving and loading state.""" + state_file = tmp_path / ".batch_state.json" + + # Create and save state + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + started_at="2024-01-01T00:00:00", + ) + state.mark_completed("video1.mp4") + state.mark_failed("video2.mp4", "Error") + state.save(state_file) + + # Load state + loaded = BatchState.load(state_file) + + assert loaded is not None + assert loaded.batch_file == "batch.jsonl" + assert "video1.mp4" in loaded.completed + assert len(loaded.failed) == 1 + assert loaded.failed[0]["output_name"] == "video2.mp4" + + def test_load_nonexistent_file(self, tmp_path): + """Test loading from a nonexistent file returns None.""" + state_file = tmp_path / "nonexistent.json" + + loaded = BatchState.load(state_file) + + assert loaded is None + + def test_load_corrupted_file(self, tmp_path): + """Test loading a corrupted state file returns None.""" + state_file = tmp_path / ".batch_state.json" + state_file.write_text("{ invalid json }") + + loaded = BatchState.load(state_file) + + assert loaded is None + + def test_load_invalid_json_structure(self, tmp_path): + """Test loading a file with valid JSON but invalid structure.""" + state_file = tmp_path / ".batch_state.json" + state_file.write_text('{"wrong": "structure"}') + + loaded = BatchState.load(state_file) + + # Should return None due to missing required fields + assert loaded is None + + +# ============================================================================= +# JSONL Parsing Tests +# ============================================================================= + + +class TestLoadBatchFile: + """Tests for JSONL parsing.""" + + def test_load_valid_jsonl(self, batch_file): + """Test loading a valid JSONL file.""" + jobs = load_batch_file(batch_file) + + assert len(jobs) == 3 + assert jobs[0].prompt == "A cat playing piano" + assert jobs[1].output_name == "beach.mp4" + assert jobs[2].seed == 999 + + def test_load_jsonl_with_comments(self, tmp_path): + """Test that comment lines are skipped.""" + content = """# This is a comment +{"prompt": "Test", "output_name": "test.mp4"} +# Another comment +{"prompt": "Test2", "output_name": "test2.mp4"}""" + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 2 + + def test_load_jsonl_with_empty_lines(self, tmp_path): + """Test that empty lines are skipped.""" + content = """{"prompt": "Test", "output_name": "test.mp4"} + +{"prompt": "Test2", "output_name": "test2.mp4"} + +""" + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 2 + + def test_load_jsonl_invalid_json_line(self, tmp_path): + """Test that invalid JSON lines are skipped with warning.""" + content = """{"prompt": "Test", "output_name": "test.mp4"} +{invalid json here} +{"prompt": "Test2", "output_name": "test2.mp4"}""" + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + # Should skip the invalid line + assert len(jobs) == 2 + + def test_load_jsonl_missing_required_fields(self, tmp_path): + """Test that jobs with missing required fields are skipped.""" + content = """{"prompt": "Test", "output_name": "test.mp4"} +{"prompt": "No output name"} +{"output_name": "no_prompt.mp4"} +{"prompt": "Test2", "output_name": "test2.mp4"}""" + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + # Should skip the two invalid jobs + assert len(jobs) == 2 + assert jobs[0].output_name == "test.mp4" + assert jobs[1].output_name == "test2.mp4" + + def test_load_empty_batch_file(self, tmp_path): + """Test loading an empty batch file.""" + batch_file = tmp_path / "empty.jsonl" + batch_file.write_text("") + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 0 + + def test_load_batch_file_all_comments(self, tmp_path): + """Test loading a file with only comments.""" + content = """# Comment 1 +# Comment 2 +# Comment 3""" + batch_file = tmp_path / "comments.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 0 + + def test_load_batch_file_tracks_line_numbers(self, tmp_path): + """Test that line numbers are tracked correctly.""" + content = """# Comment on line 1 +{"prompt": "Line 2", "output_name": "line2.mp4"} + +{"prompt": "Line 4", "output_name": "line4.mp4"}""" + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + assert jobs[0].line_number == 2 + assert jobs[1].line_number == 4 + + +# ============================================================================= +# Multi-GPU Distribution Tests +# ============================================================================= + + +class TestMultiGPUDistribution: + """Tests for multi-GPU job distribution logic.""" + + def test_single_gpu_gets_all_jobs(self): + """Test that a single GPU processes all jobs.""" + all_jobs = [ + BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") + for i in range(10) + ] + + # Simulate distribution for single GPU (gpu_id=0, num_gpus=1) + gpu_id = 0 + num_gpus = 1 + jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id] + + assert len(jobs) == 10 + + def test_two_gpu_distribution(self): + """Test job distribution across 2 GPUs.""" + all_jobs = [ + BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") + for i in range(10) + ] + + # GPU 0 gets jobs 0, 2, 4, 6, 8 + gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 0] + # GPU 1 gets jobs 1, 3, 5, 7, 9 + gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 1] + + assert len(gpu0_jobs) == 5 + assert len(gpu1_jobs) == 5 + assert gpu0_jobs[0].output_name == "video0.mp4" + assert gpu1_jobs[0].output_name == "video1.mp4" + + def test_four_gpu_distribution(self): + """Test job distribution across 4 GPUs.""" + all_jobs = [ + BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") + for i in range(12) + ] + + num_gpus = 4 + for gpu_id in range(num_gpus): + jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id] + assert len(jobs) == 3 # 12 jobs / 4 GPUs = 3 each + + def test_uneven_distribution(self): + """Test distribution when jobs don't divide evenly.""" + all_jobs = [ + BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") + for i in range(10) + ] + + num_gpus = 3 + gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 0] + gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 1] + gpu2_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 2] + + # 10 jobs / 3 GPUs: GPU0 gets 4, GPU1 gets 3, GPU2 gets 3 + assert len(gpu0_jobs) == 4 # indices 0, 3, 6, 9 + assert len(gpu1_jobs) == 3 # indices 1, 4, 7 + assert len(gpu2_jobs) == 3 # indices 2, 5, 8 + + +# ============================================================================= +# Pipeline Loading Tests +# ============================================================================= + + +class TestLoadPipeline: + """Tests for pipeline loading.""" + + @patch.object(batch_inference, "CogVideoXPipeline") + def test_load_t2v_pipeline(self, mock_cogvideo_pipeline): + """Test loading a text-to-video pipeline.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe + + with patch.object(batch_inference, "CogVideoXDPMScheduler"): + pipe = load_pipeline( + model_path="THUDM/CogVideoX1.5-5B", + generate_type="t2v", + enable_cpu_offload=False, + ) + + mock_cogvideo_pipeline.from_pretrained.assert_called_once() + assert pipe is mock_pipe + + @patch.object(batch_inference, "CogVideoXImageToVideoPipeline") + def test_load_i2v_pipeline(self, mock_i2v_pipeline): + """Test loading an image-to-video pipeline.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_i2v_pipeline.from_pretrained.return_value = mock_pipe + + with patch.object(batch_inference, "CogVideoXDPMScheduler"): + pipe = load_pipeline( + model_path="THUDM/CogVideoX1.5-5B-I2V", + generate_type="i2v", + enable_cpu_offload=False, + ) + + mock_i2v_pipeline.from_pretrained.assert_called_once() + assert pipe is mock_pipe + + @patch.object(batch_inference, "CogVideoXVideoToVideoPipeline") + def test_load_v2v_pipeline(self, mock_v2v_pipeline): + """Test loading a video-to-video pipeline.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_v2v_pipeline.from_pretrained.return_value = mock_pipe + + with patch.object(batch_inference, "CogVideoXDPMScheduler"): + pipe = load_pipeline( + model_path="THUDM/CogVideoX1.5-5B", + generate_type="v2v", + enable_cpu_offload=False, + ) + + mock_v2v_pipeline.from_pretrained.assert_called_once() + assert pipe is mock_pipe + + @patch.object(batch_inference, "CogVideoXPipeline") + def test_load_pipeline_with_cpu_offload(self, mock_cogvideo_pipeline): + """Test that CPU offload is enabled when requested.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe + + with patch.object(batch_inference, "CogVideoXDPMScheduler"): + load_pipeline( + model_path="model", + generate_type="t2v", + enable_cpu_offload=True, + ) + + mock_pipe.enable_sequential_cpu_offload.assert_called_once() + mock_pipe.to.assert_not_called() + + @patch.object(batch_inference, "CogVideoXPipeline") + def test_load_pipeline_without_cpu_offload(self, mock_cogvideo_pipeline): + """Test that pipeline moves to CUDA when CPU offload is disabled.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe + + with patch.object(batch_inference, "CogVideoXDPMScheduler"): + load_pipeline( + model_path="model", + generate_type="t2v", + enable_cpu_offload=False, + ) + + mock_pipe.to.assert_called_once_with("cuda") + mock_pipe.enable_sequential_cpu_offload.assert_not_called() + + @patch.object(batch_inference, "CogVideoXPipeline") + def test_load_pipeline_with_lora(self, mock_cogvideo_pipeline): + """Test loading pipeline with LoRA weights.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe + + with patch.object(batch_inference, "CogVideoXDPMScheduler"): + load_pipeline( + model_path="model", + generate_type="t2v", + lora_path="/path/to/lora", + enable_cpu_offload=False, + ) + + mock_pipe.load_lora_weights.assert_called_once() + mock_pipe.fuse_lora.assert_called_once() + + +# ============================================================================= +# Video Generation Tests +# ============================================================================= + + +class TestGenerateSingleVideo: + """Tests for single video generation.""" + + @patch.object(batch_inference, "export_to_video") + def test_generate_t2v_video(self, mock_export, mock_pipeline, tmp_path): + """Test generating a text-to-video.""" + job = BatchJob( + prompt="A cat playing piano", + output_name="cat.mp4", + num_frames=49, + ) + output_path = tmp_path / "cat.mp4" + + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="t2v", + model_name="cogvideox-5b", + output_path=output_path, + ) + + mock_pipeline.assert_called_once() + mock_export.assert_called_once() + + @patch.object(batch_inference, "export_to_video") + @patch.object(batch_inference, "load_image") + def test_generate_i2v_video(self, mock_load_image, mock_export, mock_pipeline, tmp_path): + """Test generating an image-to-video.""" + mock_load_image.return_value = MagicMock() + + job = BatchJob( + prompt="Transform this", + output_name="i2v.mp4", + image_path="/path/to/image.jpg", + ) + output_path = tmp_path / "i2v.mp4" + + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="i2v", + model_name="cogvideox-5b-i2v", + output_path=output_path, + ) + + mock_load_image.assert_called_once_with(image="/path/to/image.jpg") + mock_pipeline.assert_called_once() + + @patch.object(batch_inference, "export_to_video") + @patch.object(batch_inference, "load_video") + def test_generate_v2v_video(self, mock_load_video, mock_export, mock_pipeline, tmp_path): + """Test generating a video-to-video.""" + mock_load_video.return_value = MagicMock() + + job = BatchJob( + prompt="Enhance this", + output_name="v2v.mp4", + video_path="/path/to/video.mp4", + ) + output_path = tmp_path / "v2v.mp4" + + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="v2v", + model_name="cogvideox-5b", + output_path=output_path, + ) + + mock_load_video.assert_called_once_with("/path/to/video.mp4") + mock_pipeline.assert_called_once() + + def test_generate_i2v_missing_image_path(self, mock_pipeline, tmp_path): + """Test that i2v generation fails without image_path.""" + job = BatchJob( + prompt="Transform this", + output_name="i2v.mp4", + # Missing image_path + ) + output_path = tmp_path / "i2v.mp4" + + with pytest.raises(ValueError, match="image_path is required"): + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="i2v", + model_name="cogvideox-5b-i2v", + output_path=output_path, + ) + + def test_generate_v2v_missing_video_path(self, mock_pipeline, tmp_path): + """Test that v2v generation fails without video_path.""" + job = BatchJob( + prompt="Enhance this", + output_name="v2v.mp4", + # Missing video_path + ) + output_path = tmp_path / "v2v.mp4" + + with pytest.raises(ValueError, match="video_path is required"): + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="v2v", + model_name="cogvideox-5b", + output_path=output_path, + ) + + @patch.object(batch_inference, "export_to_video") + def test_generate_uses_job_specific_params(self, mock_export, mock_pipeline, tmp_path): + """Test that job-specific parameters override defaults.""" + job = BatchJob( + prompt="Test", + output_name="test.mp4", + num_frames=33, + guidance_scale=9.0, + num_inference_steps=25, + seed=456, + ) + output_path = tmp_path / "test.mp4" + + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="t2v", + model_name="cogvideox-5b", + output_path=output_path, + default_num_frames=81, + default_guidance_scale=6.0, + default_num_inference_steps=50, + default_seed=42, + ) + + # Check the pipeline was called with job-specific values + call_kwargs = mock_pipeline.call_args[1] + assert call_kwargs["num_frames"] == 33 + assert call_kwargs["guidance_scale"] == 9.0 + assert call_kwargs["num_inference_steps"] == 25 + + +# ============================================================================= +# Batch Processing Tests +# ============================================================================= + + +class TestRunBatch: + """Tests for the main batch processing function.""" + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_basic(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + """Test basic batch processing.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "output" + + summary = run_batch( + batch_file=batch_file, + model_path="THUDM/CogVideoX1.5-5B", + output_dir=output_dir, + generate_type="t2v", + resume=False, + ) + + assert summary["total"] == 3 + assert summary["completed"] == 3 + assert summary["failed"] == 0 + assert mock_generate.call_count == 3 + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_creates_output_dir(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + """Test that output directory is created if it doesn't exist.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "nested" / "output" / "dir" + assert not output_dir.exists() + + run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=False, + ) + + assert output_dir.exists() + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_saves_state(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + """Test that state is saved after processing.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "output" + + run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=True, + ) + + state_file = output_dir / ".batch_state.json" + assert state_file.exists() + + with open(state_file) as f: + state_data = json.load(f) + + assert len(state_data["completed"]) == 3 + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_handles_errors(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + """Test that batch continues after individual job failures.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + # Make the second call raise an exception + mock_generate.side_effect = [None, RuntimeError("CUDA OOM"), None] + + output_dir = tmp_path / "output" + + summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=False, + ) + + assert summary["completed"] == 2 + assert summary["failed"] == 1 + assert mock_generate.call_count == 3 + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_logs_errors_to_file(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + """Test that errors are logged to errors.log.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + mock_generate.side_effect = RuntimeError("Test error") + + output_dir = tmp_path / "output" + + run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=False, + ) + + error_log = output_dir / "errors.log" + assert error_log.exists() + content = error_log.read_text() + assert "Test error" in content + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_resume_skips_completed(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + """Test that resume skips already-completed jobs.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "output" + output_dir.mkdir() + + # Create a state file showing first job completed + state = BatchState( + batch_file=str(batch_file), + output_dir=str(output_dir), + model_path="model", + generate_type="t2v", + completed=["cat_piano.mp4"], + ) + state.save(output_dir / ".batch_state.json") + + summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=True, + ) + + # Should only process 2 jobs (skipping the completed one) + assert summary["total"] == 2 + assert summary["skipped"] == 1 + assert mock_generate.call_count == 2 + + @patch.object(batch_inference, "load_pipeline") + def test_run_batch_empty_file(self, mock_load_pipeline, tmp_path): + """Test running batch with empty file.""" + batch_file = tmp_path / "empty.jsonl" + batch_file.write_text("") + + output_dir = tmp_path / "output" + + summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=False, + ) + + assert summary["total"] == 0 + # Pipeline should not be loaded for empty batch + mock_load_pipeline.assert_not_called() + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_multi_gpu_distribution(self, mock_generate, mock_load_pipeline, tmp_path): + """Test multi-GPU job distribution in run_batch.""" + # Create batch file with 6 jobs + jobs = [ + {"prompt": f"Job {i}", "output_name": f"video{i}.mp4"} + for i in range(6) + ] + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text("\n".join(json.dumps(j) for j in jobs)) + + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "output" + + # GPU 0 of 3 should get jobs 0, 3 (indices 0, 3) + summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + gpu_id=0, + num_gpus=3, + resume=False, + ) + + assert summary["total"] == 2 # GPU 0 gets 2 jobs + assert mock_generate.call_count == 2 + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_state_persists_after_each_job(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + """Test that state is saved after each job, not just at the end.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + # Track how many times state file is written + state_writes = [] + original_save = BatchState.save + + def tracking_save(self, state_file): + original_save(self, state_file) + state_writes.append(len(self.completed)) + + with patch.object(BatchState, 'save', tracking_save): + output_dir = tmp_path / "output" + run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=True, + ) + + # State should be saved 3 times (once per job) + assert len(state_writes) == 3 + assert state_writes == [1, 2, 3] # Progressively more completed jobs + + +# ============================================================================= +# Resolution Map Tests +# ============================================================================= + + +class TestResolutionMap: + """Tests for resolution mapping.""" + + def test_resolution_map_contains_expected_models(self): + """Test that RESOLUTION_MAP contains expected model keys.""" + assert "cogvideox1.5-5b-i2v" in RESOLUTION_MAP + assert "cogvideox1.5-5b" in RESOLUTION_MAP + assert "cogvideox-5b-i2v" in RESOLUTION_MAP + assert "cogvideox-5b" in RESOLUTION_MAP + assert "cogvideox-2b" in RESOLUTION_MAP + + def test_resolution_values_are_valid(self): + """Test that all resolution values are valid (height, width) tuples.""" + for model, resolution in RESOLUTION_MAP.items(): + assert isinstance(resolution, tuple) + assert len(resolution) == 2 + assert resolution[0] > 0 # height + assert resolution[1] > 0 # width + + +# ============================================================================= +# Edge Case Tests +# ============================================================================= + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_batch_job_with_all_optional_fields(self): + """Test BatchJob with all optional fields set.""" + data = { + "prompt": "Test", + "output_name": "test.mp4", + "image_path": "/img.jpg", + "video_path": "/vid.mp4", + "num_frames": 100, + "guidance_scale": 10.0, + "num_inference_steps": 100, + "seed": 9999, + "width": 1920, + "height": 1080, + } + job = BatchJob.from_dict(data) + + assert job.width == 1920 + assert job.height == 1080 + assert job.num_frames == 100 + + def test_batch_job_preserves_extra_fields(self): + """Test that extra fields in JSON don't cause errors.""" + data = { + "prompt": "Test", + "output_name": "test.mp4", + "extra_field": "ignored", + "another_field": 123, + } + job = BatchJob.from_dict(data) + + assert job.prompt == "Test" + assert job.output_name == "test.mp4" + # Extra fields should be ignored without error + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_corrupted_state_starts_fresh(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + """Test that corrupted state file doesn't prevent batch from running.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "output" + output_dir.mkdir() + + # Create corrupted state file + state_file = output_dir / ".batch_state.json" + state_file.write_text("{corrupted: json") + + summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=True, + ) + + # Should start fresh and process all jobs + assert summary["total"] == 3 + assert mock_generate.call_count == 3 + + def test_batch_state_updated_at_is_set_on_save(self, tmp_path): + """Test that updated_at is set when saving state.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + ) + + assert state.updated_at == "" + + state_file = tmp_path / "state.json" + state.save(state_file) + + # updated_at should be set after save + assert state.updated_at != "" + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_all_jobs_fail(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + """Test batch where all jobs fail.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + mock_generate.side_effect = RuntimeError("All fail") + + output_dir = tmp_path / "output" + + summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=False, + ) + + assert summary["completed"] == 0 + assert summary["failed"] == 3 + + def test_jsonl_with_unicode(self, tmp_path): + """Test JSONL parsing with unicode characters.""" + content = """{"prompt": "猫がピアノを弾いている", "output_name": "unicode1.mp4"} +{"prompt": "海辺の波 🌊", "output_name": "unicode2.mp4"} +{"prompt": "Café résumé naïve", "output_name": "unicode3.mp4"}""" + batch_file = tmp_path / "unicode.jsonl" + batch_file.write_text(content, encoding="utf-8") + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 3 + assert jobs[0].prompt == "猫がピアノを弾いている" + assert "🌊" in jobs[1].prompt + + def test_jsonl_with_very_long_prompt(self, tmp_path): + """Test JSONL with very long prompts.""" + long_prompt = "A " + "very " * 1000 + "long prompt" + content = json.dumps({"prompt": long_prompt, "output_name": "long.mp4"}) + batch_file = tmp_path / "long.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 1 + assert len(jobs[0].prompt) > 4000