import os
import random
import threading
import time
from datetime import datetime, timedelta

import gradio as gr
import torch
import transformers
from diffusers import CogVideoXImageToVideoPipeline, CogVideoXDPMScheduler, DiffusionPipeline
from diffusers.utils import export_to_video
from moviepy import VideoFileClip
from transformers import AutoTokenizer

torch.set_float32_matmul_precision("high")

# Set default values
caption_generator_model_id = "/share/home/zyx/Models/Meta-Llama-3.1-8B-Instruct"
image_generator_model_id = "/share/home/zyx/Models/FLUX.1-dev"
video_generator_model_id = "/share/official_pretrains/hf_home/CogVideoX-5b-I2V"
seed = 1337

os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)
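
# The model paths above assume a local cache; the equivalent Hugging Face Hub
# ids (requires network access, accepted licenses, and dropping the
# "local_files_only" flag below) would be:
#   caption_generator_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
#   image_generator_model_id = "black-forest-labs/FLUX.1-dev"
#   video_generator_model_id = "THUDM/CogVideoX-5b-I2V"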

# LLM pipeline that expands short user prompts into detailed video captions
tokenizer = AutoTokenizer.from_pretrained(caption_generator_model_id, trust_remote_code=True)
caption_generator = transformers.pipeline(
    "text-generation",
    model=caption_generator_model_id,
    device_map="balanced",
    model_kwargs={
        "local_files_only": True,
        "torch_dtype": torch.bfloat16,
    },
    trust_remote_code=True,
    tokenizer=tokenizer,
)

# FLUX text-to-image pipeline used to produce the conditioning frame;
# device_map="balanced" already places it on the available GPUs, so no
# explicit .to("cuda") is needed
image_generator = DiffusionPipeline.from_pretrained(
    image_generator_model_id, torch_dtype=torch.bfloat16, device_map="balanced"
)
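
# Lower-VRAM alternative (an assumption, not part of the original setup): load
# the pipeline without device_map and let diffusers offload submodules to CPU
# between forward passes:
#   image_generator = DiffusionPipeline.from_pretrained(
#       image_generator_model_id, torch_dtype=torch.bfloat16
#   )
#   image_generator.enable_model_cpu_offload()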

# CogVideoX image-to-video pipeline
video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
    video_generator_model_id, torch_dtype=torch.bfloat16, device_map="balanced"
)

# Slice and tile VAE decoding to lower peak memory while decoding frames
video_generator.vae.enable_slicing()
video_generator.vae.enable_tiling()

video_generator.scheduler = CogVideoXDPMScheduler.from_config(
    video_generator.scheduler.config, timestep_spacing="trailing"
)
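
# The DPM scheduler with trailing timestep spacing mirrors the sampling setup
# used in the CogVideoX reference inference scripts.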

# Define prompts
SYSTEM_PROMPT = """
You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe.

For example, if you respond with "A beautiful morning in the woods with the sun peeking through the trees", the video generation model will create a video exactly as described. Your task is to summarize the descriptions of videos provided by users and create detailed prompts to feed into the generative model.

There are a few rules to follow:
- You will only ever output a single video description per request.
- If the user asks you to summarize the prompt in [X] words, make sure not to exceed the limit.

Your responses should just be the video generation prompt. Here are examples:
- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart of the city, holding a can of spray paint, spray-painting a colorful bird on a mottled wall."
""".strip()

USER_PROMPT = """
Could you generate a prompt for a video generation model? Please limit the prompt to [{0}] words.
""".strip()


def generate_caption(prompt):
    """Expand the user's prompt into a detailed caption of a random target length."""
    num_words = random.choice([25, 50, 75, 100])
    user_prompt = USER_PROMPT.format(num_words)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt + "\n" + user_prompt},
    ]

    # 226 tokens matches CogVideoX's default maximum text-conditioning length
    response = caption_generator(messages, max_new_tokens=226, return_full_text=False)
    caption = response[0]["generated_text"]
    # Strip the surrounding quotes the LLM sometimes adds
    if caption.startswith('"') and caption.endswith('"'):
        caption = caption[1:-1]
    return caption
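
# Illustrative use: generate_caption("a cat chasing a laser pointer") returns a
# single detailed caption of roughly 25-100 words, per the prompts above.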


def generate_image(caption, progress=gr.Progress(track_tqdm=True)):
    image = image_generator(
        prompt=caption,
        height=480,
        width=720,
        num_inference_steps=30,
        guidance_scale=3.5,
    ).images[0]
    return image, image  # One copy for the Image output, one for gr.State


def generate_video(caption, image, progress=gr.Progress(track_tqdm=True)):
    # Fixed module-level seed keeps video generation reproducible across runs
    generator = torch.Generator().manual_seed(seed)
    video_frames = video_generator(
        image=image,
        prompt=caption,
        height=480,
        width=720,
        num_frames=49,
        num_inference_steps=50,
        guidance_scale=6,
        use_dynamic_cfg=True,
        generator=generator,
    ).frames[0]
    video_path = save_video(video_frames)
    gif_path = convert_to_gif(video_path)
    return video_path, gif_path
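
# Because the seed is fixed, the same caption and image always yield the same
# video; to vary results per request, one option (illustrative) is:
#   generator = torch.Generator().manual_seed(random.randint(0, 2**32 - 1))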


def save_video(tensor):
    """Export the generated frames to a timestamped MP4 under ./output."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{timestamp}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    export_to_video(tensor, video_path, fps=8)
    return video_path


def convert_to_gif(video_path):
    """Convert the saved MP4 into a downscaled GIF alongside it."""
    clip = VideoFileClip(video_path)
    clip = clip.with_fps(8)
    clip = clip.resized(height=240)
    gif_path = video_path.replace(".mp4", ".gif")
    clip.write_gif(gif_path, fps=8)
    return gif_path


def delete_old_files():
    """Periodically remove generated files older than ten minutes."""
    while True:
        now = datetime.now()
        cutoff = now - timedelta(minutes=10)
        directories = ["./output", "./gradio_tmp"]

        for directory in directories:
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                if os.path.isfile(file_path):
                    file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                    if file_mtime < cutoff:
                        os.remove(file_path)
        time.sleep(600)


# Run cleanup in the background; the daemon thread exits with the main process
threading.Thread(target=delete_old_files, daemon=True).start()

with gr.Blocks() as demo:
    gr.Markdown("""
    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
        LLM + FLUX + CogVideoX-I2V Space 🤗
    </div>
    """)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=5)
            generate_caption_button = gr.Button("Generate Caption")
            caption = gr.Textbox(label="Caption", placeholder="Caption will appear here", lines=5)
            generate_image_button = gr.Button("Generate Image")
            image_output = gr.Image(label="Generated Image")
            # Keep a copy of the generated image so the video step can reuse it
            state_image = gr.State()
            generate_caption_button.click(fn=generate_caption, inputs=prompt, outputs=caption)
            generate_image_button.click(
                fn=generate_image, inputs=caption, outputs=[image_output, state_image]
            )
        with gr.Column():
            video_output = gr.Video(label="Generated Video", width=720, height=480)
            download_video_button = gr.File(label="📥 Download Video", visible=False)
            download_gif_button = gr.File(label="📥 Download GIF", visible=False)
            generate_video_button = gr.Button("Generate Video from Image")
            generate_video_button.click(
                fn=generate_video,
                inputs=[caption, state_image],
                outputs=[video_output, download_gif_button],
            )

if __name__ == "__main__":
    demo.launch()
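    # To expose the demo beyond localhost, launch() accepts the usual Gradio
    # options (values shown are illustrative):
    #   demo.launch(server_name="0.0.0.0", server_port=7860)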