mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-05-31 16:28:17 +08:00
chore: allow tests/ directory in gitignore
This commit is contained in:
parent
15a4b403c4
commit
206760830a
2
.gitignore
vendored
2
.gitignore
vendored
@ -8,6 +8,8 @@ logs/
|
|||||||
.idea
|
.idea
|
||||||
output*
|
output*
|
||||||
test*
|
test*
|
||||||
|
!tests/
|
||||||
|
!tests/**
|
||||||
venv
|
venv
|
||||||
**/.swp
|
**/.swp
|
||||||
**/*.log
|
**/*.log
|
||||||
|
|||||||
@ -34,9 +34,11 @@ mock_torch.Generator = MagicMock
|
|||||||
# Create mock diffusers module
|
# Create mock diffusers module
|
||||||
mock_diffusers = MagicMock()
|
mock_diffusers = MagicMock()
|
||||||
|
|
||||||
|
|
||||||
# Create mock tqdm that works as a context manager
|
# Create mock tqdm that works as a context manager
|
||||||
class MockTqdm:
|
class MockTqdm:
|
||||||
"""Mock tqdm that works as a context manager and returns a dummy progress bar."""
|
"""Mock tqdm that works as a context manager and returns a dummy progress bar."""
|
||||||
|
|
||||||
def __init__(self, iterable=None, total=None, **kwargs):
|
def __init__(self, iterable=None, total=None, **kwargs):
|
||||||
self.iterable = iterable
|
self.iterable = iterable
|
||||||
self.n = 0
|
self.n = 0
|
||||||
@ -488,10 +490,7 @@ class TestMultiGPUDistribution:
|
|||||||
|
|
||||||
def test_single_gpu_gets_all_jobs(self):
|
def test_single_gpu_gets_all_jobs(self):
|
||||||
"""Test that a single GPU processes all jobs."""
|
"""Test that a single GPU processes all jobs."""
|
||||||
all_jobs = [
|
all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)]
|
||||||
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
|
|
||||||
for i in range(10)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Simulate distribution for single GPU (gpu_id=0, num_gpus=1)
|
# Simulate distribution for single GPU (gpu_id=0, num_gpus=1)
|
||||||
gpu_id = 0
|
gpu_id = 0
|
||||||
@ -502,10 +501,7 @@ class TestMultiGPUDistribution:
|
|||||||
|
|
||||||
def test_two_gpu_distribution(self):
|
def test_two_gpu_distribution(self):
|
||||||
"""Test job distribution across 2 GPUs."""
|
"""Test job distribution across 2 GPUs."""
|
||||||
all_jobs = [
|
all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)]
|
||||||
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
|
|
||||||
for i in range(10)
|
|
||||||
]
|
|
||||||
|
|
||||||
# GPU 0 gets jobs 0, 2, 4, 6, 8
|
# GPU 0 gets jobs 0, 2, 4, 6, 8
|
||||||
gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 0]
|
gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % 2 == 0]
|
||||||
@ -519,10 +515,7 @@ class TestMultiGPUDistribution:
|
|||||||
|
|
||||||
def test_four_gpu_distribution(self):
|
def test_four_gpu_distribution(self):
|
||||||
"""Test job distribution across 4 GPUs."""
|
"""Test job distribution across 4 GPUs."""
|
||||||
all_jobs = [
|
all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(12)]
|
||||||
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
|
|
||||||
for i in range(12)
|
|
||||||
]
|
|
||||||
|
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
for gpu_id in range(num_gpus):
|
for gpu_id in range(num_gpus):
|
||||||
@ -531,10 +524,7 @@ class TestMultiGPUDistribution:
|
|||||||
|
|
||||||
def test_uneven_distribution(self):
|
def test_uneven_distribution(self):
|
||||||
"""Test distribution when jobs don't divide evenly."""
|
"""Test distribution when jobs don't divide evenly."""
|
||||||
all_jobs = [
|
all_jobs = [BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4") for i in range(10)]
|
||||||
BatchJob(prompt=f"Job {i}", output_name=f"video{i}.mp4")
|
|
||||||
for i in range(10)
|
|
||||||
]
|
|
||||||
|
|
||||||
num_gpus = 3
|
num_gpus = 3
|
||||||
gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 0]
|
gpu0_jobs = [j for i, j in enumerate(all_jobs) if i % num_gpus == 0]
|
||||||
@ -842,7 +832,9 @@ class TestRunBatch:
|
|||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_creates_output_dir(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_creates_output_dir(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that output directory is created if it doesn't exist."""
|
"""Test that output directory is created if it doesn't exist."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
@ -885,7 +877,9 @@ class TestRunBatch:
|
|||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_handles_errors(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_handles_errors(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that batch continues after individual job failures."""
|
"""Test that batch continues after individual job failures."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
@ -908,7 +902,9 @@ class TestRunBatch:
|
|||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_logs_errors_to_file(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_logs_errors_to_file(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that errors are logged to errors.log."""
|
"""Test that errors are logged to errors.log."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
@ -930,7 +926,9 @@ class TestRunBatch:
|
|||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_resume_skips_completed(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_resume_skips_completed(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that resume skips already-completed jobs."""
|
"""Test that resume skips already-completed jobs."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
@ -984,10 +982,7 @@ class TestRunBatch:
|
|||||||
def test_run_batch_multi_gpu_distribution(self, mock_generate, mock_load_pipeline, tmp_path):
|
def test_run_batch_multi_gpu_distribution(self, mock_generate, mock_load_pipeline, tmp_path):
|
||||||
"""Test multi-GPU job distribution in run_batch."""
|
"""Test multi-GPU job distribution in run_batch."""
|
||||||
# Create batch file with 6 jobs
|
# Create batch file with 6 jobs
|
||||||
jobs = [
|
jobs = [{"prompt": f"Job {i}", "output_name": f"video{i}.mp4"} for i in range(6)]
|
||||||
{"prompt": f"Job {i}", "output_name": f"video{i}.mp4"}
|
|
||||||
for i in range(6)
|
|
||||||
]
|
|
||||||
batch_file = tmp_path / "batch.jsonl"
|
batch_file = tmp_path / "batch.jsonl"
|
||||||
batch_file.write_text("\n".join(json.dumps(j) for j in jobs))
|
batch_file.write_text("\n".join(json.dumps(j) for j in jobs))
|
||||||
|
|
||||||
@ -1011,7 +1006,9 @@ class TestRunBatch:
|
|||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_state_persists_after_each_job(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_state_persists_after_each_job(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that state is saved after each job, not just at the end."""
|
"""Test that state is saved after each job, not just at the end."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
@ -1107,7 +1104,9 @@ class TestEdgeCases:
|
|||||||
|
|
||||||
@patch.object(batch_inference, "load_pipeline")
|
@patch.object(batch_inference, "load_pipeline")
|
||||||
@patch.object(batch_inference, "generate_single_video")
|
@patch.object(batch_inference, "generate_single_video")
|
||||||
def test_run_batch_corrupted_state_starts_fresh(self, mock_generate, mock_load_pipeline, batch_file, tmp_path):
|
def test_run_batch_corrupted_state_starts_fresh(
|
||||||
|
self, mock_generate, mock_load_pipeline, batch_file, tmp_path
|
||||||
|
):
|
||||||
"""Test that corrupted state file doesn't prevent batch from running."""
|
"""Test that corrupted state file doesn't prevent batch from running."""
|
||||||
mock_pipe = MagicMock()
|
mock_pipe = MagicMock()
|
||||||
mock_load_pipeline.return_value = mock_pipe
|
mock_load_pipeline.return_value = mock_pipe
|
||||||
|
|||||||
@ -107,6 +107,7 @@ RESOLUTION_MAP = {
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BatchJob:
|
class BatchJob:
|
||||||
"""Represents a single job in the batch."""
|
"""Represents a single job in the batch."""
|
||||||
|
|
||||||
prompt: str
|
prompt: str
|
||||||
output_name: str
|
output_name: str
|
||||||
image_path: Optional[str] = None
|
image_path: Optional[str] = None
|
||||||
@ -153,6 +154,7 @@ class BatchJob:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BatchState:
|
class BatchState:
|
||||||
"""Tracks batch progress for resume capability."""
|
"""Tracks batch progress for resume capability."""
|
||||||
|
|
||||||
batch_file: str
|
batch_file: str
|
||||||
output_dir: str
|
output_dir: str
|
||||||
model_path: str
|
model_path: str
|
||||||
@ -528,11 +530,9 @@ def run_batch(
|
|||||||
avg_time = elapsed / (completed + failed)
|
avg_time = elapsed / (completed + failed)
|
||||||
remaining = len(jobs) - (completed + failed)
|
remaining = len(jobs) - (completed + failed)
|
||||||
eta_seconds = avg_time * remaining
|
eta_seconds = avg_time * remaining
|
||||||
pbar.set_postfix({
|
pbar.set_postfix(
|
||||||
"done": completed,
|
{"done": completed, "failed": failed, "ETA": f"{eta_seconds/60:.1f}m"}
|
||||||
"failed": failed,
|
)
|
||||||
"ETA": f"{eta_seconds/60:.1f}m"
|
|
||||||
})
|
|
||||||
|
|
||||||
# Final summary
|
# Final summary
|
||||||
elapsed_total = time.time() - start_time
|
elapsed_total = time.time() - start_time
|
||||||
@ -581,21 +581,16 @@ JSONL Format:
|
|||||||
Each line is a JSON object with: prompt (required), output_name (required),
|
Each line is a JSON object with: prompt (required), output_name (required),
|
||||||
and optional: image_path, video_path, num_frames, guidance_scale,
|
and optional: image_path, video_path, num_frames, guidance_scale,
|
||||||
num_inference_steps, seed, width, height
|
num_inference_steps, seed, width, height
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Required arguments
|
# Required arguments
|
||||||
parser.add_argument(
|
parser.add_argument("--batch_file", type=str, required=True, help="Path to JSONL batch file")
|
||||||
"--batch_file",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to JSONL batch file"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_path",
|
"--model_path",
|
||||||
type=str,
|
type=str,
|
||||||
default="THUDM/CogVideoX1.5-5B",
|
default="THUDM/CogVideoX1.5-5B",
|
||||||
help="Path to the model (default: THUDM/CogVideoX1.5-5B)"
|
help="Path to the model (default: THUDM/CogVideoX1.5-5B)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Output settings
|
# Output settings
|
||||||
@ -603,91 +598,59 @@ JSONL Format:
|
|||||||
"--output_dir",
|
"--output_dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="./batch_output",
|
default="./batch_output",
|
||||||
help="Directory for output videos (default: ./batch_output)"
|
help="Directory for output videos (default: ./batch_output)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--generate_type",
|
"--generate_type",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["t2v", "i2v", "v2v"],
|
choices=["t2v", "i2v", "v2v"],
|
||||||
default="t2v",
|
default="t2v",
|
||||||
help="Generation type (default: t2v)"
|
help="Generation type (default: t2v)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Model settings
|
# Model settings
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora_path",
|
"--lora_path", type=str, default=None, help="Path to LoRA weights (optional)"
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to LoRA weights (optional)"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dtype",
|
"--dtype",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["float16", "bfloat16"],
|
choices=["float16", "bfloat16"],
|
||||||
default="bfloat16",
|
default="bfloat16",
|
||||||
help="Data type for computation (default: bfloat16)"
|
help="Data type for computation (default: bfloat16)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable_cpu_offload",
|
"--disable_cpu_offload",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable CPU offloading (uses more VRAM but faster)"
|
help="Disable CPU offloading (uses more VRAM but faster)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Default generation parameters
|
# Default generation parameters
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_frames",
|
"--num_frames", type=int, default=81, help="Default number of frames (default: 81)"
|
||||||
type=int,
|
|
||||||
default=81,
|
|
||||||
help="Default number of frames (default: 81)"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--guidance_scale",
|
"--guidance_scale", type=float, default=6.0, help="Default guidance scale (default: 6.0)"
|
||||||
type=float,
|
|
||||||
default=6.0,
|
|
||||||
help="Default guidance scale (default: 6.0)"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_inference_steps",
|
"--num_inference_steps", type=int, default=50, help="Default inference steps (default: 50)"
|
||||||
type=int,
|
|
||||||
default=50,
|
|
||||||
help="Default inference steps (default: 50)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--seed",
|
|
||||||
type=int,
|
|
||||||
default=42,
|
|
||||||
help="Default random seed (default: 42)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--fps",
|
|
||||||
type=int,
|
|
||||||
default=16,
|
|
||||||
help="Output video FPS (default: 16)"
|
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="Default random seed (default: 42)")
|
||||||
|
parser.add_argument("--fps", type=int, default=16, help="Output video FPS (default: 16)")
|
||||||
|
|
||||||
# Resume and multi-GPU
|
# Resume and multi-GPU
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--resume",
|
"--resume",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=True,
|
default=True,
|
||||||
help="Resume from previous state (default: True)"
|
help="Resume from previous state (default: True)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--no_resume", action="store_true", help="Don't resume, start fresh")
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpu_id", type=int, default=0, help="GPU ID for multi-GPU distribution (default: 0)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no_resume",
|
"--num_gpus", type=int, default=1, help="Total number of GPUs for distribution (default: 1)"
|
||||||
action="store_true",
|
|
||||||
help="Don't resume, start fresh"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--gpu_id",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="GPU ID for multi-GPU distribution (default: 0)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_gpus",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Total number of GPUs for distribution (default: 1)"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user