style: format import statements across finetune module

This commit is contained in:
OleehyO 2025-01-07 05:47:39 +00:00
parent 1789f07256
commit 36427274d6
23 changed files with 112 additions and 109 deletions

View File

@ -1,6 +1,6 @@
from .i2v_dataset import I2VDatasetWithResize, I2VDatasetWithBuckets
from .t2v_dataset import T2VDatasetWithResize, T2VDatasetWithBuckets
from .bucket_sampler import BucketSampler from .bucket_sampler import BucketSampler
from .i2v_dataset import I2VDatasetWithBuckets, I2VDatasetWithResize
from .t2v_dataset import T2VDatasetWithBuckets, T2VDatasetWithResize
__all__ = [ __all__ = [

View File

@ -1,8 +1,8 @@
import random
import logging import logging
import random
from torch.utils.data import Dataset, Sampler
from torch.utils.data import Sampler
from torch.utils.data import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,26 +1,26 @@
import torch
import hashlib import hashlib
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple, TYPE_CHECKING from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from typing_extensions import override
import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchvision import transforms from torchvision import transforms
from accelerate.logging import get_logger from typing_extensions import override
from safetensors.torch import save_file, load_file
from finetune.constants import LOG_NAME, LOG_LEVEL from finetune.constants import LOG_LEVEL, LOG_NAME
from .utils import ( from .utils import (
load_images,
load_prompts, load_prompts,
load_videos, load_videos,
load_images,
preprocess_image_with_resize, preprocess_image_with_resize,
preprocess_video_with_resize,
preprocess_video_with_buckets, preprocess_video_with_buckets,
preprocess_video_with_resize,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from finetune.trainer import Trainer from finetune.trainer import Trainer

View File

@ -1,18 +1,18 @@
import hashlib import hashlib
import torch
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple, TYPE_CHECKING from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from typing_extensions import override
import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchvision import transforms from torchvision import transforms
from accelerate.logging import get_logger from typing_extensions import override
from safetensors.torch import save_file, load_file
from finetune.constants import LOG_NAME, LOG_LEVEL from finetune.constants import LOG_LEVEL, LOG_NAME
from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize
from .utils import load_prompts, load_videos, preprocess_video_with_resize, preprocess_video_with_buckets
if TYPE_CHECKING: if TYPE_CHECKING:
from finetune.trainer import Trainer from finetune.trainer import Trainer

View File

@ -1,11 +1,11 @@
import torch
import cv2
from typing import List, Tuple
from pathlib import Path from pathlib import Path
from torchvision import transforms from typing import List, Tuple
import cv2
import torch
from torchvision.transforms.functional import resize from torchvision.transforms.functional import resize
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord GitHub issues for more relevant information. # Very few bug reports but it happens. Look in decord GitHub issues for more relevant information.
import decord # isort:skip import decord # isort:skip

View File

@ -1,5 +1,5 @@
from ..utils import register
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register
class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer): class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer):

View File

@ -1,22 +1,21 @@
import torch
from typing_extensions import override
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers import ( from diffusers import (
CogVideoXImageToVideoPipeline,
CogVideoXTransformer3DModel,
AutoencoderKLCogVideoX, AutoencoderKLCogVideoX,
CogVideoXDPMScheduler, CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXTransformer3DModel,
) )
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override
from finetune.trainer import Trainer
from finetune.schemas import Components from finetune.schemas import Components
from finetune.trainer import Trainer
from finetune.utils import unwrap_model from finetune.utils import unwrap_model
from ..utils import register from ..utils import register

View File

@ -1,23 +1,21 @@
import torch
from typing_extensions import override
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
from PIL import Image import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers import ( from diffusers import (
CogVideoXPipeline,
CogVideoXTransformer3DModel,
AutoencoderKLCogVideoX, AutoencoderKLCogVideoX,
CogVideoXDPMScheduler, CogVideoXDPMScheduler,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
) )
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override
from finetune.trainer import Trainer
from finetune.schemas import Components from finetune.schemas import Components
from finetune.trainer import Trainer
from finetune.utils import unwrap_model from finetune.utils import unwrap_model
from ..utils import register from ..utils import register

View File

@ -1,4 +1,4 @@
from typing import Literal, Dict from typing import Dict, Literal
from finetune.trainer import Trainer from finetune.trainer import Trainer

View File

@ -1,5 +1,6 @@
from .args import Args from .args import Args
from .state import State
from .components import Components from .components import Components
from .state import State
__all__ = ["Args", "State", "Components"] __all__ = ["Args", "State", "Components"]

View File

@ -1,9 +1,9 @@
import datetime
import argparse import argparse
from typing import Dict, Any, Literal, List, Tuple import datetime
from pydantic import BaseModel, field_validator, ValidationInfo
from pathlib import Path from pathlib import Path
from typing import Any, List, Literal, Tuple
from pydantic import BaseModel, ValidationInfo, field_validator
class Args(BaseModel): class Args(BaseModel):

View File

@ -1,4 +1,5 @@
from typing import Any from typing import Any
from pydantic import BaseModel from pydantic import BaseModel

View File

@ -1,8 +1,8 @@
import torch
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any from typing import Any, Dict, List
from pydantic import BaseModel, field_validator
import torch
from pydantic import BaseModel
class State(BaseModel): class State(BaseModel):

View File

@ -1,6 +1,7 @@
import argparse import argparse
import os import os
from pathlib import Path from pathlib import Path
import cv2 import cv2

View File

@ -1,10 +1,11 @@
import sys import sys
from pathlib import Path from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent)) sys.path.append(str(Path(__file__).parent.parent))
from finetune.schemas import Args
from finetune.models.utils import get_model_cls from finetune.models.utils import get_model_cls
from finetune.schemas import Args
def main(): def main():

