diff --git a/.gitignore b/.gitignore index ad4bbeb..303dbae 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ logs/ .idea output* test* +!tests/ +!tests/** venv **/.swp **/*.log diff --git a/tests/test_batch_inference.py b/tests/test_batch_inference.py index 1225cb5..3cb60fe 100644 --- a/tests/test_batch_inference.py +++ b/tests/test_batch_inference.py @@ -34,28 +34,30 @@ mock_torch.Generator = MagicMock # Create mock diffusers module mock_diffusers = MagicMock() + # Create mock tqdm that works as a context manager class MockTqdm: """Mock tqdm that works as a context manager and returns a dummy progress bar.""" + def __init__(self, iterable=None, total=None, **kwargs): self.iterable = iterable self.n = 0 - + def __iter__(self): if self.iterable is not None: for item in self.iterable: yield item self.n += 1 - + def __enter__(self): return self - + def __exit__(self, *args): pass - + def update(self, n=1): self.n += n - + def set_postfix(self, ordered_dict=None, refresh=True, **kwargs): pass @@ -172,7 +174,7 @@ class TestBatchJob: def test_from_dict_basic(self, sample_job_dict): """Test creating a BatchJob from a dictionary.""" job = BatchJob.from_dict(sample_job_dict, line_number=1) - + assert job.prompt == "A cat playing piano" assert job.output_name == "cat_piano.mp4" assert job.num_frames == 49 @@ -185,7 +187,7 @@ class TestBatchJob: """Test creating a BatchJob with only required fields.""" data = {"prompt": "Test prompt", "output_name": "test.mp4"} job = BatchJob.from_dict(data) - + assert job.prompt == "Test prompt" assert job.output_name == "test.mp4" assert job.num_frames is None @@ -195,14 +197,14 @@ class TestBatchJob: def test_from_dict_with_image_path(self, sample_i2v_job_dict): """Test creating a BatchJob with image_path for i2v.""" job = BatchJob.from_dict(sample_i2v_job_dict) - + assert job.image_path == "/path/to/image.jpg" assert job.video_path is None def test_from_dict_with_video_path(self, sample_v2v_job_dict): """Test creating a BatchJob with video_path for v2v.""" job = BatchJob.from_dict(sample_v2v_job_dict) - + assert job.video_path == "/path/to/video.mp4" assert job.image_path is None @@ -210,28 +212,28 @@ class TestBatchJob: """Test validation of a valid job returns no errors.""" job = BatchJob.from_dict(sample_job_dict) errors = job.validate() - + assert errors == [] def test_validate_missing_prompt(self): """Test validation catches missing prompt.""" job = BatchJob.from_dict({"output_name": "test.mp4"}) errors = job.validate() - + assert "Missing required field: prompt" in errors def test_validate_missing_output_name(self): """Test validation catches missing output_name.""" job = BatchJob.from_dict({"prompt": "Test"}) errors = job.validate() - + assert "Missing required field: output_name" in errors def test_validate_missing_both_required(self): """Test validation catches both missing required fields.""" job = BatchJob.from_dict({}) errors = job.validate() - + assert len(errors) == 2 assert "Missing required field: prompt" in errors assert "Missing required field: output_name" in errors @@ -253,7 +255,7 @@ class TestBatchState: model_path="THUDM/CogVideoX1.5-5B", generate_type="t2v", ) - + assert state.batch_file == "batch.jsonl" assert state.completed == [] assert state.failed == [] @@ -266,10 +268,10 @@ class TestBatchState: model_path="model", generate_type="t2v", ) - + state.mark_completed("video1.mp4") state.mark_completed("video2.mp4") - + assert "video1.mp4" in state.completed assert "video2.mp4" in state.completed @@ -281,10 +283,10 @@ class TestBatchState: model_path="model", generate_type="t2v", ) - + state.mark_completed("video1.mp4") state.mark_completed("video1.mp4") - + assert state.completed.count("video1.mp4") == 1 def test_mark_failed(self): @@ -295,9 +297,9 @@ class TestBatchState: model_path="model", generate_type="t2v", ) - + state.mark_failed("video1.mp4", "CUDA out of memory") - + assert len(state.failed) == 1 assert state.failed[0]["output_name"] == "video1.mp4" assert state.failed[0]["error"] == "CUDA out of memory" @@ -310,16 +312,16 @@ class TestBatchState: model_path="model", generate_type="t2v", ) - + state.mark_completed("video1.mp4") - + assert state.is_completed("video1.mp4") is True assert state.is_completed("video2.mp4") is False def test_save_and_load(self, tmp_path): """Test saving and loading state.""" state_file = tmp_path / ".batch_state.json" - + # Create and save state state = BatchState( batch_file="batch.jsonl", @@ -331,10 +333,10 @@ class TestBatchState: state.mark_completed("video1.mp4") state.mark_failed("video2.mp4", "Error") state.save(state_file) - + # Load state loaded = BatchState.load(state_file) - + assert loaded is not None assert loaded.batch_file == "batch.jsonl" assert "video1.mp4" in loaded.completed @@ -344,27 +346,27 @@ class TestBatchState: def test_load_nonexistent_file(self, tmp_path): """Test loading from a nonexistent file returns None.""" state_file = tmp_path / "nonexistent.json" - + loaded = BatchState.load(state_file) - + assert loaded is None def test_load_corrupted_file(self, tmp_path): """Test loading a corrupted state file returns None.""" state_file = tmp_path / ".batch_state.json" state_file.write_text("{ invalid json }") - + loaded = BatchState.load(state_file) - + assert loaded is None def test_load_invalid_json_structure(self, tmp_path): """Test loading a file with valid JSON but invalid structure.""" state_file = tmp_path / ".batch_state.json" state_file.write_text('{"wrong": "structure"}') - + loaded = BatchState.load(state_file) - + # Should return None due to missing required fields assert loaded is None @@ -380,7 +382,7 @@ class TestLoadBatchFile: def test_load_valid_jsonl(self, batch_file): """Test loading a valid JSONL file.""" jobs = load_batch_file(batch_file) - + assert len(jobs) == 3 assert jobs[0].prompt == "A cat playing piano" assert jobs[1].output_name == "beach.mp4" @@ -394,9 +396,9 @@ class TestLoadBatchFile: {"prompt": "Test2", "output_name": "test2.mp4"}""" batch_file = tmp_path / "batch.jsonl" batch_file.write_text(content) - + jobs = load_batch_file(batch_file) - + assert len(jobs) == 2 def test_load_jsonl_with_empty_lines(self, tmp_path): @@ -408,9 +410,9 @@ class TestLoadBatchFile: """ batch_file = tmp_path / "batch.jsonl" batch_file.write_text(content) - + jobs = load_batch_file(batch_file) - + assert len(jobs) == 2 def test_load_jsonl_invalid_json_line(self, tmp_path): @@ -420,9 +422,9 @@ class TestLoadBatchFile: {"prompt": "Test2", "output_name": "test2.mp4"}""" batch_file = tmp_path / "batch.jsonl" batch_file.write_text(content) - + jobs = load_batch_file(batch_file) - + # Should skip the invalid line assert len(jobs) == 2 @@ -434,9 +436,9 @@ class TestLoadBatchFile: {"prompt": "Test2", "output_name": "test2.mp4"}""" batch_file = tmp_path / "batch.jsonl" batch_file.write_text(content) - + jobs = load_batch_file(batch_file) - + # Should skip the two invalid jobs assert len(jobs) == 2 assert jobs[0].output_name == "test.mp4" @@ -446,9 +448,9 @@ class TestLoadBatchFile: """Test loading an empty batch file.""" batch_file = tmp_path / "empty.jsonl" batch_file.write_text("") - + jobs = load_batch_file(batch_file) - + assert len(jobs) == 0 def test_load_batch_file_all_comments(self, tmp_path): @@ -458,9 +460,9 @@ class TestLoadBatchFile: # Comment 3""" batch_file = tmp_path / "comments.jsonl" batch_file.write_text(content) - + jobs = load_batch_file(batch_file) - + assert len(jobs) == 0 def test_load_batch_file_tracks_line_numbers(self, tmp_path): @@ -471,9 +473,9 @@ class TestLoadBatchFile: {"prompt": "Line 4", "output_name": "line4.mp4"}""" batch_file = tmp_path / "batch.jsonl" batch_file.write_text(content) - + jobs = load_batch_file(batch_file) - + assert jobs[0].line_number == 2 assert jobs[1].line_number == 4 @@ -488,30 +490,24 @@ class TestMultiGPUDistribution: def test_single_gpu_gets_all_jobs(self): """Test that a single GPU processes all jobs.""" - all_jobs = [ - BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") - for i in range(10) - ] - + all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)] + # Simulate distribution for single GPU (gpu_id=0, num_gpus=1) gpu_id = 0 num_gpus = 1 jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id] - + assert len(jobs) == 10 def test_two_gpu_distribution(self): """Test job distribution across 2 GPUs.""" - all_jobs = [ - BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") - for i in range(10) - ] - + all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)] + # GPU 0 gets jobs 0, 2, 4, 6, 8 gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 0] # GPU 1 gets jobs 1, 3, 5, 7, 9 gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 1] - + assert len(gpu0_jobs) == 5 assert len(gpu1_jobs) == 5 assert gpu0_jobs[0].output_name == "video0.mp4" @@ -519,11 +515,8 @@ class TestMultiGPUDistribution: def test_four_gpu_distribution(self): """Test job distribution across 4 GPUs.""" - all_jobs = [ - BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") - for i in range(12) - ] - + all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(12)] + num_gpus = 4 for gpu_id in range(num_gpus): jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id] @@ -531,16 +524,13 @@ class TestMultiGPUDistribution: def test_uneven_distribution(self): """Test distribution when jobs don't divide evenly.""" - all_jobs = [ - BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") - for i in range(10) - ] - + all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)] + num_gpus = 3 gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 0] gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 1] gpu2_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 2] - + # 10 jobs / 3 GPUs: GPU0 gets 4, GPU1 gets 3, GPU2 gets 3 assert len(gpu0_jobs) == 4 # indices 0, 3, 6, 9 assert len(gpu1_jobs) == 3 # indices 1, 4, 7 @@ -562,14 +552,14 @@ class TestLoadPipeline: mock_pipe.scheduler = MagicMock() mock_pipe.vae = MagicMock() mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe - + with patch.object(batch_inference, "CogVideoXDPMScheduler"): pipe = load_pipeline( model_path="THUDM/CogVideoX1.5-5B", generate_type="t2v", enable_cpu_offload=False, ) - + mock_cogvideo_pipeline.from_pretrained.assert_called_once() assert pipe is mock_pipe @@ -580,14 +570,14 @@ class TestLoadPipeline: mock_pipe.scheduler = MagicMock() mock_pipe.vae = MagicMock() mock_i2v_pipeline.from_pretrained.return_value = mock_pipe - + with patch.object(batch_inference, "CogVideoXDPMScheduler"): pipe = load_pipeline( model_path="THUDM/CogVideoX1.5-5B-I2V", generate_type="i2v", enable_cpu_offload=False, ) - + mock_i2v_pipeline.from_pretrained.assert_called_once() assert pipe is mock_pipe @@ -598,14 +588,14 @@ class TestLoadPipeline: mock_pipe.scheduler = MagicMock() mock_pipe.vae = MagicMock() mock_v2v_pipeline.from_pretrained.return_value = mock_pipe - + with patch.object(batch_inference, "CogVideoXDPMScheduler"): pipe = load_pipeline( model_path="THUDM/CogVideoX1.5-5B", generate_type="v2v", enable_cpu_offload=False, ) - + mock_v2v_pipeline.from_pretrained.assert_called_once() assert pipe is mock_pipe @@ -616,14 +606,14 @@ class TestLoadPipeline: mock_pipe.scheduler = MagicMock() mock_pipe.vae = MagicMock() mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe - + with patch.object(batch_inference, "CogVideoXDPMScheduler"): load_pipeline( model_path="model", generate_type="t2v", enable_cpu_offload=True, ) - + mock_pipe.enable_sequential_cpu_offload.assert_called_once() mock_pipe.to.assert_not_called() @@ -634,14 +624,14 @@ class TestLoadPipeline: mock_pipe.scheduler = MagicMock() mock_pipe.vae = MagicMock() mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe - + with patch.object(batch_inference, "CogVideoXDPMScheduler"): load_pipeline( model_path="model", generate_type="t2v", enable_cpu_offload=False, ) - + mock_pipe.to.assert_called_once_with("cuda") mock_pipe.enable_sequential_cpu_offload.assert_not_called() @@ -652,7 +642,7 @@ class TestLoadPipeline: mock_pipe.scheduler = MagicMock() mock_pipe.vae = MagicMock() mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe - + with patch.object(batch_inference, "CogVideoXDPMScheduler"): load_pipeline( model_path="model", @@ -660,7 +650,7 @@ class TestLoadPipeline: lora_path="/path/to/lora", enable_cpu_offload=False, ) - + mock_pipe.load_lora_weights.assert_called_once() mock_pipe.fuse_lora.assert_called_once() @@ -682,7 +672,7 @@ class TestGenerateSingleVideo: num_frames=49, ) output_path = tmp_path / "cat.mp4" - + generate_single_video( pipe=mock_pipeline, job=job, @@ -690,7 +680,7 @@ class TestGenerateSingleVideo: model_name="cogvideox-5b", output_path=output_path, ) - + mock_pipeline.assert_called_once() mock_export.assert_called_once() @@ -699,14 +689,14 @@ class TestGenerateSingleVideo: def test_generate_i2v_video(self, mock_load_image, mock_export, mock_pipeline, tmp_path): """Test generating an image-to-video.""" mock_load_image.return_value = MagicMock() - + job = BatchJob( prompt="Transform this", output_name="i2v.mp4", image_path="/path/to/image.jpg", ) output_path = tmp_path / "i2v.mp4" - + generate_single_video( pipe=mock_pipeline, job=job, @@ -714,7 +704,7 @@ class TestGenerateSingleVideo: model_name="cogvideox-5b-i2v", output_path=output_path, ) - + mock_load_image.assert_called_once_with(image="/path/to/image.jpg") mock_pipeline.assert_called_once() @@ -723,14 +713,14 @@ class TestGenerateSingleVideo: def test_generate_v2v_video(self, mock_load_video, mock_export, mock_pipeline, tmp_path): """Test generating a video-to-video.""" mock_load_video.return_value = MagicMock() - + job = BatchJob( prompt="Enhance this", output_name="v2v.mp4", video_path="/path/to/video.mp4", ) output_path = tmp_path / "v2v.mp4" - + generate_single_video( pipe=mock_pipeline, job=job, @@ -738,7 +728,7 @@ class TestGenerateSingleVideo: model_name="cogvideox-5b", output_path=output_path, ) - + mock_load_video.assert_called_once_with("/path/to/video.mp4") mock_pipeline.assert_called_once() @@ -750,7 +740,7 @@ class TestGenerateSingleVideo: # Missing image_path ) output_path = tmp_path / "i2v.mp4" - + with pytest.raises(ValueError, match="image_path is required"): generate_single_video( pipe=mock_pipeline, @@ -768,7 +758,7 @@ class TestGenerateSingleVideo: # Missing video_path ) output_path = tmp_path / "v2v.mp4" - + with pytest.raises(ValueError, match="video_path is required"): generate_single_video( pipe=mock_pipeline, @@ -790,7 +780,7 @@ class TestGenerateSingleVideo: seed=456, ) output_path = tmp_path / "test.mp4" - + generate_single_video( pipe=mock_pipeline, job=job, @@ -802,7 +792,7 @@ class TestGenerateSingleVideo: default_num_inference_steps=50, default_seed=42, ) - + # Check the pipeline was called with job-specific values call_kwargs = mock_pipeline.call_args[1] assert call_kwargs["num_frames"] == 33 @@ -824,9 +814,9 @@ class TestRunBatch: """Test basic batch processing.""" mock_pipe = MagicMock() mock_load_pipeline.return_value = mock_pipe - + output_dir = tmp_path / "output" - + summary = run_batch( batch_file=batch_file, model_path="THUDM/CogVideoX1.5-5B", @@ -834,7 +824,7 @@ class TestRunBatch: generate_type="t2v", resume=False, ) - + assert summary["total"] == 3 assert summary["completed"] == 3 assert summary["failed"] == 0 @@ -842,21 +832,23 @@ class TestRunBatch: @patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "generate_single_video") - def test_run_batch_creates_output_dir(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + def test_run_batch_creates_output_dir( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): """Test that output directory is created if it doesn't exist.""" mock_pipe = MagicMock() mock_load_pipeline.return_value = mock_pipe - + output_dir = tmp_path / "nested" / "output" / "dir" assert not output_dir.exists() - + run_batch( batch_file=batch_file, model_path="model", output_dir=output_dir, resume=False, ) - + assert output_dir.exists() @patch.object(batch_inference, "load_pipeline") @@ -865,64 +857,68 @@ class TestRunBatch: """Test that state is saved after processing.""" mock_pipe = MagicMock() mock_load_pipeline.return_value = mock_pipe - + output_dir = tmp_path / "output" - + run_batch( batch_file=batch_file, model_path="model", output_dir=output_dir, resume=True, ) - + state_file = output_dir / ".batch_state.json" assert state_file.exists() - + with open(state_file) as f: state_data = json.load(f) - + assert len(state_data["completed"]) == 3 @patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "generate_single_video") - def test_run_batch_handles_errors(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + def test_run_batch_handles_errors( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): """Test that batch continues after individual job failures.""" mock_pipe = MagicMock() mock_load_pipeline.return_value = mock_pipe - + # Make the second call raise an exception mock_generate.side_effect = [None, RuntimeError("CUDA OOM"), None] - + output_dir = tmp_path / "output" - + summary = run_batch( batch_file=batch_file, model_path="model", output_dir=output_dir, resume=False, ) - + assert summary["completed"] == 2 assert summary["failed"] == 1 assert mock_generate.call_count == 3 @patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "generate_single_video") - def test_run_batch_logs_errors_to_file(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + def test_run_batch_logs_errors_to_file( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): """Test that errors are logged to errors.log.""" mock_pipe = MagicMock() mock_load_pipeline.return_value = mock_pipe mock_generate.side_effect = RuntimeError("Test error") - + output_dir = tmp_path / "output" - + run_batch( batch_file=batch_file, model_path="model", output_dir=output_dir, resume=False, ) - + error_log = output_dir / "errors.log" assert error_log.exists() content = error_log.read_text() @@ -930,14 +926,16 @@ class TestRunBatch: @patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "generate_single_video") - def test_run_batch_resume_skips_completed(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + def test_run_batch_resume_skips_completed( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): """Test that resume skips already-completed jobs.""" mock_pipe = MagicMock() mock_load_pipeline.return_value = mock_pipe - + output_dir = tmp_path / "output" output_dir.mkdir() - + # Create a state file showing first job completed state = BatchState( batch_file=str(batch_file), @@ -947,14 +945,14 @@ class TestRunBatch: completed=["cat_piano.mp4"], ) state.save(output_dir / ".batch_state.json") - + summary = run_batch( batch_file=batch_file, model_path="model", output_dir=output_dir, resume=True, ) - + # Should only process 2 jobs (skipping the completed one) assert summary["total"] == 2 assert summary["skipped"] == 1 @@ -965,16 +963,16 @@ class TestRunBatch: """Test running batch with empty file.""" batch_file = tmp_path / "empty.jsonl" batch_file.write_text("") - + output_dir = tmp_path / "output" - + summary = run_batch( batch_file=batch_file, model_path="model", output_dir=output_dir, resume=False, ) - + assert summary["total"] == 0 # Pipeline should not be loaded for empty batch mock_load_pipeline.assert_not_called() @@ -984,18 +982,15 @@ class TestRunBatch: def test_run_batch_multi_gpu_distribution(self, mock_generate, mock_load_pipeline, tmp_path): """Test multi-GPU job distribution in run_batch.""" # Create batch file with 6 jobs - jobs = [ - {"prompt": f"Job {i}", "output_name": f"video{i}.mp4"} - for i in range(6) - ] + jobs = [{"prompt": f"Job {i}", "output_name": f"video{i}.mp4"} for i in range(6)] batch_file = tmp_path / "batch.jsonl" batch_file.write_text("\n".join(json.dumps(j) for j in jobs)) - + mock_pipe = MagicMock() mock_load_pipeline.return_value = mock_pipe - + output_dir = tmp_path / "output" - + # GPU 0 of 3 should get jobs 0, 3 (indices 0, 3) summary = run_batch( batch_file=batch_file, @@ -1005,25 +1000,27 @@ class TestRunBatch: num_gpus=3, resume=False, ) - + assert summary["total"] == 2 # GPU 0 gets 2 jobs assert mock_generate.call_count == 2 @patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "generate_single_video") - def test_run_batch_state_persists_after_each_job(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + def test_run_batch_state_persists_after_each_job( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): """Test that state is saved after each job, not just at the end.""" mock_pipe = MagicMock() mock_load_pipeline.return_value = mock_pipe - + # Track how many times state file is written state_writes = [] original_save = BatchState.save - + def tracking_save(self, state_file): original_save(self, state_file) state_writes.append(len(self.completed)) - + with patch.object(BatchState, 'save', tracking_save): output_dir = tmp_path / "output" run_batch( @@ -1032,7 +1029,7 @@ class TestRunBatch: output_dir=output_dir, resume=True, ) - + # State should be saved 3 times (once per job) assert len(state_writes) == 3 assert state_writes == [1, 2, 3] # Progressively more completed jobs @@ -1086,7 +1083,7 @@ class TestEdgeCases: "height": 1080, } job = BatchJob.from_dict(data) - + assert job.width == 1920 assert job.height == 1080 assert job.num_frames == 100 @@ -1100,32 +1097,34 @@ class TestEdgeCases: "another_field": 123, } job = BatchJob.from_dict(data) - + assert job.prompt == "Test" assert job.output_name == "test.mp4" # Extra fields should be ignored without error @patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "generate_single_video") - def test_run_batch_corrupted_state_starts_fresh(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): + def test_run_batch_corrupted_state_starts_fresh( + self, mock_generate, mock_load_pipeline, batch_file, tmp_path + ): """Test that corrupted state file doesn't prevent batch from running.""" mock_pipe = MagicMock() mock_load_pipeline.return_value = mock_pipe - + output_dir = tmp_path / "output" output_dir.mkdir() - + # Create corrupted state file state_file = output_dir / ".batch_state.json" state_file.write_text("{corrupted: json") - + summary = run_batch( batch_file=batch_file, model_path="model", output_dir=output_dir, resume=True, ) - + # Should start fresh and process all jobs assert summary["total"] == 3 assert mock_generate.call_count == 3 @@ -1138,12 +1137,12 @@ class TestEdgeCases: model_path="model", generate_type="t2v", ) - + assert state.updated_at == "" - + state_file = tmp_path / "state.json" state.save(state_file) - + # updated_at should be set after save assert state.updated_at != "" @@ -1154,16 +1153,16 @@ class TestEdgeCases: mock_pipe = MagicMock() mock_load_pipeline.return_value = mock_pipe mock_generate.side_effect = RuntimeError("All fail") - + output_dir = tmp_path / "output" - + summary = run_batch( batch_file=batch_file, model_path="model", output_dir=output_dir, resume=False, ) - + assert summary["completed"] == 0 assert summary["failed"] == 3 @@ -1174,9 +1173,9 @@ class TestEdgeCases: {"prompt": "Café résumé naïve", "output_name": "unicode3.mp4"}""" batch_file = tmp_path / "unicode.jsonl" batch_file.write_text(content, encoding="utf-8") - + jobs = load_batch_file(batch_file) - + assert len(jobs) == 3 assert jobs[0].prompt == "猫がピアノを弾いている" assert "🌊" in jobs[1].prompt @@ -1187,8 +1186,8 @@ class TestEdgeCases: content = json.dumps({"prompt": long_prompt, "output_name": "long.mp4"}) batch_file = tmp_path / "long.jsonl" batch_file.write_text(content) - + jobs = load_batch_file(batch_file) - + assert len(jobs) == 1 assert len(jobs[0].prompt) > 4000 diff --git a/tools/batch_inference.py b/tools/batch_inference.py index e231f78..2771fac 100644 --- a/tools/batch_inference.py +++ b/tools/batch_inference.py @@ -107,6 +107,7 @@ RESOLUTION_MAP = { @dataclass class BatchJob: """Represents a single job in the batch.""" + prompt: str output_name: str image_path: Optional[str] = None @@ -117,12 +118,12 @@ class BatchJob: seed: Optional[int] = None width: Optional[int] = None height: Optional[int] = None - + # Internal fields line_number: int = 0 status: str = "pending" error: Optional[str] = None - + @classmethod def from_dict(cls, data: Dict[str, Any], line_number: int = 0) -> "BatchJob": """Create a BatchJob from a dictionary.""" @@ -139,7 +140,7 @@ class BatchJob: height=data.get("height"), line_number=line_number, ) - + def validate(self) -> List[str]: """Validate the job and return list of errors.""" errors = [] @@ -153,6 +154,7 @@ class BatchJob: @dataclass class BatchState: """Tracks batch progress for resume capability.""" + batch_file: str output_dir: str model_path: str @@ -161,7 +163,7 @@ class BatchState: failed: List[Dict[str, str]] = field(default_factory=list) started_at: str = "" updated_at: str = "" - + @classmethod def load(cls, state_file: Path) -> Optional["BatchState"]: """Load state from file.""" @@ -174,22 +176,22 @@ class BatchState: except Exception as e: logger.warning(f"Failed to load state file: {e}") return None - + def save(self, state_file: Path): """Save state to file.""" self.updated_at = datetime.now().isoformat() with open(state_file, "w") as f: json.dump(self.__dict__, f, indent=2) - + def mark_completed(self, output_name: str): """Mark a job as completed.""" if output_name not in self.completed: self.completed.append(output_name) - + def mark_failed(self, output_name: str, error: str): """Mark a job as failed.""" self.failed.append({"output_name": output_name, "error": error}) - + def is_completed(self, output_name: str) -> bool: """Check if a job was already completed.""" return output_name in self.completed @@ -198,37 +200,37 @@ class BatchState: def load_batch_file(batch_file: Path) -> List[BatchJob]: """ Load and parse a JSONL batch file. - + Args: batch_file: Path to the JSONL file - + Returns: List of BatchJob objects """ jobs = [] - + with open(batch_file, "r") as f: for line_num, line in enumerate(f, 1): line = line.strip() if not line or line.startswith("#"): continue - + try: data = json.loads(line) job = BatchJob.from_dict(data, line_number=line_num) - + # Validate job errors = job.validate() if errors: logger.warning(f"Line {line_num}: Invalid job - {', '.join(errors)}") continue - + jobs.append(job) - + except json.JSONDecodeError as e: logger.warning(f"Line {line_num}: Invalid JSON - {e}") continue - + return jobs @@ -241,26 +243,26 @@ def load_pipeline( ): """ Load the appropriate pipeline for the generation type. - + Args: model_path: Path to the model generate_type: Type of generation (t2v, i2v, v2v) dtype: Data type for computation lora_path: Optional path to LoRA weights enable_cpu_offload: Whether to enable CPU offloading - + Returns: The loaded pipeline """ logger.info(f"Loading pipeline for {generate_type} from {model_path}") - + if generate_type == "i2v": pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype) elif generate_type == "t2v": pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype) else: # v2v pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype) - + # Load LoRA weights if provided if lora_path: logger.info(f"Loading LoRA weights from {lora_path}") @@ -268,21 +270,21 @@ def load_pipeline( lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="batch_lora" ) pipe.fuse_lora(components=["transformer"], lora_scale=1.0) - + # Set scheduler pipe.scheduler = CogVideoXDPMScheduler.from_config( pipe.scheduler.config, timestep_spacing="trailing" ) - + # Enable memory optimizations if enable_cpu_offload: pipe.enable_sequential_cpu_offload() else: pipe.to("cuda") - + pipe.vae.enable_slicing() pipe.vae.enable_tiling() - + logger.info("Pipeline loaded successfully") return pipe @@ -301,7 +303,7 @@ def generate_single_video( ): """ Generate a single video from a job. - + Args: pipe: The loaded pipeline job: The batch job to process @@ -315,17 +317,17 @@ def generate_single_video( desired_resolution = RESOLUTION_MAP.get(model_name.lower(), (480, 720)) height = job.height if job.height else desired_resolution[0] width = job.width if job.width else desired_resolution[1] - + # Use job-specific or default values num_frames = job.num_frames or default_num_frames guidance_scale = job.guidance_scale or default_guidance_scale num_inference_steps = job.num_inference_steps or default_num_inference_steps seed = job.seed or default_seed - + # Load image/video if needed image = None video = None - + if generate_type == "i2v": if not job.image_path: raise ValueError("image_path is required for i2v generation") @@ -334,10 +336,10 @@ def generate_single_video( if not job.video_path: raise ValueError("video_path is required for v2v generation") video = load_video(job.video_path) - + # Generate video generator = torch.Generator().manual_seed(seed) - + if generate_type == "i2v": video_frames = pipe( height=height, @@ -376,7 +378,7 @@ def generate_single_video( guidance_scale=guidance_scale, generator=generator, ).frames[0] - + # Export video export_to_video(video_frames, str(output_path), fps=fps) @@ -400,7 +402,7 @@ def run_batch( ) -> Dict[str, Any]: """ Run batch inference on a JSONL file. - + Args: batch_file: Path to the JSONL batch file model_path: Path to the model @@ -414,7 +416,7 @@ def run_batch( num_gpus: Total number of GPUs for distribution default_*: Default values for optional parameters fps: Frames per second for output videos - + Returns: Summary dictionary with statistics """ @@ -422,26 +424,26 @@ def run_batch( output_dir.mkdir(parents=True, exist_ok=True) state_file = output_dir / ".batch_state.json" error_log = output_dir / "errors.log" - + # Load jobs logger.info(f"Loading batch file: {batch_file}") all_jobs = load_batch_file(batch_file) logger.info(f"Found {len(all_jobs)} valid jobs in batch file") - + # Distribute jobs across GPUs if using multi-GPU if num_gpus > 1: jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id] logger.info(f"GPU {gpu_id}/{num_gpus}: Processing {len(jobs)} jobs") else: jobs = all_jobs - + # Load or create state state = None if resume: state = BatchState.load(state_file) if state: logger.info(f"Resuming batch: {len(state.completed)} already completed") - + if state is None: state = BatchState( batch_file=str(batch_file), @@ -450,7 +452,7 @@ def run_batch( generate_type=generate_type, started_at=datetime.now().isoformat(), ) - + # Filter out completed jobs if resume: pending_jobs = [j for j in jobs if not state.is_completed(j.output_name)] @@ -458,11 +460,11 @@ def run_batch( if skipped > 0: logger.info(f"Skipping {skipped} already-completed jobs") jobs = pending_jobs - + if not jobs: logger.info("No jobs to process") return {"total": 0, "completed": 0, "failed": 0, "skipped": len(all_jobs)} - + # Load pipeline model_name = model_path.split("/")[-1] pipe = load_pipeline( @@ -472,19 +474,19 @@ def run_batch( lora_path=lora_path, enable_cpu_offload=enable_cpu_offload, ) - + # Process jobs completed = 0 failed = 0 start_time = time.time() - + with tqdm(total=len(jobs), desc="Generating videos", unit="video") as pbar: for job in jobs: output_path = output_dir / job.output_name - + try: logger.info(f"Processing: {job.output_name} - \"{job.prompt[:50]}...\"") - + generate_single_video( pipe=pipe, job=job, @@ -497,15 +499,15 @@ def run_batch( default_seed=default_seed, fps=fps, ) - + state.mark_completed(job.output_name) completed += 1 logger.info(f"Completed: {job.output_name}") - + except Exception as e: error_msg = f"{type(e).__name__}: {str(e)}" logger.error(f"Failed: {job.output_name} - {error_msg}") - + # Log full traceback to error log with open(error_log, "a") as f: f.write(f"\n{'='*60}\n") @@ -514,26 +516,24 @@ def run_batch( f.write(f"Time: {datetime.now().isoformat()}\n") f.write(f"Error: {error_msg}\n") f.write(traceback.format_exc()) - + state.mark_failed(job.output_name, error_msg) failed += 1 - + # Save state after each job (for resume) state.save(state_file) pbar.update(1) - + # Update ETA in progress bar elapsed = time.time() - start_time if completed + failed > 0: avg_time = elapsed / (completed + failed) remaining = len(jobs) - (completed + failed) eta_seconds = avg_time * remaining - pbar.set_postfix({ - "done": completed, - "failed": failed, - "ETA": f"{eta_seconds/60:.1f}m" - }) - + pbar.set_postfix( + {"done": completed, "failed": failed, "ETA": f"{eta_seconds/60:.1f}m"} + ) + # Final summary elapsed_total = time.time() - start_time summary = { @@ -544,7 +544,7 @@ def run_batch( "elapsed_seconds": elapsed_total, "avg_seconds_per_video": elapsed_total / max(completed + failed, 1), } - + logger.info("=" * 60) logger.info("BATCH COMPLETE") logger.info(f" Total jobs: {summary['total']}") @@ -554,7 +554,7 @@ def run_batch( if summary['failed'] > 0: logger.info(f" See errors in: {error_log}") logger.info("=" * 60) - + return summary @@ -566,11 +566,11 @@ def main(): Examples: # Basic text-to-video batch python batch_inference.py --batch_file prompts.jsonl --model_path THUDM/CogVideoX1.5-5B - + # Image-to-video batch with custom output directory python batch_inference.py --batch_file i2v.jsonl --model_path THUDM/CogVideoX1.5-5B-I2V \\ --generate_type i2v --output_dir ./my_videos - + # Multi-GPU: run on 4 GPUs (one process per GPU) for i in {0..3}; do CUDA_VISIBLE_DEVICES=$i python batch_inference.py --batch_file batch.jsonl \\ @@ -579,128 +579,91 @@ Examples: JSONL Format: Each line is a JSON object with: prompt (required), output_name (required), - and optional: image_path, video_path, num_frames, guidance_scale, + and optional: image_path, video_path, num_frames, guidance_scale, num_inference_steps, seed, width, height - """ + """, ) - + # Required arguments - parser.add_argument( - "--batch_file", - type=str, - required=True, - help="Path to JSONL batch file" - ) + parser.add_argument("--batch_file", type=str, required=True, help="Path to JSONL batch file") parser.add_argument( "--model_path", type=str, default="THUDM/CogVideoX1.5-5B", - help="Path to the model (default: THUDM/CogVideoX1.5-5B)" + help="Path to the model (default: THUDM/CogVideoX1.5-5B)", ) - + # Output settings parser.add_argument( "--output_dir", type=str, default="./batch_output", - help="Directory for output videos (default: ./batch_output)" + help="Directory for output videos (default: ./batch_output)", ) parser.add_argument( "--generate_type", type=str, choices=["t2v", "i2v", "v2v"], default="t2v", - help="Generation type (default: t2v)" + help="Generation type (default: t2v)", ) - + # Model settings parser.add_argument( - "--lora_path", - type=str, - default=None, - help="Path to LoRA weights (optional)" + "--lora_path", type=str, default=None, help="Path to LoRA weights (optional)" ) parser.add_argument( "--dtype", type=str, choices=["float16", "bfloat16"], default="bfloat16", - help="Data type for computation (default: bfloat16)" + help="Data type for computation (default: bfloat16)", ) parser.add_argument( "--disable_cpu_offload", action="store_true", - help="Disable CPU offloading (uses more VRAM but faster)" + help="Disable CPU offloading (uses more VRAM but faster)", ) - + # Default generation parameters parser.add_argument( - "--num_frames", - type=int, - default=81, - help="Default number of frames (default: 81)" + "--num_frames", type=int, default=81, help="Default number of frames (default: 81)" ) parser.add_argument( - "--guidance_scale", - type=float, - default=6.0, - help="Default guidance scale (default: 6.0)" + "--guidance_scale", type=float, default=6.0, help="Default guidance scale (default: 6.0)" ) parser.add_argument( - "--num_inference_steps", - type=int, - default=50, - help="Default inference steps (default: 50)" + "--num_inference_steps", type=int, default=50, help="Default inference steps (default: 50)" ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="Default random seed (default: 42)" - ) - parser.add_argument( - "--fps", - type=int, - default=16, - help="Output video FPS (default: 16)" - ) - + parser.add_argument("--seed", type=int, default=42, help="Default random seed (default: 42)") + parser.add_argument("--fps", type=int, default=16, help="Output video FPS (default: 16)") + # Resume and multi-GPU parser.add_argument( "--resume", action="store_true", default=True, - help="Resume from previous state (default: True)" + help="Resume from previous state (default: True)", + ) + parser.add_argument("--no_resume", action="store_true", help="Don't resume, start fresh") + parser.add_argument( + "--gpu_id", type=int, default=0, help="GPU ID for multi-GPU distribution (default: 0)" ) parser.add_argument( - "--no_resume", - action="store_true", - help="Don't resume, start fresh" + "--num_gpus", type=int, default=1, help="Total number of GPUs for distribution (default: 1)" ) - parser.add_argument( - "--gpu_id", - type=int, - default=0, - help="GPU ID for multi-GPU distribution (default: 0)" - ) - parser.add_argument( - "--num_gpus", - type=int, - default=1, - help="Total number of GPUs for distribution (default: 1)" - ) - + args = parser.parse_args() - + # Validate batch file exists batch_file = Path(args.batch_file) if not batch_file.exists(): logger.error(f"Batch file not found: {batch_file}") sys.exit(1) - + # Parse dtype dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 - + # Run batch try: summary = run_batch( @@ -720,11 +683,11 @@ JSONL Format: default_seed=args.seed, fps=args.fps, ) - + # Exit with error code if any failures if summary["failed"] > 0: sys.exit(1) - + except KeyboardInterrupt: logger.info("\nBatch interrupted by user. Progress saved for resume.") sys.exit(130)