mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 03:04:56 +08:00
style: format import statements across finetune module
This commit is contained in:
parent
1789f07256
commit
36427274d6
@ -1,6 +1,6 @@
|
||||
from .i2v_dataset import I2VDatasetWithResize, I2VDatasetWithBuckets
|
||||
from .t2v_dataset import T2VDatasetWithResize, T2VDatasetWithBuckets
|
||||
from .bucket_sampler import BucketSampler
|
||||
from .i2v_dataset import I2VDatasetWithBuckets, I2VDatasetWithResize
|
||||
from .t2v_dataset import T2VDatasetWithBuckets, T2VDatasetWithResize
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -1,8 +1,8 @@
|
||||
import random
|
||||
import logging
|
||||
import random
|
||||
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
from torch.utils.data import Sampler
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -1,26 +1,26 @@
|
||||
import torch
|
||||
import hashlib
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
from safetensors.torch import load_file, save_file
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from accelerate.logging import get_logger
|
||||
from safetensors.torch import save_file, load_file
|
||||
from typing_extensions import override
|
||||
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||
|
||||
from .utils import (
|
||||
load_images,
|
||||
load_prompts,
|
||||
load_videos,
|
||||
load_images,
|
||||
preprocess_image_with_resize,
|
||||
preprocess_video_with_resize,
|
||||
preprocess_video_with_buckets,
|
||||
preprocess_video_with_resize,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from finetune.trainer import Trainer
|
||||
|
||||
|
@ -1,18 +1,18 @@
|
||||
import hashlib
|
||||
import torch
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
from safetensors.torch import load_file, save_file
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from accelerate.logging import get_logger
|
||||
from safetensors.torch import save_file, load_file
|
||||
from typing_extensions import override
|
||||
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||
|
||||
from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize
|
||||
|
||||
from .utils import load_prompts, load_videos, preprocess_video_with_resize, preprocess_video_with_buckets
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from finetune.trainer import Trainer
|
||||
|
@ -1,11 +1,11 @@
|
||||
import torch
|
||||
import cv2
|
||||
|
||||
from typing import List, Tuple
|
||||
from pathlib import Path
|
||||
from torchvision import transforms
|
||||
from typing import List, Tuple
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from torchvision.transforms.functional import resize
|
||||
|
||||
|
||||
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
|
||||
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
|
||||
import decord # isort:skip
|
||||
|
@ -1,5 +1,5 @@
|
||||
from ..utils import register
|
||||
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
|
||||
from ..utils import register
|
||||
|
||||
|
||||
class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer):
|
||||
|
@ -1,22 +1,21 @@
|
||||
import torch
|
||||
|
||||
from typing_extensions import override
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||
import torch
|
||||
from diffusers import (
|
||||
CogVideoXImageToVideoPipeline,
|
||||
CogVideoXTransformer3DModel,
|
||||
AutoencoderKLCogVideoX,
|
||||
CogVideoXDPMScheduler,
|
||||
CogVideoXImageToVideoPipeline,
|
||||
CogVideoXTransformer3DModel,
|
||||
)
|
||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from typing_extensions import override
|
||||
|
||||
from finetune.trainer import Trainer
|
||||
from finetune.schemas import Components
|
||||
from finetune.trainer import Trainer
|
||||
from finetune.utils import unwrap_model
|
||||
|
||||
from ..utils import register
|
||||
|
||||
|
||||
|
@ -1,23 +1,21 @@
|
||||
import torch
|
||||
|
||||
from typing_extensions import override
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||
import torch
|
||||
from diffusers import (
|
||||
CogVideoXPipeline,
|
||||
CogVideoXTransformer3DModel,
|
||||
AutoencoderKLCogVideoX,
|
||||
CogVideoXDPMScheduler,
|
||||
CogVideoXPipeline,
|
||||
CogVideoXTransformer3DModel,
|
||||
)
|
||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
from typing_extensions import override
|
||||
|
||||
from finetune.trainer import Trainer
|
||||
from finetune.schemas import Components
|
||||
from finetune.trainer import Trainer
|
||||
from finetune.utils import unwrap_model
|
||||
|
||||
from ..utils import register
|
||||
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Literal, Dict
|
||||
from typing import Dict, Literal
|
||||
|
||||
from finetune.trainer import Trainer
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
from .args import Args
|
||||
from .state import State
|
||||
from .components import Components
|
||||
from .state import State
|
||||
|
||||
|
||||
__all__ = ["Args", "State", "Components"]
|
||||
|
@ -1,9 +1,9 @@
|
||||
import datetime
|
||||
import argparse
|
||||
from typing import Dict, Any, Literal, List, Tuple
|
||||
from pydantic import BaseModel, field_validator, ValidationInfo
|
||||
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Literal, Tuple
|
||||
|
||||
from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
|
||||
|
||||
class Args(BaseModel):
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel, field_validator
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class State(BaseModel):
|
||||
|
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
|
||||
|
||||
|
@ -1,10 +1,11 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
from finetune.schemas import Args
|
||||
from finetune.models.utils import get_model_cls
|
||||
from finetune.schemas import Args
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -1,56 +1,52 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import json
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
import diffusers
|
||||
import wandb
|
||||
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from PIL import Image
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from accelerate.logging import get_logger
|
||||
import diffusers
|
||||
import torch
|
||||
import transformers
|
||||
import wandb
|
||||
from accelerate.accelerator import Accelerator, DistributedType
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import (
|
||||
DistributedDataParallelKwargs,
|
||||
InitProcessGroupKwargs,
|
||||
ProjectConfiguration,
|
||||
set_seed,
|
||||
gather_object,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
from diffusers.pipelines import DiffusionPipeline
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines import DiffusionPipeline
|
||||
from diffusers.utils.export_utils import export_to_video
|
||||
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from finetune.schemas import Args, State, Components
|
||||
from finetune.utils import (
|
||||
unwrap_model,
|
||||
cast_training_params,
|
||||
get_optimizer,
|
||||
get_memory_statistics,
|
||||
free_memory,
|
||||
unload_model,
|
||||
get_latest_ckpt_path_to_resume_from,
|
||||
get_intermediate_ckpt_path,
|
||||
string_to_filename,
|
||||
)
|
||||
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
|
||||
from finetune.datasets.utils import (
|
||||
load_prompts,
|
||||
load_images,
|
||||
load_prompts,
|
||||
load_videos,
|
||||
preprocess_image_with_resize,
|
||||
preprocess_video_with_resize,
|
||||
)
|
||||
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
from finetune.schemas import Args, Components, State
|
||||
from finetune.utils import (
|
||||
cast_training_params,
|
||||
free_memory,
|
||||
get_intermediate_ckpt_path,
|
||||
get_latest_ckpt_path_to_resume_from,
|
||||
get_memory_statistics,
|
||||
get_optimizer,
|
||||
string_to_filename,
|
||||
unload_model,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from .torch_utils import *
|
||||
from .optimizer_utils import *
|
||||
from .memory_utils import *
|
||||
from .checkpointing import *
|
||||
from .file_utils import *
|
||||
from .memory_utils import *
|
||||
from .optimizer_utils import *
|
||||
from .torch_utils import *
|
||||
|
@ -1,10 +1,12 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
from ..utils.file_utils import find_files, delete_files
|
||||
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||
|
||||
from ..utils.file_utils import delete_files, find_files
|
||||
|
||||
|
||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||
|
@ -1,11 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from accelerate.logging import get_logger
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
|
||||
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||
|
||||
|
||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||
|
@ -1,10 +1,10 @@
|
||||
import gc
|
||||
import torch
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||
|
||||
|
||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||
|
@ -1,9 +1,9 @@
|
||||
import inspect
|
||||
import torch
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||
|
||||
|
||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Dict, Optional, Union, List
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
@ -17,20 +17,20 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide
|
||||
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import (
|
||||
CogVideoXPipeline,
|
||||
CogVideoXDPMScheduler,
|
||||
CogVideoXImageToVideoPipeline,
|
||||
CogVideoXPipeline,
|
||||
CogVideoXVideoToVideoPipeline,
|
||||
)
|
||||
|
||||
from diffusers.utils import export_to_video, load_image, load_video
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Recommended resolution for each model (width, height)
|
||||
@ -38,7 +38,6 @@ RESOLUTION_MAP = {
|
||||
# cogvideox1.5-*
|
||||
"cogvideox1.5-5b-i2v": (1360, 768),
|
||||
"cogvideox1.5-5b": (1360, 768),
|
||||
|
||||
# cogvideox-*
|
||||
"cogvideox-5b-i2v": (720, 480),
|
||||
"cogvideox-5b": (720, 480),
|
||||
@ -100,10 +99,14 @@ def generate_video(
|
||||
elif (width, height) != desired_resolution:
|
||||
if generate_type == "i2v":
|
||||
# For i2v models, use user-defined width and height
|
||||
logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m")
|
||||
logging.warning(
|
||||
f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m"
|
||||
)
|
||||
else:
|
||||
# Otherwise, use the recommended width and height
|
||||
logging.warning(f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m")
|
||||
logging.warning(
|
||||
f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m"
|
||||
)
|
||||
width, height = desired_resolution
|
||||
|
||||
if generate_type == "i2v":
|
||||
|
Loading…
x
Reference in New Issue
Block a user