diff --git a/inference/gradio_web_demo.py b/inference/gradio_web_demo.py
index ae1e6d7..fbefaf4 100644
--- a/inference/gradio_web_demo.py
+++ b/inference/gradio_web_demo.py
@@ -21,7 +21,9 @@ import utils
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to(
+MODEL_PATH = os.environ.get('MODEL_PATH', "THUDM/CogVideoX-2b")
+UP_SCALE_MODEL_CKPT = os.environ.get('UP_SCALE_MODEL_CKPT', "/media/gpt4-pdf-chatbot-langchain/ComfyUI/models/upscale_models/RealESRGAN_x4.pth")
+pipe = CogVideoXPipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.float16).to(
     device)
 pipe.enable_model_cpu_offload()
 
@@ -247,8 +249,6 @@ with gr.Blocks() as demo:
             with gr.Row():
                 num_inference_steps = gr.Number(label="Inference Steps", value=50)
                 guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
-            with gr.Row():
-                up_scale_model_ckpt = gr.Text(label="UP_SCALE_MODEL_CKPT", value="")
             generate_button = gr.Button("🎬 Generate Video")
 
         with gr.Column():
@@ -294,11 +294,11 @@ with gr.Blocks() as demo:
        """)
 
-    def generate(prompt, num_inference_steps, guidance_scale, upscale_ckpt, progress=gr.Progress(track_tqdm=True)):
+    def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
         global UP_SCALE_MODEL
         if not UP_SCALE_MODEL:
             # Load the upscale model with progress tracking
-            UP_SCALE_MODEL = load_sd_upscale(upscale_ckpt)
+            UP_SCALE_MODEL = load_sd_upscale(UP_SCALE_MODEL_CKPT)
 
         latents = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
 
         if UP_SCALE_MODEL:
@@ -331,7 +331,7 @@ with gr.Blocks() as demo:
 
     generate_button.click(
         generate,
-        inputs=[prompt, num_inference_steps, guidance_scale, up_scale_model_ckpt],
+        inputs=[prompt, num_inference_steps, guidance_scale],
        outputs=[video_output, download_video_button, download_gif_button]
     )
 