View File

@ -1,56 +1,52 @@
import json
import logging import logging
import math import math
import json
import torch
import transformers
import diffusers
import wandb
from datetime import timedelta from datetime import timedelta
from pathlib import Path from pathlib import Path
from tqdm import tqdm from typing import Any, Dict, List, Tuple
from typing import Dict, Any, List, Tuple
from PIL import Image
from torch.utils.data import Dataset, DataLoader import diffusers
from accelerate.logging import get_logger import torch
import transformers
import wandb
from accelerate.accelerator import Accelerator, DistributedType from accelerate.accelerator import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import ( from accelerate.utils import (
DistributedDataParallelKwargs, DistributedDataParallelKwargs,
InitProcessGroupKwargs, InitProcessGroupKwargs,
ProjectConfiguration, ProjectConfiguration,
set_seed,
gather_object, gather_object,
set_seed,
) )
from diffusers.pipelines import DiffusionPipeline
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.pipelines import DiffusionPipeline
from diffusers.utils.export_utils import export_to_video from diffusers.utils.export_utils import export_to_video
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from finetune.schemas import Args, State, Components from finetune.constants import LOG_LEVEL, LOG_NAME
from finetune.utils import (
unwrap_model,
cast_training_params,
get_optimizer,
get_memory_statistics,
free_memory,
unload_model,
get_latest_ckpt_path_to_resume_from,
get_intermediate_ckpt_path,
string_to_filename,
)
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
from finetune.datasets.utils import ( from finetune.datasets.utils import (
load_prompts,
load_images, load_images,
load_prompts,
load_videos, load_videos,
preprocess_image_with_resize, preprocess_image_with_resize,
preprocess_video_with_resize, preprocess_video_with_resize,
) )
from finetune.schemas import Args, Components, State
from finetune.constants import LOG_NAME, LOG_LEVEL from finetune.utils import (
cast_training_params,
free_memory,
get_intermediate_ckpt_path,
get_latest_ckpt_path_to_resume_from,
get_memory_statistics,
get_optimizer,
string_to_filename,
unload_model,
unwrap_model,
)
logger = get_logger(LOG_NAME, LOG_LEVEL) logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@ -1,5 +1,5 @@
from .torch_utils import *
from .optimizer_utils import *
from .memory_utils import *
from .checkpointing import * from .checkpointing import *
from .file_utils import * from .file_utils import *
from .memory_utils import *
from .optimizer_utils import *
from .torch_utils import *

View File

@ -1,10 +1,12 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Tuple from typing import Tuple
from accelerate.logging import get_logger from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL from finetune.constants import LOG_LEVEL, LOG_NAME
from ..utils.file_utils import find_files, delete_files
from ..utils.file_utils import delete_files, find_files
logger = get_logger(LOG_NAME, LOG_LEVEL) logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@ -1,11 +1,12 @@
import logging import logging
import os import os
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
from accelerate.logging import get_logger from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL) logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@ -1,10 +1,10 @@
import gc import gc
import torch
from typing import Any, Dict, Union from typing import Any, Dict, Union
import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL) logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@ -1,9 +1,9 @@
import inspect import inspect
import torch
import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL) logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional, Union, List from typing import Dict, List, Optional, Union
import torch import torch
from accelerate import Accelerator from accelerate import Accelerator

View File

@ -17,20 +17,20 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths. Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
""" """
import logging
import argparse import argparse
import logging
from typing import Literal, Optional from typing import Literal, Optional
import torch import torch
from diffusers import ( from diffusers import (
CogVideoXPipeline,
CogVideoXDPMScheduler, CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline, CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline, CogVideoXVideoToVideoPipeline,
) )
from diffusers.utils import export_to_video, load_image, load_video from diffusers.utils import export_to_video, load_image, load_video
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# Recommended resolution for each model (width, height) # Recommended resolution for each model (width, height)
@ -38,7 +38,6 @@ RESOLUTION_MAP = {
# cogvideox1.5-* # cogvideox1.5-*
"cogvideox1.5-5b-i2v": (1360, 768), "cogvideox1.5-5b-i2v": (1360, 768),
"cogvideox1.5-5b": (1360, 768), "cogvideox1.5-5b": (1360, 768),
# cogvideox-* # cogvideox-*
"cogvideox-5b-i2v": (720, 480), "cogvideox-5b-i2v": (720, 480),
"cogvideox-5b": (720, 480), "cogvideox-5b": (720, 480),
@ -100,10 +99,14 @@ def generate_video(
elif (width, height) != desired_resolution: elif (width, height) != desired_resolution:
if generate_type == "i2v": if generate_type == "i2v":
# For i2v models, use user-defined width and height # For i2v models, use user-defined width and height
logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m") logging.warning(
f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m"
)
else: else:
# Otherwise, use the recommended width and height # Otherwise, use the recommended width and height
logging.warning(f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m") logging.warning(
f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m"
)
width, height = desired_resolution width, height = desired_resolution
if generate_type == "i2v": if generate_type == "i2v":