mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-12-02 18:52:08 +08:00
model_path load for env
This commit is contained in:
parent
e4e612db05
commit
27837e3c83
@ -21,7 +21,9 @@ import utils
|
|||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to(
|
MODEL_PATH = os.environ.get('MODEL_PATH', "THUDM/CogVideoX-2b")
|
||||||
|
UP_SCALE_MODEL_CKPT = os.environ.get('UP_SCALE_MODEL_CKPT', "/media/gpt4-pdf-chatbot-langchain/ComfyUI/models/upscale_models/RealESRGAN_x4.pth")
|
||||||
|
pipe = CogVideoXPipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.float16).to(
|
||||||
device)
|
device)
|
||||||
pipe.enable_model_cpu_offload()
|
pipe.enable_model_cpu_offload()
|
||||||
|
|
||||||
@ -247,8 +249,6 @@ with gr.Blocks() as demo:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
num_inference_steps = gr.Number(label="Inference Steps", value=50)
|
num_inference_steps = gr.Number(label="Inference Steps", value=50)
|
||||||
guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
|
guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
|
||||||
with gr.Row():
|
|
||||||
up_scale_model_ckpt = gr.Text(label="UP_SCALE_MODEL_CKPT", value="")
|
|
||||||
generate_button = gr.Button("🎬 Generate Video")
|
generate_button = gr.Button("🎬 Generate Video")
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@ -294,11 +294,11 @@ with gr.Blocks() as demo:
|
|||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
def generate(prompt, num_inference_steps, guidance_scale, upscale_ckpt, progress=gr.Progress(track_tqdm=True)):
|
def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
|
||||||
global UP_SCALE_MODEL
|
global UP_SCALE_MODEL
|
||||||
if not UP_SCALE_MODEL:
|
if not UP_SCALE_MODEL:
|
||||||
# Load the upscale model with progress tracking
|
# Load the upscale model with progress tracking
|
||||||
UP_SCALE_MODEL = load_sd_upscale(upscale_ckpt)
|
UP_SCALE_MODEL = load_sd_upscale(UP_SCALE_MODEL_CKPT)
|
||||||
|
|
||||||
latents = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
|
latents = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
|
||||||
if UP_SCALE_MODEL:
|
if UP_SCALE_MODEL:
|
||||||
@ -331,7 +331,7 @@ with gr.Blocks() as demo:
|
|||||||
|
|
||||||
generate_button.click(
|
generate_button.click(
|
||||||
generate,
|
generate,
|
||||||
inputs=[prompt, num_inference_steps, guidance_scale, up_scale_model_ckpt],
|
inputs=[prompt, num_inference_steps, guidance_scale],
|
||||||
outputs=[video_output, download_video_button, download_gif_button]
|
outputs=[video_output, download_video_button, download_gif_button]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user