"""
|
|
The original experimental code for this project can be found at:
|
|
|
|
https://gist.github.com/a-r-r-o-w/d070cce059ab4ceab3a9f289ff83c69c
|
|
|
|
By using this code, description prompts will be generated through a local large language model, and images will be
|
|
generated using the black-forest-labs/FLUX.1-dev model, followed by video generation via CogVideoX.
|
|
The entire process utilizes open-source solutions, without the need for any API keys.
|
|
|
|
You can use the generate.sh file in the same folder to automate running this code
|
|
for batch generation of videos and images.
|
|
|
|
bash generate.sh
|
|
|
|
"""
|
|
|
|
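
# Example of invoking this script directly instead of via generate.sh. The filename
# "llm_flux_cogvideox.py" is assumed here; substitute the actual name of this file.
#
#   python llm_flux_cogvideox.py \
#       --num_videos 5 \
#       --model_path THUDM/CogVideoX-5b-I2V \
#       --image_generator_model_id black-forest-labs/FLUX.1-dev \
#       --use_dynamic_cfg \
#       --enable_vae_tiling \
#       --output_dir outputs/
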
import argparse
import gc
import json
import os
import pathlib
import random
from typing import Any, Dict

from transformers import AutoTokenizer

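# Verbose torch.compile / TorchDynamo debug logging (recompilations, graph breaks).
# These are set before torch is imported so that the logging configuration takes effect.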
os.environ["TORCH_LOGS"] = "+dynamo,recompiles,graph_breaks"
|
|
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
|
|
|
|
import numpy as np
|
|
import torch
|
|
import transformers
|
|
from diffusers import CogVideoXImageToVideoPipeline, CogVideoXDPMScheduler, DiffusionPipeline
|
|
from diffusers.utils.logging import get_logger
|
|
from diffusers.utils import export_to_video
|
|
|
|
torch.set_float32_matmul_precision("high")
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
SYSTEM_PROMPT = """
|
|
You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe.
|
|
|
|
For example, if you respond with "A beautiful morning in the woods with the sun peaking through the trees", the video generation model will create a video of exactly as described. You task is to summarize the descriptions of videos provided to by users, and create details prompts to feed into the generative model.
|
|
|
|
There are a few rules to follow:
|
|
- You will only ever output a single video description per request.
|
|
- If the user mentions to summarize the prompt in [X] words, make sure to not exceed the limit.
|
|
|
|
You responses should just be the video generation prompt. Here are examples:
|
|
- “A lone figure stands on a city rooftop at night, gazing up at the full moon. The moon glows brightly, casting a gentle light over the quiet cityscape. Below, the windows of countless homes shine with warm lights, creating a contrast between the bustling life below and the peaceful solitude above. The scene captures the essence of the Mid-Autumn Festival, where despite the distance, the figure feels connected to loved ones through the shared beauty of the moonlit sky.”
|
|
- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
|
|
- "A street artist, clad in a worn-out denim jacket and a colorful banana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall"
|
|
""".strip()
|
|
|
|
USER_PROMPT = """
|
|
Could you generate a prompt for a video generation model?
|
|
Please limit the prompt to [{0}] words.
|
|
""".strip()
|
|
|
|
|
|
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_videos",
        type=int,
        default=5,
        help="Number of unique videos you would like to generate.",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="THUDM/CogVideoX-5b-I2V",
        help="The path of the image-to-video CogVideoX-5B model.",
    )
    parser.add_argument(
        "--caption_generator_model_id",
        type=str,
        default="THUDM/glm-4-9b-chat",
        help="Caption generation model (default: GLM-4-9B).",
    )
    parser.add_argument(
        "--caption_generator_cache_dir",
        type=str,
        default=None,
        help="Cache directory for the caption generation model.",
    )
    parser.add_argument(
        "--image_generator_model_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Image generation model.",
    )
    parser.add_argument(
        "--image_generator_cache_dir",
        type=str,
        default=None,
        help="Cache directory for the image generation model.",
    )
    parser.add_argument(
        "--image_generator_num_inference_steps",
        type=int,
        default=50,
        help="Number of inference steps for the image generation model.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=7,
        help="Guidance scale to be used for video generation.",
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        help="Whether or not to use cosine dynamic guidance for generation [Recommended].",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/",
        help="Location where generated images and videos should be stored.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Whether or not to compile the transformers of the image and video generators.",
    )
    parser.add_argument(
        "--enable_vae_tiling",
        action="store_true",
        help="Whether or not to use VAE tiling when encoding/decoding.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for reproducibility.",
    )
    return parser.parse_args()

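
# Frees host and CUDA memory between the caption, image, and video stages so that
# each large model can occupy the GPU on its own.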
def reset_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()


@torch.no_grad()
def main(args: Dict[str, Any]) -> None:
    output_dir = pathlib.Path(args.output_dir)
    os.makedirs(output_dir.as_posix(), exist_ok=True)

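    # Seed Python, NumPy, and torch RNGs so caption sampling and generation are reproducible.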
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    reset_memory()
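    # Stage 1: generate video description prompts with the local caption LLM (GLM-4-9B by default).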
    tokenizer = AutoTokenizer.from_pretrained(args.caption_generator_model_id, trust_remote_code=True)
    caption_generator = transformers.pipeline(
        "text-generation",
        model=args.caption_generator_model_id,
        device_map="auto",
        model_kwargs={
            "local_files_only": True,
            "cache_dir": args.caption_generator_cache_dir,
            "torch_dtype": torch.bfloat16,
        },
        trust_remote_code=True,
        tokenizer=tokenizer,
    )

    captions = []
    for i in range(args.num_videos):
        num_words = random.choice([50, 75, 100])
        user_prompt = USER_PROMPT.format(num_words)

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]

        outputs = caption_generator(messages, max_new_tokens=226)
        caption = outputs[0]["generated_text"][-1]["content"]
        if caption.startswith("\"") and caption.endswith("\""):
            caption = caption[1:-1]
        captions.append(caption)
        logger.info(f"Generated caption: {caption}")

    with open(output_dir / "captions.json", "w") as file:
        json.dump(captions, file)

    del caption_generator
    reset_memory()

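    # Stage 2: generate a first-frame image for each caption with the FLUX image generator.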
    image_generator = DiffusionPipeline.from_pretrained(
        args.image_generator_model_id,
        cache_dir=args.image_generator_cache_dir,
        torch_dtype=torch.bfloat16,
    )
    image_generator.to("cuda")

    if args.compile:
        image_generator.transformer = torch.compile(image_generator.transformer, mode="max-autotune", fullgraph=True)

    if args.enable_vae_tiling:
        image_generator.vae.enable_tiling()

    images = []
    for index, caption in enumerate(captions):
        image = image_generator(
            prompt=caption,
            height=480,
            width=720,
            num_inference_steps=args.image_generator_num_inference_steps,
            guidance_scale=3.5,
        ).images[0]
        filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
        image.save(output_dir / f"{index}_{filename}.png")
        images.append(image)

    del image_generator
    reset_memory()

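    # Stage 3: animate each generated image into a 49-frame video with the CogVideoX image-to-video pipeline.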
    video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16
    ).to("cuda")
    video_generator.scheduler = CogVideoXDPMScheduler.from_config(
        video_generator.scheduler.config,
        timestep_spacing="trailing",
    )

    if args.compile:
        video_generator.transformer = torch.compile(video_generator.transformer, mode="max-autotune", fullgraph=True)

    if args.enable_vae_tiling:
        video_generator.vae.enable_tiling()

    generator = torch.Generator().manual_seed(args.seed)
    for index, (caption, image) in enumerate(zip(captions, images)):
        video = video_generator(
            image=image,
            prompt=caption,
            height=480,
            width=720,
            num_frames=49,
            num_inference_steps=50,
            guidance_scale=args.guidance_scale,
            use_dynamic_cfg=args.use_dynamic_cfg,
            generator=generator,
        ).frames[0]
        filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
        export_to_video(video, output_dir / f"{index}_{filename}.mp4", fps=8)

if __name__ == "__main__":
|
|
args = get_args()
|
|
main(args)
|