From ba982025f4e22ed510513e74f39ba0f35feec8ce Mon Sep 17 00:00:00 2001 From: cly2625 Date: Tue, 6 Aug 2024 17:27:38 +0800 Subject: [PATCH] Add negative prompt --- gradio_demo.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/gradio_demo.py b/gradio_demo.py index 65eeb48..723e079 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -14,6 +14,7 @@ import imageio import moviepy.editor as mp from typing import List, Union import PIL +from pathlib import Path dtype = torch.bfloat16 device = "cuda" if torch.cuda.is_available() else "cpu" @@ -91,15 +92,18 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str: @spaces.GPU(duration=240) def infer( prompt: str, + negative_prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True) ): torch.cuda.empty_cache() - prompt_embeds, _ = pipe.encode_prompt( + # prompt_embeds, _ = pipe.encode_prompt( + prompt_embeds, negative_prompt_embeds = pipe.encode_prompt( prompt=prompt, - negative_prompt=None, + # negative_prompt=None, + negative_prompt=negative_prompt, do_classifier_free_guidance=True, num_videos_per_prompt=1, max_sequence_length=226, @@ -111,7 +115,7 @@ def infer( num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds, - negative_prompt_embeds=torch.zeros_like(prompt_embeds), + negative_prompt_embeds=negative_prompt_embeds, ).frames[0] @@ -120,6 +124,8 @@ def infer( def save_video(tensor): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + Path("output").mkdir(exist_ok=True) + video_path = f"./output/{timestamp}.mp4" os.makedirs(os.path.dirname(video_path), exist_ok=True) export_to_video_imageio(tensor[1:], video_path) @@ -167,7 +173,8 @@ with gr.Blocks() as demo: """) with gr.Row(): with gr.Column(): - prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5) + prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=3) + negative_prompt = gr.Textbox(label="Negative_prompt (Less than 200 Words)", placeholder="Enter your negative prompt here", lines=3) with gr.Row(): gr.Markdown( "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.") @@ -224,8 +231,8 @@ with gr.Blocks() as demo: """) - def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)): - tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress) + def generate(prompt, negative_prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)): + tensor = infer(prompt, negative_prompt, num_inference_steps, guidance_scale, progress=progress) video_path = save_video(tensor) video_update = gr.update(visible=True, value=video_path) gif_path = convert_to_gif(video_path) @@ -240,7 +247,7 @@ with gr.Blocks() as demo: generate_button.click( generate, - inputs=[prompt, num_inference_steps, guidance_scale], + inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale], outputs=[video_output, download_video_button, download_gif_button] )