chore: allow tests/ directory in gitignore

This commit is contained in:
Test User 2026-02-19 03:36:33 +00:00
parent 15a4b403c4
commit 206760830a
3 changed files with 244 additions and 280 deletions

2
.gitignore vendored
View File

@ -8,6 +8,8 @@ logs/
.idea .idea
output* output*
test* test*
!tests/
!tests/**
venv venv
**/.swp **/.swp
**/*.log **/*.log

View File

@ -34,28 +34,30 @@ mock_torch.Generator = MagicMock
# Create mock diffusers module # Create mock diffusers module
mock_diffusers = MagicMock() mock_diffusers = MagicMock()
# Create mock tqdm that works as a context manager # Create mock tqdm that works as a context manager
class MockTqdm: class MockTqdm:
"""Mock tqdm that works as a context manager and returns a dummy progress bar.""" """Mock tqdm that works as a context manager and returns a dummy progress bar."""
def __init__(self, iterable=None, total=None, **kwargs): def __init__(self, iterable=None, total=None, **kwargs):
self.iterable = iterable self.iterable = iterable
self.n = 0 self.n = 0
def __iter__(self): def __iter__(self):
if self.iterable is not None: if self.iterable is not None:
for item in self.iterable: for item in self.iterable:
yield item yield item
self.n += 1 self.n += 1
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, *args): def __exit__(self, *args):
pass pass
def update(self, n=1): def update(self, n=1):
self.n += n self.n += n
def set_postfix(self, ordered_dict=None, refresh=True, **kwargs): def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
pass pass
@ -172,7 +174,7 @@ class TestBatchJob:
def test_from_dict_basic(self, sample_job_dict): def test_from_dict_basic(self, sample_job_dict):
"""Test creating a BatchJob from a dictionary.""" """Test creating a BatchJob from a dictionary."""
job = BatchJob.from_dict(sample_job_dict, line_number=1) job = BatchJob.from_dict(sample_job_dict, line_number=1)
assert job.prompt == "A cat playing piano" assert job.prompt == "A cat playing piano"
assert job.output_name == "cat_piano.mp4" assert job.output_name == "cat_piano.mp4"
assert job.num_frames == 49 assert job.num_frames == 49
@ -185,7 +187,7 @@ class TestBatchJob:
"""Test creating a BatchJob with only required fields.""" """Test creating a BatchJob with only required fields."""
data = {"prompt": "Test prompt", "output_name": "test.mp4"} data = {"prompt": "Test prompt", "output_name": "test.mp4"}
job = BatchJob.from_dict(data) job = BatchJob.from_dict(data)
assert job.prompt == "Test prompt" assert job.prompt == "Test prompt"
assert job.output_name == "test.mp4" assert job.output_name == "test.mp4"
assert job.num_frames is None assert job.num_frames is None
@ -195,14 +197,14 @@ class TestBatchJob:
def test_from_dict_with_image_path(self, sample_i2v_job_dict): def test_from_dict_with_image_path(self, sample_i2v_job_dict):
"""Test creating a BatchJob with image_path for i2v.""" """Test creating a BatchJob with image_path for i2v."""
job = BatchJob.from_dict(sample_i2v_job_dict) job = BatchJob.from_dict(sample_i2v_job_dict)
assert job.image_path == "/path/to/image.jpg" assert job.image_path == "/path/to/image.jpg"
assert job.video_path is None assert job.video_path is None
def test_from_dict_with_video_path(self, sample_v2v_job_dict): def test_from_dict_with_video_path(self, sample_v2v_job_dict):
"""Test creating a BatchJob with video_path for v2v.""" """Test creating a BatchJob with video_path for v2v."""
job = BatchJob.from_dict(sample_v2v_job_dict) job = BatchJob.from_dict(sample_v2v_job_dict)
assert job.video_path == "/path/to/video.mp4" assert job.video_path == "/path/to/video.mp4"
assert job.image_path is None assert job.image_path is None
@ -210,28 +212,28 @@ class TestBatchJob:
"""Test validation of a valid job returns no errors.""" """Test validation of a valid job returns no errors."""
job = BatchJob.from_dict(sample_job_dict) job = BatchJob.from_dict(sample_job_dict)
errors = job.validate() errors = job.validate()
assert errors == [] assert errors == []
def test_validate_missing_prompt(self): def test_validate_missing_prompt(self):
"""Test validation catches missing prompt.""" """Test validation catches missing prompt."""
job = BatchJob.from_dict({"output_name": "test.mp4"}) job = BatchJob.from_dict({"output_name": "test.mp4"})
errors = job.validate() errors = job.validate()
assert "Missing required field: prompt" in errors assert "Missing required field: prompt" in errors
def test_validate_missing_output_name(self): def test_validate_missing_output_name(self):
"""Test validation catches missing output_name.""" """Test validation catches missing output_name."""
job = BatchJob.from_dict({"prompt": "Test"}) job = BatchJob.from_dict({"prompt": "Test"})
errors = job.validate() errors = job.validate()
assert "Missing required field: output_name" in errors assert "Missing required field: output_name" in errors
def test_validate_missing_both_required(self): def test_validate_missing_both_required(self):
"""Test validation catches both missing required fields.""" """Test validation catches both missing required fields."""
job = BatchJob.from_dict({}) job = BatchJob.from_dict({})
errors = job.validate() errors = job.validate()
assert len(errors) == 2 assert len(errors) == 2
assert "Missing required field: prompt" in errors assert "Missing required field: prompt" in errors
assert "Missing required field: output_name" in errors assert "Missing required field: output_name" in errors
@ -253,7 +255,7 @@ class TestBatchState:
model_path="THUDM/CogVideoX1.5-5B", model_path="THUDM/CogVideoX1.5-5B",
generate_type="t2v", generate_type="t2v",
) )
assert state.batch_file == "batch.jsonl" assert state.batch_file == "batch.jsonl"
assert state.completed == [] assert state.completed == []
assert state.failed == [] assert state.failed == []
@ -266,10 +268,10 @@ class TestBatchState:
model_path="model", model_path="model",
generate_type="t2v", generate_type="t2v",
) )
state.mark_completed("video1.mp4") state.mark_completed("video1.mp4")
state.mark_completed("video2.mp4") state.mark_completed("video2.mp4")
assert "video1.mp4" in state.completed assert "video1.mp4" in state.completed
assert "video2.mp4" in state.completed assert "video2.mp4" in state.completed
@ -281,10 +283,10 @@ class TestBatchState:
model_path="model", model_path="model",
generate_type="t2v", generate_type="t2v",
) )
state.mark_completed("video1.mp4") state.mark_completed("video1.mp4")
state.mark_completed("video1.mp4") state.mark_completed("video1.mp4")
assert state.completed.count("video1.mp4") == 1 assert state.completed.count("video1.mp4") == 1
def test_mark_failed(self): def test_mark_failed(self):
@ -295,9 +297,9 @@ class TestBatchState:
model_path="model", model_path="model",
generate_type="t2v", generate_type="t2v",
) )
state.mark_failed("video1.mp4", "CUDA out of memory") state.mark_failed("video1.mp4", "CUDA out of memory")
assert len(state.failed) == 1 assert len(state.failed) == 1
assert state.failed[0]["output_name"] == "video1.mp4" assert state.failed[0]["output_name"] == "video1.mp4"
assert state.failed[0]["error"] == "CUDA out of memory" assert state.failed[0]["error"] == "CUDA out of memory"
@ -310,16 +312,16 @@ class TestBatchState:
model_path="model", model_path="model",
generate_type="t2v", generate_type="t2v",
) )
state.mark_completed("video1.mp4") state.mark_completed("video1.mp4")
assert state.is_completed("video1.mp4") is True assert state.is_completed("video1.mp4") is True
assert state.is_completed("video2.mp4") is False assert state.is_completed("video2.mp4") is False
def test_save_and_load(self, tmp_path): def test_save_and_load(self, tmp_path):
"""Test saving and loading state.""" """Test saving and loading state."""
state_file = tmp_path / ".batch_state.json" state_file = tmp_path / ".batch_state.json"
# Create and save state # Create and save state
state = BatchState( state = BatchState(
batch_file="batch.jsonl", batch_file="batch.jsonl",
@ -331,10 +333,10 @@ class TestBatchState:
state.mark_completed("video1.mp4") state.mark_completed("video1.mp4")
state.mark_failed("video2.mp4", "Error") state.mark_failed("video2.mp4", "Error")
state.save(state_file) state.save(state_file)
# Load state # Load state
loaded = BatchState.load(state_file) loaded = BatchState.load(state_file)
assert loaded is not None assert loaded is not None
assert loaded.batch_file == "batch.jsonl" assert loaded.batch_file == "batch.jsonl"
assert "video1.mp4" in loaded.completed assert "video1.mp4" in loaded.completed
@ -344,27 +346,27 @@ class TestBatchState:
def test_load_nonexistent_file(self, tmp_path): def test_load_nonexistent_file(self, tmp_path):
"""Test loading from a nonexistent file returns None.""" """Test loading from a nonexistent file returns None."""
state_file = tmp_path / "nonexistent.json" state_file = tmp_path / "nonexistent.json"
loaded = BatchState.load(state_file) loaded = BatchState.load(state_file)
assert loaded is None assert loaded is None
def test_load_corrupted_file(self, tmp_path): def test_load_corrupted_file(self, tmp_path):
"""Test loading a corrupted state file returns None.""" """Test loading a corrupted state file returns None."""
state_file = tmp_path / ".batch_state.json" state_file = tmp_path / ".batch_state.json"
state_file.write_text("{ invalid json }") state_file.write_text("{ invalid json }")
loaded = BatchState.load(state_file) loaded = BatchState.load(state_file)
assert loaded is None assert loaded is None
def test_load_invalid_json_structure(self, tmp_path): def test_load_invalid_json_structure(self, tmp_path):
"""Test loading a file with valid JSON but invalid structure.""" """Test loading a file with valid JSON but invalid structure."""
state_file = tmp_path / ".batch_state.json" state_file = tmp_path / ".batch_state.json"
state_file.write_text('{"wrong": "structure"}') state_file.write_text('{"wrong": "structure"}')
loaded = BatchState.load(state_file) loaded = BatchState.load(state_file)
# Should return None due to missing required fields # Should return None due to missing required fields
assert loaded is None assert loaded is None
@ -380,7 +382,7 @@ class TestLoadBatchFile:
def test_load_valid_jsonl(self, batch_file): def test_load_valid_jsonl(self, batch_file):
"""Test loading a valid JSONL file.""" """Test loading a valid JSONL file."""
jobs = load_batch_file(batch_file) jobs = load_batch_file(batch_file)
assert len(jobs) == 3 assert len(jobs) == 3
assert jobs[0].prompt == "A cat playing piano" assert jobs[0].prompt == "A cat playing piano"
assert jobs[1].output_name == "beach.mp4" assert jobs[1].output_name == "beach.mp4"
@ -394,9 +396,9 @@ class TestLoadBatchFile:
{"prompt": "Test2", "output_name": "test2.mp4"}""" {"prompt": "Test2", "output_name": "test2.mp4"}"""
batch_file = tmp_path / "batch.jsonl" batch_file = tmp_path / "batch.jsonl"
batch_file.write_text(content) batch_file.write_text(content)
jobs = load_batch_file(batch_file) jobs = load_batch_file(batch_file)
assert len(jobs) == 2 assert len(jobs) == 2
def test_load_jsonl_with_empty_lines(self, tmp_path): def test_load_jsonl_with_empty_lines(self, tmp_path):
@ -408,9 +410,9 @@ class TestLoadBatchFile:
""" """
batch_file = tmp_path / "batch.jsonl" batch_file = tmp_path / "batch.jsonl"
batch_file.write_text(content) batch_file.write_text(content)
jobs = load_batch_file(batch_file) jobs = load_batch_file(batch_file)
assert len(jobs) == 2 assert len(jobs) == 2
def test_load_jsonl_invalid_json_line(self, tmp_path): def test_load_jsonl_invalid_json_line(self, tmp_path):
@ -420,9 +422,9 @@ class TestLoadBatchFile:
{"prompt": "Test2", "output_name": "test2.mp4"}""" {"prompt": "Test2", "output_name": "test2.mp4"}"""
batch_file = tmp_path / "batch.jsonl" batch_file = tmp_path / "batch.jsonl"
batch_file.write_text(content) batch_file.write_text(content)
jobs = load_batch_file(batch_file) jobs = load_batch_file(batch_file)
# Should skip the invalid line # Should skip the invalid line
assert len(jobs) == 2 assert len(jobs) == 2
@ -434,9 +436,9 @@ class TestLoadBatchFile:
{"prompt": "Test2", "output_name": "test2.mp4"}""" {"prompt": "Test2", "output_name": "test2.mp4"}"""
batch_file = tmp_path / "batch.jsonl" batch_file = tmp_path / "batch.jsonl"
batch_file.write_text(content) batch_file.write_text(content)
jobs = load_batch_file(batch_file) jobs = load_batch_file(batch_file)
# Should skip the two invalid jobs # Should skip the two invalid jobs
assert len(jobs) == 2 assert len(jobs) == 2
assert jobs[0].output_name == "test.mp4" assert jobs[0].output_name == "test.mp4"
@ -446,9 +448,9 @@ class TestLoadBatchFile:
"""Test loading an empty batch file.""" """Test loading an empty batch file."""
batch_file = tmp_path / "empty.jsonl" batch_file = tmp_path / "empty.jsonl"
batch_file.write_text("") batch_file.write_text("")
jobs = load_batch_file(batch_file) jobs = load_batch_file(batch_file)
assert len(jobs) == 0 assert len(jobs) == 0
def test_load_batch_file_all_comments(self, tmp_path): def test_load_batch_file_all_comments(self, tmp_path):
@ -458,9 +460,9 @@ class TestLoadBatchFile:
# Comment 3""" # Comment 3"""
batch_file = tmp_path / "comments.jsonl" batch_file = tmp_path / "comments.jsonl"
batch_file.write_text(content) batch_file.write_text(content)
jobs = load_batch_file(batch_file) jobs = load_batch_file(batch_file)
assert len(jobs) == 0 assert len(jobs) == 0
def test_load_batch_file_tracks_line_numbers(self, tmp_path): def test_load_batch_file_tracks_line_numbers(self, tmp_path):
@ -471,9 +473,9 @@ class TestLoadBatchFile:
{"prompt": "Line 4", "output_name": "line4.mp4"}""" {"prompt": "Line 4", "output_name": "line4.mp4"}"""
batch_file = tmp_path / "batch.jsonl" batch_file = tmp_path / "batch.jsonl"
batch_file.write_text(content) batch_file.write_text(content)
jobs = load_batch_file(batch_file) jobs = load_batch_file(batch_file)
assert jobs[0].line_number == 2 assert jobs[0].line_number == 2
assert jobs[1].line_number == 4 assert jobs[1].line_number == 4
@ -488,30 +490,24 @@ class TestMultiGPUDistribution:
def test_single_gpu_gets_all_jobs(self): def test_single_gpu_gets_all_jobs(self):
"""Test that a single GPU processes all jobs.""" """Test that a single GPU processes all jobs."""
all_jobs = [ all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)]
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
for i in range(10)
]
# Simulate distribution for single GPU (gpu_id=0, num_gpus=1) # Simulate distribution for single GPU (gpu_id=0, num_gpus=1)
gpu_id = 0 gpu_id = 0
num_gpus = 1 num_gpus = 1
jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id] jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id]
assert len(jobs) == 10 assert len(jobs) == 10
def test_two_gpu_distribution(self): def test_two_gpu_distribution(self):
"""Test job distribution across 2 GPUs.""" """Test job distribution across 2 GPUs."""
all_jobs = [ all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)]
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
for i in range(10)
]
# GPU 0 gets jobs 0, 2, 4, 6, 8 # GPU 0 gets jobs 0, 2, 4, 6, 8
gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 0] gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 0]
# GPU 1 gets jobs 1, 3, 5, 7, 9 # GPU 1 gets jobs 1, 3, 5, 7, 9
gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 1] gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 1]
assert len(gpu0_jobs) == 5 assert len(gpu0_jobs) == 5
assert len(gpu1_jobs) == 5 assert len(gpu1_jobs) == 5
assert gpu0_jobs[0].output_name == "video0.mp4" assert gpu0_jobs[0].output_name == "video0.mp4"
@ -519,11 +515,8 @@ class TestMultiGPUDistribution:
def test_four_gpu_distribution(self): def test_four_gpu_distribution(self):
"""Test job distribution across 4 GPUs.""" """Test job distribution across 4 GPUs."""
all_jobs = [ all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(12)]
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
for i in range(12)
]
num_gpus = 4 num_gpus = 4
for gpu_id in range(num_gpus): for gpu_id in range(num_gpus):
jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id] jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id]
@ -531,16 +524,13 @@ class TestMultiGPUDistribution:
def test_uneven_distribution(self): def test_uneven_distribution(self):
"""Test distribution when jobs don't divide evenly.""" """Test distribution when jobs don't divide evenly."""
all_jobs = [ all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)]
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
for i in range(10)
]
num_gpus = 3 num_gpus = 3
gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 0] gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 0]
gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 1] gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 1]
gpu2_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 2] gpu2_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 2]
# 10 jobs / 3 GPUs: GPU0 gets 4, GPU1 gets 3, GPU2 gets 3 # 10 jobs / 3 GPUs: GPU0 gets 4, GPU1 gets 3, GPU2 gets 3
assert len(gpu0_jobs) == 4 # indices 0, 3, 6, 9 assert len(gpu0_jobs) == 4 # indices 0, 3, 6, 9
assert len(gpu1_jobs) == 3 # indices 1, 4, 7 assert len(gpu1_jobs) == 3 # indices 1, 4, 7
@ -562,14 +552,14 @@ class TestLoadPipeline:
mock_pipe.scheduler = MagicMock() mock_pipe.scheduler = MagicMock()
mock_pipe.vae = MagicMock() mock_pipe.vae = MagicMock()
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
with patch.object(batch_inference, "CogVideoXDPMScheduler"): with patch.object(batch_inference, "CogVideoXDPMScheduler"):
pipe = load_pipeline( pipe = load_pipeline(
model_path="THUDM/CogVideoX1.5-5B", model_path="THUDM/CogVideoX1.5-5B",
generate_type="t2v", generate_type="t2v",
enable_cpu_offload=False, enable_cpu_offload=False,
) )
mock_cogvideo_pipeline.from_pretrained.assert_called_once() mock_cogvideo_pipeline.from_pretrained.assert_called_once()
assert pipe is mock_pipe assert pipe is mock_pipe
@ -580,14 +570,14 @@ class TestLoadPipeline:
mock_pipe.scheduler = MagicMock() mock_pipe.scheduler = MagicMock()
mock_pipe.vae = MagicMock() mock_pipe.vae = MagicMock()
mock_i2v_pipeline.from_pretrained.return_value = mock_pipe mock_i2v_pipeline.from_pretrained.return_value = mock_pipe
with patch.object(batch_inference, "CogVideoXDPMScheduler"): with patch.object(batch_inference, "CogVideoXDPMScheduler"):
pipe = load_pipeline( pipe = load_pipeline(
model_path="THUDM/CogVideoX1.5-5B-I2V", model_path="THUDM/CogVideoX1.5-5B-I2V",
generate_type="i2v", generate_type="i2v",
enable_cpu_offload=False, enable_cpu_offload=False,
) )
mock_i2v_pipeline.from_pretrained.assert_called_once() mock_i2v_pipeline.from_pretrained.assert_called_once()
assert pipe is mock_pipe assert pipe is mock_pipe
@ -598,14 +588,14 @@ class TestLoadPipeline:
mock_pipe.scheduler = MagicMock() mock_pipe.scheduler = MagicMock()
mock_pipe.vae = MagicMock() mock_pipe.vae = MagicMock()
mock_v2v_pipeline.from_pretrained.return_value = mock_pipe mock_v2v_pipeline.from_pretrained.return_value = mock_pipe
with patch.object(batch_inference, "CogVideoXDPMScheduler"): with patch.object(batch_inference, "CogVideoXDPMScheduler"):
pipe = load_pipeline( pipe = load_pipeline(
model_path="THUDM/CogVideoX1.5-5B", model_path="THUDM/CogVideoX1.5-5B",
generate_type="v2v", generate_type="v2v",
enable_cpu_offload=False, enable_cpu_offload=False,
) )
mock_v2v_pipeline.from_pretrained.assert_called_once() mock_v2v_pipeline.from_pretrained.assert_called_once()
assert pipe is mock_pipe assert pipe is mock_pipe
@ -616,14 +606,14 @@ class TestLoadPipeline:
mock_pipe.scheduler = MagicMock() mock_pipe.scheduler = MagicMock()
mock_pipe.vae = MagicMock() mock_pipe.vae = MagicMock()
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
with patch.object(batch_inference, "CogVideoXDPMScheduler"): with patch.object(batch_inference, "CogVideoXDPMScheduler"):
load_pipeline( load_pipeline(
model_path="model", model_path="model",
generate_type="t2v", generate_type="t2v",
enable_cpu_offload=True, enable_cpu_offload=True,
) )
mock_pipe.enable_sequential_cpu_offload.assert_called_once() mock_pipe.enable_sequential_cpu_offload.assert_called_once()
mock_pipe.to.assert_not_called() mock_pipe.to.assert_not_called()
@ -634,14 +624,14 @@ class TestLoadPipeline:
mock_pipe.scheduler = MagicMock() mock_pipe.scheduler = MagicMock()
mock_pipe.vae = MagicMock() mock_pipe.vae = MagicMock()
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
with patch.object(batch_inference, "CogVideoXDPMScheduler"): with patch.object(batch_inference, "CogVideoXDPMScheduler"):
load_pipeline( load_pipeline(
model_path="model", model_path="model",
generate_type="t2v", generate_type="t2v",
enable_cpu_offload=False, enable_cpu_offload=False,
) )
mock_pipe.to.assert_called_once_with("cuda") mock_pipe.to.assert_called_once_with("cuda")
mock_pipe.enable_sequential_cpu_offload.assert_not_called() mock_pipe.enable_sequential_cpu_offload.assert_not_called()
@ -652,7 +642,7 @@ class TestLoadPipeline:
mock_pipe.scheduler = MagicMock() mock_pipe.scheduler = MagicMock()
mock_pipe.vae = MagicMock() mock_pipe.vae = MagicMock()
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
with patch.object(batch_inference, "CogVideoXDPMScheduler"): with patch.object(batch_inference, "CogVideoXDPMScheduler"):
load_pipeline( load_pipeline(
model_path="model", model_path="model",
@ -660,7 +650,7 @@ class TestLoadPipeline:
lora_path="/path/to/lora", lora_path="/path/to/lora",
enable_cpu_offload=False, enable_cpu_offload=False,
) )
mock_pipe.load_lora_weights.assert_called_once() mock_pipe.load_lora_weights.assert_called_once()
mock_pipe.fuse_lora.assert_called_once() mock_pipe.fuse_lora.assert_called_once()
@ -682,7 +672,7 @@ class TestGenerateSingleVideo:
num_frames=49, num_frames=49,
) )
output_path = tmp_path / "cat.mp4" output_path = tmp_path / "cat.mp4"
generate_single_video( generate_single_video(
pipe=mock_pipeline, pipe=mock_pipeline,
job=job, job=job,
@ -690,7 +680,7 @@ class TestGenerateSingleVideo:
model_name="cogvideox-5b", model_name="cogvideox-5b",
output_path=output_path, output_path=output_path,
) )
mock_pipeline.assert_called_once() mock_pipeline.assert_called_once()
mock_export.assert_called_once() mock_export.assert_called_once()
@ -699,14 +689,14 @@ class TestGenerateSingleVideo:
def test_generate_i2v_video(self, mock_load_image, mock_export, mock_pipeline, tmp_path): def test_generate_i2v_video(self, mock_load_image, mock_export, mock_pipeline, tmp_path):
"""Test generating an image-to-video.""" """Test generating an image-to-video."""
mock_load_image.return_value = MagicMock() mock_load_image.return_value = MagicMock()
job = BatchJob( job = BatchJob(
prompt="Transform this", prompt="Transform this",
output_name="i2v.mp4", output_name="i2v.mp4",
image_path="/path/to/image.jpg", image_path="/path/to/image.jpg",
) )
output_path = tmp_path / "i2v.mp4" output_path = tmp_path / "i2v.mp4"
generate_single_video( generate_single_video(
pipe=mock_pipeline, pipe=mock_pipeline,
job=job, job=job,
@ -714,7 +704,7 @@ class TestGenerateSingleVideo:
model_name="cogvideox-5b-i2v", model_name="cogvideox-5b-i2v",
output_path=output_path, output_path=output_path,
) )
mock_load_image.assert_called_once_with(image="/path/to/image.jpg") mock_load_image.assert_called_once_with(image="/path/to/image.jpg")
mock_pipeline.assert_called_once() mock_pipeline.assert_called_once()
@ -723,14 +713,14 @@ class TestGenerateSingleVideo:
def test_generate_v2v_video(self, mock_load_video, mock_export, mock_pipeline, tmp_path): def test_generate_v2v_video(self, mock_load_video, mock_export, mock_pipeline, tmp_path):
"""Test generating a video-to-video.""" """Test generating a video-to-video."""
mock_load_video.return_value = MagicMock() mock_load_video.return_value = MagicMock()
job = BatchJob( job = BatchJob(
prompt="Enhance this", prompt="Enhance this",
output_name="v2v.mp4", output_name="v2v.mp4",
video_path="/path/to/video.mp4", video_path="/path/to/video.mp4",
) )
output_path = tmp_path / "v2v.mp4" output_path = tmp_path / "v2v.mp4"
generate_single_video( generate_single_video(
pipe=mock_pipeline, pipe=mock_pipeline,
job=job, job=job,
@ -738,7 +728,7 @@ class TestGenerateSingleVideo:
model_name="cogvideox-5b", model_name="cogvideox-5b",
output_path=output_path, output_path=output_path,
) )
mock_load_video.assert_called_once_with("/path/to/video.mp4") mock_load_video.assert_called_once_with("/path/to/video.mp4")
mock_pipeline.assert_called_once() mock_pipeline.assert_called_once()
@ -750,7 +740,7 @@ class TestGenerateSingleVideo:
# Missing image_path # Missing image_path
) )
output_path = tmp_path / "i2v.mp4" output_path = tmp_path / "i2v.mp4"
with pytest.raises(ValueError, match="image_path is required"): with pytest.raises(ValueError, match="image_path is required"):
generate_single_video( generate_single_video(
pipe=mock_pipeline, pipe=mock_pipeline,
@ -768,7 +758,7 @@ class TestGenerateSingleVideo:
# Missing video_path # Missing video_path
) )
output_path = tmp_path / "v2v.mp4" output_path = tmp_path / "v2v.mp4"
with pytest.raises(ValueError, match="video_path is required"): with pytest.raises(ValueError, match="video_path is required"):
generate_single_video( generate_single_video(
pipe=mock_pipeline, pipe=mock_pipeline,
@ -790,7 +780,7 @@ class TestGenerateSingleVideo:
seed=456, seed=456,
) )
output_path = tmp_path / "test.mp4" output_path = tmp_path / "test.mp4"
generate_single_video( generate_single_video(
pipe=mock_pipeline, pipe=mock_pipeline,
job=job, job=job,
@ -802,7 +792,7 @@ class TestGenerateSingleVideo:
default_num_inference_steps=50, default_num_inference_steps=50,
default_seed=42, default_seed=42,
) )
# Check the pipeline was called with job-specific values # Check the pipeline was called with job-specific values
call_kwargs = mock_pipeline.call_args[1] call_kwargs = mock_pipeline.call_args[1]
assert call_kwargs["num_frames"] == 33 assert call_kwargs["num_frames"] == 33
@ -824,9 +814,9 @@ class TestRunBatch:
"""Test basic batch processing.""" """Test basic batch processing."""
mock_pipe = MagicMock() mock_pipe = MagicMock()
mock_load_pipeline.return_value = mock_pipe mock_load_pipeline.return_value = mock_pipe
output_dir = tmp_path / "output" output_dir = tmp_path / "output"
summary = run_batch( summary = run_batch(
batch_file=batch_file, batch_file=batch_file,
model_path="THUDM/CogVideoX1.5-5B", model_path="THUDM/CogVideoX1.5-5B",
@ -834,7 +824,7 @@ class TestRunBatch:
generate_type="t2v", generate_type="t2v",
resume=False, resume=False,
) )
assert summary["total"] == 3 assert summary["total"] == 3
assert summary["completed"] == 3 assert summary["completed"] == 3
assert summary["failed"] == 0 assert summary["failed"] == 0
@ -842,21 +832,23 @@ class TestRunBatch:
@patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "load_pipeline")
@patch.object(batch_inference, "generate_single_video") @patch.object(batch_inference, "generate_single_video")
def test_run_batch_creates_output_dir(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): def test_run_batch_creates_output_dir(
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
):
"""Test that output directory is created if it doesn't exist.""" """Test that output directory is created if it doesn't exist."""
mock_pipe = MagicMock() mock_pipe = MagicMock()
mock_load_pipeline.return_value = mock_pipe mock_load_pipeline.return_value = mock_pipe
output_dir = tmp_path / "nested" / "output" / "dir" output_dir = tmp_path / "nested" / "output" / "dir"
assert not output_dir.exists() assert not output_dir.exists()
run_batch( run_batch(
batch_file=batch_file, batch_file=batch_file,
model_path="model", model_path="model",
output_dir=output_dir, output_dir=output_dir,
resume=False, resume=False,
) )
assert output_dir.exists() assert output_dir.exists()
@patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "load_pipeline")
@ -865,64 +857,68 @@ class TestRunBatch:
"""Test that state is saved after processing.""" """Test that state is saved after processing."""
mock_pipe = MagicMock() mock_pipe = MagicMock()
mock_load_pipeline.return_value = mock_pipe mock_load_pipeline.return_value = mock_pipe
output_dir = tmp_path / "output" output_dir = tmp_path / "output"
run_batch( run_batch(
batch_file=batch_file, batch_file=batch_file,
model_path="model", model_path="model",
output_dir=output_dir, output_dir=output_dir,
resume=True, resume=True,
) )
state_file = output_dir / ".batch_state.json" state_file = output_dir / ".batch_state.json"
assert state_file.exists() assert state_file.exists()
with open(state_file) as f: with open(state_file) as f:
state_data = json.load(f) state_data = json.load(f)
assert len(state_data["completed"]) == 3 assert len(state_data["completed"]) == 3
@patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "load_pipeline")
@patch.object(batch_inference, "generate_single_video") @patch.object(batch_inference, "generate_single_video")
def test_run_batch_handles_errors(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): def test_run_batch_handles_errors(
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
):
"""Test that batch continues after individual job failures.""" """Test that batch continues after individual job failures."""
mock_pipe = MagicMock() mock_pipe = MagicMock()
mock_load_pipeline.return_value = mock_pipe mock_load_pipeline.return_value = mock_pipe
# Make the second call raise an exception # Make the second call raise an exception
mock_generate.side_effect = [None, RuntimeError("CUDA OOM"), None] mock_generate.side_effect = [None, RuntimeError("CUDA OOM"), None]
output_dir = tmp_path / "output" output_dir = tmp_path / "output"
summary = run_batch( summary = run_batch(
batch_file=batch_file, batch_file=batch_file,
model_path="model", model_path="model",
output_dir=output_dir, output_dir=output_dir,
resume=False, resume=False,
) )
assert summary["completed"] == 2 assert summary["completed"] == 2
assert summary["failed"] == 1 assert summary["failed"] == 1
assert mock_generate.call_count == 3 assert mock_generate.call_count == 3
@patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "load_pipeline")
@patch.object(batch_inference, "generate_single_video") @patch.object(batch_inference, "generate_single_video")
def test_run_batch_logs_errors_to_file(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): def test_run_batch_logs_errors_to_file(
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
):
"""Test that errors are logged to errors.log.""" """Test that errors are logged to errors.log."""
mock_pipe = MagicMock() mock_pipe = MagicMock()
mock_load_pipeline.return_value = mock_pipe mock_load_pipeline.return_value = mock_pipe
mock_generate.side_effect = RuntimeError("Test error") mock_generate.side_effect = RuntimeError("Test error")
output_dir = tmp_path / "output" output_dir = tmp_path / "output"
run_batch( run_batch(
batch_file=batch_file, batch_file=batch_file,
model_path="model", model_path="model",
output_dir=output_dir, output_dir=output_dir,
resume=False, resume=False,
) )
error_log = output_dir / "errors.log" error_log = output_dir / "errors.log"
assert error_log.exists() assert error_log.exists()
content = error_log.read_text() content = error_log.read_text()
@ -930,14 +926,16 @@ class TestRunBatch:
@patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "load_pipeline")
@patch.object(batch_inference, "generate_single_video") @patch.object(batch_inference, "generate_single_video")
def test_run_batch_resume_skips_completed(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): def test_run_batch_resume_skips_completed(
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
):
"""Test that resume skips already-completed jobs.""" """Test that resume skips already-completed jobs."""
mock_pipe = MagicMock() mock_pipe = MagicMock()
mock_load_pipeline.return_value = mock_pipe mock_load_pipeline.return_value = mock_pipe
output_dir = tmp_path / "output" output_dir = tmp_path / "output"
output_dir.mkdir() output_dir.mkdir()
# Create a state file showing first job completed # Create a state file showing first job completed
state = BatchState( state = BatchState(
batch_file=str(batch_file), batch_file=str(batch_file),
@ -947,14 +945,14 @@ class TestRunBatch:
completed=["cat_piano.mp4"], completed=["cat_piano.mp4"],
) )
state.save(output_dir / ".batch_state.json") state.save(output_dir / ".batch_state.json")
summary = run_batch( summary = run_batch(
batch_file=batch_file, batch_file=batch_file,
model_path="model", model_path="model",
output_dir=output_dir, output_dir=output_dir,
resume=True, resume=True,
) )
# Should only process 2 jobs (skipping the completed one) # Should only process 2 jobs (skipping the completed one)
assert summary["total"] == 2 assert summary["total"] == 2
assert summary["skipped"] == 1 assert summary["skipped"] == 1
@ -965,16 +963,16 @@ class TestRunBatch:
"""Test running batch with empty file.""" """Test running batch with empty file."""
batch_file = tmp_path / "empty.jsonl" batch_file = tmp_path / "empty.jsonl"
batch_file.write_text("") batch_file.write_text("")
output_dir = tmp_path / "output" output_dir = tmp_path / "output"
summary = run_batch( summary = run_batch(
batch_file=batch_file, batch_file=batch_file,
model_path="model", model_path="model",
output_dir=output_dir, output_dir=output_dir,
resume=False, resume=False,
) )
assert summary["total"] == 0 assert summary["total"] == 0
# Pipeline should not be loaded for empty batch # Pipeline should not be loaded for empty batch
mock_load_pipeline.assert_not_called() mock_load_pipeline.assert_not_called()
@ -984,18 +982,15 @@ class TestRunBatch:
def test_run_batch_multi_gpu_distribution(self, mock_generate, mock_load_pipeline, tmp_path): def test_run_batch_multi_gpu_distribution(self, mock_generate, mock_load_pipeline, tmp_path):
"""Test multi-GPU job distribution in run_batch.""" """Test multi-GPU job distribution in run_batch."""
# Create batch file with 6 jobs # Create batch file with 6 jobs
jobs = [ jobs = [{"prompt": f"Job {i}", "output_name": f"video{i}.mp4"} for i in range(6)]
{"prompt": f"Job {i}", "output_name": f"video{i}.mp4"}
for i in range(6)
]
batch_file = tmp_path / "batch.jsonl" batch_file = tmp_path / "batch.jsonl"
batch_file.write_text("\n".join(json.dumps(j) for j in jobs)) batch_file.write_text("\n".join(json.dumps(j) for j in jobs))
mock_pipe = MagicMock() mock_pipe = MagicMock()
mock_load_pipeline.return_value = mock_pipe mock_load_pipeline.return_value = mock_pipe
output_dir = tmp_path / "output" output_dir = tmp_path / "output"
# GPU 0 of 3 should get jobs 0, 3 (indices 0, 3) # GPU 0 of 3 should get jobs 0, 3 (indices 0, 3)
summary = run_batch( summary = run_batch(
batch_file=batch_file, batch_file=batch_file,
@ -1005,25 +1000,27 @@ class TestRunBatch:
num_gpus=3, num_gpus=3,
resume=False, resume=False,
) )
assert summary["total"] == 2 # GPU 0 gets 2 jobs assert summary["total"] == 2 # GPU 0 gets 2 jobs
assert mock_generate.call_count == 2 assert mock_generate.call_count == 2
@patch.object(batch_inference, "load_pipeline") @patch.object(batch_inference, "load_pipeline")
@patch.object(batch_inference, "generate_single_video") @patch.object(batch_inference, "generate_single_video")
def test_run_batch_state_persists_after_each_job(self, mock_generate, mock_load_pipeline, batch_file, tmp_path): def test_run_batch_state_persists_after_each_job(
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
):
"""Test that state is saved after each job, not just at the end.""" """Test that state is saved after each job, not just at the end."""
mock_pipe = MagicMock() mock_pipe = MagicMock()
mock_load_pipeline.return_value = mock_pipe mock_load_pipeline.return_value = mock_pipe
# Track how many times state file is written # Track how many times state file is written
state_writes = [] state_writes = []
original_save = BatchState.save original_save = BatchState.save
def tracking_save(self, state_file): def tracking_save(self, state_file):
original_save(self, state_file) original_save(self, state_file)
state_writes.append(len(self.completed)) state_writes.append(len(self.completed))
with patch.object(BatchState, 'save', tracking_save): with patch.object(BatchState, 'save', tracking_save):
output_dir = tmp_path / "output" output_dir = tmp_path / "output"
run_batch( run_batch(
@ -1032,7 +1029,7 @@ class TestRunBatch:
output_dir=output_dir, output_dir=output_dir,
resume=True, resume=True,
) )
# State should be saved 3 times (once per job) # State should be saved 3 times (once per job)
assert len(state_writes) == 3 assert len(state_writes) == 3
assert state_writes == [1, 2, 3] # Progressively more completed jobs assert state_writes == [1, 2, 3] # Progressively more completed jobs
@ -1086,7 +1083,7 @@ class TestEdgeCases:
"height": 1080, "height": 1080,
} }
job = BatchJob.from_dict(data) job = BatchJob.from_dict(data)
assert job.width == 1920 assert job.width == 1920
assert job.height == 1080 assert job.height == 1080
assert job.num_frames == 100 assert job.num_frames == 100
@ -1100,32 +1097,34 @@ class TestEdgeCases:
"another_field": 123, "another_field": 123,
} }
job = BatchJob.from_dict(data) job = BatchJob.from_dict(data)
assert job.prompt == "Test" assert job.prompt == "Test"
assert job.output_name == "test.mp4" assert job.output_name == "test.mp4"
# Extra fields should be ignored without error # Extra fields should be ignored without error
@patch.object(batch_inference, "load_pipeline")
@patch.object(batch_inference, "generate_single_video")
def test_run_batch_corrupted_state_starts_fresh(
    self, mock_generate, mock_load_pipeline, batch_file, tmp_path
):
    """A state file that cannot be parsed must not stop the batch from running."""
    mock_load_pipeline.return_value = MagicMock()

    # Plant an unparseable state file at the location run_batch reads on resume.
    output_dir = tmp_path / "output"
    output_dir.mkdir()
    (output_dir / ".batch_state.json").write_text("{corrupted: json")

    summary = run_batch(
        batch_file=batch_file,
        model_path="model",
        output_dir=output_dir,
        resume=True,
    )

    # With no usable saved state, nothing is skipped: all three jobs run.
    assert summary["total"] == 3
    assert mock_generate.call_count == 3
