Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-06 03:57:56 +08:00)

Commit 0a558e0964 ("app done"), parent 6e64359524.
@@ -3,7 +3,7 @@ This is the main file for the gradio web demo. It uses the CogVideoX-5B model to
 set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.

 Usage:
-    OpenAI_API_KEY=your_openai_api_key OpenAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
+    OpenAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
 """

 import math
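Note: as the docstring change above shows, prompt enhancement is driven entirely by environment variables (the fix is the OPENAI_BASE_URL casing). A minimal pre-launch sanity check, sketched here and not part of the commit, mirroring the convert_prompt guard further down in this diff:

import os

# Assumption: enhancement is skipped when OPENAI_API_KEY is unset;
# OPENAI_BASE_URL is optional and falls back to the public OpenAI endpoint.
if not os.environ.get("OPENAI_API_KEY"):
    print("OPENAI_API_KEY not set - prompts will be used unmodified.")
else:
    base_url = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
    print(f"Prompt enhancement enabled against {base_url}")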
@@ -12,9 +12,20 @@ import random
 import threading
 import time

+import cv2
+import tempfile
+import imageio_ffmpeg
 import gradio as gr
 import torch
-from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from PIL import Image
+from diffusers import (
+    CogVideoXPipeline,
+    CogVideoXDPMScheduler,
+    CogVideoXVideoToVideoPipeline,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXTransformer3DModel,
+)
+from diffusers.utils import load_video, load_image
 from datetime import datetime, timedelta

 from diffusers.image_processor import VaeImageProcessor
@@ -31,18 +42,33 @@ snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")

 pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(device)
 pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    transformer=pipe.transformer,
+    vae=pipe.vae,
+    scheduler=pipe.scheduler,
+    tokenizer=pipe.tokenizer,
+    text_encoder=pipe.text_encoder,
+    torch_dtype=torch.bfloat16,
+).to(device)

-# Unnecessary
-pipe.enable_model_cpu_offload()
-pipe.enable_sequential_cpu_offload()
-pipe.vae.enable_slicing()
-pipe.vae.enable_tiling()
+pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    transformer=CogVideoXTransformer3DModel.from_pretrained(
+        "THUDM/CogVideoX-5b-I2V", subfolder="transformers", torch_dtype=torch.bfloat16
+    ),
+    vae=pipe.vae,
+    scheduler=pipe.scheduler,
+    tokenizer=pipe.tokenizer,
+    text_encoder=pipe.text_encoder,
+    torch_dtype=torch.bfloat16,
+).to(device)

-# Compile
-# pipe.transformer.to(memory_format=torch.channels_last)
-# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+pipe.transformer.to(memory_format=torch.channels_last)
+pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+pipe_image.transformer.to(memory_format=torch.channels_last)
+pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True)

 os.makedirs("./output", exist_ok=True)
 os.makedirs("./gradio_tmp", exist_ok=True)
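Note: the two new pipelines are assembled from the components of the existing text-to-video pipeline rather than downloaded separately, so the heavy weights are shared; only the I2V transformer is loaded from its own checkpoint. A quick way to confirm the sharing (a sketch assuming the three pipe* objects above are in scope; not part of the commit):

# The V2V and I2V pipelines hold references to the same modules as the base pipeline,
# so the VAE and text encoder exist only once in GPU memory.
assert pipe_video.vae is pipe.vae
assert pipe_video.text_encoder is pipe.text_encoder
assert pipe_image.vae is pipe.vae
# Only pipe_image.transformer is a distinct set of weights (from CogVideoX-5b-I2V above).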
@@ -64,6 +90,80 @@ Video descriptions must have the same num of words as examples below. Extra word
 """


+def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
+    width, height = get_video_dimensions(input_video)
+
+    if width == 720 and height == 480:
+        processed_video = input_video
+    else:
+        processed_video = center_crop_resize(input_video)
+    return processed_video
+
+
+def get_video_dimensions(input_video_path):
+    reader = imageio_ffmpeg.read_frames(input_video_path)
+    metadata = next(reader)
+    return metadata["size"]
+
+
+def center_crop_resize(input_video_path, target_width=720, target_height=480):
+    cap = cv2.VideoCapture(input_video_path)
+
+    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    orig_fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    width_factor = target_width / orig_width
+    height_factor = target_height / orig_height
+    resize_factor = max(width_factor, height_factor)
+
+    inter_width = int(orig_width * resize_factor)
+    inter_height = int(orig_height * resize_factor)
+
+    target_fps = 8
+    ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
+    skip = min(5, ideal_skip)  # Cap at 5
+
+    while (total_frames / (skip + 1)) < 49 and skip > 0:
+        skip -= 1
+
+    processed_frames = []
+    frame_count = 0
+    total_read = 0
+
+    while frame_count < 49 and total_read < total_frames:
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        if total_read % (skip + 1) == 0:
+            resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)
+
+            start_x = (inter_width - target_width) // 2
+            start_y = (inter_height - target_height) // 2
+            cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]
+
+            processed_frames.append(cropped)
+            frame_count += 1
+
+        total_read += 1
+
+    cap.release()
+
+    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
+        temp_video_path = temp_file.name
+        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+        out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))
+
+        for frame in processed_frames:
+            out.write(frame)
+
+        out.release()
+
+    return temp_video_path
+
+
 def convert_prompt(prompt: str, retry_times: int = 3) -> str:
     if not os.environ.get("OPENAI_API_KEY"):
         return prompt
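Note: the helpers added above normalise arbitrary uploads to the 720x480, 49-frame, 8 fps clip the V2V pipeline expects. A hypothetical local test of that path (the file name sample.mp4 is illustrative and not part of the commit):

# resize_if_unfit returns the original path for 720x480 input
# and a temporary re-encoded .mp4 otherwise.
processed = resize_if_unfit("sample.mp4")
print("dimensions after preprocessing:", get_video_dimensions(processed))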
@@ -103,7 +203,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
                     "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
                 },
             ],
-            model="glm-4-0520",
+            model="glm-4-plus",
             temperature=0.01,
             top_p=0.7,
             stream=False,
@@ -116,6 +216,9 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:

 def infer(
     prompt: str,
+    image_input: str,
+    video_input: str,
+    video_strenght: float,
     num_inference_steps: int,
     guidance_scale: float,
     seed: int = -1,
@@ -123,16 +226,44 @@ def infer(
 ):
     if seed == -1:
         seed = random.randint(0, 2**8 - 1)
-    video_pt = pipe(
-        prompt=prompt,
-        num_videos_per_prompt=1,
-        num_inference_steps=num_inference_steps,
-        num_frames=49,
-        use_dynamic_cfg=True,
-        output_type="pt",
-        guidance_scale=guidance_scale,
-        generator=torch.Generator(device="cpu").manual_seed(seed),
-    ).frames
+
+    if video_input is not None:
+        video = load_video(video_input)[:49]  # Limit to 49 frames
+        video_pt = pipe_video(
+            video=video,
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            num_videos_per_prompt=1,
+            strength=video_strenght,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames
+    elif image_input is not None:
+        image_input = Image.fromarray(image_input)  # Change to PIL
+        image = load_image(image_input)
+        video_pt = pipe_image(
+            image=image,
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            num_videos_per_prompt=1,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames
+    else:
+        video_pt = pipe(
+            prompt=prompt,
+            num_videos_per_prompt=1,
+            num_inference_steps=num_inference_steps,
+            num_frames=49,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames

     return (video_pt, seed)
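Note: infer now dispatches on which optional input is present: video input takes precedence, then image input, with plain text-to-video as the fallback. A minimal text-only call for illustration (argument values are illustrative, not from the commit; any remaining keyword parameters are left at their defaults):

# None for both image_input and video_input selects the original T2V branch;
# the video_strenght argument only matters for the V2V branch.
frames, used_seed = infer(
    "a small boat sailing on a calm lake at sunrise",
    None,   # image_input
    None,   # video_input
    0.8,    # video_strenght
    num_inference_steps=50,
    guidance_scale=7.0,
    seed=-1,
)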
@@ -163,6 +294,7 @@ def delete_old_files():


 threading.Thread(target=delete_old_files, daemon=True).start()
+examples = [["horse.mp4"], ["kitten.mp4"], ["train_running.mp4"]]

 with gr.Blocks() as demo:
     gr.Markdown("""
@@ -174,13 +306,24 @@ with gr.Blocks() as demo:
            <a href="https://github.com/THUDM/CogVideo">🌐 Github</a> |
            <a href="https://arxiv.org/pdf/2408.06072">📜 arxiv </a>
        </div>
+       <div style="text-align: center;display: flex;justify-content: center;align-items: center;margin-top: 1em;margin-bottom: .5em;">
+          <span>If the Space is too busy, duplicate it to use privately</span>
+          <a href="https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" width="160" style="
+             margin-left: .75em;
+          "></a>
+       </div>
        <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
            ⚠️ This demo is for academic research and experiential use only.
        </div>
        """)
     with gr.Row():
         with gr.Column():
+            with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
+                image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
+            with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
+                video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
+                strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
+                examples_component = gr.Examples(examples, inputs=[video_input], cache_examples=False)
             prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)

         with gr.Row():
@@ -188,7 +331,6 @@ with gr.Blocks() as demo:
                 "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
             )
             enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
-
             with gr.Group():
                 with gr.Column():
                     with gr.Row():
@@ -275,13 +417,25 @@ with gr.Blocks() as demo:
     </table>
     """)

-    def generate(prompt, seed_value, scale_status, rife_status, progress=gr.Progress(track_tqdm=True)):
+    def generate(
+        prompt,
+        image_input,
+        video_input,
+        video_strength,
+        seed_value,
+        scale_status,
+        rife_status,
+        progress=gr.Progress(track_tqdm=True)
+    ):
         latents, seed = infer(
             prompt,
+            image_input,
+            video_input,
+            video_strength,
             num_inference_steps=50,  # NOT Changed
             guidance_scale=7.0,  # NOT Changed
             seed=seed_value,
-            # progress=progress,
+            progress=progress,
         )
         if scale_status:
             latents = utils.upscale_batch_and_concatenate(upscale_model, latents, device)
@@ -311,11 +465,13 @@ with gr.Blocks() as demo:

     generate_button.click(
         generate,
-        inputs=[prompt, seed_param, enable_scale, enable_rife],
+        inputs=[prompt, image_input, video_input, strength, seed_param, enable_scale, enable_rife],
         outputs=[video_output, download_video_button, download_gif_button, seed_text],
     )

     enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
+    video_input.upload(resize_if_unfit, inputs=[video_input], outputs=[video_input])

 if __name__ == "__main__":
+    demo.queue(max_size=15)
     demo.launch()
@@ -1,21 +1,19 @@
-spaces==0.29.3
-safetensors>=0.4.4
-spandrel>=0.3.4
+spaces>=0.29.3
+safetensors>=0.4.5
+spandrel>=0.4.0
 tqdm>=4.66.5
-opencv-python>=4.10.0.84
 scikit-video>=1.1.11
 diffusers>=0.30.1
 transformers>=4.44.0
-accelerate>=0.33.0
+accelerate>=0.34.2
+opencv-python>=4.10.0.84
 sentencepiece>=0.2.0
-SwissArmyTransformer>=0.4.12
 numpy==1.26.0
 torch>=2.4.0
 torchvision>=0.19.0
-gradio>=4.42.0
-streamlit>=1.37.1
-imageio==2.34.2
-imageio-ffmpeg==0.5.1
-openai>=1.42.0
-moviepy==1.0.3
+gradio>=4.44.0
+imageio>=2.34.2
+imageio-ffmpeg>=0.5.1
+openai>=1.45.0
+moviepy>=1.0.3
 pillow==9.5.0
@@ -10,7 +10,6 @@ import skvideo.io
 from rife.RIFE_HDv3 import Model

 logger = logging.getLogger(__name__)

 device = "cuda" if torch.cuda.is_available() else "cpu"
-

@@ -37,8 +36,7 @@ def make_inference(model, I0, I1, upscale_amount, n):

 @torch.inference_mode()
 def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
-    print(f"samples dtype:{samples.dtype}")
-    print(f"samples shape:{samples.shape}")
     output = []
     # [f, c, h, w]
     for b in range(samples.shape[0]):
@@ -119,13 +117,11 @@ def rife_inference_with_path(model, video_path):


 def rife_inference_with_latents(model, latents):
-    pbar = utils.ProgressBar(latents.shape[1], desc="RIFE inference")
     rife_results = []
     latents = latents.to(device)
     for i in range(latents.size(0)):
         # [f, c, w, h]
         latent = latents[i]

         frames = ssim_interpolation_rife(model, latent)
         pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))])  # (to [f, c, w, h])
         rife_results.append(pt_image)