Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-12-02 10:32:09 +08:00)
Add negative prompt

commit ba982025f4 (parent 8e8275d2e8)
@@ -14,6 +14,7 @@ import imageio
 import moviepy.editor as mp
 from typing import List, Union
 import PIL
+from pathlib import Path
 
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -91,15 +92,18 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
 @spaces.GPU(duration=240)
 def infer(
     prompt: str,
+    negative_prompt: str,
     num_inference_steps: int,
     guidance_scale: float,
     progress=gr.Progress(track_tqdm=True)
 ):
     torch.cuda.empty_cache()
 
-    prompt_embeds, _ = pipe.encode_prompt(
+    # prompt_embeds, _ = pipe.encode_prompt(
+    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
         prompt=prompt,
-        negative_prompt=None,
+        # negative_prompt=None,
+        negative_prompt=negative_prompt,
         do_classifier_free_guidance=True,
         num_videos_per_prompt=1,
         max_sequence_length=226,
@@ -111,7 +115,7 @@ def infer(
         num_inference_steps=num_inference_steps,
         guidance_scale=guidance_scale,
         prompt_embeds=prompt_embeds,
-        negative_prompt_embeds=torch.zeros_like(prompt_embeds),
+        negative_prompt_embeds=negative_prompt_embeds,
     ).frames[0]
 
 
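
For context, a minimal sketch of the call pattern this change moves to, assuming the Hugging Face diffusers CogVideoXPipeline API (the model id, prompts, and sampler settings below are illustrative, not taken from this commit). With do_classifier_free_guidance=True, encode_prompt returns embeddings for both the positive and the negative prompt, so the torch.zeros_like placeholder is no longer needed:

import torch
from diffusers import CogVideoXPipeline

# Sketch only: model id and settings are assumptions, not part of this commit.
pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16
).to("cuda")

# encode_prompt encodes both prompts when classifier-free guidance is on; an
# empty negative prompt falls back to the empty string rather than zeros.
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="a panda playing a guitar in a bamboo forest",
    negative_prompt="blurry, low quality, distorted",
    do_classifier_free_guidance=True,
    num_videos_per_prompt=1,
    max_sequence_length=226,
    device="cuda",
)

video = pipe(
    num_inference_steps=50,
    guidance_scale=6.0,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
).frames[0]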
@@ -120,6 +124,8 @@ def infer(
 
 def save_video(tensor):
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    Path("output").mkdir(exist_ok=True)
+
     video_path = f"./output/{timestamp}.mp4"
     os.makedirs(os.path.dirname(video_path), exist_ok=True)
     export_to_video_imageio(tensor[1:], video_path)
@@ -167,7 +173,8 @@ with gr.Blocks() as demo:
     """)
     with gr.Row():
         with gr.Column():
-            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
+            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=3)
+            negative_prompt = gr.Textbox(label="Negative_prompt (Less than 200 Words)", placeholder="Enter your negative prompt here", lines=3)
             with gr.Row():
                 gr.Markdown(
                     "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
@@ -224,8 +231,8 @@ with gr.Blocks() as demo:
     """)
 
 
-    def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
-        tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
+    def generate(prompt, negative_prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
+        tensor = infer(prompt, negative_prompt, num_inference_steps, guidance_scale, progress=progress)
         video_path = save_video(tensor)
         video_update = gr.update(visible=True, value=video_path)
         gif_path = convert_to_gif(video_path)
@@ -240,7 +247,7 @@ with gr.Blocks() as demo:
 
     generate_button.click(
         generate,
-        inputs=[prompt, num_inference_steps, guidance_scale],
+        inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale],
         outputs=[video_output, download_video_button, download_gif_button]
     )
 
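
Taken together, the UI wiring after this commit can be sketched as below; a hedged reconstruction, assuming infer and save_video are the functions patched above, with placeholder slider ranges and a single output component standing in for the repo's actual ones:

import gradio as gr

def generate(prompt, negative_prompt, num_inference_steps, guidance_scale,
             progress=gr.Progress(track_tqdm=True)):
    tensor = infer(prompt, negative_prompt, num_inference_steps, guidance_scale, progress=progress)
    return save_video(tensor)

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt (Less than 200 Words)", lines=3)
    negative_prompt = gr.Textbox(label="Negative_prompt (Less than 200 Words)", lines=3)
    num_inference_steps = gr.Slider(1, 100, value=50, label="Inference Steps")  # placeholder range
    guidance_scale = gr.Slider(1.0, 20.0, value=6.0, label="Guidance Scale")    # placeholder range
    video_output = gr.Video(label="Generated Video")
    generate_button = gr.Button("Generate Video")

    # The new negative_prompt textbox is threaded through inputs= so it reaches
    # infer() inside generate().
    generate_button.click(
        generate,
        inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale],
        outputs=[video_output],
    )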