@ -1138,12 +1137,12 @@ class TestEdgeCases:
model_path="model", model_path="model",
generate_type="t2v", generate_type="t2v",
) )
assert state.updated_at == "" assert state.updated_at == ""
state_file = tmp_path / "state.json" state_file = tmp_path / "state.json"
state.save(state_file) state.save(state_file)
# updated_at should be set after save # updated_at should be set after save
assert state.updated_at != "" assert state.updated_at != ""
@ -1154,16 +1153,16 @@ class TestEdgeCases:
mock_pipe = MagicMock() mock_pipe = MagicMock()
mock_load_pipeline.return_value = mock_pipe mock_load_pipeline.return_value = mock_pipe
mock_generate.side_effect = RuntimeError("All fail") mock_generate.side_effect = RuntimeError("All fail")
output_dir = tmp_path / "output" output_dir = tmp_path / "output"
summary = run_batch( summary = run_batch(
batch_file=batch_file, batch_file=batch_file,
model_path="model", model_path="model",
output_dir=output_dir, output_dir=output_dir,
resume=False, resume=False,
) )
assert summary["completed"] == 0 assert summary["completed"] == 0
assert summary["failed"] == 3 assert summary["failed"] == 3
@ -1174,9 +1173,9 @@ class TestEdgeCases:
{"prompt": "Café résumé naïve", "output_name": "unicode3.mp4"}""" {"prompt": "Café résumé naïve", "output_name": "unicode3.mp4"}"""
batch_file = tmp_path / "unicode.jsonl" batch_file = tmp_path / "unicode.jsonl"
batch_file.write_text(content, encoding="utf-8") batch_file.write_text(content, encoding="utf-8")
jobs = load_batch_file(batch_file) jobs = load_batch_file(batch_file)
assert len(jobs) == 3 assert len(jobs) == 3
assert jobs[0].prompt == "猫がピアノを弾いている" assert jobs[0].prompt == "猫がピアノを弾いている"
assert "🌊" in jobs[1].prompt assert "🌊" in jobs[1].prompt
@ -1187,8 +1186,8 @@ class TestEdgeCases:
content = json.dumps({"prompt": long_prompt, "output_name": "long.mp4"}) content = json.dumps({"prompt": long_prompt, "output_name": "long.mp4"})
batch_file = tmp_path / "long.jsonl" batch_file = tmp_path / "long.jsonl"
batch_file.write_text(content) batch_file.write_text(content)
jobs = load_batch_file(batch_file) jobs = load_batch_file(batch_file)
assert len(jobs) == 1 assert len(jobs) == 1
assert len(jobs[0].prompt) > 4000 assert len(jobs[0].prompt) > 4000

