Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 19:41:59 +08:00)

Commit 0a558e0964 ("app done"), parent 6e64359524.
@@ -3,7 +3,7 @@ THis is the main file for the gradio web demo. It uses the CogVideoX-5B model to
set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.

Usage:
    OpenAI_API_KEY=your_openai_api_key OpenAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
    OpenAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
"""

import math
@@ -12,9 +12,20 @@ import random
import threading
import time

import cv2
import tempfile
import imageio_ffmpeg
import gradio as gr
import torch
from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from PIL import Image
from diffusers import (
    CogVideoXPipeline,
    CogVideoXDPMScheduler,
    CogVideoXVideoToVideoPipeline,
    CogVideoXImageToVideoPipeline,
    CogVideoXTransformer3DModel,
)
from diffusers.utils import load_video, load_image
from datetime import datetime, timedelta

from diffusers.image_processor import VaeImageProcessor
@@ -31,18 +42,33 @@ snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(device)
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b",
    transformer=pipe.transformer,
    vae=pipe.vae,
    scheduler=pipe.scheduler,
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder,
    torch_dtype=torch.bfloat16,
).to(device)

# Unnecessary
pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b",
    transformer=CogVideoXTransformer3DModel.from_pretrained(
        "THUDM/CogVideoX-5b-I2V", subfolder="transformers", torch_dtype=torch.bfloat16
    ),
    vae=pipe.vae,
    scheduler=pipe.scheduler,
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder,
    torch_dtype=torch.bfloat16,
).to(device)

pipe.enable_model_cpu_offload()
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

# Compile

# pipe.transformer.to(memory_format=torch.channels_last)
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe_image.transformer.to(memory_format=torch.channels_last)
pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True)

os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)
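Not part of the commit: an illustrative check of the component sharing above. The V2V pipeline reuses the 5B VAE and text encoder, so only the separately loaded I2V transformer adds extra weights.

assert pipe_video.vae is pipe.vae and pipe_image.vae is pipe.vae
assert pipe_video.text_encoder is pipe.text_encoder
assert pipe_image.transformer is not pipe_video.transformer  # I2V carries its own transformer weights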
@@ -64,6 +90,80 @@ Video descriptions must have the same num of words as examples below. Extra word
"""


def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
    width, height = get_video_dimensions(input_video)

    if width == 720 and height == 480:
        processed_video = input_video
    else:
        processed_video = center_crop_resize(input_video)
    return processed_video


def get_video_dimensions(input_video_path):
    reader = imageio_ffmpeg.read_frames(input_video_path)
    metadata = next(reader)
    return metadata["size"]


def center_crop_resize(input_video_path, target_width=720, target_height=480):
    cap = cv2.VideoCapture(input_video_path)

    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    orig_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    width_factor = target_width / orig_width
    height_factor = target_height / orig_height
    resize_factor = max(width_factor, height_factor)

    inter_width = int(orig_width * resize_factor)
    inter_height = int(orig_height * resize_factor)

    target_fps = 8
    ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
    skip = min(5, ideal_skip)  # Cap at 5

    while (total_frames / (skip + 1)) < 49 and skip > 0:
        skip -= 1

    processed_frames = []
    frame_count = 0
    total_read = 0

    while frame_count < 49 and total_read < total_frames:
        ret, frame = cap.read()
        if not ret:
            break

        if total_read % (skip + 1) == 0:
            resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)

            start_x = (inter_width - target_width) // 2
            start_y = (inter_height - target_height) // 2
            cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]

            processed_frames.append(cropped)
            frame_count += 1

        total_read += 1

    cap.release()

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
        temp_video_path = temp_file.name
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))

        for frame in processed_frames:
            out.write(frame)

        out.release()

    return temp_video_path


def convert_prompt(prompt: str, retry_times: int = 3) -> str:
    if not os.environ.get("OPENAI_API_KEY"):
        return prompt
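Illustrative usage of the helpers above, not from the commit; "some_clip.mp4" is a placeholder path.

normalized = resize_if_unfit("some_clip.mp4")  # returns the original path, or a 720x480, 8 fps temp .mp4
print(get_video_dimensions(normalized))        # -> (720, 480)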
@@ -103,7 +203,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
                    "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
                },
            ],
            model="glm-4-0520",
            model="glm-4-plus",
            temperature=0.01,
            top_p=0.7,
            stream=False,
@@ -116,6 +216,9 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:

def infer(
    prompt: str,
    image_input: str,
    video_input: str,
    video_strenght: float,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int = -1,
@@ -123,16 +226,44 @@ def infer(
):
    if seed == -1:
        seed = random.randint(0, 2**8 - 1)
    video_pt = pipe(
        prompt=prompt,
        num_videos_per_prompt=1,
        num_inference_steps=num_inference_steps,
        num_frames=49,
        use_dynamic_cfg=True,
        output_type="pt",
        guidance_scale=guidance_scale,
        generator=torch.Generator(device="cpu").manual_seed(seed),
    ).frames

    if video_input is not None:
        video = load_video(video_input)[:49]  # Limit to 49 frames
        video_pt = pipe_video(
            video=video,
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            num_videos_per_prompt=1,
            strength=video_strenght,
            use_dynamic_cfg=True,
            output_type="pt",
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cpu").manual_seed(seed),
        ).frames
    elif image_input is not None:
        image_input = Image.fromarray(image_input)  # Change to PIL
        image = load_image(image_input)
        video_pt = pipe_image(
            image=image,
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            num_videos_per_prompt=1,
            use_dynamic_cfg=True,
            output_type="pt",
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cpu").manual_seed(seed),
        ).frames
    else:
        video_pt = pipe(
            prompt=prompt,
            num_videos_per_prompt=1,
            num_inference_steps=num_inference_steps,
            num_frames=49,
            use_dynamic_cfg=True,
            output_type="pt",
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cpu").manual_seed(seed),
        ).frames

    return (video_pt, seed)
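Illustrative calls into the new infer() branches, not from the commit; prompts and paths are placeholders, and each call runs a full 50-step generation.

frames, used_seed = infer("A golden retriever running on a beach.", None, None, 0.8, 50, 7.0, seed=42)  # T2V
# frames, used_seed = infer("Make it snow.", None, "some_clip.mp4", 0.8, 50, 7.0, seed=42)              # V2V
# For I2V, pass image_input as a numpy array (the format gr.Image provides), not a file path.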
@@ -163,6 +294,7 @@ def delete_old_files():


threading.Thread(target=delete_old_files, daemon=True).start()
examples = [["horse.mp4"], ["kitten.mp4"], ["train_running.mp4"]]

with gr.Blocks() as demo:
    gr.Markdown("""
@@ -174,13 +306,24 @@ with gr.Blocks() as demo:
           <a href="https://github.com/THUDM/CogVideo">🌐 Github</a> |
           <a href="https://arxiv.org/pdf/2408.06072">📜 arxiv </a>
        </div>

        <div style="text-align: center;display: flex;justify-content: center;align-items: center;margin-top: 1em;margin-bottom: .5em;">
           <span>If the Space is too busy, duplicate it to use privately</span>
           <a href="https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" width="160" style="
             margin-left: .75em;
           "></a>
        </div>
        <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
           ⚠️ This demo is for academic research and experiential use only.
        </div>
    """)
    with gr.Row():
        with gr.Column():
            with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
                image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
            with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
                video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
                strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
            examples_component = gr.Examples(examples, inputs=[video_input], cache_examples=False)
            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)

            with gr.Row():
@@ -188,7 +331,6 @@ with gr.Blocks() as demo:
                    "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
                )
                enhance_button = gr.Button("✨ Enhance Prompt(Optional)")

            with gr.Group():
                with gr.Column():
                    with gr.Row():
@@ -275,13 +417,25 @@ with gr.Blocks() as demo:
    </table>
    """)

    def generate(prompt, seed_value, scale_status, rife_status, progress=gr.Progress(track_tqdm=True)):
    def generate(
        prompt,
        image_input,
        video_input,
        video_strength,
        seed_value,
        scale_status,
        rife_status,
        progress=gr.Progress(track_tqdm=True)
    ):
        latents, seed = infer(
            prompt,
            image_input,
            video_input,
            video_strength,
            num_inference_steps=50,  # NOT Changed
            guidance_scale=7.0,  # NOT Changed
            seed=seed_value,
            # progress=progress,
            progress=progress,
        )
        if scale_status:
            latents = utils.upscale_batch_and_concatenate(upscale_model, latents, device)
@@ -311,11 +465,13 @@ with gr.Blocks() as demo:

    generate_button.click(
        generate,
        inputs=[prompt, seed_param, enable_scale, enable_rife],
        inputs=[prompt, image_input, video_input, strength, seed_param, enable_scale, enable_rife],
        outputs=[video_output, download_video_button, download_gif_button, seed_text],
    )

    enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
    video_input.upload(resize_if_unfit, inputs=[video_input], outputs=[video_input])

if __name__ == "__main__":
    demo.queue(max_size=15)
    demo.launch()
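For readability, not part of the commit: the new inputs list lines up positionally with the updated generate() parameters.

# [prompt, image_input, video_input, strength, seed_param, enable_scale, enable_rife]
#  -> generate(prompt, image_input, video_input, video_strength, seed_value, scale_status, rife_status)
#     (the trailing progress argument is injected by Gradio at call time)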
@@ -1,21 +1,19 @@
spaces==0.29.3
safetensors>=0.4.4
spandrel>=0.3.4
spaces>=0.29.3
safetensors>=0.4.5
spandrel>=0.4.0
tqdm>=4.66.5
opencv-python>=4.10.0.84
scikit-video>=1.1.11
diffusers>=0.30.1
transformers>=4.44.0
accelerate>=0.33.0
accelerate>=0.34.2
opencv-python>=4.10.0.84
sentencepiece>=0.2.0
SwissArmyTransformer>=0.4.12
numpy==1.26.0
torch>=2.4.0
torchvision>=0.19.0
gradio>=4.42.0
streamlit>=1.37.1
imageio==2.34.2
imageio-ffmpeg==0.5.1
openai>=1.42.0
moviepy==1.0.3
gradio>=4.44.0
imageio>=2.34.2
imageio-ffmpeg>=0.5.1
openai>=1.45.0
moviepy>=1.0.3
pillow==9.5.0
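Not from the commit: a quick, illustrative check that an existing environment satisfies the updated floors above.

import diffusers, gradio
print(diffusers.__version__, gradio.__version__)  # expect >= 0.30.1 and >= 4.44.0 respectively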
@@ -10,7 +10,6 @@ import skvideo.io
from rife.RIFE_HDv3 import Model

logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -37,8 +36,7 @@ def make_inference(model, I0, I1, upscale_amount, n):

@torch.inference_mode()
def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
    print(f"samples dtype:{samples.dtype}")
    print(f"samples shape:{samples.shape}")

    output = []
    # [f, c, h, w]
    for b in range(samples.shape[0]):
@@ -119,13 +117,11 @@ def rife_inference_with_path(model, video_path):


def rife_inference_with_latents(model, latents):
    pbar = utils.ProgressBar(latents.shape[1], desc="RIFE inference")
    rife_results = []
    latents = latents.to(device)
    for i in range(latents.size(0)):
        # [f, c, w, h]
        latent = latents[i]

        frames = ssim_interpolation_rife(model, latent)
        pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))])  # (to [f, c, w, h])
        rife_results.append(pt_image)