mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
style: format import statements across finetune module
This commit is contained in:
parent
1789f07256
commit
36427274d6
@ -1,6 +1,6 @@
|
|||||||
from .i2v_dataset import I2VDatasetWithResize, I2VDatasetWithBuckets
|
|
||||||
from .t2v_dataset import T2VDatasetWithResize, T2VDatasetWithBuckets
|
|
||||||
from .bucket_sampler import BucketSampler
|
from .bucket_sampler import BucketSampler
|
||||||
|
from .i2v_dataset import I2VDatasetWithBuckets, I2VDatasetWithResize
|
||||||
|
from .t2v_dataset import T2VDatasetWithBuckets, T2VDatasetWithResize
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import random
|
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
|
|
||||||
|
from torch.utils.data import Dataset, Sampler
|
||||||
|
|
||||||
from torch.utils.data import Sampler
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -1,26 +1,26 @@
|
|||||||
import torch
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from accelerate.logging import get_logger
|
from typing_extensions import override
|
||||||
from safetensors.torch import save_file, load_file
|
|
||||||
|
|
||||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
load_images,
|
||||||
load_prompts,
|
load_prompts,
|
||||||
load_videos,
|
load_videos,
|
||||||
load_images,
|
|
||||||
preprocess_image_with_resize,
|
preprocess_image_with_resize,
|
||||||
preprocess_video_with_resize,
|
|
||||||
preprocess_video_with_buckets,
|
preprocess_video_with_buckets,
|
||||||
|
preprocess_video_with_resize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from finetune.trainer import Trainer
|
from finetune.trainer import Trainer
|
||||||
|
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import torch
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from accelerate.logging import get_logger
|
from typing_extensions import override
|
||||||
from safetensors.torch import save_file, load_file
|
|
||||||
|
|
||||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||||
|
|
||||||
|
from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize
|
||||||
|
|
||||||
from .utils import load_prompts, load_videos, preprocess_video_with_resize, preprocess_video_with_buckets
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from finetune.trainer import Trainer
|
from finetune.trainer import Trainer
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import torch
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
from typing import List, Tuple
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from torchvision import transforms
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
from torchvision.transforms.functional import resize
|
from torchvision.transforms.functional import resize
|
||||||
|
|
||||||
|
|
||||||
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
|
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
|
||||||
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
|
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
|
||||||
import decord # isort:skip
|
import decord # isort:skip
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from ..utils import register
|
|
||||||
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
|
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
|
||||||
|
from ..utils import register
|
||||||
|
|
||||||
|
|
||||||
class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer):
|
class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer):
|
||||||
|
@ -1,22 +1,21 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
from typing_extensions import override
|
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer, T5EncoderModel
|
import torch
|
||||||
|
|
||||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
CogVideoXImageToVideoPipeline,
|
|
||||||
CogVideoXTransformer3DModel,
|
|
||||||
AutoencoderKLCogVideoX,
|
AutoencoderKLCogVideoX,
|
||||||
CogVideoXDPMScheduler,
|
CogVideoXDPMScheduler,
|
||||||
|
CogVideoXImageToVideoPipeline,
|
||||||
|
CogVideoXTransformer3DModel,
|
||||||
)
|
)
|
||||||
|
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoTokenizer, T5EncoderModel
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from finetune.trainer import Trainer
|
|
||||||
from finetune.schemas import Components
|
from finetune.schemas import Components
|
||||||
|
from finetune.trainer import Trainer
|
||||||
from finetune.utils import unwrap_model
|
from finetune.utils import unwrap_model
|
||||||
|
|
||||||
from ..utils import register
|
from ..utils import register
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,23 +1,21 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
from typing_extensions import override
|
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
from PIL import Image
|
import torch
|
||||||
|
|
||||||
from transformers import AutoTokenizer, T5EncoderModel
|
|
||||||
|
|
||||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
CogVideoXPipeline,
|
|
||||||
CogVideoXTransformer3DModel,
|
|
||||||
AutoencoderKLCogVideoX,
|
AutoencoderKLCogVideoX,
|
||||||
CogVideoXDPMScheduler,
|
CogVideoXDPMScheduler,
|
||||||
|
CogVideoXPipeline,
|
||||||
|
CogVideoXTransformer3DModel,
|
||||||
)
|
)
|
||||||
|
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoTokenizer, T5EncoderModel
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from finetune.trainer import Trainer
|
|
||||||
from finetune.schemas import Components
|
from finetune.schemas import Components
|
||||||
|
from finetune.trainer import Trainer
|
||||||
from finetune.utils import unwrap_model
|
from finetune.utils import unwrap_model
|
||||||
|
|
||||||
from ..utils import register
|
from ..utils import register
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Literal, Dict
|
from typing import Dict, Literal
|
||||||
|
|
||||||
from finetune.trainer import Trainer
|
from finetune.trainer import Trainer
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from .args import Args
|
from .args import Args
|
||||||
from .state import State
|
|
||||||
from .components import Components
|
from .components import Components
|
||||||
|
from .state import State
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Args", "State", "Components"]
|
__all__ = ["Args", "State", "Components"]
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import datetime
|
|
||||||
import argparse
|
import argparse
|
||||||
from typing import Dict, Any, Literal, List, Tuple
|
import datetime
|
||||||
from pydantic import BaseModel, field_validator, ValidationInfo
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, List, Literal, Tuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ValidationInfo, field_validator
|
||||||
|
|
||||||
|
|
||||||
class Args(BaseModel):
|
class Args(BaseModel):
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Any
|
from typing import Any, Dict, List
|
||||||
from pydantic import BaseModel, field_validator
|
|
||||||
|
import torch
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class State(BaseModel):
|
class State(BaseModel):
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
from finetune.schemas import Args
|
|
||||||
from finetune.models.utils import get_model_cls
|
from finetune.models.utils import get_model_cls
|
||||||
|
from finetune.schemas import Args
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -1,56 +1,52 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import json
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
import diffusers
|
|
||||||
import wandb
|
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tqdm import tqdm
|
from typing import Any, Dict, List, Tuple
|
||||||
from typing import Dict, Any, List, Tuple
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
import diffusers
|
||||||
from accelerate.logging import get_logger
|
import torch
|
||||||
|
import transformers
|
||||||
|
import wandb
|
||||||
from accelerate.accelerator import Accelerator, DistributedType
|
from accelerate.accelerator import Accelerator, DistributedType
|
||||||
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import (
|
from accelerate.utils import (
|
||||||
DistributedDataParallelKwargs,
|
DistributedDataParallelKwargs,
|
||||||
InitProcessGroupKwargs,
|
InitProcessGroupKwargs,
|
||||||
ProjectConfiguration,
|
ProjectConfiguration,
|
||||||
set_seed,
|
|
||||||
gather_object,
|
gather_object,
|
||||||
|
set_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
from diffusers.pipelines import DiffusionPipeline
|
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
|
from diffusers.pipelines import DiffusionPipeline
|
||||||
from diffusers.utils.export_utils import export_to_video
|
from diffusers.utils.export_utils import export_to_video
|
||||||
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from finetune.schemas import Args, State, Components
|
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||||
from finetune.utils import (
|
|
||||||
unwrap_model,
|
|
||||||
cast_training_params,
|
|
||||||
get_optimizer,
|
|
||||||
get_memory_statistics,
|
|
||||||
free_memory,
|
|
||||||
unload_model,
|
|
||||||
get_latest_ckpt_path_to_resume_from,
|
|
||||||
get_intermediate_ckpt_path,
|
|
||||||
string_to_filename,
|
|
||||||
)
|
|
||||||
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
|
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
|
||||||
from finetune.datasets.utils import (
|
from finetune.datasets.utils import (
|
||||||
load_prompts,
|
|
||||||
load_images,
|
load_images,
|
||||||
|
load_prompts,
|
||||||
load_videos,
|
load_videos,
|
||||||
preprocess_image_with_resize,
|
preprocess_image_with_resize,
|
||||||
preprocess_video_with_resize,
|
preprocess_video_with_resize,
|
||||||
)
|
)
|
||||||
|
from finetune.schemas import Args, Components, State
|
||||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
from finetune.utils import (
|
||||||
|
cast_training_params,
|
||||||
|
free_memory,
|
||||||
|
get_intermediate_ckpt_path,
|
||||||
|
get_latest_ckpt_path_to_resume_from,
|
||||||
|
get_memory_statistics,
|
||||||
|
get_optimizer,
|
||||||
|
string_to_filename,
|
||||||
|
unload_model,
|
||||||
|
unwrap_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from .torch_utils import *
|
|
||||||
from .optimizer_utils import *
|
|
||||||
from .memory_utils import *
|
|
||||||
from .checkpointing import *
|
from .checkpointing import *
|
||||||
from .file_utils import *
|
from .file_utils import *
|
||||||
|
from .memory_utils import *
|
||||||
|
from .optimizer_utils import *
|
||||||
|
from .torch_utils import *
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||||
from ..utils.file_utils import find_files, delete_files
|
|
||||||
|
from ..utils.file_utils import delete_files, find_files
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
|
||||||
|
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import gc
|
import gc
|
||||||
import torch
|
|
||||||
|
|
||||||
from typing import Any, Dict, Union
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import torch
|
|
||||||
|
|
||||||
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, Optional, Union, List
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
@ -17,20 +17,20 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide
|
|||||||
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
|
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
CogVideoXPipeline,
|
|
||||||
CogVideoXDPMScheduler,
|
CogVideoXDPMScheduler,
|
||||||
CogVideoXImageToVideoPipeline,
|
CogVideoXImageToVideoPipeline,
|
||||||
|
CogVideoXPipeline,
|
||||||
CogVideoXVideoToVideoPipeline,
|
CogVideoXVideoToVideoPipeline,
|
||||||
)
|
)
|
||||||
|
|
||||||
from diffusers.utils import export_to_video, load_image, load_video
|
from diffusers.utils import export_to_video, load_image, load_video
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
# Recommended resolution for each model (width, height)
|
# Recommended resolution for each model (width, height)
|
||||||
@ -38,7 +38,6 @@ RESOLUTION_MAP = {
|
|||||||
# cogvideox1.5-*
|
# cogvideox1.5-*
|
||||||
"cogvideox1.5-5b-i2v": (1360, 768),
|
"cogvideox1.5-5b-i2v": (1360, 768),
|
||||||
"cogvideox1.5-5b": (1360, 768),
|
"cogvideox1.5-5b": (1360, 768),
|
||||||
|
|
||||||
# cogvideox-*
|
# cogvideox-*
|
||||||
"cogvideox-5b-i2v": (720, 480),
|
"cogvideox-5b-i2v": (720, 480),
|
||||||
"cogvideox-5b": (720, 480),
|
"cogvideox-5b": (720, 480),
|
||||||
@ -100,10 +99,14 @@ def generate_video(
|
|||||||
elif (width, height) != desired_resolution:
|
elif (width, height) != desired_resolution:
|
||||||
if generate_type == "i2v":
|
if generate_type == "i2v":
|
||||||
# For i2v models, use user-defined width and height
|
# For i2v models, use user-defined width and height
|
||||||
logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m")
|
logging.warning(
|
||||||
|
f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Otherwise, use the recommended width and height
|
# Otherwise, use the recommended width and height
|
||||||
logging.warning(f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m")
|
logging.warning(
|
||||||
|
f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m"
|
||||||
|
)
|
||||||
width, height = desired_resolution
|
width, height = desired_resolution
|
||||||
|
|
||||||
if generate_type == "i2v":
|
if generate_type == "i2v":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user