View File

@ -107,6 +107,7 @@ RESOLUTION_MAP = {
@dataclass @dataclass
class BatchJob: class BatchJob:
"""Represents a single job in the batch.""" """Represents a single job in the batch."""
prompt: str prompt: str
output_name: str output_name: str
image_path: Optional[str] = None image_path: Optional[str] = None
@ -117,12 +118,12 @@ class BatchJob:
seed: Optional[int] = None seed: Optional[int] = None
width: Optional[int] = None width: Optional[int] = None
height: Optional[int] = None height: Optional[int] = None
# Internal fields # Internal fields
line_number: int = 0 line_number: int = 0
status: str = "pending" status: str = "pending"
error: Optional[str] = None error: Optional[str] = None
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any], line_number: int = 0) -> "BatchJob": def from_dict(cls, data: Dict[str, Any], line_number: int = 0) -> "BatchJob":
"""Create a BatchJob from a dictionary.""" """Create a BatchJob from a dictionary."""
@ -139,7 +140,7 @@ class BatchJob:
height=data.get("height"), height=data.get("height"),
line_number=line_number, line_number=line_number,
) )
def validate(self) -> List[str]: def validate(self) -> List[str]:
"""Validate the job and return list of errors.""" """Validate the job and return list of errors."""
errors = [] errors = []
@ -153,6 +154,7 @@ class BatchJob:
@dataclass @dataclass
class BatchState: class BatchState:
"""Tracks batch progress for resume capability.""" """Tracks batch progress for resume capability."""
batch_file: str batch_file: str
output_dir: str output_dir: str
model_path: str model_path: str
@ -161,7 +163,7 @@ class BatchState:
failed: List[Dict[str, str]] = field(default_factory=list) failed: List[Dict[str, str]] = field(default_factory=list)
started_at: str = "" started_at: str = ""
updated_at: str = "" updated_at: str = ""
@classmethod @classmethod
def load(cls, state_file: Path) -> Optional["BatchState"]: def load(cls, state_file: Path) -> Optional["BatchState"]:
"""Load state from file.""" """Load state from file."""
@ -174,22 +176,22 @@ class BatchState:
except Exception as e: except Exception as e:
logger.warning(f"Failed to load state file: {e}") logger.warning(f"Failed to load state file: {e}")
return None return None
def save(self, state_file: Path): def save(self, state_file: Path):
"""Save state to file.""" """Save state to file."""
self.updated_at = datetime.now().isoformat() self.updated_at = datetime.now().isoformat()
with open(state_file, "w") as f: with open(state_file, "w") as f:
json.dump(self.__dict__, f, indent=2) json.dump(self.__dict__, f, indent=2)
def mark_completed(self, output_name: str): def mark_completed(self, output_name: str):
"""Mark a job as completed.""" """Mark a job as completed."""
if output_name not in self.completed: if output_name not in self.completed:
self.completed.append(output_name) self.completed.append(output_name)
def mark_failed(self, output_name: str, error: str): def mark_failed(self, output_name: str, error: str):
"""Mark a job as failed.""" """Mark a job as failed."""
self.failed.append({"output_name": output_name, "error": error}) self.failed.append({"output_name": output_name, "error": error})
def is_completed(self, output_name: str) -> bool: def is_completed(self, output_name: str) -> bool:
"""Check if a job was already completed.""" """Check if a job was already completed."""
return output_name in self.completed return output_name in self.completed
@ -198,37 +200,37 @@ class BatchState:
def load_batch_file(batch_file: Path) -> List[BatchJob]:
    """
    Load and parse a JSONL batch file.

    Blank lines and lines starting with ``#`` are ignored. A line is logged
    as a warning and skipped, without aborting the rest of the file, when it
    is not valid JSON, when it is valid JSON but not an object, or when the
    resulting job reports errors from ``BatchJob.validate()``.

    Args:
        batch_file: Path to the JSONL file

    Returns:
        List of BatchJob objects in file order, each tagged with the 1-based
        line number it was read from
    """
    jobs = []

    # Decode as UTF-8 explicitly rather than with the platform default, so
    # non-ASCII prompts load the same way on every machine.
    with open(batch_file, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                logger.warning(f"Line {line_num}: Invalid JSON - {e}")
                continue

            # A bare list, string or number is valid JSON but cannot describe
            # a job; BatchJob.from_dict needs a mapping to read fields from.
            if not isinstance(data, dict):
                logger.warning(
                    f"Line {line_num}: Invalid job - expected a JSON object, "
                    f"got {type(data).__name__}"
                )
                continue

            job = BatchJob.from_dict(data, line_number=line_num)

            # Validate job
            errors = job.validate()
            if errors:
                logger.warning(f"Line {line_num}: Invalid job - {', '.join(errors)}")
                continue

            jobs.append(job)

    return jobs
@ -241,26 +243,26 @@ def load_pipeline(
): ):
""" """
Load the appropriate pipeline for the generation type. Load the appropriate pipeline for the generation type.
Args: Args:
model_path: Path to the model model_path: Path to the model
generate_type: Type of generation (t2v, i2v, v2v) generate_type: Type of generation (t2v, i2v, v2v)
dtype: Data type for computation dtype: Data type for computation
lora_path: Optional path to LoRA weights lora_path: Optional path to LoRA weights
enable_cpu_offload: Whether to enable CPU offloading enable_cpu_offload: Whether to enable CPU offloading
Returns: Returns:
The loaded pipeline The loaded pipeline
""" """
logger.info(f"Loading pipeline for {generate_type} from {model_path}") logger.info(f"Loading pipeline for {generate_type} from {model_path}")
if generate_type == "i2v": if generate_type == "i2v":
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype) pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
elif generate_type == "t2v": elif generate_type == "t2v":
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype) pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
else: # v2v else: # v2v
pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype) pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
# Load LoRA weights if provided # Load LoRA weights if provided
if lora_path: if lora_path:
logger.info(f"Loading LoRA weights from {lora_path}") logger.info(f"Loading LoRA weights from {lora_path}")
@ -268,21 +270,21 @@ def load_pipeline(
lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="batch_lora" lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="batch_lora"
) )
pipe.fuse_lora(components=["transformer"], lora_scale=1.0) pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
# Set scheduler # Set scheduler
pipe.scheduler = CogVideoXDPMScheduler.from_config( pipe.scheduler = CogVideoXDPMScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing" pipe.scheduler.config, timestep_spacing="trailing"
) )
# Enable memory optimizations # Enable memory optimizations
if enable_cpu_offload: if enable_cpu_offload:
pipe.enable_sequential_cpu_offload() pipe.enable_sequential_cpu_offload()
else: else:
pipe.to("cuda") pipe.to("cuda")
pipe.vae.enable_slicing() pipe.vae.enable_slicing()
pipe.vae.enable_tiling() pipe.vae.enable_tiling()
logger.info("Pipeline loaded successfully") logger.info("Pipeline loaded successfully")
return pipe return pipe
@ -301,7 +303,7 @@ def generate_single_video(
): ):
""" """
Generate a single video from a job. Generate a single video from a job.
Args: Args:
pipe: The loaded pipeline pipe: The loaded pipeline
job: The batch job to process job: The batch job to process
@ -315,17 +317,17 @@ def generate_single_video(
desired_resolution = RESOLUTION_MAP.get(model_name.lower(), (480, 720)) desired_resolution = RESOLUTION_MAP.get(model_name.lower(), (480, 720))
height = job.height if job.height else desired_resolution[0] height = job.height if job.height else desired_resolution[0]
width = job.width if job.width else desired_resolution[1] width = job.width if job.width else desired_resolution[1]
# Use job-specific or default values # Use job-specific or default values
num_frames = job.num_frames or default_num_frames num_frames = job.num_frames or default_num_frames
guidance_scale = job.guidance_scale or default_guidance_scale guidance_scale = job.guidance_scale or default_guidance_scale
num_inference_steps = job.num_inference_steps or default_num_inference_steps num_inference_steps = job.num_inference_steps or default_num_inference_steps
seed = job.seed or default_seed seed = job.seed or default_seed
# Load image/video if needed # Load image/video if needed
image = None image = None
video = None video = None
if generate_type == "i2v": if generate_type == "i2v":
if not job.image_path: if not job.image_path:
raise ValueError("image_path is required for i2v generation") raise ValueError("image_path is required for i2v generation")
@ -334,10 +336,10 @@ def generate_single_video(
if not job.video_path: if not job.video_path:
raise ValueError("video_path is required for v2v generation") raise ValueError("video_path is required for v2v generation")
video = load_video(job.video_path) video = load_video(job.video_path)
# Generate video # Generate video
generator = torch.Generator().manual_seed(seed) generator = torch.Generator().manual_seed(seed)
if generate_type == "i2v": if generate_type == "i2v":
video_frames = pipe( video_frames = pipe(
height=height, height=height,
@ -376,7 +378,7 @@ def generate_single_video(
guidance_scale=guidance_scale, guidance_scale=guidance_scale,
generator=generator, generator=generator,
).frames[0] ).frames[0]
# Export video # Export video
export_to_video(video_frames, str(output_path), fps=fps) export_to_video(video_frames, str(output_path), fps=fps)
@ -400,7 +402,7 @@ def run_batch(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Run batch inference on a JSONL file. Run batch inference on a JSONL file.
Args: Args:
batch_file: Path to the JSONL batch file batch_file: Path to the JSONL batch file
model_path: Path to the model model_path: Path to the model
@ -414,7 +416,7 @@ def run_batch(
num_gpus: Total number of GPUs for distribution num_gpus: Total number of GPUs for distribution
default_*: Default values for optional parameters default_*: Default values for optional parameters
fps: Frames per second for output videos fps: Frames per second for output videos
Returns: Returns:
Summary dictionary with statistics Summary dictionary with statistics
""" """
@ -422,26 +424,26 @@ def run_batch(
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
state_file = output_dir / ".batch_state.json" state_file = output_dir / ".batch_state.json"
error_log = output_dir / "errors.log" error_log = output_dir / "errors.log"
# Load jobs # Load jobs
logger.info(f"Loading batch file: {batch_file}") logger.info(f"Loading batch file: {batch_file}")
all_jobs = load_batch_file(batch_file) all_jobs = load_batch_file(batch_file)
logger.info(f"Found {len(all_jobs)} valid jobs in batch file") logger.info(f"Found {len(all_jobs)} valid jobs in batch file")
# Distribute jobs across GPUs if using multi-GPU # Distribute jobs across GPUs if using multi-GPU
if num_gpus > 1: if num_gpus > 1:
jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id] jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id]
logger.info(f"GPU {gpu_id}/{num_gpus}: Processing {len(jobs)} jobs") logger.info(f"GPU {gpu_id}/{num_gpus}: Processing {len(jobs)} jobs")
else: else:
jobs = all_jobs jobs = all_jobs
# Load or create state # Load or create state
state = None state = None
if resume: if resume:
state = BatchState.load(state_file) state = BatchState.load(state_file)
if state: if state:
logger.info(f"Resuming batch: {len(state.completed)} already completed") logger.info(f"Resuming batch: {len(state.completed)} already completed")
if state is None: if state is None:
state = BatchState( state = BatchState(
batch_file=str(batch_file), batch_file=str(batch_file),
@ -450,7 +452,7 @@ def run_batch(
generate_type=generate_type, generate_type=generate_type,
started_at=datetime.now().isoformat(), started_at=datetime.now().isoformat(),
) )
# Filter out completed jobs # Filter out completed jobs
if resume: if resume:
pending_jobs = [j for j in jobs if not state.is_completed(j.output_name)] pending_jobs = [j for j in jobs if not state.is_completed(j.output_name)]
@ -458,11 +460,11 @@ def run_batch(
if skipped > 0: if skipped > 0:
logger.info(f"Skipping {skipped} already-completed jobs") logger.info(f"Skipping {skipped} already-completed jobs")
jobs = pending_jobs jobs = pending_jobs
if not jobs: if not jobs:
logger.info("No jobs to process") logger.info("No jobs to process")
return {"total": 0, "completed": 0, "failed": 0, "skipped": len(all_jobs)} return {"total": 0, "completed": 0, "failed": 0, "skipped": len(all_jobs)}
# Load pipeline # Load pipeline
model_name = model_path.split("/")[-1] model_name = model_path.split("/")[-1]
pipe = load_pipeline( pipe = load_pipeline(
@ -472,19 +474,19 @@ def run_batch(
lora_path=lora_path, lora_path=lora_path,
enable_cpu_offload=enable_cpu_offload, enable_cpu_offload=enable_cpu_offload,
) )
# Process jobs # Process jobs
completed = 0 completed = 0
failed = 0 failed = 0
start_time = time.time() start_time = time.time()
with tqdm(total=len(jobs), desc="Generating videos", unit="video") as pbar: with tqdm(total=len(jobs), desc="Generating videos", unit="video") as pbar:
for job in jobs: for job in jobs:
output_path = output_dir / job.output_name output_path = output_dir / job.output_name
try: try:
logger.info(f"Processing: {job.output_name} - \"{job.prompt[:50]}...\"") logger.info(f"Processing: {job.output_name} - \"{job.prompt[:50]}...\"")
generate_single_video( generate_single_video(
pipe=pipe, pipe=pipe,
job=job, job=job,
@ -497,15 +499,15 @@ def run_batch(
default_seed=default_seed, default_seed=default_seed,
fps=fps, fps=fps,
) )
state.mark_completed(job.output_name) state.mark_completed(job.output_name)
completed += 1 completed += 1
logger.info(f"Completed: {job.output_name}") logger.info(f"Completed: {job.output_name}")
except Exception as e: except Exception as e:
error_msg = f"{type(e).__name__}: {str(e)}" error_msg = f"{type(e).__name__}: {str(e)}"
logger.error(f"Failed: {job.output_name} - {error_msg}") logger.error(f"Failed: {job.output_name} - {error_msg}")
# Log full traceback to error log # Log full traceback to error log
with open(error_log, "a") as f: with open(error_log, "a") as f:
f.write(f"\n{'='*60}\n") f.write(f"\n{'='*60}\n")
@ -514,26 +516,24 @@ def run_batch(
f.write(f"Time: {datetime.now().isoformat()}\n") f.write(f"Time: {datetime.now().isoformat()}\n")
f.write(f"Error: {error_msg}\n") f.write(f"Error: {error_msg}\n")
f.write(traceback.format_exc()) f.write(traceback.format_exc())
state.mark_failed(job.output_name, error_msg) state.mark_failed(job.output_name, error_msg)
failed += 1 failed += 1
# Save state after each job (for resume) # Save state after each job (for resume)
state.save(state_file) state.save(state_file)
pbar.update(1) pbar.update(1)
# Update ETA in progress bar # Update ETA in progress bar
elapsed = time.time() - start_time elapsed = time.time() - start_time
if completed + failed > 0: if completed + failed > 0:
avg_time = elapsed / (completed + failed) avg_time = elapsed / (completed + failed)
remaining = len(jobs) - (completed + failed) remaining = len(jobs) - (completed + failed)
eta_seconds = avg_time * remaining eta_seconds = avg_time * remaining
pbar.set_postfix({ pbar.set_postfix(
"done": completed, {"done": completed, "failed": failed, "ETA": f"{eta_seconds/60:.1f}m"}
"failed": failed, )
"ETA": f"{eta_seconds/60:.1f}m"
})
# Final summary # Final summary
elapsed_total = time.time() - start_time elapsed_total = time.time() - start_time
summary = { summary = {
@ -544,7 +544,7 @@ def run_batch(
"elapsed_seconds": elapsed_total, "elapsed_seconds": elapsed_total,
"avg_seconds_per_video": elapsed_total / max(completed + failed, 1), "avg_seconds_per_video": elapsed_total / max(completed + failed, 1),
} }
logger.info("=" * 60) logger.info("=" * 60)
logger.info("BATCH COMPLETE") logger.info("BATCH COMPLETE")
logger.info(f" Total jobs: {summary['total']}") logger.info(f" Total jobs: {summary['total']}")
@ -554,7 +554,7 @@ def run_batch(
if summary['failed'] > 0: if summary['failed'] > 0:
logger.info(f" See errors in: {error_log}") logger.info(f" See errors in: {error_log}")
logger.info("=" * 60) logger.info("=" * 60)
return summary return summary
@ -566,11 +566,11 @@ def main():
Examples: Examples:
# Basic text-to-video batch # Basic text-to-video batch
python batch_inference.py --batch_file prompts.jsonl --model_path THUDM/CogVideoX1.5-5B python batch_inference.py --batch_file prompts.jsonl --model_path THUDM/CogVideoX1.5-5B
# Image-to-video batch with custom output directory # Image-to-video batch with custom output directory
python batch_inference.py --batch_file i2v.jsonl --model_path THUDM/CogVideoX1.5-5B-I2V \\ python batch_inference.py --batch_file i2v.jsonl --model_path THUDM/CogVideoX1.5-5B-I2V \\
--generate_type i2v --output_dir ./my_videos --generate_type i2v --output_dir ./my_videos
# Multi-GPU: run on 4 GPUs (one process per GPU) # Multi-GPU: run on 4 GPUs (one process per GPU)
for i in {0..3}; do for i in {0..3}; do
CUDA_VISIBLE_DEVICES=$i python batch_inference.py --batch_file batch.jsonl \\ CUDA_VISIBLE_DEVICES=$i python batch_inference.py --batch_file batch.jsonl \\
@ -579,128 +579,91 @@ Examples:
JSONL Format: JSONL Format:
Each line is a JSON object with: prompt (required), output_name (required), Each line is a JSON object with: prompt (required), output_name (required),
and optional: image_path, video_path, num_frames, guidance_scale, and optional: image_path, video_path, num_frames, guidance_scale,
num_inference_steps, seed, width, height num_inference_steps, seed, width, height
""" """,
) )
# Required arguments # Required arguments
parser.add_argument( parser.add_argument("--batch_file", type=str, required=True, help="Path to JSONL batch file")
"--batch_file",
type=str,
required=True,
help="Path to JSONL batch file"
)
parser.add_argument( parser.add_argument(
"--model_path", "--model_path",
type=str, type=str,
default="THUDM/CogVideoX1.5-5B", default="THUDM/CogVideoX1.5-5B",
help="Path to the model (default: THUDM/CogVideoX1.5-5B)" help="Path to the model (default: THUDM/CogVideoX1.5-5B)",
) )
# Output settings # Output settings
parser.add_argument( parser.add_argument(
"--output_dir", "--output_dir",
type=str, type=str,
default="./batch_output", default="./batch_output",
help="Directory for output videos (default: ./batch_output)" help="Directory for output videos (default: ./batch_output)",
) )
parser.add_argument( parser.add_argument(
"--generate_type", "--generate_type",
type=str, type=str,
choices=["t2v", "i2v", "v2v"], choices=["t2v", "i2v", "v2v"],
default="t2v", default="t2v",
help="Generation type (default: t2v)" help="Generation type (default: t2v)",
) )
# Model settings # Model settings
parser.add_argument( parser.add_argument(
"--lora_path", "--lora_path", type=str, default=None, help="Path to LoRA weights (optional)"
type=str,
default=None,
help="Path to LoRA weights (optional)"
) )
parser.add_argument( parser.add_argument(
"--dtype", "--dtype",
type=str, type=str,
choices=["float16", "bfloat16"], choices=["float16", "bfloat16"],
default="bfloat16", default="bfloat16",
help="Data type for computation (default: bfloat16)" help="Data type for computation (default: bfloat16)",
) )
parser.add_argument( parser.add_argument(
"--disable_cpu_offload", "--disable_cpu_offload",
action="store_true", action="store_true",
help="Disable CPU offloading (uses more VRAM but faster)" help="Disable CPU offloading (uses more VRAM but faster)",
) )
# Default generation parameters # Default generation parameters
parser.add_argument( parser.add_argument(
"--num_frames", "--num_frames", type=int, default=81, help="Default number of frames (default: 81)"
type=int,
default=81,
help="Default number of frames (default: 81)"
) )
parser.add_argument( parser.add_argument(
"--guidance_scale", "--guidance_scale", type=float, default=6.0, help="Default guidance scale (default: 6.0)"
type=float,
default=6.0,
help="Default guidance scale (default: 6.0)"
) )
parser.add_argument( parser.add_argument(
"--num_inference_steps", "--num_inference_steps", type=int, default=50, help="Default inference steps (default: 50)"
type=int,
default=50,
help="Default inference steps (default: 50)"
) )
parser.add_argument( parser.add_argument("--seed", type=int, default=42, help="Default random seed (default: 42)")
"--seed", parser.add_argument("--fps", type=int, default=16, help="Output video FPS (default: 16)")
type=int,
default=42,
help="Default random seed (default: 42)"
)
parser.add_argument(
"--fps",
type=int,
default=16,
help="Output video FPS (default: 16)"
)
# Resume and multi-GPU # Resume and multi-GPU
parser.add_argument( parser.add_argument(
"--resume", "--resume",
action="store_true", action="store_true",
default=True, default=True,
help="Resume from previous state (default: True)" help="Resume from previous state (default: True)",
)
parser.add_argument("--no_resume", action="store_true", help="Don't resume, start fresh")
parser.add_argument(
"--gpu_id", type=int, default=0, help="GPU ID for multi-GPU distribution (default: 0)"
) )
parser.add_argument( parser.add_argument(
"--no_resume", "--num_gpus", type=int, default=1, help="Total number of GPUs for distribution (default: 1)"
action="store_true",
help="Don't resume, start fresh"
) )
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="GPU ID for multi-GPU distribution (default: 0)"
)
parser.add_argument(
"--num_gpus",
type=int,
default=1,
help="Total number of GPUs for distribution (default: 1)"
)
args = parser.parse_args() args = parser.parse_args()
# Validate batch file exists # Validate batch file exists
batch_file = Path(args.batch_file) batch_file = Path(args.batch_file)
if not batch_file.exists(): if not batch_file.exists():
logger.error(f"Batch file not found: {batch_file}") logger.error(f"Batch file not found: {batch_file}")
sys.exit(1) sys.exit(1)
# Parse dtype # Parse dtype
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
# Run batch # Run batch
try: try:
summary = run_batch( summary = run_batch(
@ -720,11 +683,11 @@ JSONL Format:
default_seed=args.seed, default_seed=args.seed,
fps=args.fps, fps=args.fps,
) )
# Exit with error code if any failures # Exit with error code if any failures
if summary["failed"] > 0: if summary["failed"] > 0:
sys.exit(1) sys.exit(1)
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("\nBatch interrupted by user. Progress saved for resume.") logger.info("\nBatch interrupted by user. Progress saved for resume.")
sys.exit(130) sys.exit(130)