mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-05-28 22:38:18 +08:00
chore: allow tests/ directory in gitignore
This commit is contained in:
parent
15a4b403c4
commit
206760830a
2
.gitignore
vendored
2
.gitignore
vendored
@ -8,6 +8,8 @@ logs/
|
|||||||
.idea
|
.idea
|
||||||
output*
|
output*
|
||||||
test*
|
test*
|
||||||
|
!tests/
|
||||||
|
!tests/**
|
||||||
venv
|
venv
|
||||||
**/.swp
|
**/.swp
|
||||||
**/*.log
|
**/*.log
|
||||||
|
|||||||
@ -34,28 +34,30 @@ mock_torch.Generator = MagicMock
|
|||||||
# Create mock diffusers module
|
# Create mock diffusers module
|
||||||
mock_diffusers = MagicMock()
|
mock_diffusers = MagicMock()
|
||||||
|
|
||||||
|
|
||||||
# Create mock tqdm that works as a context manager
|
# Create mock tqdm that works as a context manager
|
||||||
class MockTqdm:
|
class MockTqdm:
|
||||||
"""Mock tqdm that works as a context manager and returns a dummy progress bar."""
|
"""Mock tqdm that works as a context manager and returns a dummy progress bar."""
|
||||||
|
|
||||||
def __init__(self, iterable=None, total=None, **kwargs):
|
def __init__(self, iterable=None, total=None, **kwargs):
|
||||||
self.iterable = iterable
|
self.iterable = iterable
|
||||||
self.n = 0
|
self.n = 0
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
if self.iterable is not None:
|
if self.iterable is not None:
|
||||||
for item in self.iterable:
|
for item in self.iterable:
|
||||||
yield item
|
yield item
|
||||||
self.n += 1
|
self.n += 1
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def update(self, n=1):
|
def update(self, n=1):
|
||||||
self.n += n
|
self.n += n
|
||||||
|
|
||||||
def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
|
def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -172,7 +174,7 @@ class TestBatchJob:
|
|||||||
def test_from_dict_basic(self, sample_job_dict):
|
def test_from_dict_basic(self, sample_job_dict):
|
||||||
"""Test creating a BatchJob from a dictionary."""
|
"""Test creating a BatchJob from a dictionary."""
|
||||||
job = BatchJob.from_dict(sample_job_dict, line_number=1)
|
job = BatchJob.from_dict(sample_job_dict, line_number=1)
|
||||||
|
|
||||||
assert job.prompt == "A cat playing piano"
|
assert job.prompt == "A cat playing piano"
|
||||||
assert job.output_name == "cat_piano.mp4"
|
assert job.output_name == "cat_piano.mp4"
|
||||||
assert job.num_frames == 49
|
assert job.num_frames == 49
|
||||||
@ -185,7 +187,7 @@ class TestBatchJob:
|
|||||||
"""Test creating a BatchJob with only required fields."""
|
"""Test creating a BatchJob with only required fields."""
|
||||||
data = {"prompt": "Test prompt", "output_name": "test.mp4"}
|
data = {"prompt": "Test prompt", "output_name": "test.mp4"}
|
||||||
job = BatchJob.from_dict(data)
|
job = BatchJob.from_dict(data)
|
||||||
|
|
||||||
assert job.prompt == "Test prompt"
|
assert job.prompt == "Test prompt"
|
||||||
assert job.output_name == "test.mp4"
|
assert job.output_name == "test.mp4"
|
||||||
assert job.num_frames is None
|
assert job.num_frames is None
|
||||||
@ -195,14 +197,14 @@ class TestBatchJob:
|
|||||||
def test_from_dict_with_image_path(self, sample_i2v_job_dict):
|
def test_from_dict_with_image_path(self, sample_i2v_job_dict):
|
||||||
"""Test creating a BatchJob with image_path for i2v."""
|
"""Test creating a BatchJob with image_path for i2v."""
|
||||||
job = BatchJob.from_dict(sample_i2v_job_dict)
|
job = BatchJob.from_dict(sample_i2v_job_dict)
|
||||||
|
|
||||||
assert job.image_path == "/path/to/image.jpg"
|
assert job.image_path == "/path/to/image.jpg"
|
||||||
assert job.video_path is None
|
assert job.video_path is None
|
||||||
|
|
||||||
def test_from_dict_with_video_path(self, sample_v2v_job_dict):
|
def test_from_dict_with_video_path(self, sample_v2v_job_dict):
|
||||||
"""Test creating a BatchJob with video_path for v2v."""
|
"""Test creating a BatchJob with video_path for v2v."""
|
||||||
job = BatchJob.from_dict(sample_v2v_job_dict)
|
job = BatchJob.from_dict(sample_v2v_job_dict)
|
||||||
|
|
||||||
assert job.video_path == "/path/to/video.mp4"
|
assert job.video_path == "/path/to/video.mp4"
|
||||||
assert job.image_path is None
|
assert job.image_path is None
|
||||||
|
|
||||||
@ -210,28 +212,28 @@ class TestBatchJob:
|
|||||||
"""Test validation of a valid job returns no errors."""
|
"""Test validation of a valid job returns no errors."""
|
||||||
job = BatchJob.from_dict(sample_job_dict)
|
job = BatchJob.from_dict(sample_job_dict)
|
||||||
errors = job.validate()
|
errors = job.validate()
|
||||||
|
|
||||||
assert errors == []
|
assert errors == []
|
||||||
|
|
||||||
def test_validate_missing_prompt(self):
|
def test_validate_missing_prompt(self):
|
||||||
"""Test validation catches missing prompt."""
|
"""Test validation catches missing prompt."""
|
||||||
job = BatchJob.from_dict({"output_name": "test.mp4"})
|
job = BatchJob.from_dict({"output_name": "test.mp4"})
|
||||||
errors = job.validate()
|
errors = job.validate()
|
||||||
|
|
||||||
assert "Missing required field: prompt" in errors
|
assert "Missing required field: prompt" in errors
|
||||||
|
|
||||||
def test_validate_missing_output_name(self):
|
def test_validate_missing_output_name(self):
|
||||||
"""Test validation catches missing output_name."""
|
"""Test validation catches missing output_name."""
|
||||||
job = BatchJob.from_dict({"prompt": "Test"})
|
job = BatchJob.from_dict({"prompt": "Test"})
|
||||||
errors = job.validate()
|
errors = job.validate()
|
||||||
|
|
||||||
assert "Missing required field: output_name" in errors
|
assert "Missing required field: output_name" in errors
|
||||||
|
|
||||||
def test_validate_missing_both_required(self):
|
def test_validate_missing_both_required(self):
|
||||||
"""Test validation catches both missing required fields."""
|
"""Test validation catches both missing required fields."""
|
||||||
job = BatchJob.from_dict({})
|
job = BatchJob.from_dict({})
|
||||||
errors = job.validate()
|
errors = job.validate()
|
||||||
|
|
||||||
assert len(errors) == 2
|
assert len(errors) == 2
|
||||||
assert "Missing required field: prompt" in errors
|
assert "Missing required field: prompt" in errors
|
||||||
assert "Missing required field: output_name" in errors
|
assert "Missing required field: output_name" in errors
|
||||||
@ -253,7 +255,7 @@ class TestBatchState:
|
|||||||
model_path="THUDM/CogVideoX1.5-5B",
|
model_path="THUDM/CogVideoX1.5-5B",
|
||||||
generate_type="t2v",
|
generate_type="t2v",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert state.batch_file == "batch.jsonl"
|
assert state.batch_file == "batch.jsonl"
|
||||||
assert state.completed == []
|
assert state.completed == []
|
||||||
assert state.failed == []
|
assert state.failed == []
|
||||||
@ -266,10 +268,10 @@ class TestBatchState:
|
|||||||
model_path="model",
|
model_path="model",
|
||||||
generate_type="t2v",
|
generate_type="t2v",
|
||||||
)
|
)
|
||||||
|
|
||||||
state.mark_completed("video1.mp4")
|
state.mark_completed("video1.mp4")
|
||||||
state.mark_completed("video2.mp4")
|
state.mark_completed("video2.mp4")
|
||||||
|
|
||||||
assert "video1.mp4" in state.completed
|
assert "video1.mp4" in state.completed
|
||||||
assert "video2.mp4" in state.completed
|
assert "video2.mp4" in state.completed
|
||||||
|
|
||||||
@ -281,10 +283,10 @@ class TestBatchState:
|
|||||||
model_path="model",
|
model_path="model",
|
||||||
generate_type="t2v",
|
generate_type="t2v",
|
||||||
)
|
)
|
||||||
|
|
||||||
state.mark_completed("video1.mp4")
|
state.mark_completed("video1.mp4")
|
||||||
state.mark_completed("video1.mp4")
|
state.mark_completed("video1.mp4")
|
||||||
|
|
||||||
assert state.completed.count("video1.mp4") == 1
|
assert state.completed.count("video1.mp4") == 1
|
||||||
|
|
||||||
def test_mark_failed(self):
|
def test_mark_failed(self):
|
||||||
@ -295,9 +297,9 @@ class TestBatchState:
|
|||||||
model_path="model",
|
model_path="model",
|
||||||
generate_type="t2v",
|
generate_type="t2v",
|
||||||
)
|
)
|
||||||
|
|
||||||
state.mark_failed("video1.mp4", "CUDA out of memory")
|
state.mark_failed("video1.mp4", "CUDA out of memory")
|
||||||
|
|
||||||
assert len(state.failed) == 1
|
assert len(state.failed) == 1
|
||||||
assert state.failed[0]["output_name"] == "video1.mp4"
|
assert state.failed[0]["output_name"] == "video1.mp4"
|
||||||
assert state.failed[0]["error"] == "CUDA out of memory"
|
assert state.failed[0]["error"] == "CUDA out of memory"
|
||||||
@ -310,16 +312,16 @@ class TestBatchState:
|
|||||||
model_path="model",
|
model_path="model",
|
||||||
generate_type="t2v",
|
generate_type="t2v",
|
||||||
)
|
)
|
||||||
|
|
||||||
state.mark_completed("video1.mp4")
|
state.mark_completed("video1.mp4")
|
||||||
|
|
||||||
assert state.is_completed("video1.mp4") is True
|
assert state.is_completed("video1.mp4") is True
|
||||||
assert state.is_completed("video2.mp4") is False
|
assert state.is_completed("video2.mp4") is False
|
||||||
|
|
||||||
def test_save_and_load(self, tmp_path):
|
def test_save_and_load(self, tmp_path):
|
||||||
"""Test saving and loading state."""
|
"""Test saving and loading state."""
|
||||||
state_file = tmp_path / ".batch_state.json"
|
state_file = tmp_path / ".batch_state.json"
|
||||||
|
|
||||||
# Create and save state
|
# Create and save state
|
||||||
state = BatchState(
|
state = BatchState(
|
||||||
batch_file="batch.jsonl",
|
batch_file="batch.jsonl",
|
||||||
@ -331,10 +333,10 @@ class TestBatchState:
|
|||||||
state.mark_completed("video1.mp4")
|
state.mark_completed("video1.mp4")
|
||||||
state.mark_failed("video2.mp4", "Error")
|
state.mark_failed("video2.mp4", "Error")
|
||||||
state.save(state_file)
|
state.save(state_file)
|
||||||
|
|
||||||
# Load state
|
# Load state
|
||||||
loaded = BatchState.load(state_file)
|
loaded = BatchState.load(state_file)
|
||||||
|
|
||||||
assert loaded is not None
|
assert loaded is not None
|
||||||
assert loaded.batch_file == "batch.jsonl"
|
assert loaded.batch_file == "batch.jsonl"
|
||||||
assert "video1.mp4" in loaded.completed
|
assert "video1.mp4" in loaded.completed
|
||||||
@ -344,27 +346,27 @@ class TestBatchState:
|
|||||||
def test_load_nonexistent_file(self, tmp_path):
|
def test_load_nonexistent_file(self, tmp_path):
|
||||||
"""Test loading from a nonexistent file returns None."""
|
"""Test loading from a nonexistent file returns None."""
|
||||||
state_file = tmp_path / "nonexistent.json"
|
state_file = tmp_path / "nonexistent.json"
|
||||||
|
|
||||||
loaded = BatchState.load(state_file)
|
loaded = BatchState.load(state_file)
|
||||||
|
|
||||||
assert loaded is None
|
assert loaded is None
|
||||||
|
|
||||||
def test_load_corrupted_file(self, tmp_path):
|
def test_load_corrupted_file(self, tmp_path):
|
||||||
"""Test loading a corrupted state file returns None."""
|
"""Test loading a corrupted state file returns None."""
|
||||||
state_file = tmp_path / ".batch_state.json"
|
state_file = tmp_path / ".batch_state.json"
|
||||||
state_file.write_text("{ invalid json }")
|
state_file.write_text("{ invalid json }")
|
||||||
|
|
||||||
loaded = BatchState.load(state_file)
|
loaded = BatchState.load(state_file)
|
||||||
|
|
||||||
assert loaded is None
|
assert loaded is None
|
||||||
|
|
||||||
def test_load_invalid_json_structure(self, tmp_path):
|
def test_load_invalid_json_structure(self, tmp_path):
|
||||||
"""Test loading a file with valid JSON but invalid structure."""
|
"""Test loading a file with valid JSON but invalid structure."""
|
||||||
state_file = tmp_path / ".batch_state.json"
|
state_file = tmp_path / ".batch_state.json"
|
||||||
state_file.write_text('{"wrong": "structure"}')
|
state_file.write_text('{"wrong": "structure"}')
|
||||||
|
|
||||||
loaded = BatchState.load(state_file)
|
loaded = BatchState.load(state_file)
|
||||||
|
|
||||||
# Should return None due to missing required fields
|
# Should return None due to missing required fields
|
||||||
assert loaded is None
|
assert loaded is None
|
||||||
|
|
||||||
@ -380,7 +382,7 @@ class TestLoadBatchFile:
|
|||||||
def test_load_valid_jsonl(self, batch_file):
|
def test_load_valid_jsonl(self, batch_file):
|
||||||
"""Test loading a valid JSONL file."""
|
"""Test loading a valid JSONL file."""
|
||||||
jobs = load_batch_file(batch_file)
|
jobs = load_batch_file(batch_file)
|
||||||
|
|
||||||
assert len(jobs) == 3
|
assert len(jobs) == 3
|
||||||
assert jobs[0].prompt == "A cat playing piano"
|
assert jobs[0].prompt == "A cat playing piano"
|
||||||
assert jobs[1].output_name == "beach.mp4"
|
assert jobs[1].output_name == "beach.mp4"
|
||||||
@ -394,9 +396,9 @@ class TestLoadBatchFile:
|
|||||||
{"prompt": "Test2", "output_name": "test2.mp4"}"""
|
{"prompt": "Test2", "output_name": "test2.mp4"}"""
|
||||||
batch_file = tmp_path / "batch.jsonl"
|
batch_file = tmp_path / "batch.jsonl"
|
||||||
batch_file.write_text(content)
|
batch_file.write_text(content)
|
||||||
|
|
||||||
jobs = load_batch_file(batch_file)
|
jobs = load_batch_file(batch_file)
|
||||||
|
|
||||||
assert len(jobs) == 2
|
assert len(jobs) == 2
|
||||||
|
|
||||||
def test_load_jsonl_with_empty_lines(self, tmp_path):
|
def test_load_jsonl_with_empty_lines(self, tmp_path):
|
||||||
@ -408,9 +410,9 @@ class TestLoadBatchFile:
|
|||||||
"""
|
"""
|
||||||
batch_file = tmp_path / "batch.jsonl"
|
batch_file = tmp_path / "batch.jsonl"
|
||||||
batch_file.write_text(content)
|
batch_file.write_text(content)
|
||||||
|
|
||||||
jobs = load_batch_file(batch_file)
|
jobs = load_batch_file(batch_file)
|
||||||
|
|
||||||
assert len(jobs) == 2
|
assert len(jobs) == 2
|
||||||
|
|
||||||
def test_load_jsonl_invalid_json_line(self, tmp_path):
|
def test_load_jsonl_invalid_json_line(self, tmp_path):
|
||||||
@ -420,9 +422,9 @@ class TestLoadBatchFile:
|
|||||||
{"prompt": "Test2", "output_name": "test2.mp4"}"""
|
{"prompt": "Test2", "output_name": "test2.mp4"}"""
|
||||||
batch_file = tmp_path / "batch.jsonl"
|
batch_file = tmp_path / "batch.jsonl"
|
||||||
batch_file.write_text(content)
|
batch_file.write_text(content)
|
||||||
|
|
||||||
jobs = load_batch_file(batch_file)
|
jobs = load_batch_file(batch_file)
|
||||||
|
|
||||||
# Should skip the invalid line
|
# Should skip the invalid line
|
||||||
assert len(jobs) == 2
|
assert len(jobs) == 2
|
||||||
|
|
||||||
@ -434,9 +436,9 @@ class TestLoadBatchFile:
|
|||||||
{"prompt": "Test2", "output_name": "test2.mp4"}"""
|
{"prompt": "Test2", "output_name": "test2.mp4"}"""
|
||||||
batch_file = tmp_path / "batch.jsonl"
|
batch_file = tmp_path / "batch.jsonl"
|
||||||
batch_file.write_text(content)
|
batch_file.write_text(content)
|
||||||
|
|
||||||
jobs = load_batch_file(batch_file)
|
jobs = load_batch_file(batch_file)
|
||||||
|
|
||||||
# Should skip the two invalid jobs
|
# Should skip the two invalid jobs
|
||||||
assert len(jobs) == 2
|
assert len(jobs) == 2
|
||||||
assert jobs[0].output_name == "test.mp4"
|
assert jobs[0].output_name == "test.mp4"
|
||||||
@ -446,9 +448,9 @@ class TestLoadBatchFile:
|
|||||||
"""Test loading an empty batch file."""
|
"""Test loading an empty batch file."""
|
||||||
batch_file = tmp_path / "empty.jsonl"
|
batch_file = tmp_path / "empty.jsonl"
|
||||||
batch_file.write_text("")
|
batch_file.write_text("")
|
||||||
|
|
||||||
jobs = load_batch_file(batch_file)
|
jobs = load_batch_file(batch_file)
|
||||||
|
|
||||||
assert len(jobs) == 0
|
assert len(jobs) == 0
|
||||||
|
|
||||||
def test_load_batch_file_all_comments(self, tmp_path):
|
def test_load_batch_file_all_comments(self, tmp_path):
|
||||||
@ -458,9 +460,9 @@ class TestLoadBatchFile:
|
|||||||
# Comment 3"""
|
# Comment 3"""
|
||||||
batch_file = tmp_path / "comments.jsonl"
|
batch_file = tmp_path / "comments.jsonl"
|
||||||
batch_file.write_text(content)
|
batch_file.write_text(content)
|
||||||
|
|
||||||
jobs = load_batch_file(batch_file)
|
jobs = load_batch_file(batch_file)
|
||||||
|
|
||||||
assert len(jobs) == 0
|
assert len(jobs) == 0
|
||||||
|
|
||||||
def test_load_batch_file_tracks_line_numbers(self, tmp_path):
|
def test_load_batch_file_tracks_line_numbers(self, tmp_path):
|
||||||
@ -471,9 +473,9 @@ class TestLoadBatchFile:
|
|||||||
{"prompt": "Line 4", "output_name": "line4.mp4"}"""
|
{"prompt": "Line 4", "output_name": "line4.mp4"}"""
|
||||||
batch_file = tmp_path / "batch.jsonl"
|
batch_file = tmp_path / "batch.jsonl"
|
||||||
batch_file.write_text(content)
|
batch_file.write_text(content)
|
||||||
|
|
||||||
jobs = load_batch_file(batch_file)
|
jobs = load_batch_file(batch_file)
|
||||||
|
|
||||||
assert jobs[0].line_number == 2
|
assert jobs[0].line_number == 2
|
||||||
assert jobs[1].line_number == 4
|
assert jobs[1].line_number == 4
|
||||||
|
|
||||||
@ -488,30 +490,24 @@ class TestMultiGPUDistribution:
|
|||||||
|
|
||||||
def test_single_gpu_gets_all_jobs(self):
|
def test_single_gpu_gets_all_jobs(self):
|
||||||
"""Test that a single GPU processes all jobs."""
|
"""Test that a single GPU processes all jobs."""
|
||||||
all_jobs = [
|
all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)]
|
||||||
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
|
|
||||||
for i in range(10)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Simulate distribution for single GPU (gpu_id=0, num_gpus=1)
|
# Simulate distribution for single GPU (gpu_id=0, num_gpus=1)
|
||||||
gpu_id = 0
|
gpu_id = 0
|
||||||
num_gpus = 1
|
num_gpus = 1
|
||||||
jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id]
|
jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id]
|
||||||
|
|
||||||
assert len(jobs) == 10
|
assert len(jobs) == 10
|
||||||
|
|
||||||
def test_two_gpu_distribution(self):
|
def test_two_gpu_distribution(self):
|
||||||
"""Test job distribution across 2 GPUs."""
|
"""Test job distribution across 2 GPUs."""
|
||||||
all_jobs = [
|
all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)]
|
||||||
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
|
|
||||||
for i in range(10)
|
|
||||||
]
|
|
||||||
|
|
||||||
# GPU 0 gets jobs 0, 2, 4, 6, 8
|
# GPU 0 gets jobs 0, 2, 4, 6, 8
|
||||||
gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 0]
|
gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 0]
|
||||||
# GPU 1 gets jobs 1, 3, 5, 7, 9
|
# GPU 1 gets jobs 1, 3, 5, 7, 9
|
||||||
gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 1]
|
gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 1]
|
||||||
|
|
||||||
assert len(gpu0_jobs) == 5
|
assert len(gpu0_jobs) == 5
|
||||||
assert len(gpu1_jobs) == 5
|
assert len(gpu1_jobs) == 5
|
||||||
assert gpu0_jobs[0].output_name == "video0.mp4"
|
assert gpu0_jobs[0].output_name == "video0.mp4"
|
||||||
@ -519,11 +515,8 @@ class TestMultiGPUDistribution:
|
|||||||
|
|
||||||
def test_four_gpu_distribution(self):
|
def test_four_gpu_distribution(self):
|
||||||
"""Test job distribution across 4 GPUs."""
|
"""Test job distribution across 4 GPUs."""
|
||||||
all_jobs = [
|
all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(12)]
|
||||||
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
|
|
||||||
for i in range(12)
|
|
||||||
]
|
|
||||||
|
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
for gpu_id in range(num_gpus):
|
for gpu_id in range(num_gpus):
|
||||||
jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id]
|
jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id]
|
||||||
@ -531,16 +524,13 @@ class TestMultiGPUDistribution:
|
|||||||
|
|
||||||
def test_uneven_distribution(self):
|
def test_uneven_distribution(self):
|
||||||
"""Test distribution when jobs don't divide evenly."""
|
"""Test distribution when jobs don't divide evenly."""
|
||||||
all_jobs = [
|
all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)]
|
||||||
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
|
|
||||||
for i in range(10)
|
|
||||||
]
|
|
||||||
|
|
||||||
num_gpus = 3
|
num_gpus = 3
|
||||||
gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 0]
|
gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 0]
|
||||||
gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 1]
|
gpu1_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 1]
|
||||||
gpu2_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 2]
|
gpu2_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 2]
|
||||||
|
|
||||||
# 10 jobs / 3 GPUs: GPU0 gets 4, GPU1 gets 3, GPU2 gets 3
|
# 10 jobs / 3 GPUs: GPU0 gets 4, GPU1 gets 3, GPU2 gets 3
|
||||||
assert len(gpu0_jobs) == 4 # indices 0, 3, 6, 9
|
assert len(gpu0_jobs) == 4 # indices 0, 3, 6, 9
|
||||||
assert len(gpu1_jobs) == 3 # indices 1, 4, 7
|
assert len(gpu1_jobs) == 3 # indices 1, 4, 7
|
||||||
@ -562,14 +552,14 @@ class TestLoadPipeline:
|
|||||||
mock_pipe.scheduler = MagicMock()
|
mock_pipe.scheduler = MagicMock()
|
||||||
mock_pipe.vae = MagicMock()
|
mock_pipe.vae = MagicMock()
|
||||||
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
|
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
|
||||||
|
|
||||||
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
model_path="THUDM/CogVideoX1.5-5B",
|
model_path="THUDM/CogVideoX1.5-5B",
|
||||||
generate_type="t2v",
|
generate_type="t2v",
|
||||||
enable_cpu_offload=False,
|
enable_cpu_offload=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_cogvideo_pipeline.from_pretrained.assert_called_once()
|
mock_cogvideo_pipeline.from_pretrained.assert_called_once()
|
||||||
assert pipe is mock_pipe
|
assert pipe is mock_pipe
|
||||||
|
|
||||||
@ -580,14 +570,14 @@ class TestLoadPipeline:
|
|||||||
mock_pipe.scheduler = MagicMock()
|
mock_pipe.scheduler = MagicMock()
|
||||||
mock_pipe.vae = MagicMock()
|
mock_pipe.vae = MagicMock()
|
||||||
mock_i2v_pipeline.from_pretrained.return_value = mock_pipe
|
mock_i2v_pipeline.from_pretrained.return_value = mock_pipe
|
||||||
|
|
||||||
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
model_path="THUDM/CogVideoX1.5-5B-I2V",
|
model_path="THUDM/CogVideoX1.5-5B-I2V",
|
||||||
generate_type="i2v",
|
generate_type="i2v",
|
||||||
enable_cpu_offload=False,
|
enable_cpu_offload=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_i2v_pipeline.from_pretrained.assert_called_once()
|
mock_i2v_pipeline.from_pretrained.assert_called_once()
|
||||||
assert pipe is mock_pipe
|
assert pipe is mock_pipe
|
||||||
|
|
||||||
@ -598,14 +588,14 @@ class TestLoadPipeline:
|
|||||||
mock_pipe.scheduler = MagicMock()
|
mock_pipe.scheduler = MagicMock()
|
||||||
mock_pipe.vae = MagicMock()
|
mock_pipe.vae = MagicMock()
|
||||||
mock_v2v_pipeline.from_pretrained.return_value = mock_pipe
|
mock_v2v_pipeline.from_pretrained.return_value = mock_pipe
|
||||||
|
|
||||||
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
model_path="THUDM/CogVideoX1.5-5B",
|
model_path="THUDM/CogVideoX1.5-5B",
|
||||||
generate_type="v2v",
|
generate_type="v2v",
|
||||||
enable_cpu_offload=False,
|
enable_cpu_offload=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_v2v_pipeline.from_pretrained.assert_called_once()
|
mock_v2v_pipeline.from_pretrained.assert_called_once()
|
||||||
assert pipe is mock_pipe
|
assert pipe is mock_pipe
|
||||||
|
|
||||||
@ -616,14 +606,14 @@ class TestLoadPipeline:
|
|||||||
mock_pipe.scheduler = MagicMock()
|
mock_pipe.scheduler = MagicMock()
|
||||||
mock_pipe.vae = MagicMock()
|
mock_pipe.vae = MagicMock()
|
||||||
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
|
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
|
||||||
|
|
||||||
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
||||||
load_pipeline(
|
load_pipeline(
|
||||||
model_path="model",
|
model_path="model",
|
||||||
generate_type="t2v",
|
generate_type="t2v",
|
||||||
enable_cpu_offload=True,
|
enable_cpu_offload=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_pipe.enable_sequential_cpu_offload.assert_called_once()
|
mock_pipe.enable_sequential_cpu_offload.assert_called_once()
|
||||||
mock_pipe.to.assert_not_called()
|
mock_pipe.to.assert_not_called()
|
||||||
|
|
||||||
@ -634,14 +624,14 @@ class TestLoadPipeline:
|
|||||||
mock_pipe.scheduler = MagicMock()
|
mock_pipe.scheduler = MagicMock()
|
||||||
mock_pipe.vae = MagicMock()
|
mock_pipe.vae = MagicMock()
|
||||||
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
|
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
|
||||||
|
|
||||||
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
||||||
load_pipeline(
|
load_pipeline(
|
||||||
model_path="model",
|
model_path="model",
|
||||||
generate_type="t2v",
|
generate_type="t2v",
|
||||||
enable_cpu_offload=False,
|
enable_cpu_offload=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_pipe.to.assert_called_once_with("cuda")
|
mock_pipe.to.assert_called_once_with("cuda")
|
||||||
mock_pipe.enable_sequential_cpu_offload.assert_not_called()
|
mock_pipe.enable_sequential_cpu_offload.assert_not_called()
|
||||||
|
|
||||||
@ -652,7 +642,7 @@ class TestLoadPipeline:
|
|||||||
mock_pipe.scheduler = MagicMock()
|
mock_pipe.scheduler = MagicMock()
|
||||||
mock_pipe.vae = MagicMock()
|
mock_pipe.vae = MagicMock()
|
||||||
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
|
mock_cogvideo_pipeline.from_pretrained.return_value = mock_pipe
|
||||||
|
|
||||||
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
with patch.object(batch_inference, "CogVideoXDPMScheduler"):
|
||||||
load_pipeline(
|
load_pipeline(
|
||||||
model_path="model",
|
model_path="model",
|
||||||
@ -660,7 +650,7 @@ class TestLoadPipeline:
|
|||||||
lora_path="/path/to/lora",
|
lora_path="/path/to/lora",
|
||||||
enable_cpu_offload=False,
|
enable_cpu_offload=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_pipe.load_lora_weights.assert_called_once()
|
mock_pipe.load_lora_weights.assert_called_once()
|
||||||
mock_pipe.fuse_lora.assert_called_once()
|
mock_pipe.fuse_lora.assert_called_once()
|
||||||
|
|
||||||
@ -682,7 +672,7 @@ class TestGenerateSingleVideo:
|
|||||||
num_frames=49,
|
num_frames=49,
|
||||||
)
|
)
|
||||||
output_path = tmp_path / "cat.mp4"
|
output_path = tmp_path / "cat.mp4"
|
||||||
|
|
||||||
generate_single_video(
|
generate_single_video(
|
||||||
pipe=mock_pipeline,
|
pipe=mock_pipeline,
|
||||||
job=job,
|
job=job,
|
||||||
@ -690,7 +680,7 @@ class TestGenerateSingleVideo:
|
|||||||
model_name="cogvideox-5b",
|
model_name="cogvideox-5b",
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_pipeline.assert_called_once()
|
mock_pipeline.assert_called_once()
|
||||||
mock_export.assert_called_once()
|
mock_export.assert_called_once()
|
||||||
|
|
||||||
@ -699,14 +689,14 @@ class TestGenerateSingleVideo:
|
|||||||
def test_generate_i2v_video(self, mock_load_image, mock_export, mock_pipeline, tmp_path):
|
def test_generate_i2v_video(self, mock_load_image, mock_export, mock_pipeline, tmp_path):
|
||||||
"""Test generating an image-to-video."""
|
"""Test generating an image-to-video."""
|
||||||
mock_load_image.return_value = MagicMock()
|
mock_load_image.return_value = MagicMock()
|
||||||
|
|
||||||
job = BatchJob(
|
job = BatchJob(
|
||||||
prompt="Transform this",
|
prompt="Transform this",
|
||||||
output_name="i2v.mp4",
|
output_name="i2v.mp4",
|
||||||
image_path="/path/to/image.jpg",
|
image_path="/path/to/image.jpg",
|
||||||
)
|
)
|
||||||
output_path = tmp_path / "i2v.mp4"
|
output_path = tmp_path / "i2v.mp4"
|
||||||
|
|
||||||
generate_single_video(
|
generate_single_video(
|
||||||
pipe=mock_pipeline,
|
pipe=mock_pipeline,
|
||||||
job=job,
|
job=job,
|
||||||
@ -714,7 +704,7 @@ class TestGenerateSingleVideo:
|
|||||||
model_name="cogvideox-5b-i2v",
|
model_name="cogvideox-5b-i2v",
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_load_image.assert_called_once_with(image="/path/to/image.jpg")
|
mock_load_image.assert_called_once_with(image="/path/to/image.jpg")
|
||||||
mock_pipeline.assert_called_once()
|
mock_pipeline.assert_called_once()
|
||||||
|
|
||||||
@ -723,14 +713,14 @@ class TestGenerateSingleVideo:
|
|||||||
def test_generate_v2v_video(self, mock_load_video, mock_export, mock_pipeline, tmp_path):
|
def test_generate_v2v_video(self, mock_load_video, mock_export, mock_pipeline, tmp_path):
|
||||||
"""Test generating a video-to-video."""
|
"""Test generating a video-to-video."""
|
||||||
mock_load_video.return_value = MagicMock()
|
mock_load_video.return_value = MagicMock()
|
||||||
|
|
||||||
job = BatchJob(
|
job = BatchJob(
|
||||||
prompt="Enhance this",
|
prompt="Enhance this",
|
||||||
output_name="v2v.mp4",
|
output_name="v2v.mp4",
|
||||||
video_path="/path/to/video.mp4",
|
video_path="/path/to/video.mp4",
|
||||||
)
|
)
|
||||||
output_path = tmp_path / "v2v.mp4"
|
output_path = tmp_path / "v2v.mp4"
|
||||||
|
|
||||||
generate_single_video(
|
generate_single_video(
|
||||||
pipe=mock_pipeline,
|
pipe=mock_pipeline,
|
||||||
job=job,
|
job=job,
|
||||||
@ -738,7 +728,7 @@ class TestGenerateSingleVideo:
|
|||||||
model_name="cogvideox-5b",
|
model_name="cogvideox-5b",
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_load_video.assert_called_once_with("/path/to/video.mp4")
|
mock_load_video.assert_called_once_with("/path/to/video.mp4")
|
||||||
mock_pipeline.assert_called_once()
|
mock_pipeline.assert_called_once()
|
||||||
|
|
||||||
@ -750,7 +740,7 @@ class TestGenerateSingleVideo:
|
|||||||
# Missing image_path
|
# Missing image_path
|
||||||
)
|
)
|
||||||
output_path = tmp_path / "i2v.mp4"
|
output_path = tmp_path / "i2v.mp4"
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="image_path is required"):
|
with pytest.raises(ValueError, match="image_path is required"):
|
||||||
generate_single_video(
|
generate_single_video(
|
||||||
pipe=mock_pipeline,
|
pipe=mock_pipeline,
|
||||||
@ -768,7 +758,7 @@ class TestGenerateSingleVideo:
|
|||||||
# Missing video_path
|
# Missing video_path
|
||||||
)
|
)
|
||||||
output_path = tmp_path / "v2v.mp4"
|
output_path = tmp_path / "v2v.mp4"
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="video_path is required"):
|
with pytest.raises(ValueError, match="video_path is required"):
|
||||||
generate_single_video(
|
generate_single_video(
|
||||||
pipe=mock_pipeline,
|
pipe=mock_pipeline,
|
||||||
@ -790,7 +780,7 @@ class TestGenerateSingleVideo:
|
|||||||
seed=456,
|
seed=456,
|
||||||
)
|
)
|
||||||
output_path = tmp_path / "test.mp4"
|
output_path = tmp_path / "test.mp4"
|
||||||
|
|
||||||
generate_single_video(
|
generate_single_video(
|
||||||
pipe=mock_pipeline,
|
pipe=mock_pipeline,
|
||||||
job=job,
|
job=job,
|
||||||
@ -802,7 +792,7 @@ class TestGenerateSingleVideo:
|
|||||||
default_num_inference_steps=50,
|
default_num_inference_steps=50,
|
||||||
default_seed=42,
|
default_seed=42,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check the pipeline was called with job-specific values
|
# Check the pipeline was called with job-specific values
|
||||||
call_kwargs = mock_pipeline.call_args[1]
|
call_kwargs = mock_pipeline.call_args[1]
|
||||||
assert call_kwargs["num_frames"] == 33
|
assert call_kwargs["num_frames"] == 33
|
||||||
@ -824,9 +814,9 @@ class TestRunBatch:
|
|||||||
"""Test basic batch processing."""
|
"""Test basic batch processing."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
|
|
||||||
output_dir = tmp_path / "output"
|
output_dir = tmp_path / "output"
|
||||||
|
|
||||||
summary = run_batch(
|
summary = run_batch(
|
||||||
batch_file=batch_file,
|
batch_file=batch_file,
|
||||||
model_path="THUDM/CogVideoX1.5-5B",
|
model_path="THUDM/CogVideoX1.5-5B",
|
||||||
@ -834,7 +824,7 @@ class TestRunBatch:
|
|||||||
generate_type="t2v",
|
generate_type="t2v",
|
||||||
resume=False,
|
resume=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert summary["total"] == 3
|
assert summary["total"] == 3
|
||||||
assert summary["completed"] == 3
|
assert summary["completed"] == 3
|
||||||
assert summary["failed"] == 0
|
assert summary["failed"] == 0
|
||||||
@ -842,21 +832,23 @@ class TestRunBatch:
|
|||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_creates_output_dir(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_creates_output_dir(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that output directory is created if it doesn't exist."""
|
"""Test that output directory is created if it doesn't exist."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
|
|
||||||
output_dir = tmp_path / "nested" / "output" / "dir"
|
output_dir = tmp_path / "nested" / "output" / "dir"
|
||||||
assert not output_dir.exists()
|
assert not output_dir.exists()
|
||||||
|
|
||||||
run_batch(
|
run_batch(
|
||||||
batch_file=batch_file,
|
batch_file=batch_file,
|
||||||
model_path="model",
|
model_path="model",
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
resume=False,
|
resume=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert output_dir.exists()
|
assert output_dir.exists()
|
||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@ -865,64 +857,68 @@ class TestRunBatch:
|
|||||||
"""Test that state is saved after processing."""
|
"""Test that state is saved after processing."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
|
|
||||||
output_dir = tmp_path / "output"
|
output_dir = tmp_path / "output"
|
||||||
|
|
||||||
run_batch(
|
run_batch(
|
||||||
batch_file=batch_file,
|
batch_file=batch_file,
|
||||||
model_path="model",
|
model_path="model",
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
resume=True,
|
resume=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
state_file = output_dir / ".batch_state.json"
|
state_file = output_dir / ".batch_state.json"
|
||||||
assert state_file.exists()
|
assert state_file.exists()
|
||||||
|
|
||||||
with open(state_file) as f:
|
with open(state_file) as f:
|
||||||
state_data = json.load(f)
|
state_data = json.load(f)
|
||||||
|
|
||||||
assert len(state_data["completed"]) == 3
|
assert len(state_data["completed"]) == 3
|
||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_handles_errors(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_handles_errors(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that batch continues after individual job failures."""
|
"""Test that batch continues after individual job failures."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
|
|
||||||
# Make the second call raise an exception
|
# Make the second call raise an exception
|
||||||
mock_generate.side_effect = [None, RuntimeError("CUDA OOM"), None]
|
mock_generate.side_effect = [None, RuntimeError("CUDA OOM"), None]
|
||||||
|
|
||||||
output_dir = tmp_path / "output"
|
output_dir = tmp_path / "output"
|
||||||
|
|
||||||
summary = run_batch(
|
summary = run_batch(
|
||||||
batch_file=batch_file,
|
batch_file=batch_file,
|
||||||
model_path="model",
|
model_path="model",
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
resume=False,
|
resume=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert summary["completed"] == 2
|
assert summary["completed"] == 2
|
||||||
assert summary["failed"] == 1
|
assert summary["failed"] == 1
|
||||||
assert mock_generate.call_count == 3
|
assert mock_generate.call_count == 3
|
||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_logs_errors_to_file(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_logs_errors_to_file(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that errors are logged to errors.log."""
|
"""Test that errors are logged to errors.log."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
mock_generate.side_effect = RuntimeError("Test error")
|
mock_generate.side_effect = RuntimeError("Test error")
|
||||||
|
|
||||||
output_dir = tmp_path / "output"
|
output_dir = tmp_path / "output"
|
||||||
|
|
||||||
run_batch(
|
run_batch(
|
||||||
batch_file=batch_file,
|
batch_file=batch_file,
|
||||||
model_path="model",
|
model_path="model",
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
resume=False,
|
resume=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
error_log = output_dir / "errors.log"
|
error_log = output_dir / "errors.log"
|
||||||
assert error_log.exists()
|
assert error_log.exists()
|
||||||
content = error_log.read_text()
|
content = error_log.read_text()
|
||||||
@ -930,14 +926,16 @@ class TestRunBatch:
|
|||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_resume_skips_completed(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_resume_skips_completed(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that resume skips already-completed jobs."""
|
"""Test that resume skips already-completed jobs."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
|
|
||||||
output_dir = tmp_path / "output"
|
output_dir = tmp_path / "output"
|
||||||
output_dir.mkdir()
|
output_dir.mkdir()
|
||||||
|
|
||||||
# Create a state file showing first job completed
|
# Create a state file showing first job completed
|
||||||
state = BatchState(
|
state = BatchState(
|
||||||
batch_file=str(batch_file),
|
batch_file=str(batch_file),
|
||||||
@ -947,14 +945,14 @@ class TestRunBatch:
|
|||||||
completed=["cat_piano.mp4"],
|
completed=["cat_piano.mp4"],
|
||||||
)
|
)
|
||||||
state.save(output_dir / ".batch_state.json")
|
state.save(output_dir / ".batch_state.json")
|
||||||
|
|
||||||
summary = run_batch(
|
summary = run_batch(
|
||||||
batch_file=batch_file,
|
batch_file=batch_file,
|
||||||
model_path="model",
|
model_path="model",
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
resume=True,
|
resume=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should only process 2 jobs (skipping the completed one)
|
# Should only process 2 jobs (skipping the completed one)
|
||||||
assert summary["total"] == 2
|
assert summary["total"] == 2
|
||||||
assert summary["skipped"] == 1
|
assert summary["skipped"] == 1
|
||||||
@ -965,16 +963,16 @@ class TestRunBatch:
|
|||||||
"""Test running batch with empty file."""
|
"""Test running batch with empty file."""
|
||||||
batch_file = tmp_path / "empty.jsonl"
|
batch_file = tmp_path / "empty.jsonl"
|
||||||
batch_file.write_text("")
|
batch_file.write_text("")
|
||||||
|
|
||||||
output_dir = tmp_path / "output"
|
output_dir = tmp_path / "output"
|
||||||
|
|
||||||
summary = run_batch(
|
summary = run_batch(
|
||||||
batch_file=batch_file,
|
batch_file=batch_file,
|
||||||
model_path="model",
|
model_path="model",
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
resume=False,
|
resume=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert summary["total"] == 0
|
assert summary["total"] == 0
|
||||||
# Pipeline should not be loaded for empty batch
|
# Pipeline should not be loaded for empty batch
|
||||||
mock_load_pipeline.assert_not_called()
|
mock_load_pipeline.assert_not_called()
|
||||||
@ -984,18 +982,15 @@ class TestRunBatch:
|
|||||||
def test_run_batch_multi_gpu_distribution(self, mock_generate, mock_load_pipeline, tmp_path):
|
def test_run_batch_multi_gpu_distribution(self, mock_generate, mock_load_pipeline, tmp_path):
|
||||||
"""Test multi-GPU job distribution in run_batch."""
|
"""Test multi-GPU job distribution in run_batch."""
|
||||||
# Create batch file with 6 jobs
|
# Create batch file with 6 jobs
|
||||||
jobs = [
|
jobs = [{"prompt": f"Job {i}", "output_name": f"video{i}.mp4"} for i in range(6)]
|
||||||
{"prompt": f"Job {i}", "output_name": f"video{i}.mp4"}
|
|
||||||
for i in range(6)
|
|
||||||
]
|
|
||||||
batch_file = tmp_path / "batch.jsonl"
|
batch_file = tmp_path / "batch.jsonl"
|
||||||
batch_file.write_text("\n".join(json.dumps(j) for j in jobs))
|
batch_file.write_text("\n".join(json.dumps(j) for j in jobs))
|
||||||
|
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
|
|
||||||
output_dir = tmp_path / "output"
|
output_dir = tmp_path / "output"
|
||||||
|
|
||||||
# GPU 0 of 3 should get jobs 0, 3 (indices 0, 3)
|
# GPU 0 of 3 should get jobs 0, 3 (indices 0, 3)
|
||||||
summary = run_batch(
|
summary = run_batch(
|
||||||
batch_file=batch_file,
|
batch_file=batch_file,
|
||||||
@ -1005,25 +1000,27 @@ class TestRunBatch:
|
|||||||
num_gpus=3,
|
num_gpus=3,
|
||||||
resume=False,
|
resume=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert summary["total"] == 2 # GPU 0 gets 2 jobs
|
assert summary["total"] == 2 # GPU 0 gets 2 jobs
|
||||||
assert mock_generate.call_count == 2
|
assert mock_generate.call_count == 2
|
||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_state_persists_after_each_job(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_state_persists_after_each_job(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that state is saved after each job, not just at the end."""
|
"""Test that state is saved after each job, not just at the end."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
|
|
||||||
# Track how many times state file is written
|
# Track how many times state file is written
|
||||||
state_writes = []
|
state_writes = []
|
||||||
original_save = BatchState.save
|
original_save = BatchState.save
|
||||||
|
|
||||||
def tracking_save(self, state_file):
|
def tracking_save(self, state_file):
|
||||||
original_save(self, state_file)
|
original_save(self, state_file)
|
||||||
state_writes.append(len(self.completed))
|
state_writes.append(len(self.completed))
|
||||||
|
|
||||||
with patch.object(BatchState, 'save', tracking_save):
|
with patch.object(BatchState, 'save', tracking_save):
|
||||||
output_dir = tmp_path / "output"
|
output_dir = tmp_path / "output"
|
||||||
run_batch(
|
run_batch(
|
||||||
@ -1032,7 +1029,7 @@ class TestRunBatch:
|
|||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
resume=True,
|
resume=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# State should be saved 3 times (once per job)
|
# State should be saved 3 times (once per job)
|
||||||
assert len(state_writes) == 3
|
assert len(state_writes) == 3
|
||||||
assert state_writes == [1, 2, 3] # Progressively more completed jobs
|
assert state_writes == [1, 2, 3] # Progressively more completed jobs
|
||||||
@ -1086,7 +1083,7 @@ class TestEdgeCases:
|
|||||||
"height": 1080,
|
"height": 1080,
|
||||||
}
|
}
|
||||||
job = BatchJob.from_dict(data)
|
job = BatchJob.from_dict(data)
|
||||||
|
|
||||||
assert job.width == 1920
|
assert job.width == 1920
|
||||||
assert job.height == 1080
|
assert job.height == 1080
|
||||||
assert job.num_frames == 100
|
assert job.num_frames == 100
|
||||||
@ -1100,32 +1097,34 @@ class TestEdgeCases:
|
|||||||
"another_field": 123,
|
"another_field": 123,
|
||||||
}
|
}
|
||||||
job = BatchJob.from_dict(data)
|
job = BatchJob.from_dict(data)
|
||||||
|
|
||||||
assert job.prompt == "Test"
|
assert job.prompt == "Test"
|
||||||
assert job.output_name == "test.mp4"
|
assert job.output_name == "test.mp4"
|
||||||
# Extra fields should be ignored without error
|
# Extra fields should be ignored without error
|
||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_corrupted_state_starts_fresh(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_corrupted_state_starts_fresh(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that corrupted state file doesn't prevent batch from running."""
|
"""Test that corrupted state file doesn't prevent batch from running."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
|
|
||||||
output_dir = tmp_path / "output"
|
output_dir = tmp_path / "output"
|
||||||
output_dir.mkdir()
|
output_dir.mkdir()
|
||||||
|
|
||||||
# Create corrupted state file
|
# Create corrupted state file
|
||||||
state_file = output_dir / ".batch_state.json"
|
state_file = output_dir / ".batch_state.json"
|
||||||
state_file.write_text("{corrupted: json")
|
state_file.write_text("{corrupted: json")
|
||||||
|
|
||||||
summary = run_batch(
|
summary = run_batch(
|
||||||
batch_file=batch_file,
|
batch_file=batch_file,
|
||||||
model_path="model",
|
model_path="model",
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
resume=True,
|
resume=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should start fresh and process all jobs
|
# Should start fresh and process all jobs
|
||||||
assert summary["total"] == 3
|
assert summary["total"] == 3
|
||||||
assert mock_generate.call_count == 3
|
assert mock_generate.call_count == 3
|
||||||
@ -1138,12 +1137,12 @@ class TestEdgeCases:
|
|||||||
model_path="model",
|
model_path="model",
|
||||||
generate_type="t2v",
|
generate_type="t2v",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert state.updated_at == ""
|
assert state.updated_at == ""
|
||||||
|
|
||||||
state_file = tmp_path / "state.json"
|
state_file = tmp_path / "state.json"
|
||||||
state.save(state_file)
|
state.save(state_file)
|
||||||
|
|
||||||
# updated_at should be set after save
|
# updated_at should be set after save
|
||||||
assert state.updated_at != ""
|
assert state.updated_at != ""
|
||||||
|
|
||||||
@ -1154,16 +1153,16 @@ class TestEdgeCases:
|
|||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
mock_generate.side_effect = RuntimeError("All fail")
|
mock_generate.side_effect = RuntimeError("All fail")
|
||||||
|
|
||||||
output_dir = tmp_path / "output"
|
output_dir = tmp_path / "output"
|
||||||
|
|
||||||
summary = run_batch(
|
summary = run_batch(
|
||||||
batch_file=batch_file,
|
batch_file=batch_file,
|
||||||
model_path="model",
|
model_path="model",
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
resume=False,
|
resume=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert summary["completed"] == 0
|
assert summary["completed"] == 0
|
||||||
assert summary["failed"] == 3
|
assert summary["failed"] == 3
|
||||||
|
|
||||||
@ -1174,9 +1173,9 @@ class TestEdgeCases:
|
|||||||
{"prompt": "Café résumé naïve", "output_name": "unicode3.mp4"}"""
|
{"prompt": "Café résumé naïve", "output_name": "unicode3.mp4"}"""
|
||||||
batch_file = tmp_path / "unicode.jsonl"
|
batch_file = tmp_path / "unicode.jsonl"
|
||||||
batch_file.write_text(content, encoding="utf-8")
|
batch_file.write_text(content, encoding="utf-8")
|
||||||
|
|
||||||
jobs = load_batch_file(batch_file)
|
jobs = load_batch_file(batch_file)
|
||||||
|
|
||||||
assert len(jobs) == 3
|
assert len(jobs) == 3
|
||||||
assert jobs[0].prompt == "猫がピアノを弾いている"
|
assert jobs[0].prompt == "猫がピアノを弾いている"
|
||||||
assert "🌊" in jobs[1].prompt
|
assert "🌊" in jobs[1].prompt
|
||||||
@ -1187,8 +1186,8 @@ class TestEdgeCases:
|
|||||||
content = json.dumps({"prompt": long_prompt, "output_name": "long.mp4"})
|
content = json.dumps({"prompt": long_prompt, "output_name": "long.mp4"})
|
||||||
batch_file = tmp_path / "long.jsonl"
|
batch_file = tmp_path / "long.jsonl"
|
||||||
batch_file.write_text(content)
|
batch_file.write_text(content)
|
||||||
|
|
||||||
jobs = load_batch_file(batch_file)
|
jobs = load_batch_file(batch_file)
|
||||||
|
|
||||||
assert len(jobs) == 1
|
assert len(jobs) == 1
|
||||||
assert len(jobs[0].prompt) > 4000
|
assert len(jobs[0].prompt) > 4000
|
||||||
|
|||||||
@ -107,6 +107,7 @@ RESOLUTION_MAP = {
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BatchJob:
|
class BatchJob:
|
||||||
"""Represents a single job in the batch."""
|
"""Represents a single job in the batch."""
|
||||||
|
|
||||||
prompt: str
|
prompt: str
|
||||||
output_name: str
|
output_name: str
|
||||||
image_path: Optional[str] = None
|
image_path: Optional[str] = None
|
||||||
@ -117,12 +118,12 @@ class BatchJob:
|
|||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
width: Optional[int] = None
|
width: Optional[int] = None
|
||||||
height: Optional[int] = None
|
height: Optional[int] = None
|
||||||
|
|
||||||
# Internal fields
|
# Internal fields
|
||||||
line_number: int = 0
|
line_number: int = 0
|
||||||
status: str = "pending"
|
status: str = "pending"
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any], line_number: int = 0) -> "BatchJob":
|
def from_dict(cls, data: Dict[str, Any], line_number: int = 0) -> "BatchJob":
|
||||||
"""Create a BatchJob from a dictionary."""
|
"""Create a BatchJob from a dictionary."""
|
||||||
@ -139,7 +140,7 @@ class BatchJob:
|
|||||||
height=data.get("height"),
|
height=data.get("height"),
|
||||||
line_number=line_number,
|
line_number=line_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate(self) -> List[str]:
|
def validate(self) -> List[str]:
|
||||||
"""Validate the job and return list of errors."""
|
"""Validate the job and return list of errors."""
|
||||||
errors = []
|
errors = []
|
||||||
@ -153,6 +154,7 @@ class BatchJob:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BatchState:
|
class BatchState:
|
||||||
"""Tracks batch progress for resume capability."""
|
"""Tracks batch progress for resume capability."""
|
||||||
|
|
||||||
batch_file: str
|
batch_file: str
|
||||||
output_dir: str
|
output_dir: str
|
||||||
model_path: str
|
model_path: str
|
||||||
@ -161,7 +163,7 @@ class BatchState:
|
|||||||
failed: List[Dict[str, str]] = field(default_factory=list)
|
failed: List[Dict[str, str]] = field(default_factory=list)
|
||||||
started_at: str = ""
|
started_at: str = ""
|
||||||
updated_at: str = ""
|
updated_at: str = ""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, state_file: Path) -> Optional["BatchState"]:
|
def load(cls, state_file: Path) -> Optional["BatchState"]:
|
||||||
"""Load state from file."""
|
"""Load state from file."""
|
||||||
@ -174,22 +176,22 @@ class BatchState:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load state file: {e}")
|
logger.warning(f"Failed to load state file: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def save(self, state_file: Path):
|
def save(self, state_file: Path):
|
||||||
"""Save state to file."""
|
"""Save state to file."""
|
||||||
self.updated_at = datetime.now().isoformat()
|
self.updated_at = datetime.now().isoformat()
|
||||||
with open(state_file, "w") as f:
|
with open(state_file, "w") as f:
|
||||||
json.dump(self.__dict__, f, indent=2)
|
json.dump(self.__dict__, f, indent=2)
|
||||||
|
|
||||||
def mark_completed(self, output_name: str):
|
def mark_completed(self, output_name: str):
|
||||||
"""Mark a job as completed."""
|
"""Mark a job as completed."""
|
||||||
if output_name not in self.completed:
|
if output_name not in self.completed:
|
||||||
self.completed.append(output_name)
|
self.completed.append(output_name)
|
||||||
|
|
||||||
def mark_failed(self, output_name: str, error: str):
|
def mark_failed(self, output_name: str, error: str):
|
||||||
"""Mark a job as failed."""
|
"""Mark a job as failed."""
|
||||||
self.failed.append({"output_name": output_name, "error": error})
|
self.failed.append({"output_name": output_name, "error": error})
|
||||||
|
|
||||||
def is_completed(self, output_name: str) -> bool:
|
def is_completed(self, output_name: str) -> bool:
|
||||||
"""Check if a job was already completed."""
|
"""Check if a job was already completed."""
|
||||||
return output_name in self.completed
|
return output_name in self.completed
|
||||||
@ -198,37 +200,37 @@ class BatchState:
|
|||||||
def load_batch_file(batch_file: Path) -> List[BatchJob]:
|
def load_batch_file(batch_file: Path) -> List[BatchJob]:
|
||||||
"""
|
"""
|
||||||
Load and parse a JSONL batch file.
|
Load and parse a JSONL batch file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch_file: Path to the JSONL file
|
batch_file: Path to the JSONL file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of BatchJob objects
|
List of BatchJob objects
|
||||||
"""
|
"""
|
||||||
jobs = []
|
jobs = []
|
||||||
|
|
||||||
with open(batch_file, "r") as f:
|
with open(batch_file, "r") as f:
|
||||||
for line_num, line in enumerate(f, 1):
|
for line_num, line in enumerate(f, 1):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if not line or line.startswith("#"):
|
if not line or line.startswith("#"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
job = BatchJob.from_dict(data, line_number=line_num)
|
job = BatchJob.from_dict(data, line_number=line_num)
|
||||||
|
|
||||||
# Validate job
|
# Validate job
|
||||||
errors = job.validate()
|
errors = job.validate()
|
||||||
if errors:
|
if errors:
|
||||||
logger.warning(f"Line {line_num}: Invalid job - {', '.join(errors)}")
|
logger.warning(f"Line {line_num}: Invalid job - {', '.join(errors)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
jobs.append(job)
|
jobs.append(job)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning(f"Line {line_num}: Invalid JSON - {e}")
|
logger.warning(f"Line {line_num}: Invalid JSON - {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return jobs
|
return jobs
|
||||||
|
|
||||||
|
|
||||||
@ -241,26 +243,26 @@ def load_pipeline(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Load the appropriate pipeline for the generation type.
|
Load the appropriate pipeline for the generation type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path: Path to the model
|
model_path: Path to the model
|
||||||
generate_type: Type of generation (t2v, i2v, v2v)
|
generate_type: Type of generation (t2v, i2v, v2v)
|
||||||
dtype: Data type for computation
|
dtype: Data type for computation
|
||||||
lora_path: Optional path to LoRA weights
|
lora_path: Optional path to LoRA weights
|
||||||
enable_cpu_offload: Whether to enable CPU offloading
|
enable_cpu_offload: Whether to enable CPU offloading
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The loaded pipeline
|
The loaded pipeline
|
||||||
"""
|
"""
|
||||||
logger.info(f"Loading pipeline for {generate_type} from {model_path}")
|
logger.info(f"Loading pipeline for {generate_type} from {model_path}")
|
||||||
|
|
||||||
if generate_type == "i2v":
|
if generate_type == "i2v":
|
||||||
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||||
elif generate_type == "t2v":
|
elif generate_type == "t2v":
|
||||||
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||||
else: # v2v
|
else: # v2v
|
||||||
pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||||
|
|
||||||
# Load LoRA weights if provided
|
# Load LoRA weights if provided
|
||||||
if lora_path:
|
if lora_path:
|
||||||
logger.info(f"Loading LoRA weights from {lora_path}")
|
logger.info(f"Loading LoRA weights from {lora_path}")
|
||||||
@ -268,21 +270,21 @@ def load_pipeline(
|
|||||||
lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="batch_lora"
|
lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="batch_lora"
|
||||||
)
|
)
|
||||||
pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
|
pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
|
||||||
|
|
||||||
# Set scheduler
|
# Set scheduler
|
||||||
pipe.scheduler = CogVideoXDPMScheduler.from_config(
|
pipe.scheduler = CogVideoXDPMScheduler.from_config(
|
||||||
pipe.scheduler.config, timestep_spacing="trailing"
|
pipe.scheduler.config, timestep_spacing="trailing"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Enable memory optimizations
|
# Enable memory optimizations
|
||||||
if enable_cpu_offload:
|
if enable_cpu_offload:
|
||||||
pipe.enable_sequential_cpu_offload()
|
pipe.enable_sequential_cpu_offload()
|
||||||
else:
|
else:
|
||||||
pipe.to("cuda")
|
pipe.to("cuda")
|
||||||
|
|
||||||
pipe.vae.enable_slicing()
|
pipe.vae.enable_slicing()
|
||||||
pipe.vae.enable_tiling()
|
pipe.vae.enable_tiling()
|
||||||
|
|
||||||
logger.info("Pipeline loaded successfully")
|
logger.info("Pipeline loaded successfully")
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
@ -301,7 +303,7 @@ def generate_single_video(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generate a single video from a job.
|
Generate a single video from a job.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pipe: The loaded pipeline
|
pipe: The loaded pipeline
|
||||||
job: The batch job to process
|
job: The batch job to process
|
||||||
@ -315,17 +317,17 @@ def generate_single_video(
|
|||||||
desired_resolution = RESOLUTION_MAP.get(model_name.lower(), (480, 720))
|
desired_resolution = RESOLUTION_MAP.get(model_name.lower(), (480, 720))
|
||||||
height = job.height if job.height else desired_resolution[0]
|
height = job.height if job.height else desired_resolution[0]
|
||||||
width = job.width if job.width else desired_resolution[1]
|
width = job.width if job.width else desired_resolution[1]
|
||||||
|
|
||||||
# Use job-specific or default values
|
# Use job-specific or default values
|
||||||
num_frames = job.num_frames or default_num_frames
|
num_frames = job.num_frames or default_num_frames
|
||||||
guidance_scale = job.guidance_scale or default_guidance_scale
|
guidance_scale = job.guidance_scale or default_guidance_scale
|
||||||
num_inference_steps = job.num_inference_steps or default_num_inference_steps
|
num_inference_steps = job.num_inference_steps or default_num_inference_steps
|
||||||
seed = job.seed or default_seed
|
seed = job.seed or default_seed
|
||||||
|
|
||||||
# Load image/video if needed
|
# Load image/video if needed
|
||||||
image = None
|
image = None
|
||||||
video = None
|
video = None
|
||||||
|
|
||||||
if generate_type == "i2v":
|
if generate_type == "i2v":
|
||||||
if not job.image_path:
|
if not job.image_path:
|
||||||
raise ValueError("image_path is required for i2v generation")
|
raise ValueError("image_path is required for i2v generation")
|
||||||
@ -334,10 +336,10 @@ def generate_single_video(
|
|||||||
if not job.video_path:
|
if not job.video_path:
|
||||||
raise ValueError("video_path is required for v2v generation")
|
raise ValueError("video_path is required for v2v generation")
|
||||||
video = load_video(job.video_path)
|
video = load_video(job.video_path)
|
||||||
|
|
||||||
# Generate video
|
# Generate video
|
||||||
generator = torch.Generator().manual_seed(seed)
|
generator = torch.Generator().manual_seed(seed)
|
||||||
|
|
||||||
if generate_type == "i2v":
|
if generate_type == "i2v":
|
||||||
video_frames = pipe(
|
video_frames = pipe(
|
||||||
height=height,
|
height=height,
|
||||||
@ -376,7 +378,7 @@ def generate_single_video(
|
|||||||
guidance_scale=guidance_scale,
|
guidance_scale=guidance_scale,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
).frames[0]
|
).frames[0]
|
||||||
|
|
||||||
# Export video
|
# Export video
|
||||||
export_to_video(video_frames, str(output_path), fps=fps)
|
export_to_video(video_frames, str(output_path), fps=fps)
|
||||||
|
|
||||||
@ -400,7 +402,7 @@ def run_batch(
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Run batch inference on a JSONL file.
|
Run batch inference on a JSONL file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch_file: Path to the JSONL batch file
|
batch_file: Path to the JSONL batch file
|
||||||
model_path: Path to the model
|
model_path: Path to the model
|
||||||
@ -414,7 +416,7 @@ def run_batch(
|
|||||||
num_gpus: Total number of GPUs for distribution
|
num_gpus: Total number of GPUs for distribution
|
||||||
default_*: Default values for optional parameters
|
default_*: Default values for optional parameters
|
||||||
fps: Frames per second for output videos
|
fps: Frames per second for output videos
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Summary dictionary with statistics
|
Summary dictionary with statistics
|
||||||
"""
|
"""
|
||||||
@ -422,26 +424,26 @@ def run_batch(
|
|||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
state_file = output_dir / ".batch_state.json"
|
state_file = output_dir / ".batch_state.json"
|
||||||
error_log = output_dir / "errors.log"
|
error_log = output_dir / "errors.log"
|
||||||
|
|
||||||
# Load jobs
|
# Load jobs
|
||||||
logger.info(f"Loading batch file: {batch_file}")
|
logger.info(f"Loading batch file: {batch_file}")
|
||||||
all_jobs = load_batch_file(batch_file)
|
all_jobs = load_batch_file(batch_file)
|
||||||
logger.info(f"Found {len(all_jobs)} valid jobs in batch file")
|
logger.info(f"Found {len(all_jobs)} valid jobs in batch file")
|
||||||
|
|
||||||
# Distribute jobs across GPUs if using multi-GPU
|
# Distribute jobs across GPUs if using multi-GPU
|
||||||
if num_gpus > 1:
|
if num_gpus > 1:
|
||||||
jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id]
|
jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == gpu_id]
|
||||||
logger.info(f"GPU {gpu_id}/{num_gpus}: Processing {len(jobs)} jobs")
|
logger.info(f"GPU {gpu_id}/{num_gpus}: Processing {len(jobs)} jobs")
|
||||||
else:
|
else:
|
||||||
jobs = all_jobs
|
jobs = all_jobs
|
||||||
|
|
||||||
# Load or create state
|
# Load or create state
|
||||||
state = None
|
state = None
|
||||||
if resume:
|
if resume:
|
||||||
state = BatchState.load(state_file)
|
state = BatchState.load(state_file)
|
||||||
if state:
|
if state:
|
||||||
logger.info(f"Resuming batch: {len(state.completed)} already completed")
|
logger.info(f"Resuming batch: {len(state.completed)} already completed")
|
||||||
|
|
||||||
if state is None:
|
if state is None:
|
||||||
state = BatchState(
|
state = BatchState(
|
||||||
batch_file=str(batch_file),
|
batch_file=str(batch_file),
|
||||||
@ -450,7 +452,7 @@ def run_batch(
|
|||||||
generate_type=generate_type,
|
generate_type=generate_type,
|
||||||
started_at=datetime.now().isoformat(),
|
started_at=datetime.now().isoformat(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filter out completed jobs
|
# Filter out completed jobs
|
||||||
if resume:
|
if resume:
|
||||||
pending_jobs = [j for j in jobs if not state.is_completed(j.output_name)]
|
pending_jobs = [j for j in jobs if not state.is_completed(j.output_name)]
|
||||||
@ -458,11 +460,11 @@ def run_batch(
|
|||||||
if skipped > 0:
|
if skipped > 0:
|
||||||
logger.info(f"Skipping {skipped} already-completed jobs")
|
logger.info(f"Skipping {skipped} already-completed jobs")
|
||||||
jobs = pending_jobs
|
jobs = pending_jobs
|
||||||
|
|
||||||
if not jobs:
|
if not jobs:
|
||||||
logger.info("No jobs to process")
|
logger.info("No jobs to process")
|
||||||
return {"total": 0, "completed": 0, "failed": 0, "skipped": len(all_jobs)}
|
return {"total": 0, "completed": 0, "failed": 0, "skipped": len(all_jobs)}
|
||||||
|
|
||||||
# Load pipeline
|
# Load pipeline
|
||||||
model_name = model_path.split("/")[-1]
|
model_name = model_path.split("/")[-1]
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
@ -472,19 +474,19 @@ def run_batch(
|
|||||||
lora_path=lora_path,
|
lora_path=lora_path,
|
||||||
enable_cpu_offload=enable_cpu_offload,
|
enable_cpu_offload=enable_cpu_offload,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process jobs
|
# Process jobs
|
||||||
completed = 0
|
completed = 0
|
||||||
failed = 0
|
failed = 0
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
with tqdm(total=len(jobs), desc="Generating videos", unit="video") as pbar:
|
with tqdm(total=len(jobs), desc="Generating videos", unit="video") as pbar:
|
||||||
for job in jobs:
|
for job in jobs:
|
||||||
output_path = output_dir / job.output_name
|
output_path = output_dir / job.output_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Processing: {job.output_name} - \"{job.prompt[:50]}...\"")
|
logger.info(f"Processing: {job.output_name} - \"{job.prompt[:50]}...\"")
|
||||||
|
|
||||||
generate_single_video(
|
generate_single_video(
|
||||||
pipe=pipe,
|
pipe=pipe,
|
||||||
job=job,
|
job=job,
|
||||||
@ -497,15 +499,15 @@ def run_batch(
|
|||||||
default_seed=default_seed,
|
default_seed=default_seed,
|
||||||
fps=fps,
|
fps=fps,
|
||||||
)
|
)
|
||||||
|
|
||||||
state.mark_completed(job.output_name)
|
state.mark_completed(job.output_name)
|
||||||
completed += 1
|
completed += 1
|
||||||
logger.info(f"Completed: {job.output_name}")
|
logger.info(f"Completed: {job.output_name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"{type(e).__name__}: {str(e)}"
|
error_msg = f"{type(e).__name__}: {str(e)}"
|
||||||
logger.error(f"Failed: {job.output_name} - {error_msg}")
|
logger.error(f"Failed: {job.output_name} - {error_msg}")
|
||||||
|
|
||||||
# Log full traceback to error log
|
# Log full traceback to error log
|
||||||
with open(error_log, "a") as f:
|
with open(error_log, "a") as f:
|
||||||
f.write(f"\n{'='*60}\n")
|
f.write(f"\n{'='*60}\n")
|
||||||
@ -514,26 +516,24 @@ def run_batch(
|
|||||||
f.write(f"Time: {datetime.now().isoformat()}\n")
|
f.write(f"Time: {datetime.now().isoformat()}\n")
|
||||||
f.write(f"Error: {error_msg}\n")
|
f.write(f"Error: {error_msg}\n")
|
||||||
f.write(traceback.format_exc())
|
f.write(traceback.format_exc())
|
||||||
|
|
||||||
state.mark_failed(job.output_name, error_msg)
|
state.mark_failed(job.output_name, error_msg)
|
||||||
failed += 1
|
failed += 1
|
||||||
|
|
||||||
# Save state after each job (for resume)
|
# Save state after each job (for resume)
|
||||||
state.save(state_file)
|
state.save(state_file)
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
# Update ETA in progress bar
|
# Update ETA in progress bar
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
if completed + failed > 0:
|
if completed + failed > 0:
|
||||||
avg_time = elapsed / (completed + failed)
|
avg_time = elapsed / (completed + failed)
|
||||||
remaining = len(jobs) - (completed + failed)
|
remaining = len(jobs) - (completed + failed)
|
||||||
eta_seconds = avg_time * remaining
|
eta_seconds = avg_time * remaining
|
||||||
pbar.set_postfix({
|
pbar.set_postfix(
|
||||||
"done": completed,
|
{"done": completed, "failed": failed, "ETA": f"{eta_seconds/60:.1f}m"}
|
||||||
"failed": failed,
|
)
|
||||||
"ETA": f"{eta_seconds/60:.1f}m"
|
|
||||||
})
|
|
||||||
|
|
||||||
# Final summary
|
# Final summary
|
||||||
elapsed_total = time.time() - start_time
|
elapsed_total = time.time() - start_time
|
||||||
summary = {
|
summary = {
|
||||||
@ -544,7 +544,7 @@ def run_batch(
|
|||||||
"elapsed_seconds": elapsed_total,
|
"elapsed_seconds": elapsed_total,
|
||||||
"avg_seconds_per_video": elapsed_total / max(completed + failed, 1),
|
"avg_seconds_per_video": elapsed_total / max(completed + failed, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
logger.info("BATCH COMPLETE")
|
logger.info("BATCH COMPLETE")
|
||||||
logger.info(f" Total jobs: {summary['total']}")
|
logger.info(f" Total jobs: {summary['total']}")
|
||||||
@ -554,7 +554,7 @@ def run_batch(
|
|||||||
if summary['failed'] > 0:
|
if summary['failed'] > 0:
|
||||||
logger.info(f" See errors in: {error_log}")
|
logger.info(f" See errors in: {error_log}")
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
|
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
@ -566,11 +566,11 @@ def main():
|
|||||||
Examples:
|
Examples:
|
||||||
# Basic text-to-video batch
|
# Basic text-to-video batch
|
||||||
python batch_inference.py --batch_file prompts.jsonl --model_path THUDM/CogVideoX1.5-5B
|
python batch_inference.py --batch_file prompts.jsonl --model_path THUDM/CogVideoX1.5-5B
|
||||||
|
|
||||||
# Image-to-video batch with custom output directory
|
# Image-to-video batch with custom output directory
|
||||||
python batch_inference.py --batch_file i2v.jsonl --model_path THUDM/CogVideoX1.5-5B-I2V \\
|
python batch_inference.py --batch_file i2v.jsonl --model_path THUDM/CogVideoX1.5-5B-I2V \\
|
||||||
--generate_type i2v --output_dir ./my_videos
|
--generate_type i2v --output_dir ./my_videos
|
||||||
|
|
||||||
# Multi-GPU: run on 4 GPUs (one process per GPU)
|
# Multi-GPU: run on 4 GPUs (one process per GPU)
|
||||||
for i in {0..3}; do
|
for i in {0..3}; do
|
||||||
CUDA_VISIBLE_DEVICES=$i python batch_inference.py --batch_file batch.jsonl \\
|
CUDA_VISIBLE_DEVICES=$i python batch_inference.py --batch_file batch.jsonl \\
|
||||||
@ -579,128 +579,91 @@ Examples:
|
|||||||
|
|
||||||
JSONL Format:
|
JSONL Format:
|
||||||
Each line is a JSON object with: prompt (required), output_name (required),
|
Each line is a JSON object with: prompt (required), output_name (required),
|
||||||
and optional: image_path, video_path, num_frames, guidance_scale,
|
and optional: image_path, video_path, num_frames, guidance_scale,
|
||||||
num_inference_steps, seed, width, height
|
num_inference_steps, seed, width, height
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Required arguments
|
# Required arguments
|
||||||
parser.add_argument(
|
parser.add_argument("--batch_file", type=str, required=True, help="Path to JSONL batch file")
|
||||||
"--batch_file",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to JSONL batch file"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_path",
|
"--model_path",
|
||||||
type=str,
|
type=str,
|
||||||
default="THUDM/CogVideoX1.5-5B",
|
default="THUDM/CogVideoX1.5-5B",
|
||||||
help="Path to the model (default: THUDM/CogVideoX1.5-5B)"
|
help="Path to the model (default: THUDM/CogVideoX1.5-5B)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Output settings
|
# Output settings
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output_dir",
|
"--output_dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="./batch_output",
|
default="./batch_output",
|
||||||
help="Directory for output videos (default: ./batch_output)"
|
help="Directory for output videos (default: ./batch_output)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--generate_type",
|
"--generate_type",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["t2v", "i2v", "v2v"],
|
choices=["t2v", "i2v", "v2v"],
|
||||||
default="t2v",
|
default="t2v",
|
||||||
help="Generation type (default: t2v)"
|
help="Generation type (default: t2v)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Model settings
|
# Model settings
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora_path",
|
"--lora_path", type=str, default=None, help="Path to LoRA weights (optional)"
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to LoRA weights (optional)"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dtype",
|
"--dtype",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["float16", "bfloat16"],
|
choices=["float16", "bfloat16"],
|
||||||
default="bfloat16",
|
default="bfloat16",
|
||||||
help="Data type for computation (default: bfloat16)"
|
help="Data type for computation (default: bfloat16)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable_cpu_offload",
|
"--disable_cpu_offload",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable CPU offloading (uses more VRAM but faster)"
|
help="Disable CPU offloading (uses more VRAM but faster)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Default generation parameters
|
# Default generation parameters
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_frames",
|
"--num_frames", type=int, default=81, help="Default number of frames (default: 81)"
|
||||||
type=int,
|
|
||||||
default=81,
|
|
||||||
help="Default number of frames (default: 81)"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--guidance_scale",
|
"--guidance_scale", type=float, default=6.0, help="Default guidance scale (default: 6.0)"
|
||||||
type=float,
|
|
||||||
default=6.0,
|
|
||||||
help="Default guidance scale (default: 6.0)"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_inference_steps",
|
"--num_inference_steps", type=int, default=50, help="Default inference steps (default: 50)"
|
||||||
type=int,
|
|
||||||
default=50,
|
|
||||||
help="Default inference steps (default: 50)"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--seed", type=int, default=42, help="Default random seed (default: 42)")
|
||||||
"--seed",
|
parser.add_argument("--fps", type=int, default=16, help="Output video FPS (default: 16)")
|
||||||
type=int,
|
|
||||||
default=42,
|
|
||||||
help="Default random seed (default: 42)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--fps",
|
|
||||||
type=int,
|
|
||||||
default=16,
|
|
||||||
help="Output video FPS (default: 16)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resume and multi-GPU
|
# Resume and multi-GPU
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--resume",
|
"--resume",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=True,
|
default=True,
|
||||||
help="Resume from previous state (default: True)"
|
help="Resume from previous state (default: True)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--no_resume", action="store_true", help="Don't resume, start fresh")
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpu_id", type=int, default=0, help="GPU ID for multi-GPU distribution (default: 0)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no_resume",
|
"--num_gpus", type=int, default=1, help="Total number of GPUs for distribution (default: 1)"
|
||||||
action="store_true",
|
|
||||||
help="Don't resume, start fresh"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--gpu_id",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="GPU ID for multi-GPU distribution (default: 0)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_gpus",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Total number of GPUs for distribution (default: 1)"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate batch file exists
|
# Validate batch file exists
|
||||||
batch_file = Path(args.batch_file)
|
batch_file = Path(args.batch_file)
|
||||||
if not batch_file.exists():
|
if not batch_file.exists():
|
||||||
logger.error(f"Batch file not found: {batch_file}")
|
logger.error(f"Batch file not found: {batch_file}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Parse dtype
|
# Parse dtype
|
||||||
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
|
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
|
||||||
|
|
||||||
# Run batch
|
# Run batch
|
||||||
try:
|
try:
|
||||||
summary = run_batch(
|
summary = run_batch(
|
||||||
@ -720,11 +683,11 @@ JSONL Format:
|
|||||||
default_seed=args.seed,
|
default_seed=args.seed,
|
||||||
fps=args.fps,
|
fps=args.fps,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Exit with error code if any failures
|
# Exit with error code if any failures
|
||||||
if summary["failed"] > 0:
|
if summary["failed"] > 0:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("\nBatch interrupted by user. Progress saved for resume.")
|
logger.info("\nBatch interrupted by user. Progress saved for resume.")
|
||||||
sys.exit(130)
|
sys.exit(130)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user