diff --git a/inference/gradio_composite_demo/app.py b/inference/gradio_composite_demo/app.py
index 03ec1a1..07ff37f 100644
--- a/inference/gradio_composite_demo/app.py
+++ b/inference/gradio_composite_demo/app.py
@@ -3,7 +3,7 @@ THis is the main file for the gradio web demo. It uses the CogVideoX-5B model to
 set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
 
 Usage:
-    OpenAI_API_KEY=your_openai_api_key OpenAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
+    OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
 """
 
 import math
@@ -12,9 +12,20 @@ import random
 import threading
 import time
 
+import cv2
+import tempfile
+import imageio_ffmpeg
 import gradio as gr
 import torch
-from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from PIL import Image
+from diffusers import (
+    CogVideoXPipeline,
+    CogVideoXDPMScheduler,
+    CogVideoXVideoToVideoPipeline,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXTransformer3DModel,
+)
+from diffusers.utils import load_video, load_image
 from datetime import datetime, timedelta
 from diffusers.image_processor import VaeImageProcessor
 
@@ -31,18 +42,33 @@ snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
 
 pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(device)
 pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    transformer=pipe.transformer,
+    vae=pipe.vae,
+    scheduler=pipe.scheduler,
+    tokenizer=pipe.tokenizer,
+    text_encoder=pipe.text_encoder,
+    torch_dtype=torch.bfloat16,
+).to(device)
 
-# Unnecessary
+pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    transformer=CogVideoXTransformer3DModel.from_pretrained(
+        "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
+    ),
+    vae=pipe.vae,
+    scheduler=pipe.scheduler,
+    tokenizer=pipe.tokenizer,
+    text_encoder=pipe.text_encoder,
+    torch_dtype=torch.bfloat16,
+).to(device)
 
-pipe.enable_model_cpu_offload()
-pipe.enable_sequential_cpu_offload()
-pipe.vae.enable_slicing()
-pipe.vae.enable_tiling()
-# Compile
-
-# pipe.transformer.to(memory_format=torch.channels_last)
-# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+pipe.transformer.to(memory_format=torch.channels_last)
+pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+pipe_image.transformer.to(memory_format=torch.channels_last)
+pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True)
 
 os.makedirs("./output", exist_ok=True)
 os.makedirs("./gradio_tmp", exist_ok=True)
@@ -64,6 +90,80 @@ Video descriptions must have the same num of words as examples below. Extra word
 """
 
 
+def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
+    width, height = get_video_dimensions(input_video)
+
+    if width == 720 and height == 480:
+        processed_video = input_video
+    else:
+        processed_video = center_crop_resize(input_video)
+    return processed_video
+
+
+def get_video_dimensions(input_video_path):
+    reader = imageio_ffmpeg.read_frames(input_video_path)
+    metadata = next(reader)
+    return metadata["size"]
+
+
+def center_crop_resize(input_video_path, target_width=720, target_height=480):
+    cap = cv2.VideoCapture(input_video_path)
+
+    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    orig_fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    width_factor = target_width / orig_width
+    height_factor = target_height / orig_height
+    resize_factor = max(width_factor, height_factor)
+
+    inter_width = int(orig_width * resize_factor)
+    inter_height = int(orig_height * resize_factor)
+
+    target_fps = 8
+    ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
+    skip = min(5, ideal_skip)  # Cap at 5
+
+    while (total_frames / (skip + 1)) < 49 and skip > 0:
+        skip -= 1
+
+    processed_frames = []
+    frame_count = 0
+    total_read = 0
+
+    while frame_count < 49 and total_read < total_frames:
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        if total_read % (skip + 1) == 0:
+            resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)
+
+            start_x = (inter_width - target_width) // 2
+            start_y = (inter_height - target_height) // 2
+            cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]
+
+            processed_frames.append(cropped)
+            frame_count += 1
+
+        total_read += 1
+
+    cap.release()
+
+    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
+        temp_video_path = temp_file.name
+        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+        out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))
+
+        for frame in processed_frames:
+            out.write(frame)
+
+        out.release()
+
+    return temp_video_path
+
+
 def convert_prompt(prompt: str, retry_times: int = 3) -> str:
     if not os.environ.get("OPENAI_API_KEY"):
         return prompt
@@ -103,7 +203,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
                     "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
                 },
             ],
-            model="glm-4-0520",
+            model="glm-4-plus",
             temperature=0.01,
             top_p=0.7,
             stream=False,
@@ -116,6 +216,9 @@ def infer(
     prompt: str,
+    image_input: str,
+    video_input: str,
+    video_strenght: float,
     num_inference_steps: int,
     guidance_scale: float,
     seed: int = -1,
@@ -123,16 +226,44 @@
 ):
     if seed == -1:
         seed = random.randint(0, 2**8 - 1)
-    video_pt = pipe(
-        prompt=prompt,
-        num_videos_per_prompt=1,
-        num_inference_steps=num_inference_steps,
-        num_frames=49,
-        use_dynamic_cfg=True,
-        output_type="pt",
-        guidance_scale=guidance_scale,
-        generator=torch.Generator(device="cpu").manual_seed(seed),
-    ).frames
+
+    if video_input is not None:
+        video = load_video(video_input)[:49]  # Limit to 49 frames
+        video_pt = pipe_video(
+            video=video,
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            num_videos_per_prompt=1,
+            strength=video_strenght,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames
+    elif image_input is not None:
+        image_input = Image.fromarray(image_input)  # Change to PIL
+        image = load_image(image_input)
+        video_pt = pipe_image(
+            image=image,
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            num_videos_per_prompt=1,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames
+    else:
+        video_pt = pipe(
+            prompt=prompt,
+            num_videos_per_prompt=1,
+            num_inference_steps=num_inference_steps,
+            num_frames=49,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames
 
     return (video_pt, seed)
 
@@ -163,6 +294,7 @@ def delete_old_files():
 
 threading.Thread(target=delete_old_files, daemon=True).start()
 
+examples = [["horse.mp4"], ["kitten.mp4"], ["train_running.mp4"]]
 
 with gr.Blocks() as demo:
     gr.Markdown("""
@@ -174,13 +306,24 @@ with gr.Blocks() as demo:
            🌐 Github |
            📜 arxiv
-
+
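
Note on the frame-sampling rule introduced in `center_crop_resize` above: the helper targets roughly 8 fps by keeping every `(skip + 1)`-th frame, caps `skip` at 5, then relaxes it until at least 49 sampled frames survive. A minimal standalone sketch of that rule follows; the helper name `choose_frame_skip` and the example inputs are illustrative, not part of the PR.

```python
import math


def choose_frame_skip(orig_fps: float, total_frames: int,
                      target_fps: int = 8, needed_frames: int = 49) -> int:
    """Mirror of the skip heuristic in center_crop_resize (sketch only)."""
    # Keep every (skip + 1)-th frame to approximate target_fps, capped at 5.
    skip = min(5, max(0, math.ceil(orig_fps / target_fps) - 1))
    # Relax the skip while the clip is too short to yield `needed_frames` samples.
    while (total_frames / (skip + 1)) < needed_frames and skip > 0:
        skip -= 1
    return skip


# A 30 fps, 300-frame clip keeps every 4th frame (skip=3), i.e. ~7.5 fps; the first 49 kept frames are used.
assert choose_frame_skip(30, 300) == 3
# A 60 fps, 120-frame clip relaxes from the capped skip=5 down to skip=1 so at least 49 frames remain.
assert choose_frame_skip(60, 120) == 1
```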