diff --git a/.gitignore b/.gitignore index ad4bbeb..303dbae 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ logs/ .idea output* test* +!tests/ +!tests/** venv **/.swp **/*.log diff --git a/resources/example_batch_i2v.jsonl b/resources/example_batch_i2v.jsonl new file mode 100644 index 0000000..aaad4cf --- /dev/null +++ b/resources/example_batch_i2v.jsonl @@ -0,0 +1,8 @@ +# Example image-to-video batch file +# For i2v generation, include image_path pointing to your input images +# Run with: python tools/batch_inference.py --batch_file resources/example_batch_i2v.jsonl --model_path THUDM/CogVideoX1.5-5B-I2V --generate_type i2v + +{"prompt": "The person in the image starts walking forward with a confident stride", "output_name": "person_walking.mp4", "image_path": "./input_images/person.jpg"} +{"prompt": "The landscape transforms with clouds moving across the sky and grass swaying", "output_name": "landscape_motion.mp4", "image_path": "./input_images/landscape.jpg", "num_frames": 81} +{"prompt": "Zoom in slowly on the subject while adding subtle motion blur", "output_name": "zoom_effect.mp4", "image_path": "./input_images/portrait.jpg", "guidance_scale": 5.0} +{"prompt": "The still life scene comes alive with gentle movement and shifting shadows", "output_name": "still_life_animated.mp4", "image_path": "./input_images/still_life.jpg"} diff --git a/resources/example_batch_t2v.jsonl b/resources/example_batch_t2v.jsonl new file mode 100644 index 0000000..fa18ea8 --- /dev/null +++ b/resources/example_batch_t2v.jsonl @@ -0,0 +1,12 @@ +# Example text-to-video batch file +# Each line is a JSON object with prompt and output_name (required) +# Optional: num_frames, guidance_scale, num_inference_steps, seed, width, height + +{"prompt": "A majestic eagle soaring through golden sunset clouds, cinematic lighting, 4K quality", "output_name": "eagle_sunset.mp4"} +{"prompt": "A cozy coffee shop on a rainy day, steam rising from cups, warm ambient lighting", "output_name": 
"coffee_shop_rain.mp4", "num_frames": 49} +{"prompt": "An astronaut floating in space with Earth visible in the background, peaceful and serene", "output_name": "astronaut_space.mp4", "seed": 123} +{"prompt": "A field of sunflowers swaying gently in the summer breeze, bright and cheerful", "output_name": "sunflowers.mp4", "guidance_scale": 7.0} +{"prompt": "A futuristic city at night with neon lights and flying cars, cyberpunk aesthetic", "output_name": "cyberpunk_city.mp4", "num_inference_steps": 60} +{"prompt": "A serene Japanese garden with cherry blossoms falling, koi pond, peaceful atmosphere", "output_name": "japanese_garden.mp4"} +{"prompt": "Waves crashing on a tropical beach at sunset, palm trees silhouetted against orange sky", "output_name": "tropical_sunset.mp4"} +{"prompt": "A mystical forest with glowing mushrooms and fireflies, fantasy atmosphere", "output_name": "mystical_forest.mp4", "seed": 456} diff --git a/resources/example_batch_v2v.jsonl b/resources/example_batch_v2v.jsonl new file mode 100644 index 0000000..c4bc9b6 --- /dev/null +++ b/resources/example_batch_v2v.jsonl @@ -0,0 +1,8 @@ +# Example video-to-video batch file +# For v2v generation, include video_path pointing to your input videos +# Run with: python tools/batch_inference.py --batch_file resources/example_batch_v2v.jsonl --model_path THUDM/CogVideoX1.5-5B --generate_type v2v + +{"prompt": "Transform this video into a watercolor painting style with soft brushstrokes", "output_name": "watercolor_style.mp4", "video_path": "./input_videos/original1.mp4"} +{"prompt": "Convert to anime style with vibrant colors and dramatic lighting", "output_name": "anime_style.mp4", "video_path": "./input_videos/original2.mp4", "guidance_scale": 7.5} +{"prompt": "Add cinematic color grading with film grain and dramatic contrast", "output_name": "cinematic_grade.mp4", "video_path": "./input_videos/original3.mp4", "num_inference_steps": 40} +{"prompt": "Transform into a vintage black and white film with 
classic aesthetics", "output_name": "vintage_bw.mp4", "video_path": "./input_videos/original4.mp4"} diff --git a/tests/test_batch_inference.py b/tests/test_batch_inference.py new file mode 100644 index 0000000..3cb60fe --- /dev/null +++ b/tests/test_batch_inference.py @@ -0,0 +1,1193 @@ +""" +Comprehensive tests for the batch inference pipeline. + +Tests cover: +- JSONL parsing (valid and invalid input) +- Batch processing logic (iteration, state tracking) +- Resume capability (read/write .batch_state.json) +- Error handling (skip failed items, continue processing) +- Multi-GPU job distribution (--gpu_id, --num_gpus) +- Output path creation +- Progress tracking +- All generation types: t2v, i2v, v2v +- Edge cases: empty batch, corrupted state, missing files, invalid JSONL +""" + +import importlib.util +import json +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch +import pytest + + +# ============================================================================= +# Mock heavy dependencies before importing batch_inference +# ============================================================================= + +# Create mock torch module +mock_torch = MagicMock() +mock_torch.bfloat16 = "bfloat16" +mock_torch.float16 = "float16" +mock_torch.Generator = MagicMock + +# Create mock diffusers module +mock_diffusers = MagicMock() + + +# Create mock tqdm that works as a context manager +class MockTqdm: + """Mock tqdm that works as a context manager and returns a dummy progress bar.""" + + def __init__(self, iterable=None, total=None, **kwargs): + self.iterable = iterable + self.n = 0 + + def __iter__(self): + if self.iterable is not None: + for item in self.iterable: + yield item + self.n += 1 + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def update(self, n=1): + self.n += n + + def set_postfix(self, ordered_dict=None, refresh=True, **kwargs): + pass + + +mock_tqdm = MagicMock() +mock_tqdm.tqdm = MockTqdm + +# 
Install mocks before importing +sys.modules["torch"] = mock_torch +sys.modules["diffusers"] = mock_diffusers +sys.modules["diffusers.utils"] = mock_diffusers.utils +sys.modules["tqdm"] = mock_tqdm + + +def _load_batch_inference_module(): + """Load the batch_inference module from tools directory.""" + module_path = Path(__file__).parent.parent / "tools" / "batch_inference.py" + spec = importlib.util.spec_from_file_location("batch_inference", module_path) + module = importlib.util.module_from_spec(spec) + sys.modules["batch_inference"] = module + spec.loader.exec_module(module) + return module + + +batch_inference = _load_batch_inference_module() + +# Import the symbols we need +BatchJob = batch_inference.BatchJob +BatchState = batch_inference.BatchState +load_batch_file = batch_inference.load_batch_file +load_pipeline = batch_inference.load_pipeline +generate_single_video = batch_inference.generate_single_video +run_batch = batch_inference.run_batch +RESOLUTION_MAP = batch_inference.RESOLUTION_MAP + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def sample_job_dict(): + """A valid job dictionary.""" + return { + "prompt": "A cat playing piano", + "output_name": "cat_piano.mp4", + "num_frames": 49, + "guidance_scale": 7.0, + "seed": 123, + } + + +@pytest.fixture +def sample_i2v_job_dict(): + """A valid i2v job dictionary.""" + return { + "prompt": "Transform this image", + "output_name": "i2v_output.mp4", + "image_path": "/path/to/image.jpg", + } + + +@pytest.fixture +def sample_v2v_job_dict(): + """A valid v2v job dictionary.""" + return { + "prompt": "Enhance this video", + "output_name": "v2v_output.mp4", + "video_path": "/path/to/video.mp4", + } + + +@pytest.fixture +def valid_jsonl_content(): + """Valid JSONL content with multiple jobs.""" + jobs = [ + {"prompt": "A cat playing piano", "output_name": 
"cat_piano.mp4"}, + {"prompt": "Waves crashing on beach", "output_name": "beach.mp4", "num_frames": 49}, + {"prompt": "A dog running", "output_name": "dog.mp4", "seed": 999}, + ] + return "\n".join(json.dumps(j) for j in jobs) + + +@pytest.fixture +def batch_file(tmp_path, valid_jsonl_content): + """Create a temporary batch file with valid JSONL content.""" + batch_path = tmp_path / "batch.jsonl" + batch_path.write_text(valid_jsonl_content) + return batch_path + + +@pytest.fixture +def mock_pipeline(): + """Create a mock pipeline that returns fake video frames.""" + mock = MagicMock() + # Mock the pipeline call to return a fake frames result + mock_result = MagicMock() + mock_result.frames = [[MagicMock()]] # Nested list like real output + mock.return_value = mock_result + mock.scheduler = MagicMock() + mock.vae = MagicMock() + return mock + + +# ============================================================================= +# BatchJob Tests +# ============================================================================= + + +class TestBatchJob: + """Tests for the BatchJob dataclass.""" + + def test_from_dict_basic(self, sample_job_dict): + """Test creating a BatchJob from a dictionary.""" + job = BatchJob.from_dict(sample_job_dict, line_number=1) + + assert job.prompt == "A cat playing piano" + assert job.output_name == "cat_piano.mp4" + assert job.num_frames == 49 + assert job.guidance_scale == 7.0 + assert job.seed == 123 + assert job.line_number == 1 + assert job.status == "pending" + + def test_from_dict_minimal(self): + """Test creating a BatchJob with only required fields.""" + data = {"prompt": "Test prompt", "output_name": "test.mp4"} + job = BatchJob.from_dict(data) + + assert job.prompt == "Test prompt" + assert job.output_name == "test.mp4" + assert job.num_frames is None + assert job.image_path is None + assert job.video_path is None + + def test_from_dict_with_image_path(self, sample_i2v_job_dict): + """Test creating a BatchJob with image_path for 
i2v.""" + job = BatchJob.from_dict(sample_i2v_job_dict) + + assert job.image_path == "/path/to/image.jpg" + assert job.video_path is None + + def test_from_dict_with_video_path(self, sample_v2v_job_dict): + """Test creating a BatchJob with video_path for v2v.""" + job = BatchJob.from_dict(sample_v2v_job_dict) + + assert job.video_path == "/path/to/video.mp4" + assert job.image_path is None + + def test_validate_valid_job(self, sample_job_dict): + """Test validation of a valid job returns no errors.""" + job = BatchJob.from_dict(sample_job_dict) + errors = job.validate() + + assert errors == [] + + def test_validate_missing_prompt(self): + """Test validation catches missing prompt.""" + job = BatchJob.from_dict({"output_name": "test.mp4"}) + errors = job.validate() + + assert "Missing required field: prompt" in errors + + def test_validate_missing_output_name(self): + """Test validation catches missing output_name.""" + job = BatchJob.from_dict({"prompt": "Test"}) + errors = job.validate() + + assert "Missing required field: output_name" in errors + + def test_validate_missing_both_required(self): + """Test validation catches both missing required fields.""" + job = BatchJob.from_dict({}) + errors = job.validate() + + assert len(errors) == 2 + assert "Missing required field: prompt" in errors + assert "Missing required field: output_name" in errors + + +# ============================================================================= +# BatchState Tests +# ============================================================================= + + +class TestBatchState: + """Tests for the BatchState dataclass.""" + + def test_create_new_state(self): + """Test creating a new BatchState.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="THUDM/CogVideoX1.5-5B", + generate_type="t2v", + ) + + assert state.batch_file == "batch.jsonl" + assert state.completed == [] + assert state.failed == [] + + def test_mark_completed(self): + """Test 
marking a job as completed.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + ) + + state.mark_completed("video1.mp4") + state.mark_completed("video2.mp4") + + assert "video1.mp4" in state.completed + assert "video2.mp4" in state.completed + + def test_mark_completed_no_duplicates(self): + """Test that marking the same job twice doesn't duplicate.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + ) + + state.mark_completed("video1.mp4") + state.mark_completed("video1.mp4") + + assert state.completed.count("video1.mp4") == 1 + + def test_mark_failed(self): + """Test marking a job as failed.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + ) + + state.mark_failed("video1.mp4", "CUDA out of memory") + + assert len(state.failed) == 1 + assert state.failed[0]["output_name"] == "video1.mp4" + assert state.failed[0]["error"] == "CUDA out of memory" + + def test_is_completed(self): + """Test checking if a job is completed.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + ) + + state.mark_completed("video1.mp4") + + assert state.is_completed("video1.mp4") is True + assert state.is_completed("video2.mp4") is False + + def test_save_and_load(self, tmp_path): + """Test saving and loading state.""" + state_file = tmp_path / ".batch_state.json" + + # Create and save state + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + started_at="2024-01-01T00:00:00", + ) + state.mark_completed("video1.mp4") + state.mark_failed("video2.mp4", "Error") + state.save(state_file) + + # Load state + loaded = BatchState.load(state_file) + + assert loaded is not None + assert loaded.batch_file == "batch.jsonl" + assert "video1.mp4" 
in loaded.completed + assert len(loaded.failed) == 1 + assert loaded.failed[0]["output_name"] == "video2.mp4" + + def test_load_nonexistent_file(self, tmp_path): + """Test loading from a nonexistent file returns None.""" + state_file = tmp_path / "nonexistent.json" + + loaded = BatchState.load(state_file) + + assert loaded is None + + def test_load_corrupted_file(self, tmp_path): + """Test loading a corrupted state file returns None.""" + state_file = tmp_path / ".batch_state.json" + state_file.write_text("{ invalid json }") + + loaded = BatchState.load(state_file) + + assert loaded is None + + def test_load_invalid_json_structure(self, tmp_path): + """Test loading a file with valid JSON but invalid structure.""" + state_file = tmp_path / ".batch_state.json" + state_file.write_text('{"wrong": "structure"}') + + loaded = BatchState.load(state_file) + + # Should return None due to missing required fields + assert loaded is None + + +# ============================================================================= +# JSONL Parsing Tests +# ============================================================================= + + +class TestLoadBatchFile: + """Tests for JSONL parsing.""" + + def test_load_valid_jsonl(self, batch_file): + """Test loading a valid JSONL file.""" + jobs = load_batch_file(batch_file) + + assert len(jobs) == 3 + assert jobs[0].prompt == "A cat playing piano" + assert jobs[1].output_name == "beach.mp4" + assert jobs[2].seed == 999 + + def test_load_jsonl_with_comments(self, tmp_path): + """Test that comment lines are skipped.""" + content = """# This is a comment +{"prompt": "Test", "output_name": "test.mp4"} +# Another comment +{"prompt": "Test2", "output_name": "test2.mp4"}""" + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 2 + + def test_load_jsonl_with_empty_lines(self, tmp_path): + """Test that empty lines are skipped.""" + content = """{"prompt": "Test", 
"output_name": "test.mp4"} + +{"prompt": "Test2", "output_name": "test2.mp4"} + +""" + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 2 + + def test_load_jsonl_invalid_json_line(self, tmp_path): + """Test that invalid JSON lines are skipped with warning.""" + content = """{"prompt": "Test", "output_name": "test.mp4"} +{invalid json here} +{"prompt": "Test2", "output_name": "test2.mp4"}""" + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + # Should skip the invalid line + assert len(jobs) == 2 + + def test_load_jsonl_missing_required_fields(self, tmp_path): + """Test that jobs with missing required fields are skipped.""" + content = """{"prompt": "Test", "output_name": "test.mp4"} +{"prompt": "No output name"} +{"output_name": "no_prompt.mp4"} +{"prompt": "Test2", "output_name": "test2.mp4"}""" + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + # Should skip the two invalid jobs + assert len(jobs) == 2 + assert jobs[0].output_name == "test.mp4" + assert jobs[1].output_name == "test2.mp4" + + def test_load_empty_batch_file(self, tmp_path): + """Test loading an empty batch file.""" + batch_file = tmp_path / "empty.jsonl" + batch_file.write_text("") + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 0 + + def test_load_batch_file_all_comments(self, tmp_path): + """Test loading a file with only comments.""" + content = """# Comment 1 +# Comment 2 +# Comment 3""" + batch_file = tmp_path / "comments.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 0 + + def test_load_batch_file_tracks_line_numbers(self, tmp_path): + """Test that line numbers are tracked correctly.""" + content = """# Comment on line 1 +{"prompt": "Line 2", "output_name": "line2.mp4"} + +{"prompt": "Line 4", "output_name": 
"line4.mp4"}""" + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + assert jobs[0].line_number == 2 + assert jobs[1].line_number == 4 + + +# ============================================================================= +# Multi-GPU Distribution Tests +# ============================================================================= + + +class TestMultiGPUDistribution: + """Tests for multi-GPU job distribution logic.""" + + def test_single_gpu_gets_all_jobs(self): + """Test that a single GPU processes all jobs.""" + all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)] + + # Simulate distribution for single GPU (gpu_id=0, num_gpus=1) + gpu_id = 0 + num_gpus = 1 + jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id] + + assert len(jobs) == 10 + + def test_two_gpu_distribution(self): + """Test job distribution across 2 GPUs.""" + all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)] + + # GPU 0 gets jobs 0, 2, 4, 6, 8 + gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 0] + # GPU 1 gets jobs 1, 3, 5, 7, 9 + gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 1] + + assert len(gpu0_jobs) == 5 + assert len(gpu1_jobs) == 5 + assert gpu0_jobs[0].output_name == "video0.mp4" + assert gpu1_jobs[0].output_name == "video1.mp4" + + def test_four_gpu_distribution(self): + """Test job distribution across 4 GPUs.""" + all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(12)] + + num_gpus = 4 + for gpu_id in range(num_gpus): + jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id] + assert len(jobs) == 3 # 12 jobs / 4 GPUs = 3 each + + def test_uneven_distribution(self): + """Test distribution when jobs don't divide evenly.""" + all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)] + + num_gpus = 3 + gpu0_jobs = [j for i, j in 
enumerate(all_jobs) if i % num_gpus == 0] + gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 1] + gpu2_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 2] + + # 10 jobs / 3 GPUs: GPU0 gets 4, GPU1 gets 3, GPU2 gets 3 + assert len(gpu0_jobs) == 4 # indices 0, 3, 6, 9 + assert len(gpu1_jobs) == 3 # indices 1, 4, 7 + assert len(gpu2_jobs) == 3 # indices 2, 5, 8 + + +# ============================================================================= +# Pipeline Loading Tests +# ============================================================================= + + +class TestLoadPipeline: + """Tests for pipeline loading.""" + + @patch.object(batch_inference, "CogVideoXPipeline") + def test_load_t2v_pipeline(self, mock_cogvideo_pipeline): + """Test loading a text-to-video pipeline.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe + + with patch.object(batch_inference, "CogVideoXDPMScheduler"): + pipe = load_pipeline( + model_path="THUDM/CogVideoX1.5-5B", + generate_type="t2v", + enable_cpu_offload=False, + ) + + mock_cogvideo_pipeline.from_pretrained.assert_called_once() + assert pipe is mock_pipe + + @patch.object(batch_inference, "CogVideoXImageToVideoPipeline") + def test_load_i2v_pipeline(self, mock_i2v_pipeline): + """Test loading an image-to-video pipeline.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_i2v_pipeline.from_pretrained.return_value = mock_pipe + + with patch.object(batch_inference, "CogVideoXDPMScheduler"): + pipe = load_pipeline( + model_path="THUDM/CogVideoX1.5-5B-I2V", + generate_type="i2v", + enable_cpu_offload=False, + ) + + mock_i2v_pipeline.from_pretrained.assert_called_once() + assert pipe is mock_pipe + + @patch.object(batch_inference, "CogVideoXVideoToVideoPipeline") + def test_load_v2v_pipeline(self, mock_v2v_pipeline): + """Test loading a video-to-video 
pipeline.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_v2v_pipeline.from_pretrained.return_value = mock_pipe + + with patch.object(batch_inference, "CogVideoXDPMScheduler"): + pipe = load_pipeline( + model_path="THUDM/CogVideoX1.5-5B", + generate_type="v2v", + enable_cpu_offload=False, + ) + + mock_v2v_pipeline.from_pretrained.assert_called_once() + assert pipe is mock_pipe + + @patch.object(batch_inference, "CogVideoXPipeline") + def test_load_pipeline_with_cpu_offload(self, mock_cogvideo_pipeline): + """Test that CPU offload is enabled when requested.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe + + with patch.object(batch_inference, "CogVideoXDPMScheduler"): + load_pipeline( + model_path="model", + generate_type="t2v", + enable_cpu_offload=True, + ) + + mock_pipe.enable_sequential_cpu_offload.assert_called_once() + mock_pipe.to.assert_not_called() + + @patch.object(batch_inference, "CogVideoXPipeline") + def test_load_pipeline_without_cpu_offload(self, mock_cogvideo_pipeline): + """Test that pipeline moves to CUDA when CPU offload is disabled.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe + + with patch.object(batch_inference, "CogVideoXDPMScheduler"): + load_pipeline( + model_path="model", + generate_type="t2v", + enable_cpu_offload=False, + ) + + mock_pipe.to.assert_called_once_with("cuda") + mock_pipe.enable_sequential_cpu_offload.assert_not_called() + + @patch.object(batch_inference, "CogVideoXPipeline") + def test_load_pipeline_with_lora(self, mock_cogvideo_pipeline): + """Test loading pipeline with LoRA weights.""" + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.vae = MagicMock() + mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe + + 
with patch.object(batch_inference, "CogVideoXDPMScheduler"): + load_pipeline( + model_path="model", + generate_type="t2v", + lora_path="/path/to/lora", + enable_cpu_offload=False, + ) + + mock_pipe.load_lora_weights.assert_called_once() + mock_pipe.fuse_lora.assert_called_once() + + +# ============================================================================= +# Video Generation Tests +# ============================================================================= + + +class TestGenerateSingleVideo: + """Tests for single video generation.""" + + @patch.object(batch_inference, "export_to_video") + def test_generate_t2v_video(self, mock_export, mock_pipeline, tmp_path): + """Test generating a text-to-video.""" + job = BatchJob( + prompt="A cat playing piano", + output_name="cat.mp4", + num_frames=49, + ) + output_path = tmp_path / "cat.mp4" + + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="t2v", + model_name="cogvideox-5b", + output_path=output_path, + ) + + mock_pipeline.assert_called_once() + mock_export.assert_called_once() + + @patch.object(batch_inference, "export_to_video") + @patch.object(batch_inference, "load_image") + def test_generate_i2v_video(self, mock_load_image, mock_export, mock_pipeline, tmp_path): + """Test generating an image-to-video.""" + mock_load_image.return_value = MagicMock() + + job = BatchJob( + prompt="Transform this", + output_name="i2v.mp4", + image_path="/path/to/image.jpg", + ) + output_path = tmp_path / "i2v.mp4" + + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="i2v", + model_name="cogvideox-5b-i2v", + output_path=output_path, + ) + + mock_load_image.assert_called_once_with(image="/path/to/image.jpg") + mock_pipeline.assert_called_once() + + @patch.object(batch_inference, "export_to_video") + @patch.object(batch_inference, "load_video") + def test_generate_v2v_video(self, mock_load_video, mock_export, mock_pipeline, tmp_path): + """Test generating a video-to-video.""" + 
mock_load_video.return_value = MagicMock() + + job = BatchJob( + prompt="Enhance this", + output_name="v2v.mp4", + video_path="/path/to/video.mp4", + ) + output_path = tmp_path / "v2v.mp4" + + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="v2v", + model_name="cogvideox-5b", + output_path=output_path, + ) + + mock_load_video.assert_called_once_with("/path/to/video.mp4") + mock_pipeline.assert_called_once() + + def test_generate_i2v_missing_image_path(self, mock_pipeline, tmp_path): + """Test that i2v generation fails without image_path.""" + job = BatchJob( + prompt="Transform this", + output_name="i2v.mp4", + # Missing image_path + ) + output_path = tmp_path / "i2v.mp4" + + with pytest.raises(ValueError, match="image_path is required"): + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="i2v", + model_name="cogvideox-5b-i2v", + output_path=output_path, + ) + + def test_generate_v2v_missing_video_path(self, mock_pipeline, tmp_path): + """Test that v2v generation fails without video_path.""" + job = BatchJob( + prompt="Enhance this", + output_name="v2v.mp4", + # Missing video_path + ) + output_path = tmp_path / "v2v.mp4" + + with pytest.raises(ValueError, match="video_path is required"): + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="v2v", + model_name="cogvideox-5b", + output_path=output_path, + ) + + @patch.object(batch_inference, "export_to_video") + def test_generate_uses_job_specific_params(self, mock_export, mock_pipeline, tmp_path): + """Test that job-specific parameters override defaults.""" + job = BatchJob( + prompt="Test", + output_name="test.mp4", + num_frames=33, + guidance_scale=9.0, + num_inference_steps=25, + seed=456, + ) + output_path = tmp_path / "test.mp4" + + generate_single_video( + pipe=mock_pipeline, + job=job, + generate_type="t2v", + model_name="cogvideox-5b", + output_path=output_path, + default_num_frames=81, + default_guidance_scale=6.0, + 
default_num_inference_steps=50, + default_seed=42, + ) + + # Check the pipeline was called with job-specific values + call_kwargs = mock_pipeline.call_args[1] + assert call_kwargs["num_frames"] == 33 + assert call_kwargs["guidance_scale"] == 9.0 + assert call_kwargs["num_inference_steps"] == 25 + + +# ============================================================================= +# Batch Processing Tests +# ============================================================================= + + +class TestRunBatch: + """Tests for the main batch processing function.""" + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_basic(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + """Test basic batch processing.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "output" + + summary = run_batch( + batch_file=batch_file, + model_path="THUDM/CogVideoX1.5-5B", + output_dir=output_dir, + generate_type="t2v", + resume=False, + ) + + assert summary["total"] == 3 + assert summary["completed"] == 3 + assert summary["failed"] == 0 + assert mock_generate.call_count == 3 + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_creates_output_dir( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): + """Test that output directory is created if it doesn't exist.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "nested" / "output" / "dir" + assert not output_dir.exists() + + run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=False, + ) + + assert output_dir.exists() + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_saves_state(self, mock_generate, mock_load_pipeline, batch_file, 
tmp_path): + """Test that state is saved after processing.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "output" + + run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=True, + ) + + state_file = output_dir / ".batch_state.json" + assert state_file.exists() + + with open(state_file) as f: + state_data = json.load(f) + + assert len(state_data["completed"]) == 3 + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_handles_errors( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): + """Test that batch continues after individual job failures.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + # Make the second call raise an exception + mock_generate.side_effect = [None, RuntimeError("CUDA OOM"), None] + + output_dir = tmp_path / "output" + + summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=False, + ) + + assert summary["completed"] == 2 + assert summary["failed"] == 1 + assert mock_generate.call_count == 3 + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_logs_errors_to_file( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): + """Test that errors are logged to errors.log.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + mock_generate.side_effect = RuntimeError("Test error") + + output_dir = tmp_path / "output" + + run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=False, + ) + + error_log = output_dir / "errors.log" + assert error_log.exists() + content = error_log.read_text() + assert "Test error" in content + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def 
test_run_batch_resume_skips_completed( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): + """Test that resume skips already-completed jobs.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "output" + output_dir.mkdir() + + # Create a state file showing first job completed + state = BatchState( + batch_file=str(batch_file), + output_dir=str(output_dir), + model_path="model", + generate_type="t2v", + completed=["cat_piano.mp4"], + ) + state.save(output_dir / ".batch_state.json") + + summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=True, + ) + + # Should only process 2 jobs (skipping the completed one) + assert summary["total"] == 2 + assert summary["skipped"] == 1 + assert mock_generate.call_count == 2 + + @patch.object(batch_inference, "load_pipeline") + def test_run_batch_empty_file(self, mock_load_pipeline, tmp_path): + """Test running batch with empty file.""" + batch_file = tmp_path / "empty.jsonl" + batch_file.write_text("") + + output_dir = tmp_path / "output" + + summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=False, + ) + + assert summary["total"] == 0 + # Pipeline should not be loaded for empty batch + mock_load_pipeline.assert_not_called() + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_multi_gpu_distribution(self, mock_generate, mock_load_pipeline, tmp_path): + """Test multi-GPU job distribution in run_batch.""" + # Create batch file with 6 jobs + jobs = [{"prompt": f"Job {i}", "output_name": f"video{i}.mp4"} for i in range(6)] + batch_file = tmp_path / "batch.jsonl" + batch_file.write_text("\n".join(json.dumps(j) for j in jobs)) + + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "output" + + # GPU 0 of 3 should get jobs 0, 3 (indices 0, 3) + 
summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + gpu_id=0, + num_gpus=3, + resume=False, + ) + + assert summary["total"] == 2 # GPU 0 gets 2 jobs + assert mock_generate.call_count == 2 + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_state_persists_after_each_job( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): + """Test that state is saved after each job, not just at the end.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + # Track how many times state file is written + state_writes = [] + original_save = BatchState.save + + def tracking_save(self, state_file): + original_save(self, state_file) + state_writes.append(len(self.completed)) + + with patch.object(BatchState, 'save', tracking_save): + output_dir = tmp_path / "output" + run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=True, + ) + + # State should be saved 3 times (once per job) + assert len(state_writes) == 3 + assert state_writes == [1, 2, 3] # Progressively more completed jobs + + +# ============================================================================= +# Resolution Map Tests +# ============================================================================= + + +class TestResolutionMap: + """Tests for resolution mapping.""" + + def test_resolution_map_contains_expected_models(self): + """Test that RESOLUTION_MAP contains expected model keys.""" + assert "cogvideox1.5-5b-i2v" in RESOLUTION_MAP + assert "cogvideox1.5-5b" in RESOLUTION_MAP + assert "cogvideox-5b-i2v" in RESOLUTION_MAP + assert "cogvideox-5b" in RESOLUTION_MAP + assert "cogvideox-2b" in RESOLUTION_MAP + + def test_resolution_values_are_valid(self): + """Test that all resolution values are valid (height, width) tuples.""" + for model, resolution in RESOLUTION_MAP.items(): + assert isinstance(resolution, tuple) + 
assert len(resolution) == 2 + assert resolution[0] > 0 # height + assert resolution[1] > 0 # width + + +# ============================================================================= +# Edge Case Tests +# ============================================================================= + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_batch_job_with_all_optional_fields(self): + """Test BatchJob with all optional fields set.""" + data = { + "prompt": "Test", + "output_name": "test.mp4", + "image_path": "/img.jpg", + "video_path": "/vid.mp4", + "num_frames": 100, + "guidance_scale": 10.0, + "num_inference_steps": 100, + "seed": 9999, + "width": 1920, + "height": 1080, + } + job = BatchJob.from_dict(data) + + assert job.width == 1920 + assert job.height == 1080 + assert job.num_frames == 100 + + def test_batch_job_preserves_extra_fields(self): + """Test that extra fields in JSON don't cause errors.""" + data = { + "prompt": "Test", + "output_name": "test.mp4", + "extra_field": "ignored", + "another_field": 123, + } + job = BatchJob.from_dict(data) + + assert job.prompt == "Test" + assert job.output_name == "test.mp4" + # Extra fields should be ignored without error + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_corrupted_state_starts_fresh( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): + """Test that corrupted state file doesn't prevent batch from running.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + + output_dir = tmp_path / "output" + output_dir.mkdir() + + # Create corrupted state file + state_file = output_dir / ".batch_state.json" + state_file.write_text("{corrupted: json") + + summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=True, + ) + + # Should start fresh and process all jobs + assert summary["total"] == 3 + assert 
mock_generate.call_count == 3 + + def test_batch_state_updated_at_is_set_on_save(self, tmp_path): + """Test that updated_at is set when saving state.""" + state = BatchState( + batch_file="batch.jsonl", + output_dir="./output", + model_path="model", + generate_type="t2v", + ) + + assert state.updated_at == "" + + state_file = tmp_path / "state.json" + state.save(state_file) + + # updated_at should be set after save + assert state.updated_at != "" + + @patch.object(batch_inference, "load_pipeline") + @patch.object(batch_inference, "generate_single_video") + def test_run_batch_all_jobs_fail(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + """Test batch where all jobs fail.""" + mock_pipe = MagicMock() + mock_load_pipeline.return_value = mock_pipe + mock_generate.side_effect = RuntimeError("All fail") + + output_dir = tmp_path / "output" + + summary = run_batch( + batch_file=batch_file, + model_path="model", + output_dir=output_dir, + resume=False, + ) + + assert summary["completed"] == 0 + assert summary["failed"] == 3 + + def test_jsonl_with_unicode(self, tmp_path): + """Test JSONL parsing with unicode characters.""" + content = """{"prompt": "猫がピアノを弾いている", "output_name": "unicode1.mp4"} +{"prompt": "海辺の波 🌊", "output_name": "unicode2.mp4"} +{"prompt": "Café résumé naïve", "output_name": "unicode3.mp4"}""" + batch_file = tmp_path / "unicode.jsonl" + batch_file.write_text(content, encoding="utf-8") + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 3 + assert jobs[0].prompt == "猫がピアノを弾いている" + assert "🌊" in jobs[1].prompt + + def test_jsonl_with_very_long_prompt(self, tmp_path): + """Test JSONL with very long prompts.""" + long_prompt = "A " + "very " * 1000 + "long prompt" + content = json.dumps({"prompt": long_prompt, "output_name": "long.mp4"}) + batch_file = tmp_path / "long.jsonl" + batch_file.write_text(content) + + jobs = load_batch_file(batch_file) + + assert len(jobs) == 1 + assert len(jobs[0].prompt) > 4000 diff --git 
#!/usr/bin/env python3
"""
Batch Inference Pipeline for CogVideo

A production-grade tool for generating multiple videos from a batch input file.
Supports text-to-video (t2v), image-to-video (i2v), and video-to-video (v2v) generation.

Features:
- JSONL batch input format (one job per line)
- Resume capability (skips already-completed jobs)
- Progress tracking with ETA
- Robust error handling (logs errors, continues batch)
- Memory-efficient (loads model once, generates sequentially)
- Multi-GPU support via job-level parallelism

Input Format (JSONL):
Each line is a JSON object with the following fields:
  - prompt (str, required): Text description for video generation
  - output_name (str, required): Output filename (without path, e.g., "my_video.mp4")
  - image_path (str, optional): Path to input image for i2v generation
  - video_path (str, optional): Path to input video for v2v generation
  - num_frames (int, optional): Number of frames (default: 81)
  - guidance_scale (float, optional): CFG scale (default: 6.0)
  - num_inference_steps (int, optional): Inference steps (default: 50)
  - seed (int, optional): Random seed (default: 42)
  - width (int, optional): Video width
  - height (int, optional): Video height

Example JSONL (resources/example_batch.jsonl):
    {"prompt": "A cat playing piano", "output_name": "cat_piano.mp4"}
    {"prompt": "Waves crashing on beach", "output_name": "beach.mp4", "num_frames": 49}
    {"prompt": "Transform this image", "output_name": "i2v_output.mp4", "image_path": "./input.jpg"}

Usage:
    # Basic usage (text-to-video)
    python tools/batch_inference.py \\
        --batch_file resources/example_batch.jsonl \\
        --model_path THUDM/CogVideoX1.5-5B \\
        --output_dir ./batch_output

    # Image-to-video batch
    python tools/batch_inference.py \\
        --batch_file resources/i2v_batch.jsonl \\
        --model_path THUDM/CogVideoX1.5-5B-I2V \\
        --generate_type i2v \\
        --output_dir ./batch_output

    # Resume interrupted batch
    python tools/batch_inference.py \\
        --batch_file resources/example_batch.jsonl \\
        --model_path THUDM/CogVideoX1.5-5B \\
        --output_dir ./batch_output \\
        --resume

    # Multi-GPU: distribute jobs across GPUs
    CUDA_VISIBLE_DEVICES=0 python tools/batch_inference.py --batch_file batch.jsonl --gpu_id 0 --num_gpus 4 &
    CUDA_VISIBLE_DEVICES=1 python tools/batch_inference.py --batch_file batch.jsonl --gpu_id 1 --num_gpus 4 &
    CUDA_VISIBLE_DEVICES=2 python tools/batch_inference.py --batch_file batch.jsonl --gpu_id 2 --num_gpus 4 &
    CUDA_VISIBLE_DEVICES=3 python tools/batch_inference.py --batch_file batch.jsonl --gpu_id 3 --num_gpus 4 &

Author: CogVideo Contributors
License: Apache 2.0
"""

import argparse
import json
import logging
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from tqdm import tqdm

from diffusers import (
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXPipeline,
    CogVideoXVideoToVideoPipeline,
)
from diffusers.utils import export_to_video, load_image, load_video


# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Recommended resolution for each model (height, width)
RESOLUTION_MAP = {
    "cogvideox1.5-5b-i2v": (768, 1360),
    "cogvideox1.5-5b": (768, 1360),
    "cogvideox-5b-i2v": (480, 720),
    "cogvideox-5b": (480, 720),
    "cogvideox-2b": (480, 720),
}


@dataclass
class BatchJob:
    """Represents a single job in the batch."""

    prompt: str
    output_name: str
    image_path: Optional[str] = None
    video_path: Optional[str] = None
    num_frames: Optional[int] = None
    guidance_scale: Optional[float] = None
    num_inference_steps: Optional[int] = None
    seed: Optional[int] = None
    width: Optional[int] = None
    height: Optional[int] = None

    # Internal bookkeeping (not read from JSON input, except line_number)
    line_number: int = 0
    status: str = "pending"
    error: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any], line_number: int = 0) -> "BatchJob":
        """Create a BatchJob from a dictionary; unknown keys are ignored."""
        return cls(
            prompt=data.get("prompt", ""),
            output_name=data.get("output_name", ""),
            image_path=data.get("image_path"),
            video_path=data.get("video_path"),
            num_frames=data.get("num_frames"),
            guidance_scale=data.get("guidance_scale"),
            num_inference_steps=data.get("num_inference_steps"),
            seed=data.get("seed"),
            width=data.get("width"),
            height=data.get("height"),
            line_number=line_number,
        )

    def validate(self) -> List[str]:
        """Validate the job and return list of errors (empty list if valid)."""
        errors = []
        if not self.prompt:
            errors.append("Missing required field: prompt")
        if not self.output_name:
            errors.append("Missing required field: output_name")
        return errors


@dataclass
class BatchState:
    """Tracks batch progress for resume capability."""

    batch_file: str
    output_dir: str
    model_path: str
    generate_type: str
    completed: List[str] = field(default_factory=list)
    failed: List[Dict[str, str]] = field(default_factory=list)
    started_at: str = ""
    updated_at: str = ""

    @classmethod
    def load(cls, state_file: Path) -> Optional["BatchState"]:
        """Load state from file; returns None if missing or unreadable.

        A corrupted state file is deliberately treated as "no state" so an
        interrupted batch can always start fresh instead of crashing.
        """
        if not state_file.exists():
            return None
        try:
            # Explicit UTF-8: state may contain non-ASCII prompts/paths.
            with open(state_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            return cls(**data)
        except Exception as e:
            logger.warning(f"Failed to load state file: {e}")
            return None

    def save(self, state_file: Path):
        """Save state to file, stamping updated_at with the current time."""
        self.updated_at = datetime.now().isoformat()
        with open(state_file, "w", encoding="utf-8") as f:
            json.dump(self.__dict__, f, indent=2)

    def mark_completed(self, output_name: str):
        """Mark a job as completed (idempotent)."""
        if output_name not in self.completed:
            self.completed.append(output_name)

    def mark_failed(self, output_name: str, error: str):
        """Mark a job as failed."""
        self.failed.append({"output_name": output_name, "error": error})

    def is_completed(self, output_name: str) -> bool:
        """Check if a job was already completed."""
        return output_name in self.completed


def load_batch_file(batch_file: Path) -> List[BatchJob]:
    """
    Load and parse a JSONL batch file.

    Blank lines and lines starting with '#' are skipped; lines with invalid
    JSON or failing validation are logged and skipped, not fatal.

    Args:
        batch_file: Path to the JSONL file

    Returns:
        List of BatchJob objects
    """
    jobs = []

    # Explicit UTF-8 so prompts with non-ASCII text parse on every platform,
    # regardless of the locale's default encoding.
    with open(batch_file, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            try:
                data = json.loads(line)
                job = BatchJob.from_dict(data, line_number=line_num)

                # Validate job
                errors = job.validate()
                if errors:
                    logger.warning(f"Line {line_num}: Invalid job - {', '.join(errors)}")
                    continue

                jobs.append(job)

            except json.JSONDecodeError as e:
                logger.warning(f"Line {line_num}: Invalid JSON - {e}")
                continue

    return jobs


def load_pipeline(
    model_path: str,
    generate_type: str,
    dtype: torch.dtype = torch.bfloat16,
    lora_path: Optional[str] = None,
    enable_cpu_offload: bool = True,
):
    """
    Load the appropriate pipeline for the generation type.

    Args:
        model_path: Path to the model
        generate_type: Type of generation (t2v, i2v, v2v)
        dtype: Data type for computation
        lora_path: Optional path to LoRA weights
        enable_cpu_offload: Whether to enable CPU offloading

    Returns:
        The loaded pipeline
    """
    logger.info(f"Loading pipeline for {generate_type} from {model_path}")

    if generate_type == "i2v":
        pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
    elif generate_type == "t2v":
        pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
    else:  # v2v
        pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)

    # Load LoRA weights if provided
    if lora_path:
        logger.info(f"Loading LoRA weights from {lora_path}")
        pipe.load_lora_weights(
            lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="batch_lora"
        )
        pipe.fuse_lora(components=["transformer"], lora_scale=1.0)

    # Set scheduler
    pipe.scheduler = CogVideoXDPMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )

    # Enable memory optimizations
    if enable_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.to("cuda")

    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    logger.info("Pipeline loaded successfully")
    return pipe


def generate_single_video(
    pipe,
    job: BatchJob,
    generate_type: str,
    model_name: str,
    output_path: Path,
    default_num_frames: int = 81,
    default_guidance_scale: float = 6.0,
    default_num_inference_steps: int = 50,
    default_seed: int = 42,
    fps: int = 16,
):
    """
    Generate a single video from a job.

    Args:
        pipe: The loaded pipeline
        job: The batch job to process
        generate_type: Type of generation
        model_name: Name of the model (for resolution lookup)
        output_path: Full path for output video
        default_*: Default values for optional parameters
        fps: Frames per second for output video

    Raises:
        ValueError: If image_path (i2v) or video_path (v2v) is missing.
    """
    # Determine resolution; fall back to a generic 480x720 for unknown models
    desired_resolution = RESOLUTION_MAP.get(model_name.lower(), (480, 720))
    height = job.height if job.height is not None else desired_resolution[0]
    width = job.width if job.width is not None else desired_resolution[1]

    # Use job-specific or default values. `is not None` (not `or`) so that
    # explicit falsy values such as seed=0 or guidance_scale=0.0 are honored
    # instead of being silently replaced by the defaults.
    num_frames = job.num_frames if job.num_frames is not None else default_num_frames
    guidance_scale = (
        job.guidance_scale if job.guidance_scale is not None else default_guidance_scale
    )
    num_inference_steps = (
        job.num_inference_steps
        if job.num_inference_steps is not None
        else default_num_inference_steps
    )
    seed = job.seed if job.seed is not None else default_seed

    # Load image/video if needed
    image = None
    video = None

    if generate_type == "i2v":
        if not job.image_path:
            raise ValueError("image_path is required for i2v generation")
        image = load_image(image=job.image_path)
    elif generate_type == "v2v":
        if not job.video_path:
            raise ValueError("video_path is required for v2v generation")
        video = load_video(job.video_path)

    # Generate video (CPU generator keeps seeding device-independent)
    generator = torch.Generator().manual_seed(seed)

    if generate_type == "i2v":
        video_frames = pipe(
            height=height,
            width=width,
            prompt=job.prompt,
            image=image,
            num_videos_per_prompt=1,
            num_inference_steps=num_inference_steps,
            num_frames=num_frames,
            use_dynamic_cfg=True,
            guidance_scale=guidance_scale,
            generator=generator,
        ).frames[0]
    elif generate_type == "t2v":
        video_frames = pipe(
            height=height,
            width=width,
            prompt=job.prompt,
            num_videos_per_prompt=1,
            num_inference_steps=num_inference_steps,
            num_frames=num_frames,
            use_dynamic_cfg=True,
            guidance_scale=guidance_scale,
            generator=generator,
        ).frames[0]
    else:  # v2v
        video_frames = pipe(
            height=height,
            width=width,
            prompt=job.prompt,
            video=video,
            num_videos_per_prompt=1,
            num_inference_steps=num_inference_steps,
            num_frames=num_frames,
            use_dynamic_cfg=True,
            guidance_scale=guidance_scale,
            generator=generator,
        ).frames[0]

    # Export video
    export_to_video(video_frames, str(output_path), fps=fps)


def run_batch(
    batch_file: Path,
    model_path: str,
    output_dir: Path,
    generate_type: str = "t2v",
    dtype: torch.dtype = torch.bfloat16,
    lora_path: Optional[str] = None,
    enable_cpu_offload: bool = True,
    resume: bool = True,
    gpu_id: int = 0,
    num_gpus: int = 1,
    default_num_frames: int = 81,
    default_guidance_scale: float = 6.0,
    default_num_inference_steps: int = 50,
    default_seed: int = 42,
    fps: int = 16,
) -> Dict[str, Any]:
    """
    Run batch inference on a JSONL file.

    Args:
        batch_file: Path to the JSONL batch file
        model_path: Path to the model
        output_dir: Directory for output videos
        generate_type: Type of generation (t2v, i2v, v2v)
        dtype: Data type for computation
        lora_path: Optional path to LoRA weights
        enable_cpu_offload: Whether to enable CPU offloading
        resume: Whether to resume from previous state
        gpu_id: GPU ID for multi-GPU distribution
        num_gpus: Total number of GPUs for distribution
        default_*: Default values for optional parameters
        fps: Frames per second for output videos

    Returns:
        Summary dictionary with keys: total, completed, failed, skipped,
        elapsed_seconds, avg_seconds_per_video.
    """
    # Setup paths
    output_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): with num_gpus > 1 all GPU processes share this one state
    # file; concurrent saves may clobber each other — TODO consider per-GPU
    # state files if resume fidelity across multi-GPU runs matters.
    state_file = output_dir / ".batch_state.json"
    error_log = output_dir / "errors.log"

    # Load jobs
    logger.info(f"Loading batch file: {batch_file}")
    all_jobs = load_batch_file(batch_file)
    logger.info(f"Found {len(all_jobs)} valid jobs in batch file")

    # Distribute jobs across GPUs if using multi-GPU (round-robin by index)
    if num_gpus > 1:
        jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id]
        logger.info(f"GPU {gpu_id}/{num_gpus}: Processing {len(jobs)} jobs")
    else:
        jobs = all_jobs

    # Load or create state
    state = None
    if resume:
        state = BatchState.load(state_file)
        if state:
            logger.info(f"Resuming batch: {len(state.completed)} already completed")

    if state is None:
        state = BatchState(
            batch_file=str(batch_file),
            output_dir=str(output_dir),
            model_path=model_path,
            generate_type=generate_type,
            started_at=datetime.now().isoformat(),
        )

    # Filter out completed jobs
    if resume:
        pending_jobs = [j for j in jobs if not state.is_completed(j.output_name)]
        skipped = len(jobs) - len(pending_jobs)
        if skipped > 0:
            logger.info(f"Skipping {skipped} already-completed jobs")
        jobs = pending_jobs

    if not jobs:
        logger.info("No jobs to process")
        # Same key set as the normal-path summary so callers can rely on it.
        return {
            "total": 0,
            "completed": 0,
            "failed": 0,
            "skipped": len(all_jobs),
            "elapsed_seconds": 0.0,
            "avg_seconds_per_video": 0.0,
        }

    # Load pipeline (once — jobs are generated sequentially on it)
    model_name = model_path.split("/")[-1]
    pipe = load_pipeline(
        model_path=model_path,
        generate_type=generate_type,
        dtype=dtype,
        lora_path=lora_path,
        enable_cpu_offload=enable_cpu_offload,
    )

    # Process jobs
    completed = 0
    failed = 0
    start_time = time.time()

    with tqdm(total=len(jobs), desc="Generating videos", unit="video") as pbar:
        for job in jobs:
            output_path = output_dir / job.output_name

            try:
                logger.info(f"Processing: {job.output_name} - \"{job.prompt[:50]}...\"")

                generate_single_video(
                    pipe=pipe,
                    job=job,
                    generate_type=generate_type,
                    model_name=model_name,
                    output_path=output_path,
                    default_num_frames=default_num_frames,
                    default_guidance_scale=default_guidance_scale,
                    default_num_inference_steps=default_num_inference_steps,
                    default_seed=default_seed,
                    fps=fps,
                )

                state.mark_completed(job.output_name)
                completed += 1
                logger.info(f"Completed: {job.output_name}")

            except Exception as e:
                # Broad catch is intentional: one bad job must not kill the batch.
                error_msg = f"{type(e).__name__}: {str(e)}"
                logger.error(f"Failed: {job.output_name} - {error_msg}")

                # Log full traceback to error log
                with open(error_log, "a", encoding="utf-8") as f:
                    f.write(f"\n{'='*60}\n")
                    f.write(f"Job: {job.output_name}\n")
                    f.write(f"Prompt: {job.prompt}\n")
                    f.write(f"Time: {datetime.now().isoformat()}\n")
                    f.write(f"Error: {error_msg}\n")
                    f.write(traceback.format_exc())

                state.mark_failed(job.output_name, error_msg)
                failed += 1

            # Save state after each job (for resume)
            state.save(state_file)
            pbar.update(1)

            # Update ETA in progress bar
            elapsed = time.time() - start_time
            if completed + failed > 0:
                avg_time = elapsed / (completed + failed)
                remaining = len(jobs) - (completed + failed)
                eta_seconds = avg_time * remaining
                pbar.set_postfix(
                    {"done": completed, "failed": failed, "ETA": f"{eta_seconds/60:.1f}m"}
                )

    # Final summary
    elapsed_total = time.time() - start_time
    summary = {
        "total": len(jobs),
        "completed": completed,
        "failed": failed,
        "skipped": len(all_jobs) - len(jobs),
        "elapsed_seconds": elapsed_total,
        "avg_seconds_per_video": elapsed_total / max(completed + failed, 1),
    }

    logger.info("=" * 60)
    logger.info("BATCH COMPLETE")
    logger.info(f"  Total jobs: {summary['total']}")
    logger.info(f"  Completed:  {summary['completed']}")
    logger.info(f"  Failed:     {summary['failed']}")
    logger.info(f"  Elapsed:    {elapsed_total/60:.1f} minutes")
    if summary['failed'] > 0:
        logger.info(f"  See errors in: {error_log}")
    logger.info("=" * 60)

    return summary


def main():
    """CLI entry point: parse args, run the batch, set the exit code."""
    parser = argparse.ArgumentParser(
        description="Batch inference for CogVideo - generate multiple videos from a JSONL file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic text-to-video batch
  python batch_inference.py --batch_file prompts.jsonl --model_path THUDM/CogVideoX1.5-5B

  # Image-to-video batch with custom output directory
  python batch_inference.py --batch_file i2v.jsonl --model_path THUDM/CogVideoX1.5-5B-I2V \\
      --generate_type i2v --output_dir ./my_videos

  # Multi-GPU: run on 4 GPUs (one process per GPU)
  for i in {0..3}; do
    CUDA_VISIBLE_DEVICES=$i python batch_inference.py --batch_file batch.jsonl \\
        --gpu_id $i --num_gpus 4 &
  done

JSONL Format:
  Each line is a JSON object with: prompt (required), output_name (required),
  and optional: image_path, video_path, num_frames, guidance_scale,
  num_inference_steps, seed, width, height
        """,
    )

    # Required arguments
    parser.add_argument("--batch_file", type=str, required=True, help="Path to JSONL batch file")
    parser.add_argument(
        "--model_path",
        type=str,
        default="THUDM/CogVideoX1.5-5B",
        help="Path to the model (default: THUDM/CogVideoX1.5-5B)",
    )

    # Output settings
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./batch_output",
        help="Directory for output videos (default: ./batch_output)",
    )
    parser.add_argument(
        "--generate_type",
        type=str,
        choices=["t2v", "i2v", "v2v"],
        default="t2v",
        help="Generation type (default: t2v)",
    )

    # Model settings
    parser.add_argument(
        "--lora_path", type=str, default=None, help="Path to LoRA weights (optional)"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["float16", "bfloat16"],
        default="bfloat16",
        help="Data type for computation (default: bfloat16)",
    )
    parser.add_argument(
        "--disable_cpu_offload",
        action="store_true",
        help="Disable CPU offloading (uses more VRAM but faster)",
    )

    # Default generation parameters
    parser.add_argument(
        "--num_frames", type=int, default=81, help="Default number of frames (default: 81)"
    )
    parser.add_argument(
        "--guidance_scale", type=float, default=6.0, help="Default guidance scale (default: 6.0)"
    )
    parser.add_argument(
        "--num_inference_steps", type=int, default=50, help="Default inference steps (default: 50)"
    )
    parser.add_argument("--seed", type=int, default=42, help="Default random seed (default: 42)")
    parser.add_argument("--fps", type=int, default=16, help="Output video FPS (default: 16)")

    # Resume and multi-GPU
    # NOTE: resume is the default; --resume is kept only for CLI
    # backward-compatibility (it is a no-op), --no_resume turns it off.
    parser.add_argument(
        "--resume",
        action="store_true",
        default=True,
        help="Resume from previous state (default: True)",
    )
    parser.add_argument("--no_resume", action="store_true", help="Don't resume, start fresh")
    parser.add_argument(
        "--gpu_id", type=int, default=0, help="GPU ID for multi-GPU distribution (default: 0)"
    )
    parser.add_argument(
        "--num_gpus", type=int, default=1, help="Total number of GPUs for distribution (default: 1)"
    )

    args = parser.parse_args()

    # Validate batch file exists
    batch_file = Path(args.batch_file)
    if not batch_file.exists():
        logger.error(f"Batch file not found: {batch_file}")
        sys.exit(1)

    # Parse dtype
    dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16

    # Run batch
    try:
        summary = run_batch(
            batch_file=batch_file,
            model_path=args.model_path,
            output_dir=Path(args.output_dir),
            generate_type=args.generate_type,
            dtype=dtype,
            lora_path=args.lora_path,
            enable_cpu_offload=not args.disable_cpu_offload,
            resume=not args.no_resume,
            gpu_id=args.gpu_id,
            num_gpus=args.num_gpus,
            default_num_frames=args.num_frames,
            default_guidance_scale=args.guidance_scale,
            default_num_inference_steps=args.num_inference_steps,
            default_seed=args.seed,
            fps=args.fps,
        )

        # Exit with error code if any failures
        if summary["failed"] > 0:
            sys.exit(1)

    except KeyboardInterrupt:
        logger.info("\nBatch interrupted by user. Progress saved for resume.")
        sys.exit(130)
    except Exception as e:
        logger.error(f"Batch failed: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()