style: format import statements across finetune module

This commit is contained in:
OleehyO 2025-01-07 05:47:39 +00:00
parent 1789f07256
commit 36427274d6
23 changed files with 112 additions and 109 deletions

View File

@@ -1,6 +1,6 @@
from .i2v_dataset import I2VDatasetWithResize, I2VDatasetWithBuckets
from .t2v_dataset import T2VDatasetWithResize, T2VDatasetWithBuckets
from .bucket_sampler import BucketSampler
from .i2v_dataset import I2VDatasetWithBuckets, I2VDatasetWithResize
from .t2v_dataset import T2VDatasetWithBuckets, T2VDatasetWithResize
__all__ = [

View File

@@ -1,8 +1,8 @@
import random
import logging
import random
from torch.utils.data import Dataset, Sampler
from torch.utils.data import Sampler
from torch.utils.data import Dataset
logger = logging.getLogger(__name__)

View File

@@ -1,26 +1,26 @@
import torch
import hashlib
from pathlib import Path
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
from typing_extensions import override
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from accelerate.logging import get_logger
from safetensors.torch import save_file, load_file
from typing_extensions import override
from finetune.constants import LOG_NAME, LOG_LEVEL
from finetune.constants import LOG_LEVEL, LOG_NAME
from .utils import (
load_images,
load_prompts,
load_videos,
load_images,
preprocess_image_with_resize,
preprocess_video_with_resize,
preprocess_video_with_buckets,
preprocess_video_with_resize,
)
if TYPE_CHECKING:
from finetune.trainer import Trainer

View File

@@ -1,18 +1,18 @@
import hashlib
import torch
from pathlib import Path
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
from typing_extensions import override
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from accelerate.logging import get_logger
from safetensors.torch import save_file, load_file
from typing_extensions import override
from finetune.constants import LOG_NAME, LOG_LEVEL
from finetune.constants import LOG_LEVEL, LOG_NAME
from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize
from .utils import load_prompts, load_videos, preprocess_video_with_resize, preprocess_video_with_buckets
if TYPE_CHECKING:
from finetune.trainer import Trainer

View File

@@ -1,11 +1,11 @@
import torch
import cv2
from typing import List, Tuple
from pathlib import Path
from torchvision import transforms
from typing import List, Tuple
import cv2
import torch
from torchvision.transforms.functional import resize
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort:skip

View File

@@ -1,5 +1,5 @@
from ..utils import register
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register
class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer):

View File

@@ -1,22 +1,21 @@
import torch
from typing_extensions import override
from typing import Any, Dict, List, Tuple
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers.models.embeddings import get_3d_rotary_pos_embed
import torch
from diffusers import (
CogVideoXImageToVideoPipeline,
CogVideoXTransformer3DModel,
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXTransformer3DModel,
)
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override
from finetune.trainer import Trainer
from finetune.schemas import Components
from finetune.trainer import Trainer
from finetune.utils import unwrap_model
from ..utils import register

View File

@@ -1,23 +1,21 @@
import torch
from typing_extensions import override
from typing import Any, Dict, List, Tuple
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers.models.embeddings import get_3d_rotary_pos_embed
import torch
from diffusers import (
CogVideoXPipeline,
CogVideoXTransformer3DModel,
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
)
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override
from finetune.trainer import Trainer
from finetune.schemas import Components
from finetune.trainer import Trainer
from finetune.utils import unwrap_model
from ..utils import register

View File

@@ -1,4 +1,4 @@
from typing import Literal, Dict
from typing import Dict, Literal
from finetune.trainer import Trainer

View File

@@ -1,5 +1,6 @@
from .args import Args
from .state import State
from .components import Components
from .state import State
__all__ = ["Args", "State", "Components"]

View File

@@ -1,9 +1,9 @@
import datetime
import argparse
from typing import Dict, Any, Literal, List, Tuple
from pydantic import BaseModel, field_validator, ValidationInfo
import datetime
from pathlib import Path
from typing import Any, List, Literal, Tuple
from pydantic import BaseModel, ValidationInfo, field_validator
class Args(BaseModel):

View File

@@ -1,4 +1,5 @@
from typing import Any
from pydantic import BaseModel

