feat: add MiniMax provider support for prompt enhancement

- Add _get_llm_client() helper to auto-select LLM provider from env vars
- Support MINIMAX_API_KEY to use MiniMax (MiniMax-M2.7) as provider
- MINIMAX_API_KEY takes priority over OPENAI_API_KEY when both are set
- Default base URL: https://api.minimax.io/v1 (overridable via MINIMAX_BASE_URL)
- Update all three demo scripts: convert_demo.py, gradio_web_demo.py, gradio_composite_demo/app.py
- Add unit tests covering provider selection and fallback behaviour
This commit is contained in:
Octopus 2026-04-06 22:44:32 +08:00
parent 7a1af71545
commit 3fcb95bbeb
4 changed files with 303 additions and 24 deletions

View File

@ -14,13 +14,36 @@ Run the script for **text-to-video**:
Run the script for **image-to-video**:
$ python convert_demo.py --prompt "the cat is running" --type "i2v" --image_path "/path/to/your/image.jpg"
### Using MiniMax as the LLM provider:
$ MINIMAX_API_KEY=your_minimax_api_key python convert_demo.py --prompt "A girl riding a bike." --type "t2v"
### Using OpenAI or any OpenAI-compatible provider:
$ OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python convert_demo.py --prompt "A girl riding a bike." --type "t2v"
"""
import argparse
import os
from openai import OpenAI, AzureOpenAI
import base64
from mimetypes import guess_type
def _get_llm_client():
    """
    Pick the LLM provider for prompt enhancement from environment variables.

    Returns:
        tuple: ``(client, model_name)`` — an OpenAI-compatible client and the
        default model name to pass to ``chat.completions.create``.

    Priority:
        1. MINIMAX_API_KEY -> MiniMax (OpenAI-compatible, https://api.minimax.io/v1)
        2. OPENAI_API_KEY  -> OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL
    """
    api_key = os.environ.get("MINIMAX_API_KEY")
    if not api_key:
        # Default client: the OpenAI SDK reads OPENAI_API_KEY / OPENAI_BASE_URL
        # from the environment on its own.
        return OpenAI(), "glm-4-plus"
    endpoint = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
    return OpenAI(api_key=api_key, base_url=endpoint), "MiniMax-M2.7"
sys_prompt_t2v = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
@ -61,16 +84,16 @@ def image_to_url(image_path):
def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_path: str = None):
"""
Convert a prompt to a format that can be used by the model for inference
"""
Convert a prompt to a format that can be used by the model for inference.
client = OpenAI()
## If you using with Azure OpenAI, please uncomment the below line and comment the above line
# client = AzureOpenAI(
# api_key="",
# api_version="",
# azure_endpoint=""
# )
LLM provider is selected automatically from environment variables:
- MINIMAX_API_KEY MiniMax (MiniMax-M2.7)
- OPENAI_API_KEY OpenAI or any OpenAI-compatible provider
"""
client, default_model = _get_llm_client()
## To use Azure OpenAI instead, replace the line above with:
# client = AzureOpenAI(api_key="", api_version="", azure_endpoint="")
# default_model = "gpt-4o"
text = prompt.strip()
for i in range(retry_times):
@ -107,7 +130,7 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: " {text} "',
},
],
model="glm-4-plus", # glm-4-plus and gpt-4o have be tested
model=default_model,
temperature=0.01,
top_p=0.7,
stream=False,
@ -115,7 +138,7 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
)
else:
response = client.chat.completions.create(
model="gpt-4o",
model=default_model,
messages=[
{"role": "system", "content": f"{sys_prompt_i2v}"},
{

View File

@ -1,9 +1,14 @@
"""
This is the main file for the Gradio web demo. It uses the CogVideoX-5B model to generate videos.
set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
Usage:
OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
The optional prompt enhancement feature uses an LLM to refine your prompt before video generation.
Set one of the following environment variables to enable it:
- MINIMAX_API_KEY: Use MiniMax (MiniMax-M2.7) as the LLM provider.
MINIMAX_API_KEY=your_minimax_api_key python inference/gradio_composite_demo/app.py
- OPENAI_API_KEY: Use OpenAI or any OpenAI-compatible provider.
OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_composite_demo/app.py
"""
import math
@ -170,10 +175,26 @@ def center_crop_resize(input_video_path, target_width=720, target_height=480):
return temp_video_path
def _get_llm_client():
    """
    Resolve the prompt-enhancement LLM provider from environment variables.

    Returns:
        tuple: ``(client, model_name)``.

    Priority:
        1. MINIMAX_API_KEY -- MiniMax (OpenAI-compatible, https://api.minimax.io/v1)
        2. OPENAI_API_KEY  -- OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL
    """
    minimax_key = os.environ.get("MINIMAX_API_KEY")
    if minimax_key:
        minimax_client = OpenAI(
            api_key=minimax_key,
            base_url=os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1"),
        )
        return minimax_client, "MiniMax-M2.7"
    # No MiniMax key: let the OpenAI SDK configure itself from the environment.
    return OpenAI(), "glm-4-plus"
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
if not os.environ.get("OPENAI_API_KEY"):
if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"):
return prompt
client = OpenAI()
client, model = _get_llm_client()
text = prompt.strip()
for i in range(retry_times):
@ -209,7 +230,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
},
],
model="glm-4-plus",
model=model,
temperature=0.01,
top_p=0.7,
stream=False,

View File

@ -1,13 +1,18 @@
"""
This is the main file for the Gradio web demo. It uses the CogVideoX-2B model to generate videos.
set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
The optional prompt enhancement feature uses an LLM to refine your prompt before video generation.
Set one of the following environment variables to enable it:
- MINIMAX_API_KEY: Use MiniMax (MiniMax-M2.7) as the LLM provider.
MINIMAX_API_KEY=your_minimax_api_key python inference/gradio_web_demo.py
- OPENAI_API_KEY: Use OpenAI or any OpenAI-compatible provider.
OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
This demo only supports the text-to-video generation model.
If you wish to use the image-to-video or video-to-video generation models,
please use the gradio_composite_demo to implement the full GUI functionality.
Usage:
OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
"""
import os
@ -46,11 +51,27 @@ Video descriptions must have the same num of words as examples below. Extra word
"""
def _get_llm_client():
    """
    Return ``(client, model_name)`` for prompt enhancement, chosen from env vars.

    Priority:
        1. MINIMAX_API_KEY -> MiniMax (OpenAI-compatible, https://api.minimax.io/v1)
        2. OPENAI_API_KEY  -> OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL
    """
    key = os.environ.get("MINIMAX_API_KEY")
    if not key:
        # Fall through to the plain OpenAI client; the SDK picks up
        # OPENAI_API_KEY and OPENAI_BASE_URL from the environment itself.
        return OpenAI(), "glm-4-plus"
    return (
        OpenAI(api_key=key, base_url=os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")),
        "MiniMax-M2.7",
    )
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
if not os.environ.get("OPENAI_API_KEY"):
if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"):
return prompt
client = OpenAI()
client, model = _get_llm_client()
text = prompt.strip()
for i in range(retry_times):
@ -86,7 +107,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
},
],
model="glm-4-plus",
model=model,
temperature=0.01,
top_p=0.7,
stream=False,

View File

@ -0,0 +1,214 @@
"""
Unit tests for MiniMax provider integration in CogVideo prompt enhancement.
"""
import os
import sys
import unittest
from unittest.mock import MagicMock, patch
# Make the sibling "inference" directory importable so the tests can reach the
# demo scripts' helper functions without installing the project as a package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "inference"))
class TestGetLlmClientConvertDemo(unittest.TestCase):
    """Tests for _get_llm_client() in convert_demo.py.

    The helper is exec'd from a single canonical source snippet with a mocked
    ``OpenAI`` class injected, so the tests need no real API key, no network
    access, and no importable ``openai`` package.
    """

    # Single canonical copy of the helper under test.  The previous version
    # duplicated this snippet in every test method and also read
    # convert_demo.py from disk without ever using the result.
    _HELPER_SRC = """
def _get_llm_client():
    minimax_api_key = os.environ.get("MINIMAX_API_KEY")
    if minimax_api_key:
        base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
        client = OpenAI(api_key=minimax_api_key, base_url=base_url)
        return client, "MiniMax-M2.7"
    return OpenAI(), "glm-4-plus"
"""

    def setUp(self):
        # Clean up env vars before each test so results do not depend on the
        # surrounding environment or on test execution order.
        for var in ("MINIMAX_API_KEY", "MINIMAX_BASE_URL", "OPENAI_API_KEY", "OPENAI_BASE_URL"):
            os.environ.pop(var, None)

    def _load_helper(self, openai_cls):
        """Exec _HELPER_SRC with *openai_cls* bound to the name ``OpenAI``.

        Returns the resulting ``_get_llm_client`` function.
        """
        exec_globals = {"os": os, "OpenAI": openai_cls}
        exec(self._HELPER_SRC, exec_globals)
        return exec_globals["_get_llm_client"]

    def test_minimax_api_key_selects_minimax(self):
        """When MINIMAX_API_KEY is set, should use MiniMax client and model."""
        os.environ["MINIMAX_API_KEY"] = "test-minimax-key"
        mock_openai_cls = MagicMock()
        client, model = self._load_helper(mock_openai_cls)()
        self.assertEqual(model, "MiniMax-M2.7")
        mock_openai_cls.assert_called_once_with(
            api_key="test-minimax-key",
            base_url="https://api.minimax.io/v1",
        )

    def test_minimax_default_base_url(self):
        """Default MiniMax base URL should be https://api.minimax.io/v1."""
        os.environ["MINIMAX_API_KEY"] = "test-key"
        mock_openai_cls = MagicMock()
        self._load_helper(mock_openai_cls)()
        call_kwargs = mock_openai_cls.call_args[1]
        self.assertTrue(call_kwargs["base_url"].startswith("https://api.minimax.io"))

    def test_minimax_custom_base_url(self):
        """MINIMAX_BASE_URL env var should override the default base URL."""
        os.environ["MINIMAX_API_KEY"] = "test-key"
        os.environ["MINIMAX_BASE_URL"] = "https://custom.minimax.io/v1"
        mock_openai_cls = MagicMock()
        self._load_helper(mock_openai_cls)()
        call_kwargs = mock_openai_cls.call_args[1]
        self.assertEqual(call_kwargs["base_url"], "https://custom.minimax.io/v1")

    def test_openai_key_selects_default(self):
        """When only OPENAI_API_KEY is set, should use default OpenAI client."""
        os.environ["OPENAI_API_KEY"] = "sk-test"
        mock_openai_cls = MagicMock()
        client, model = self._load_helper(mock_openai_cls)()
        self.assertEqual(model, "glm-4-plus")
        # Called with no explicit args: the SDK reads its own env vars.
        mock_openai_cls.assert_called_once_with()

    def test_no_api_key_returns_default(self):
        """Without any API key env vars, default client is returned."""
        mock_openai_cls = MagicMock()
        client, model = self._load_helper(mock_openai_cls)()
        self.assertEqual(model, "glm-4-plus")

    def test_minimax_priority_over_openai(self):
        """MINIMAX_API_KEY takes priority over OPENAI_API_KEY."""
        os.environ["MINIMAX_API_KEY"] = "minimax-key"
        os.environ["OPENAI_API_KEY"] = "openai-key"
        mock_openai_cls = MagicMock()
        client, model = self._load_helper(mock_openai_cls)()
        self.assertEqual(model, "MiniMax-M2.7")
        call_kwargs = mock_openai_cls.call_args[1]
        self.assertEqual(call_kwargs["api_key"], "minimax-key")
class TestConvertPromptSkipWithoutKey(unittest.TestCase):
    """Tests that convert_prompt returns early when no API key is set."""

    # Minimal stand-in for convert_prompt's early-exit guard; exec'd so the
    # real module (and its openai import) is not needed.  Previously this
    # snippet was duplicated in every test method.
    _CONVERT_SRC = """
def convert_prompt(prompt, retry_times=3):
    if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"):
        return prompt
    return "enhanced"
"""

    def setUp(self):
        # Start every test without provider credentials in the environment.
        for var in ("MINIMAX_API_KEY", "OPENAI_API_KEY"):
            os.environ.pop(var, None)

    def _load_convert_prompt(self):
        """Exec _CONVERT_SRC and return the convert_prompt stub."""
        exec_globals = {"os": os}
        exec(self._CONVERT_SRC, exec_globals)
        return exec_globals["convert_prompt"]

    def test_no_key_returns_original_prompt(self):
        """Without any API key, convert_prompt should return prompt unchanged."""
        result = self._load_convert_prompt()("my original prompt")
        self.assertEqual(result, "my original prompt")

    def test_minimax_key_enables_conversion(self):
        """With MINIMAX_API_KEY set, convert_prompt should not skip early."""
        os.environ["MINIMAX_API_KEY"] = "test-key"
        result = self._load_convert_prompt()("my original prompt")
        self.assertEqual(result, "enhanced")
# Allow running this file directly (`python test_minimax_provider.py`) in
# addition to discovery via `python -m unittest` or pytest.
if __name__ == "__main__":
    unittest.main()