Mirror of https://github.com/THUDM/CogVideo.git (synced 2026-05-06 22:58:13 +08:00)

feat: add MiniMax provider support for prompt enhancement

- Add _get_llm_client() helper to auto-select the LLM provider from env vars
- Support MINIMAX_API_KEY to use MiniMax (MiniMax-M2.7) as the provider
- MINIMAX_API_KEY takes priority over OPENAI_API_KEY when both are set
- Default base URL: https://api.minimax.io/v1 (overridable via MINIMAX_BASE_URL)
- Update all three demo scripts: convert_demo.py, gradio_web_demo.py, gradio_composite_demo/app.py
- Add unit tests covering provider selection and fallback behaviour

parent 7a1af71545
commit 3fcb95bbeb
inference/convert_demo.py

@@ -14,13 +14,36 @@ Run the script for **text-to-video**:

 Run the script for **image-to-video**:

 $ python convert_demo.py --prompt "the cat is running" --type "i2v" --image_path "/path/to/your/image.jpg"

+### Using MiniMax as the LLM provider:
+
+$ MINIMAX_API_KEY=your_minimax_api_key python convert_demo.py --prompt "A girl riding a bike." --type "t2v"
+
+### Using OpenAI or any OpenAI-compatible provider:
+
+$ OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python convert_demo.py --prompt "A girl riding a bike." --type "t2v"
 """

 import argparse
 import os
 from openai import OpenAI, AzureOpenAI
 import base64
 from mimetypes import guess_type


+def _get_llm_client():
+    """
+    Return (client, model_name) based on environment variables.
+
+    Priority:
+    1. MINIMAX_API_KEY → MiniMax (OpenAI-compatible, https://api.minimax.io/v1)
+    2. OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL
+    """
+    minimax_api_key = os.environ.get("MINIMAX_API_KEY")
+    if minimax_api_key:
+        base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
+        client = OpenAI(api_key=minimax_api_key, base_url=base_url)
+        return client, "MiniMax-M2.7"
+    return OpenAI(), "glm-4-plus"
+
+
 sys_prompt_t2v = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.

 For example, outputting "a beautiful morning in the woods with the sun peaking through the trees" will trigger your partner bot to output a video of a forest morning, as described. You will be prompted by people looking to create detailed, amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
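To make the selection rules concrete, here is a minimal usage sketch (not part of the commit). It assumes `_get_llm_client` is importable from `inference/convert_demo.py` and uses placeholder keys:

```python
import os
from convert_demo import _get_llm_client  # assumes inference/ is on sys.path

# MiniMax wins when both keys are set.
os.environ["MINIMAX_API_KEY"] = "placeholder-minimax-key"
os.environ["OPENAI_API_KEY"] = "placeholder-openai-key"
client, model = _get_llm_client()
assert model == "MiniMax-M2.7"  # client targets https://api.minimax.io/v1

# Without MINIMAX_API_KEY, the default OpenAI() client and glm-4-plus are used.
del os.environ["MINIMAX_API_KEY"]
client, model = _get_llm_client()
assert model == "glm-4-plus"
```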
@@ -61,16 +84,16 @@ def image_to_url(image_path):

 def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_path: str = None):
     """
-    Convert a prompt to a format that can be used by the model for inference
-    """
-    client = OpenAI()
-    ## If you using with Azure OpenAI, please uncomment the below line and comment the above line
-    # client = AzureOpenAI(
-    #     api_key="",
-    #     api_version="",
-    #     azure_endpoint=""
-    # )
+    Convert a prompt to a format that can be used by the model for inference.
+
+    LLM provider is selected automatically from environment variables:
+    - MINIMAX_API_KEY → MiniMax (MiniMax-M2.7)
+    - OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider
+    """
+    client, default_model = _get_llm_client()
+    ## To use Azure OpenAI instead, replace the line above with:
+    # client = AzureOpenAI(api_key="", api_version="", azure_endpoint="")
+    # default_model = "gpt-4o"

     text = prompt.strip()
     for i in range(retry_times):

@@ -107,7 +130,7 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
                     "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: " {text} "',
                 },
             ],
-            model="glm-4-plus",  # glm-4-plus and gpt-4o have been tested
+            model=default_model,
             temperature=0.01,
             top_p=0.7,
             stream=False,

@@ -115,7 +138,7 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
         )
     else:
         response = client.chat.completions.create(
-            model="gpt-4o",
+            model=default_model,
             messages=[
                 {"role": "system", "content": f"{sys_prompt_i2v}"},
                 {
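Condensed from the two call-site hunks above, this is the pattern the commit converges on for both branches; a sketch only, with the messages abbreviated:

```python
# Both the t2v and i2v branches now pass the auto-selected model instead of a
# hardcoded "glm-4-plus" / "gpt-4o" (parameters copied from the diff).
client, default_model = _get_llm_client()
response = client.chat.completions.create(
    model=default_model,  # "MiniMax-M2.7" or "glm-4-plus" depending on env vars
    messages=[
        {"role": "system", "content": sys_prompt_t2v},  # sys_prompt_i2v for i2v
        {"role": "user", "content": "Create an imaginative video descriptive caption ..."},
    ],
    temperature=0.01,
    top_p=0.7,
    stream=False,
)
enhanced_prompt = response.choices[0].message.content
```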
inference/gradio_composite_demo/app.py

@@ -1,9 +1,14 @@
 """
 This is the main file for the gradio web demo. It uses the CogVideoX-5B model to generate videos.
-set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
-
-Usage:
-    OpenAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
+The optional prompt enhancement feature uses an LLM to refine your prompt before video generation.
+Set one of the following environment variables to enable it:
+
+- MINIMAX_API_KEY: Use MiniMax (MiniMax-M2.7) as the LLM provider.
+    MINIMAX_API_KEY=your_minimax_api_key python inference/gradio_composite_demo/app.py
+
+- OPENAI_API_KEY: Use OpenAI or any OpenAI-compatible provider.
+    OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_composite_demo/app.py
 """

 import math

@@ -170,10 +175,26 @@ def center_crop_resize(input_video_path, target_width=720, target_height=480):
     return temp_video_path


+def _get_llm_client():
+    """
+    Return (client, model_name) based on environment variables.
+
+    Priority:
+    1. MINIMAX_API_KEY → MiniMax (OpenAI-compatible, https://api.minimax.io/v1)
+    2. OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL
+    """
+    minimax_api_key = os.environ.get("MINIMAX_API_KEY")
+    if minimax_api_key:
+        base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
+        client = OpenAI(api_key=minimax_api_key, base_url=base_url)
+        return client, "MiniMax-M2.7"
+    return OpenAI(), "glm-4-plus"
+
+
 def convert_prompt(prompt: str, retry_times: int = 3) -> str:
-    if not os.environ.get("OPENAI_API_KEY"):
+    if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"):
         return prompt
-    client = OpenAI()
+    client, model = _get_llm_client()
     text = prompt.strip()

     for i in range(retry_times):

@@ -209,7 +230,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
                     "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
                 },
             ],
-            model="glm-4-plus",
+            model=model,
             temperature=0.01,
             top_p=0.7,
             stream=False,
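The widened guard means enhancement is silently skipped unless at least one key is present. A small behavioral sketch (it assumes convert_prompt is imported from the demo module; the prompt text is arbitrary):

```python
import os
from app import convert_prompt  # assumption: run from inference/gradio_composite_demo/

for var in ("OPENAI_API_KEY", "MINIMAX_API_KEY"):
    os.environ.pop(var, None)
# With no key set, the prompt comes back unchanged and no LLM call is made.
assert convert_prompt("a cat is running") == "a cat is running"

# With MINIMAX_API_KEY set, the function proceeds to call the selected client.
os.environ["MINIMAX_API_KEY"] = "placeholder-key"
```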
inference/gradio_web_demo.py

@@ -1,13 +1,18 @@
 """
 This is the main file for the gradio web demo. It uses the CogVideoX-2B model to generate videos.
-set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
+The optional prompt enhancement feature uses an LLM to refine your prompt before video generation.
+Set one of the following environment variables to enable it:
+
+- MINIMAX_API_KEY: Use MiniMax (MiniMax-M2.7) as the LLM provider.
+    MINIMAX_API_KEY=your_minimax_api_key python inference/gradio_web_demo.py
+
+- OPENAI_API_KEY: Use OpenAI or any OpenAI-compatible provider.
+    OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py

 This demo only supports the text-to-video generation model.
 If you wish to use the image-to-video or video-to-video generation models,
 please use the gradio_composite_demo to implement the full GUI functionality.

-Usage:
-    OpenAI_API_KEY=your_openai_api_key OpenAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
 """

 import os

@@ -46,11 +51,27 @@ Video descriptions must have the same num of words as examples below. Extra word
 """


+def _get_llm_client():
+    """
+    Return (client, model_name) based on environment variables.
+
+    Priority:
+    1. MINIMAX_API_KEY → MiniMax (OpenAI-compatible, https://api.minimax.io/v1)
+    2. OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL
+    """
+    minimax_api_key = os.environ.get("MINIMAX_API_KEY")
+    if minimax_api_key:
+        base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
+        client = OpenAI(api_key=minimax_api_key, base_url=base_url)
+        return client, "MiniMax-M2.7"
+    return OpenAI(), "glm-4-plus"
+
+
 def convert_prompt(prompt: str, retry_times: int = 3) -> str:
-    if not os.environ.get("OPENAI_API_KEY"):
+    if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"):
         return prompt

-    client = OpenAI()
+    client, model = _get_llm_client()
     text = prompt.strip()

     for i in range(retry_times):

@@ -86,7 +107,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
                     "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
                 },
             ],
-            model="glm-4-plus",
+            model=model,
             temperature=0.01,
             top_p=0.7,
             stream=False,
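One detail shared by all three copies of the helper: MINIMAX_BASE_URL can point the client at any OpenAI-compatible endpoint, for example a local proxy. A minimal sketch mirroring the helper's lookup (the proxy URL is hypothetical):

```python
import os
from openai import OpenAI

os.environ["MINIMAX_API_KEY"] = "placeholder-key"
os.environ["MINIMAX_BASE_URL"] = "https://proxy.example.com/v1"  # hypothetical override

# Same lookup the helper performs: the override wins over the default.
base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
client = OpenAI(api_key=os.environ["MINIMAX_API_KEY"], base_url=base_url)
print(client.base_url)  # https://proxy.example.com/v1
```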
tests/test_minimax_provider.py (new file)

@@ -0,0 +1,214 @@
+"""
+Unit tests for MiniMax provider integration in CogVideo prompt enhancement.
+"""
+
+import os
+import unittest
+from unittest.mock import MagicMock
+
+# The demo scripts each define an identical _get_llm_client(); the tests exec a
+# copy of that helper with the OpenAI class mocked, so no demo module (and none
+# of its heavy top-level code) needs to be imported.
+HELPER_CODE = """
+def _get_llm_client():
+    minimax_api_key = os.environ.get("MINIMAX_API_KEY")
+    if minimax_api_key:
+        base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
+        client = OpenAI(api_key=minimax_api_key, base_url=base_url)
+        return client, "MiniMax-M2.7"
+    return OpenAI(), "glm-4-plus"
+"""
+
+
+def _load_helper(mock_openai_cls):
+    """Exec HELPER_CODE with a mocked OpenAI class and return _get_llm_client."""
+    exec_globals = {"os": os, "OpenAI": mock_openai_cls}
+    exec(HELPER_CODE, exec_globals)
+    return exec_globals["_get_llm_client"]
+
+
+class TestGetLlmClient(unittest.TestCase):
+    """Tests for the _get_llm_client() helper shared by the demo scripts."""
+
+    def setUp(self):
+        # Clean up env vars before each test
+        for var in ("MINIMAX_API_KEY", "MINIMAX_BASE_URL", "OPENAI_API_KEY", "OPENAI_BASE_URL"):
+            os.environ.pop(var, None)
+
+    def test_minimax_api_key_selects_minimax(self):
+        """When MINIMAX_API_KEY is set, should use MiniMax client and model."""
+        os.environ["MINIMAX_API_KEY"] = "test-minimax-key"
+        mock_openai_cls = MagicMock()
+
+        client, model = _load_helper(mock_openai_cls)()
+
+        self.assertEqual(model, "MiniMax-M2.7")
+        mock_openai_cls.assert_called_once_with(
+            api_key="test-minimax-key",
+            base_url="https://api.minimax.io/v1",
+        )
+
+    def test_minimax_default_base_url(self):
+        """Default MiniMax base URL should be https://api.minimax.io/v1."""
+        os.environ["MINIMAX_API_KEY"] = "test-key"
+        mock_openai_cls = MagicMock()
+
+        _load_helper(mock_openai_cls)()
+
+        call_kwargs = mock_openai_cls.call_args[1]
+        self.assertTrue(call_kwargs["base_url"].startswith("https://api.minimax.io"))
+
+    def test_minimax_custom_base_url(self):
+        """MINIMAX_BASE_URL env var should override the default base URL."""
+        os.environ["MINIMAX_API_KEY"] = "test-key"
+        os.environ["MINIMAX_BASE_URL"] = "https://custom.minimax.io/v1"
+        mock_openai_cls = MagicMock()
+
+        _load_helper(mock_openai_cls)()
+
+        call_kwargs = mock_openai_cls.call_args[1]
+        self.assertEqual(call_kwargs["base_url"], "https://custom.minimax.io/v1")
+
+    def test_openai_key_selects_default(self):
+        """When only OPENAI_API_KEY is set, should use the default OpenAI client."""
+        os.environ["OPENAI_API_KEY"] = "sk-test"
+        mock_openai_cls = MagicMock()
+
+        client, model = _load_helper(mock_openai_cls)()
+
+        self.assertEqual(model, "glm-4-plus")
+        # Called with no explicit args (the SDK reads env vars itself)
+        mock_openai_cls.assert_called_once_with()
+
+    def test_no_api_key_returns_default(self):
+        """Without any API key env vars, the default client is returned."""
+        mock_openai_cls = MagicMock()
+
+        client, model = _load_helper(mock_openai_cls)()
+
+        self.assertEqual(model, "glm-4-plus")
+
+    def test_minimax_priority_over_openai(self):
+        """MINIMAX_API_KEY takes priority over OPENAI_API_KEY."""
+        os.environ["MINIMAX_API_KEY"] = "minimax-key"
+        os.environ["OPENAI_API_KEY"] = "openai-key"
+        mock_openai_cls = MagicMock()
+
+        client, model = _load_helper(mock_openai_cls)()
+
+        self.assertEqual(model, "MiniMax-M2.7")
+        call_kwargs = mock_openai_cls.call_args[1]
+        self.assertEqual(call_kwargs["api_key"], "minimax-key")
+
+
+class TestConvertPromptSkipWithoutKey(unittest.TestCase):
+    """Tests that convert_prompt returns early when no API key is set."""
+
+    # Mirrors the guard added to convert_prompt in the demo scripts.
+    GUARD_CODE = """
+def convert_prompt(prompt, retry_times=3):
+    if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"):
+        return prompt
+    return "enhanced"
+"""
+
+    def setUp(self):
+        for var in ("MINIMAX_API_KEY", "OPENAI_API_KEY"):
+            os.environ.pop(var, None)
+
+    def _load_guard(self):
+        exec_globals = {"os": os}
+        exec(self.GUARD_CODE, exec_globals)
+        return exec_globals["convert_prompt"]
+
+    def test_no_key_returns_original_prompt(self):
+        """Without any API key, convert_prompt should return the prompt unchanged."""
+        self.assertEqual(self._load_guard()("my original prompt"), "my original prompt")
+
+    def test_minimax_key_enables_conversion(self):
+        """With MINIMAX_API_KEY set, convert_prompt should not skip early."""
+        os.environ["MINIMAX_API_KEY"] = "test-key"
+        self.assertEqual(self._load_guard()("my original prompt"), "enhanced")
+
+
+if __name__ == "__main__":
+    unittest.main()
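As a design note, the setUp hooks above mutate the real process environment; unittest.mock.patch.dict gives scoped isolation that is restored even if a test fails. A sketch of the first test rewritten that way (same names as the file above):

```python
import os
import unittest
from unittest.mock import MagicMock, patch


class TestWithPatchedEnv(unittest.TestCase):
    # clear=True empties os.environ for the test and restores it afterwards.
    @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-minimax-key"}, clear=True)
    def test_minimax_selected(self):
        mock_openai_cls = MagicMock()
        client, model = _load_helper(mock_openai_cls)()  # _load_helper as defined above
        self.assertEqual(model, "MiniMax-M2.7")
```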