Add negative prompt

This commit is contained in:
cly2625 2024-08-06 17:27:38 +08:00
parent 8e8275d2e8
commit ba982025f4

View File

@ -14,6 +14,7 @@ import imageio
import moviepy.editor as mp import moviepy.editor as mp
from typing import List, Union from typing import List, Union
import PIL import PIL
from pathlib import Path
dtype = torch.bfloat16 dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
@ -91,15 +92,18 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
@spaces.GPU(duration=240) @spaces.GPU(duration=240)
def infer( def infer(
prompt: str, prompt: str,
negative_prompt: str,
num_inference_steps: int, num_inference_steps: int,
guidance_scale: float, guidance_scale: float,
progress=gr.Progress(track_tqdm=True) progress=gr.Progress(track_tqdm=True)
): ):
torch.cuda.empty_cache() torch.cuda.empty_cache()
prompt_embeds, _ = pipe.encode_prompt( # prompt_embeds, _ = pipe.encode_prompt(
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
prompt=prompt, prompt=prompt,
negative_prompt=None, # negative_prompt=None,
negative_prompt=negative_prompt,
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
num_videos_per_prompt=1, num_videos_per_prompt=1,
max_sequence_length=226, max_sequence_length=226,
@ -111,7 +115,7 @@ def infer(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale, guidance_scale=guidance_scale,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=torch.zeros_like(prompt_embeds), negative_prompt_embeds=negative_prompt_embeds,
).frames[0] ).frames[0]
@ -120,6 +124,8 @@ def infer(
def save_video(tensor): def save_video(tensor):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
Path("output").mkdir(exist_ok=True)
video_path = f"./output/{timestamp}.mp4" video_path = f"./output/{timestamp}.mp4"
os.makedirs(os.path.dirname(video_path), exist_ok=True) os.makedirs(os.path.dirname(video_path), exist_ok=True)
export_to_video_imageio(tensor[1:], video_path) export_to_video_imageio(tensor[1:], video_path)
@ -167,7 +173,8 @@ with gr.Blocks() as demo:
""") """)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5) prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=3)
negative_prompt = gr.Textbox(label="Negative_prompt (Less than 200 Words)", placeholder="Enter your negative prompt here", lines=3)
with gr.Row(): with gr.Row():
gr.Markdown( gr.Markdown(
"✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.") "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
@ -224,8 +231,8 @@ with gr.Blocks() as demo:
""") """)
def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)): def generate(prompt, negative_prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress) tensor = infer(prompt, negative_prompt, num_inference_steps, guidance_scale, progress=progress)
video_path = save_video(tensor) video_path = save_video(tensor)
video_update = gr.update(visible=True, value=video_path) video_update = gr.update(visible=True, value=video_path)
gif_path = convert_to_gif(video_path) gif_path = convert_to_gif(video_path)
@ -240,7 +247,7 @@ with gr.Blocks() as demo:
generate_button.click( generate_button.click(
generate, generate,
inputs=[prompt, num_inference_steps, guidance_scale], inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale],
outputs=[video_output, download_video_button, download_gif_button] outputs=[video_output, download_video_button, download_gif_button]
) )