model_path load for env

This commit is contained in:
glide-the 2024-08-20 15:13:34 +08:00
parent e4e612db05
commit 27837e3c83

View File

@ -21,7 +21,9 @@ import utils
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to( MODEL_PATH = os.environ.get('MODEL_PATH', "THUDM/CogVideoX-2b")
UP_SCALE_MODEL_CKPT = os.environ.get('UP_SCALE_MODEL_CKPT', "/media/gpt4-pdf-chatbot-langchain/ComfyUI/models/upscale_models/RealESRGAN_x4.pth")
pipe = CogVideoXPipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.float16).to(
device) device)
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
@ -247,8 +249,6 @@ with gr.Blocks() as demo:
with gr.Row(): with gr.Row():
num_inference_steps = gr.Number(label="Inference Steps", value=50) num_inference_steps = gr.Number(label="Inference Steps", value=50)
guidance_scale = gr.Number(label="Guidance Scale", value=6.0) guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
with gr.Row():
up_scale_model_ckpt = gr.Text(label="UP_SCALE_MODEL_CKPT", value="")
generate_button = gr.Button("🎬 Generate Video") generate_button = gr.Button("🎬 Generate Video")
with gr.Column(): with gr.Column():
@ -294,11 +294,11 @@ with gr.Blocks() as demo:
""") """)
def generate(prompt, num_inference_steps, guidance_scale, upscale_ckpt, progress=gr.Progress(track_tqdm=True)): def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
global UP_SCALE_MODEL global UP_SCALE_MODEL
if not UP_SCALE_MODEL: if not UP_SCALE_MODEL:
# Load the upscale model with progress tracking # Load the upscale model with progress tracking
UP_SCALE_MODEL = load_sd_upscale(upscale_ckpt) UP_SCALE_MODEL = load_sd_upscale(UP_SCALE_MODEL_CKPT)
latents = infer(prompt, num_inference_steps, guidance_scale, progress=progress) latents = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
if UP_SCALE_MODEL: if UP_SCALE_MODEL:
@ -331,7 +331,7 @@ with gr.Blocks() as demo:
generate_button.click( generate_button.click(
generate, generate,
inputs=[prompt, num_inference_steps, guidance_scale, up_scale_model_ckpt], inputs=[prompt, num_inference_steps, guidance_scale],
outputs=[video_output, download_video_button, download_gif_button] outputs=[video_output, download_video_button, download_gif_button]
) )