View File

@@ -1,8 +1,8 @@
import torch
from pathlib import Path
from typing import List, Dict, Any
from pydantic import BaseModel, field_validator
from typing import Any, Dict, List
import torch
from pydantic import BaseModel
class State(BaseModel):

View File

@@ -1,6 +1,7 @@
import argparse
import os
from pathlib import Path
import cv2

View File

@@ -1,10 +1,11 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from finetune.schemas import Args
from finetune.models.utils import get_model_cls
from finetune.schemas import Args
def main():

View File

@@ -1,56 +1,52 @@
import json
import logging
import math
import json
import torch
import transformers
import diffusers
import wandb
from datetime import timedelta
from pathlib import Path
from tqdm import tqdm
from typing import Dict, Any, List, Tuple
from PIL import Image
from typing import Any, Dict, List, Tuple
from torch.utils.data import Dataset, DataLoader
from accelerate.logging import get_logger
import diffusers
import torch
import transformers
import wandb
from accelerate.accelerator import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
DistributedDataParallelKwargs,
InitProcessGroupKwargs,
ProjectConfiguration,
set_seed,
gather_object,
set_seed,
)
from diffusers.pipelines import DiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.pipelines import DiffusionPipeline
from diffusers.utils.export_utils import export_to_video
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from finetune.schemas import Args, State, Components
from finetune.utils import (
unwrap_model,
cast_training_params,
get_optimizer,
get_memory_statistics,
free_memory,
unload_model,
get_latest_ckpt_path_to_resume_from,
get_intermediate_ckpt_path,
string_to_filename,
)
from finetune.constants import LOG_LEVEL, LOG_NAME
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
from finetune.datasets.utils import (
load_prompts,
load_images,
load_prompts,
load_videos,
preprocess_image_with_resize,
preprocess_video_with_resize,
)
from finetune.constants import LOG_NAME, LOG_LEVEL
from finetune.schemas import Args, Components, State
from finetune.utils import (
cast_training_params,
free_memory,
get_intermediate_ckpt_path,
get_latest_ckpt_path_to_resume_from,
get_memory_statistics,
get_optimizer,
string_to_filename,
unload_model,
unwrap_model,
)
logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@@ -1,5 +1,5 @@
from .torch_utils import *
from .optimizer_utils import *
from .memory_utils import *
from .checkpointing import *
from .file_utils import *
from .memory_utils import *
from .optimizer_utils import *
from .torch_utils import *

View File

@@ -1,10 +1,12 @@
import os
from pathlib import Path
from typing import Tuple
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
from ..utils.file_utils import find_files, delete_files
from finetune.constants import LOG_LEVEL, LOG_NAME
from ..utils.file_utils import delete_files, find_files
logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@@ -1,11 +1,12 @@
import logging
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List, Union
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@@ -1,10 +1,10 @@
import gc
import torch
from typing import Any, Dict, Union
import torch
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@@ -1,9 +1,9 @@
import inspect
import torch
import torch
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional, Union, List
from typing import Dict, List, Optional, Union
import torch
from accelerate import Accelerator

View File

@@ -17,20 +17,20 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
"""
import logging
import argparse
import logging
from typing import Literal, Optional
import torch
from diffusers import (
CogVideoXPipeline,
CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
)
from diffusers.utils import export_to_video, load_image, load_video
logging.basicConfig(level=logging.INFO)
# Recommended resolution for each model (width, height)
@@ -38,7 +38,6 @@ RESOLUTION_MAP = {
# cogvideox1.5-*
"cogvideox1.5-5b-i2v": (1360, 768),
"cogvideox1.5-5b": (1360, 768),
# cogvideox-*
"cogvideox-5b-i2v": (720, 480),
"cogvideox-5b": (720, 480),
@@ -100,10 +99,14 @@ def generate_video(
elif (width, height) != desired_resolution:
if generate_type == "i2v":
# For i2v models, use user-defined width and height
logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m")
logging.warning(
f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m"
)
else:
# Otherwise, use the recommended width and height
logging.warning(f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m")
logging.warning(
f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m"
)
width, height = desired_resolution
if generate_type == "i2v":