diff --git a/inference/gradio_composite_demo/app.py b/inference/gradio_composite_demo/app.py
index 03ec1a1..07ff37f 100644
--- a/inference/gradio_composite_demo/app.py
+++ b/inference/gradio_composite_demo/app.py
@@ -3,7 +3,7 @@ THis is the main file for the gradio web demo. It uses the CogVideoX-5B model to
 set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
 
 Usage:
-    OpenAI_API_KEY=your_openai_api_key OpenAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
+    OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
 """
 
 import math
@@ -12,9 +12,20 @@ import random
 import threading
 import time
 
+import cv2
+import tempfile
+import imageio_ffmpeg
 import gradio as gr
 import torch
-from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from PIL import Image
+from diffusers import (
+    CogVideoXPipeline,
+    CogVideoXDPMScheduler,
+    CogVideoXVideoToVideoPipeline,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXTransformer3DModel,
+)
+from diffusers.utils import load_video, load_image
 from datetime import datetime, timedelta
 
 from diffusers.image_processor import VaeImageProcessor
@@ -31,18 +42,33 @@ snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
 
 pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(device)
 pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    transformer=pipe.transformer,
+    vae=pipe.vae,
+    scheduler=pipe.scheduler,
+    tokenizer=pipe.tokenizer,
+    text_encoder=pipe.text_encoder,
+    torch_dtype=torch.bfloat16,
+).to(device)
 
-# Unnecessary
+pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    transformer=CogVideoXTransformer3DModel.from_pretrained(
+        "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
+    ),
+    vae=pipe.vae,
+    scheduler=pipe.scheduler,
+    tokenizer=pipe.tokenizer,
+    text_encoder=pipe.text_encoder,
+    torch_dtype=torch.bfloat16,
+).to(device)
 
-pipe.enable_model_cpu_offload()
-pipe.enable_sequential_cpu_offload()
-pipe.vae.enable_slicing()
-pipe.vae.enable_tiling()
-# Compile
-
-# pipe.transformer.to(memory_format=torch.channels_last)
-# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+pipe.transformer.to(memory_format=torch.channels_last)
+pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+pipe_image.transformer.to(memory_format=torch.channels_last)
+pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True)
 
 os.makedirs("./output", exist_ok=True)
 os.makedirs("./gradio_tmp", exist_ok=True)
@@ -64,6 +90,80 @@ Video descriptions must have the same num of words as examples below. Extra word
 """
 
 
+def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
+    width, height = get_video_dimensions(input_video)
+
+    if width == 720 and height == 480:
+        processed_video = input_video
+    else:
+        processed_video = center_crop_resize(input_video)
+    return processed_video
+
+
+def get_video_dimensions(input_video_path):
+    reader = imageio_ffmpeg.read_frames(input_video_path)
+    metadata = next(reader)
+    return metadata["size"]
+
+
+def center_crop_resize(input_video_path, target_width=720, target_height=480):
+    cap = cv2.VideoCapture(input_video_path)
+
+    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    orig_fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    width_factor = target_width / orig_width
+    height_factor = target_height / orig_height
+    resize_factor = max(width_factor, height_factor)
+
+    inter_width = int(orig_width * resize_factor)
+    inter_height = int(orig_height * resize_factor)
+
+    target_fps = 8
+    ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
+    skip = min(5, ideal_skip)  # Cap at 5
+
+    while (total_frames / (skip + 1)) < 49 and skip > 0:
+        skip -= 1
+
+    processed_frames = []
+    frame_count = 0
+    total_read = 0
+
+    while frame_count < 49 and total_read < total_frames:
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        if total_read % (skip + 1) == 0:
+            resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)
+
+            start_x = (inter_width - target_width) // 2
+            start_y = (inter_height - target_height) // 2
+            cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]
+
+            processed_frames.append(cropped)
+            frame_count += 1
+
+        total_read += 1
+
+    cap.release()
+
+    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
+        temp_video_path = temp_file.name
+        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+        out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))
+
+        for frame in processed_frames:
+            out.write(frame)
+
+        out.release()
+
+    return temp_video_path
+
+
 def convert_prompt(prompt: str, retry_times: int = 3) -> str:
     if not os.environ.get("OPENAI_API_KEY"):
         return prompt
@@ -103,7 +203,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
                 "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
             },
         ],
-        model="glm-4-0520",
+        model="glm-4-plus",
         temperature=0.01,
         top_p=0.7,
         stream=False,
@@ -116,6 +216,9 @@
 def infer(
     prompt: str,
+    image_input: str,
+    video_input: str,
+    video_strength: float,
     num_inference_steps: int,
     guidance_scale: float,
     seed: int = -1,
@@ -123,16 +226,44 @@
 ):
     if seed == -1:
         seed = random.randint(0, 2**8 - 1)
-    video_pt = pipe(
-        prompt=prompt,
-        num_videos_per_prompt=1,
-        num_inference_steps=num_inference_steps,
-        num_frames=49,
-        use_dynamic_cfg=True,
-        output_type="pt",
-        guidance_scale=guidance_scale,
-        generator=torch.Generator(device="cpu").manual_seed(seed),
-    ).frames
+
+    if video_input is not None:
+        video = load_video(video_input)[:49]  # Limit to 49 frames
+        video_pt = pipe_video(
+            video=video,
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            num_videos_per_prompt=1,
+            strength=video_strength,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames
+    elif image_input is not None:
+        image_input = Image.fromarray(image_input)  # Change to PIL
+        image = load_image(image_input)
+        video_pt = pipe_image(
+            image=image,
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            num_videos_per_prompt=1,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames
+    else:
+        video_pt = pipe(
+            prompt=prompt,
+            num_videos_per_prompt=1,
+            num_inference_steps=num_inference_steps,
+            num_frames=49,
+            use_dynamic_cfg=True,
+            output_type="pt",
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device="cpu").manual_seed(seed),
+        ).frames
 
     return (video_pt, seed)
 
@@ -163,6 +294,7 @@ def delete_old_files():
 
 threading.Thread(target=delete_old_files, daemon=True).start()
 
+examples = [["horse.mp4"], ["kitten.mp4"], ["train_running.mp4"]]
 
 with gr.Blocks() as demo:
     gr.Markdown("""
@@ -174,13 +306,24 @@ with gr.Blocks() as demo:
           🌐 Github | 📜 arxiv
-
+
+          If the Space is too busy, duplicate it to use privately
+
+
           ⚠️ This demo is for academic research and experiential use only.
           """)
     with gr.Row():
         with gr.Column():
+            with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
+                image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
+            with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
+                video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
+                strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
+                examples_component = gr.Examples(examples, inputs=[video_input], cache_examples=False)
             prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
 
             with gr.Row():
@@ -188,7 +331,6 @@ with gr.Blocks() as demo:
                    "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
                )
                enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
-
         with gr.Group():
             with gr.Column():
                 with gr.Row():
@@ -275,13 +417,25 @@ with gr.Blocks() as demo:
        """)
 
-    def generate(prompt, seed_value, scale_status, rife_status, progress=gr.Progress(track_tqdm=True)):
+    def generate(
+        prompt,
+        image_input,
+        video_input,
+        video_strength,
+        seed_value,
+        scale_status,
+        rife_status,
+        progress=gr.Progress(track_tqdm=True)
+    ):
         latents, seed = infer(
             prompt,
+            image_input,
+            video_input,
+            video_strength,
             num_inference_steps=50,  # NOT Changed
             guidance_scale=7.0,  # NOT Changed
             seed=seed_value,
-            # progress=progress,
+            progress=progress,
         )
         if scale_status:
             latents = utils.upscale_batch_and_concatenate(upscale_model, latents, device)
@@ -311,11 +465,13 @@ with gr.Blocks() as demo:
 
     generate_button.click(
         generate,
-        inputs=[prompt, seed_param, enable_scale, enable_rife],
+        inputs=[prompt, image_input, video_input, strength, seed_param, enable_scale, enable_rife],
         outputs=[video_output, download_video_button, download_gif_button, seed_text],
     )
 
     enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
+    video_input.upload(resize_if_unfit, inputs=[video_input], outputs=[video_input])
 
 if __name__ == "__main__":
+    demo.queue(max_size=15)
     demo.launch()
diff --git a/inference/gradio_composite_demo/requirements.txt b/inference/gradio_composite_demo/requirements.txt
index 0464d81..8c64659 100644
--- a/inference/gradio_composite_demo/requirements.txt
+++ b/inference/gradio_composite_demo/requirements.txt
@@ -1,21 +1,19 @@
-spaces==0.29.3
-safetensors>=0.4.4
-spandrel>=0.3.4
+spaces>=0.29.3
+safetensors>=0.4.5
+spandrel>=0.4.0
 tqdm>=4.66.5
-opencv-python>=4.10.0.84
 scikit-video>=1.1.11
 diffusers>=0.30.1
 transformers>=4.44.0
-accelerate>=0.33.0
+accelerate>=0.34.2
+opencv-python>=4.10.0.84
 sentencepiece>=0.2.0
-SwissArmyTransformer>=0.4.12
 numpy==1.26.0
 torch>=2.4.0
 torchvision>=0.19.0
-gradio>=4.42.0
-streamlit>=1.37.1
-imageio==2.34.2
-imageio-ffmpeg==0.5.1
-openai>=1.42.0
-moviepy==1.0.3
+gradio>=4.44.0
+imageio>=2.34.2
+imageio-ffmpeg>=0.5.1
+openai>=1.45.0
+moviepy>=1.0.3
 pillow==9.5.0
\ No newline at end of file
diff --git a/inference/gradio_composite_demo/rife_model.py b/inference/gradio_composite_demo/rife_model.py
index 0a69ca6..dbb7d00 100644
--- a/inference/gradio_composite_demo/rife_model.py
+++ b/inference/gradio_composite_demo/rife_model.py
@@ -10,7 +10,6 @@ import skvideo.io
 from rife.RIFE_HDv3 import Model
 
 logger = logging.getLogger(__name__)
-
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -37,8 +36,7 @@ def make_inference(model, I0, I1, upscale_amount, n):
 @torch.inference_mode()
 def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
-    print(f"samples dtype:{samples.dtype}")
-    print(f"samples shape:{samples.shape}")
+
     output = []
     # [f, c, h, w]
     for b in range(samples.shape[0]):
@@ -119,13 +117,11 @@ def rife_inference_with_path(model, video_path):
 
 
 def rife_inference_with_latents(model, latents):
-    pbar = utils.ProgressBar(latents.shape[1], desc="RIFE inference")
     rife_results = []
     latents = latents.to(device)
    for i in range(latents.size(0)):  # [f, c, w, h]
         latent = latents[i]
-
         frames = ssim_interpolation_rife(model, latent)
         pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))])  # (to [f, c, w, h])
         rife_results.append(pt_image)
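
For readers who want to exercise the new video-to-video path outside of Gradio, the sketch below strings together the same diffusers calls the demo now wires into the UI (`load_video`, `CogVideoXVideoToVideoPipeline`, `strength`, `use_dynamic_cfg`). It is a minimal example under stated assumptions, not part of this patch: the file names, prompt, and seed are placeholders, and it assumes a GPU with enough memory to hold the whole pipeline in bfloat16.

```python
# Minimal V2V sketch mirroring the demo's pipe_video call path (placeholder paths/prompt).
import torch
from diffusers import CogVideoXVideoToVideoPipeline
from diffusers.utils import export_to_video, load_video

pipe_v2v = CogVideoXVideoToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
).to("cuda")

frames = load_video("input.mp4")[:49]  # the demo caps clips at 49 frames (6 s at 8 fps)
result = pipe_v2v(
    video=frames,
    prompt="a corgi running on a beach at sunset",  # placeholder prompt
    strength=0.8,              # same default the new Strength slider exposes
    num_inference_steps=50,
    guidance_scale=7.0,
    use_dynamic_cfg=True,
    generator=torch.Generator(device="cpu").manual_seed(42),
).frames[0]
export_to_video(result, "output.mp4", fps=8)
```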
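The app.py changes also drop the CPU-offload and VAE slicing/tiling calls in favor of keeping everything resident on the GPU and compiling the transformers, which assumes a large-memory card (as on the Space). If that assumption does not hold, the removed calls are still valid diffusers API and can be reinstated instead of compiling; a sketch of that trade-off, not something this diff does:

```python
# Low-VRAM alternative to the torch.compile setup (sketch; do not combine with .to(device)).
import torch
from diffusers import CogVideoXPipeline, CogVideoXDPMScheduler

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

pipe.enable_model_cpu_offload()  # move each submodule to the GPU only while it runs
pipe.vae.enable_slicing()        # decode the latent batch one slice at a time
pipe.vae.enable_tiling()         # decode large latents in spatial tiles
```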