Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-12-02 10:32:09 +08:00)
Add negative prompt

commit ba982025f4 (parent: 8e8275d2e8)
@@ -14,6 +14,7 @@ import imageio
 import moviepy.editor as mp
 from typing import List, Union
 import PIL
+from pathlib import Path

 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -91,15 +92,18 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
 @spaces.GPU(duration=240)
 def infer(
     prompt: str,
+    negative_prompt: str,
     num_inference_steps: int,
     guidance_scale: float,
     progress=gr.Progress(track_tqdm=True)
 ):
     torch.cuda.empty_cache()

-    prompt_embeds, _ = pipe.encode_prompt(
+    # prompt_embeds, _ = pipe.encode_prompt(
+    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
         prompt=prompt,
-        negative_prompt=None,
+        # negative_prompt=None,
+        negative_prompt=negative_prompt,
         do_classifier_free_guidance=True,
         num_videos_per_prompt=1,
         max_sequence_length=226,
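For context: in diffusers, `CogVideoXPipeline.encode_prompt` returns a `(prompt_embeds, negative_prompt_embeds)` pair when `do_classifier_free_guidance=True`, so the second return value no longer needs to be discarded. A minimal standalone sketch of the call, assuming the CogVideoX-2b weights and with purely illustrative prompts:

import torch
from diffusers import CogVideoXPipeline

# Sketch only: load the pipeline once, then encode both prompts in one call.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="a panda playing guitar in a bamboo forest",  # illustrative
    negative_prompt="blurry, low quality, watermark",    # illustrative
    do_classifier_free_guidance=True,
    num_videos_per_prompt=1,
    max_sequence_length=226,
    device="cuda",
)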
@@ -111,7 +115,7 @@ def infer(
         num_inference_steps=num_inference_steps,
         guidance_scale=guidance_scale,
         prompt_embeds=prompt_embeds,
-        negative_prompt_embeds=torch.zeros_like(prompt_embeds),
+        negative_prompt_embeds=negative_prompt_embeds,
     ).frames[0]
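The effect of this change: under classifier-free guidance the model output is extrapolated away from the unconditional branch, so passing real negative-prompt embeddings (rather than `torch.zeros_like(prompt_embeds)`) actively steers samples away from the artifacts the user describes. A toy, runnable sketch of the standard guidance arithmetic; the tensors and names are illustrative, not the pipeline's internals:

import torch

guidance_scale = 6.0
noise_cond = torch.randn(1, 4)    # stands in for the denoiser output given prompt_embeds
noise_uncond = torch.randn(1, 4)  # stands in for the output given negative_prompt_embeds
# CFG update: move past the unconditional prediction toward the conditional one.
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)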
@@ -120,6 +124,8 @@ def infer(

 def save_video(tensor):
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    Path("output").mkdir(exist_ok=True)
+
     video_path = f"./output/{timestamp}.mp4"
     os.makedirs(os.path.dirname(video_path), exist_ok=True)
     export_to_video_imageio(tensor[1:], video_path)
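One small redundancy worth noting: the new `Path("output").mkdir(exist_ok=True)` and the pre-existing `os.makedirs(os.path.dirname(video_path), exist_ok=True)` both create `./output` and both tolerate it already existing, so either line alone would suffice:

import os
from pathlib import Path

# These two calls are interchangeable for this directory.
Path("output").mkdir(exist_ok=True)
os.makedirs("output", exist_ok=True)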
@@ -167,7 +173,8 @@ with gr.Blocks() as demo:
     """)
     with gr.Row():
         with gr.Column():
-            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
+            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=3)
+            negative_prompt = gr.Textbox(label="Negative_prompt (Less than 200 Words)", placeholder="Enter your negative prompt here", lines=3)
             with gr.Row():
                 gr.Markdown(
                     "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
@@ -224,8 +231,8 @@ with gr.Blocks() as demo:
     """)


-    def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
-        tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
+    def generate(prompt, negative_prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
+        tensor = infer(prompt, negative_prompt, num_inference_steps, guidance_scale, progress=progress)
         video_path = save_video(tensor)
         video_update = gr.update(visible=True, value=video_path)
         gif_path = convert_to_gif(video_path)
@@ -240,7 +247,7 @@ with gr.Blocks() as demo:

     generate_button.click(
         generate,
-        inputs=[prompt, num_inference_steps, guidance_scale],
+        inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale],
         outputs=[video_output, download_video_button, download_gif_button]
     )
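Taken together, the UI changes follow the usual Gradio pattern: every new input component must appear both in the handler's signature and, in the same order, in the `inputs=` list of the click binding. A self-contained sketch of that pattern, with a hypothetical echo handler standing in for the real video pipeline:

import gradio as gr

def generate(prompt, negative_prompt):
    # Hypothetical stand-in for infer() + save_video().
    return f"prompt={prompt!r}, negative_prompt={negative_prompt!r}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt", lines=3)
    negative_prompt = gr.Textbox(label="Negative prompt", lines=3)
    result = gr.Textbox(label="Result")
    generate_button = gr.Button("Generate")
    generate_button.click(generate, inputs=[prompt, negative_prompt], outputs=[result])

# demo.launch()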