diff --git a/finetune/datasets/__init__.py b/finetune/datasets/__init__.py index 204c00c..25f24e1 100644 --- a/finetune/datasets/__init__.py +++ b/finetune/datasets/__init__.py @@ -1,6 +1,6 @@ -from .i2v_dataset import I2VDatasetWithResize, I2VDatasetWithBuckets -from .t2v_dataset import T2VDatasetWithResize, T2VDatasetWithBuckets from .bucket_sampler import BucketSampler +from .i2v_dataset import I2VDatasetWithBuckets, I2VDatasetWithResize +from .t2v_dataset import T2VDatasetWithBuckets, T2VDatasetWithResize __all__ = [ diff --git a/finetune/datasets/bucket_sampler.py b/finetune/datasets/bucket_sampler.py index 91440a0..8bc1dde 100644 --- a/finetune/datasets/bucket_sampler.py +++ b/finetune/datasets/bucket_sampler.py @@ -1,8 +1,8 @@ -import random import logging +import random + +from torch.utils.data import Dataset, Sampler -from torch.utils.data import Sampler -from torch.utils.data import Dataset logger = logging.getLogger(__name__) diff --git a/finetune/datasets/i2v_dataset.py b/finetune/datasets/i2v_dataset.py index f4ba6d5..6b06da4 100644 --- a/finetune/datasets/i2v_dataset.py +++ b/finetune/datasets/i2v_dataset.py @@ -1,26 +1,26 @@ -import torch import hashlib - from pathlib import Path -from typing import Any, Dict, List, Tuple, TYPE_CHECKING -from typing_extensions import override +from typing import TYPE_CHECKING, Any, Dict, List, Tuple +import torch +from accelerate.logging import get_logger +from safetensors.torch import load_file, save_file from torch.utils.data import Dataset from torchvision import transforms -from accelerate.logging import get_logger -from safetensors.torch import save_file, load_file +from typing_extensions import override -from finetune.constants import LOG_NAME, LOG_LEVEL +from finetune.constants import LOG_LEVEL, LOG_NAME from .utils import ( + load_images, load_prompts, load_videos, - load_images, preprocess_image_with_resize, - preprocess_video_with_resize, preprocess_video_with_buckets, + preprocess_video_with_resize, ) + if TYPE_CHECKING: from finetune.trainer import Trainer diff --git a/finetune/datasets/t2v_dataset.py b/finetune/datasets/t2v_dataset.py index e42aff9..d123ccf 100644 --- a/finetune/datasets/t2v_dataset.py +++ b/finetune/datasets/t2v_dataset.py @@ -1,18 +1,18 @@ import hashlib -import torch - from pathlib import Path -from typing import Any, Dict, List, Tuple, TYPE_CHECKING -from typing_extensions import override +from typing import TYPE_CHECKING, Any, Dict, List, Tuple +import torch +from accelerate.logging import get_logger +from safetensors.torch import load_file, save_file from torch.utils.data import Dataset from torchvision import transforms -from accelerate.logging import get_logger -from safetensors.torch import save_file, load_file +from typing_extensions import override -from finetune.constants import LOG_NAME, LOG_LEVEL +from finetune.constants import LOG_LEVEL, LOG_NAME + +from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize -from .utils import load_prompts, load_videos, preprocess_video_with_resize, preprocess_video_with_buckets if TYPE_CHECKING: from finetune.trainer import Trainer diff --git a/finetune/datasets/utils.py b/finetune/datasets/utils.py index 82589ff..d28975e 100644 --- a/finetune/datasets/utils.py +++ b/finetune/datasets/utils.py @@ -1,11 +1,11 @@ -import torch -import cv2 - -from typing import List, Tuple from pathlib import Path -from torchvision import transforms +from typing import List, Tuple + +import cv2 +import torch from torchvision.transforms.functional import resize + # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error # Very few bug reports but it happens. Look in decord Github issues for more relevant information. import decord # isort:skip diff --git a/finetune/models/cogvideox1dot5_i2v/lora_trainer.py b/finetune/models/cogvideox1dot5_i2v/lora_trainer.py index 09d4b70..8b2b558 100644 --- a/finetune/models/cogvideox1dot5_i2v/lora_trainer.py +++ b/finetune/models/cogvideox1dot5_i2v/lora_trainer.py @@ -1,5 +1,5 @@ -from ..utils import register from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer +from ..utils import register class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer): diff --git a/finetune/models/cogvideox_i2v/lora_trainer.py b/finetune/models/cogvideox_i2v/lora_trainer.py index f35c8cc..e7e1512 100644 --- a/finetune/models/cogvideox_i2v/lora_trainer.py +++ b/finetune/models/cogvideox_i2v/lora_trainer.py @@ -1,22 +1,21 @@ -import torch - -from typing_extensions import override from typing import Any, Dict, List, Tuple -from PIL import Image -from transformers import AutoTokenizer, T5EncoderModel - -from diffusers.models.embeddings import get_3d_rotary_pos_embed +import torch from diffusers import ( - CogVideoXImageToVideoPipeline, - CogVideoXTransformer3DModel, AutoencoderKLCogVideoX, CogVideoXDPMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXTransformer3DModel, ) +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel +from typing_extensions import override -from finetune.trainer import Trainer from finetune.schemas import Components +from finetune.trainer import Trainer from finetune.utils import unwrap_model + from ..utils import register diff --git a/finetune/models/cogvideox_t2v/lora_trainer.py b/finetune/models/cogvideox_t2v/lora_trainer.py index c8d28e7..ec3b7fd 100644 --- a/finetune/models/cogvideox_t2v/lora_trainer.py +++ b/finetune/models/cogvideox_t2v/lora_trainer.py @@ -1,23 +1,21 @@ -import torch - -from typing_extensions import override from typing import Any, Dict, List, Tuple -from PIL import Image - -from transformers import AutoTokenizer, T5EncoderModel - -from diffusers.models.embeddings import get_3d_rotary_pos_embed +import torch from diffusers import ( - CogVideoXPipeline, - CogVideoXTransformer3DModel, AutoencoderKLCogVideoX, CogVideoXDPMScheduler, + CogVideoXPipeline, + CogVideoXTransformer3DModel, ) +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel +from typing_extensions import override -from finetune.trainer import Trainer from finetune.schemas import Components +from finetune.trainer import Trainer from finetune.utils import unwrap_model + from ..utils import register diff --git a/finetune/models/utils.py b/finetune/models/utils.py index fd5a455..dcc963f 100644 --- a/finetune/models/utils.py +++ b/finetune/models/utils.py @@ -1,4 +1,4 @@ -from typing import Literal, Dict +from typing import Dict, Literal from finetune.trainer import Trainer diff --git a/finetune/schemas/__init__.py b/finetune/schemas/__init__.py index b566983..73f547b 100644 --- a/finetune/schemas/__init__.py +++ b/finetune/schemas/__init__.py @@ -1,5 +1,6 @@ from .args import Args -from .state import State from .components import Components +from .state import State + __all__ = ["Args", "State", "Components"] diff --git a/finetune/schemas/args.py b/finetune/schemas/args.py index d4e10d2..b2fd9ce 100644 --- a/finetune/schemas/args.py +++ b/finetune/schemas/args.py @@ -1,9 +1,9 @@ -import datetime import argparse -from typing import Dict, Any, Literal, List, Tuple -from pydantic import BaseModel, field_validator, ValidationInfo - +import datetime from pathlib import Path +from typing import Any, List, Literal, Tuple + +from pydantic import BaseModel, ValidationInfo, field_validator class Args(BaseModel): diff --git a/finetune/schemas/components.py b/finetune/schemas/components.py index 2d3fef5..5edd38e 100644 --- a/finetune/schemas/components.py +++ b/finetune/schemas/components.py @@ -1,4 +1,5 @@ from typing import Any + from pydantic import BaseModel diff --git a/finetune/schemas/state.py b/finetune/schemas/state.py index d715c6d..315185c 100644 --- a/finetune/schemas/state.py +++ b/finetune/schemas/state.py @@ -1,8 +1,8 @@ -import torch - from pathlib import Path -from typing import List, Dict, Any -from pydantic import BaseModel, field_validator +from typing import Any, Dict, List + +import torch +from pydantic import BaseModel class State(BaseModel): diff --git a/finetune/scripts/extract_images.py b/finetune/scripts/extract_images.py index 8d0d9fa..42ce8e2 100644 --- a/finetune/scripts/extract_images.py +++ b/finetune/scripts/extract_images.py @@ -1,6 +1,7 @@ import argparse import os from pathlib import Path + import cv2 diff --git a/finetune/train.py b/finetune/train.py index 5f49f4b..660a744 100644 --- a/finetune/train.py +++ b/finetune/train.py @@ -1,10 +1,11 @@ import sys from pathlib import Path + sys.path.append(str(Path(__file__).parent.parent)) -from finetune.schemas import Args from finetune.models.utils import get_model_cls +from finetune.schemas import Args def main(): diff --git a/finetune/trainer.py b/finetune/trainer.py index 701f6c0..981c500 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -1,56 +1,52 @@ +import json import logging import math -import json - -import torch -import transformers -import diffusers -import wandb - from datetime import timedelta from pathlib import Path -from tqdm import tqdm -from typing import Dict, Any, List, Tuple -from PIL import Image +from typing import Any, Dict, List, Tuple -from torch.utils.data import Dataset, DataLoader -from accelerate.logging import get_logger +import diffusers +import torch +import transformers +import wandb from accelerate.accelerator import Accelerator, DistributedType +from accelerate.logging import get_logger from accelerate.utils import ( DistributedDataParallelKwargs, InitProcessGroupKwargs, ProjectConfiguration, - set_seed, gather_object, + set_seed, ) - -from diffusers.pipelines import DiffusionPipeline from diffusers.optimization import get_scheduler +from diffusers.pipelines import DiffusionPipeline from diffusers.utils.export_utils import export_to_video from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm -from finetune.schemas import Args, State, Components -from finetune.utils import ( - unwrap_model, - cast_training_params, - get_optimizer, - get_memory_statistics, - free_memory, - unload_model, - get_latest_ckpt_path_to_resume_from, - get_intermediate_ckpt_path, - string_to_filename, -) +from finetune.constants import LOG_LEVEL, LOG_NAME from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize from finetune.datasets.utils import ( - load_prompts, load_images, + load_prompts, load_videos, preprocess_image_with_resize, preprocess_video_with_resize, ) - -from finetune.constants import LOG_NAME, LOG_LEVEL +from finetune.schemas import Args, Components, State +from finetune.utils import ( + cast_training_params, + free_memory, + get_intermediate_ckpt_path, + get_latest_ckpt_path_to_resume_from, + get_memory_statistics, + get_optimizer, + string_to_filename, + unload_model, + unwrap_model, +) logger = get_logger(LOG_NAME, LOG_LEVEL) diff --git a/finetune/utils/__init__.py b/finetune/utils/__init__.py index 7becfd5..9ff4912 100644 --- a/finetune/utils/__init__.py +++ b/finetune/utils/__init__.py @@ -1,5 +1,5 @@ -from .torch_utils import * -from .optimizer_utils import * -from .memory_utils import * from .checkpointing import * from .file_utils import * +from .memory_utils import * +from .optimizer_utils import * +from .torch_utils import * diff --git a/finetune/utils/checkpointing.py b/finetune/utils/checkpointing.py index 1797153..775038c 100644 --- a/finetune/utils/checkpointing.py +++ b/finetune/utils/checkpointing.py @@ -1,10 +1,12 @@ import os from pathlib import Path from typing import Tuple + from accelerate.logging import get_logger -from finetune.constants import LOG_NAME, LOG_LEVEL -from ..utils.file_utils import find_files, delete_files +from finetune.constants import LOG_LEVEL, LOG_NAME + +from ..utils.file_utils import delete_files, find_files logger = get_logger(LOG_NAME, LOG_LEVEL) diff --git a/finetune/utils/file_utils.py b/finetune/utils/file_utils.py index f04dd85..38b1105 100644 --- a/finetune/utils/file_utils.py +++ b/finetune/utils/file_utils.py @@ -1,11 +1,12 @@ import logging import os import shutil - from pathlib import Path from typing import Any, Dict, List, Union + from accelerate.logging import get_logger -from finetune.constants import LOG_NAME, LOG_LEVEL + +from finetune.constants import LOG_LEVEL, LOG_NAME logger = get_logger(LOG_NAME, LOG_LEVEL) diff --git a/finetune/utils/memory_utils.py b/finetune/utils/memory_utils.py index b341d22..0c88d70 100644 --- a/finetune/utils/memory_utils.py +++ b/finetune/utils/memory_utils.py @@ -1,10 +1,10 @@ import gc -import torch - from typing import Any, Dict, Union + +import torch from accelerate.logging import get_logger -from finetune.constants import LOG_NAME, LOG_LEVEL +from finetune.constants import LOG_LEVEL, LOG_NAME logger = get_logger(LOG_NAME, LOG_LEVEL) diff --git a/finetune/utils/optimizer_utils.py b/finetune/utils/optimizer_utils.py index bd93f9c..d24aa3f 100644 --- a/finetune/utils/optimizer_utils.py +++ b/finetune/utils/optimizer_utils.py @@ -1,9 +1,9 @@ import inspect -import torch +import torch from accelerate.logging import get_logger -from finetune.constants import LOG_NAME, LOG_LEVEL +from finetune.constants import LOG_LEVEL, LOG_NAME logger = get_logger(LOG_NAME, LOG_LEVEL) diff --git a/finetune/utils/torch_utils.py b/finetune/utils/torch_utils.py index c63bd88..867a8bf 100644 --- a/finetune/utils/torch_utils.py +++ b/finetune/utils/torch_utils.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Union, List +from typing import Dict, List, Optional, Union import torch from accelerate import Accelerator diff --git a/inference/cli_demo.py b/inference/cli_demo.py index b9820c9..7c34216 100644 --- a/inference/cli_demo.py +++ b/inference/cli_demo.py @@ -17,20 +17,20 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths. """ -import logging import argparse +import logging from typing import Literal, Optional import torch from diffusers import ( - CogVideoXPipeline, CogVideoXDPMScheduler, CogVideoXImageToVideoPipeline, + CogVideoXPipeline, CogVideoXVideoToVideoPipeline, ) - from diffusers.utils import export_to_video, load_image, load_video + logging.basicConfig(level=logging.INFO) # Recommended resolution for each model (width, height) @@ -38,7 +38,6 @@ RESOLUTION_MAP = { # cogvideox1.5-* "cogvideox1.5-5b-i2v": (1360, 768), "cogvideox1.5-5b": (1360, 768), - # cogvideox-* "cogvideox-5b-i2v": (720, 480), "cogvideox-5b": (720, 480), @@ -100,10 +99,14 @@ def generate_video( elif (width, height) != desired_resolution: if generate_type == "i2v": # For i2v models, use user-defined width and height - logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m") + logging.warning( + f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m" + ) else: # Otherwise, use the recommended width and height - logging.warning(f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m") + logging.warning( + f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m" + ) width, height = desired_resolution if generate_type == "i2v":