CogVideo/tests/test_minimax_provider.py
Octopus 3fcb95bbeb feat: add MiniMax provider support for prompt enhancement
- Add _get_llm_client() helper to auto-select LLM provider from env vars
- Support MINIMAX_API_KEY to use MiniMax (MiniMax-M2.7) as provider
- MINIMAX_API_KEY takes priority over OPENAI_API_KEY when both are set
- Default base URL: https://api.minimax.io/v1 (overridable via MINIMAX_BASE_URL)
- Update all three demo scripts: convert_demo.py, gradio_web_demo.py, gradio_composite_demo/app.py
- Add unit tests covering provider selection and fallback behaviour
2026-04-06 22:44:32 +08:00

215 lines
8.2 KiB
Python

"""
Unit tests for MiniMax provider integration in CogVideo prompt enhancement.
"""
import os
import sys
import unittest
from unittest.mock import MagicMock, patch
# Put the sibling "inference" directory at the front of the module search
# path so the helper functions defined there can be imported from this test.
sys.path.insert(
    0,
    os.path.join(os.path.dirname(__file__), "..", "inference"),
)
class TestGetLlmClientConvertDemo(unittest.TestCase):
    """Tests for _get_llm_client() in convert_demo.py.

    NOTE(review): the tests execute ``_HELPER_CODE``, a replica of the helper
    kept in this file, not the function defined in convert_demo.py. They pin
    the intended provider-selection rules but will not notice if the real
    helper drifts from this copy -- confirm whether that is acceptable.
    """

    # Environment variables that influence provider selection.
    _ENV_VARS = (
        "MINIMAX_API_KEY",
        "MINIMAX_BASE_URL",
        "OPENAI_API_KEY",
        "OPENAI_BASE_URL",
    )

    # Single shared copy of the helper under test (previously pasted into
    # every test method).
    _HELPER_CODE = "\n".join(
        (
            "def _get_llm_client():",
            '    minimax_api_key = os.environ.get("MINIMAX_API_KEY")',
            "    if minimax_api_key:",
            '        base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")',
            "        client = OpenAI(api_key=minimax_api_key, base_url=base_url)",
            '        return client, "MiniMax-M2.7"',
            '    return OpenAI(), "glm-4-plus"',
            "",
        )
    )

    def setUp(self):
        """Start every test with none of the provider env vars set.

        ``patch.dict`` snapshots ``os.environ`` and the registered cleanup
        restores it, so values set or removed by a test never leak into other
        tests or into the process that runs them.
        """
        env_patcher = patch.dict(os.environ)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)
        for var in self._ENV_VARS:
            os.environ.pop(var, None)

    def _build_helper(self):
        """Define the helper against a fresh mock ``OpenAI`` class.

        Returns:
            A ``(_get_llm_client, mock_openai_cls)`` tuple, where
            ``mock_openai_cls`` is the ``MagicMock`` the helper sees as
            ``OpenAI``.
        """
        mock_openai_cls = MagicMock()
        namespace = {"os": os, "OpenAI": mock_openai_cls}
        exec(self._HELPER_CODE, namespace)
        return namespace["_get_llm_client"], mock_openai_cls

    def _import_helper(self):
        """Load ``_get_llm_client`` from the real convert_demo.py source.

        Executes the part of convert_demo.py that precedes the text
        ``def image_to_url`` in a namespace pre-seeded with ``os`` and a mock
        ``OpenAI``. Any module-level statements in that prefix are executed
        too.

        NOTE(review): if that prefix itself binds ``OpenAI`` (for example via
        an import), the mock placed in the namespace is replaced -- verify
        against convert_demo.py before relying on the returned class.

        Returns:
            A ``(_get_llm_client, OpenAI)`` tuple taken from the executed
            namespace.
        """
        source_path = os.path.join(
            os.path.dirname(__file__), "..", "inference", "convert_demo.py"
        )
        # Context manager so the file handle is closed deterministically.
        with open(source_path, encoding="utf-8") as source_file:
            source = source_file.read()
        exec_globals = {"os": os, "OpenAI": MagicMock()}
        exec(source.split("def image_to_url")[0], exec_globals)
        return exec_globals["_get_llm_client"], exec_globals["OpenAI"]

    def test_minimax_api_key_selects_minimax(self):
        """When MINIMAX_API_KEY is set, should use MiniMax client and model."""
        os.environ["MINIMAX_API_KEY"] = "test-minimax-key"
        get_llm_client, mock_openai_cls = self._build_helper()

        client, model = get_llm_client()

        self.assertEqual(model, "MiniMax-M2.7")
        self.assertIs(client, mock_openai_cls.return_value)
        mock_openai_cls.assert_called_once_with(
            api_key="test-minimax-key",
            base_url="https://api.minimax.io/v1",
        )

    def test_minimax_default_base_url(self):
        """Default MiniMax base URL should be https://api.minimax.io/v1."""
        os.environ["MINIMAX_API_KEY"] = "test-key"
        get_llm_client, mock_openai_cls = self._build_helper()

        get_llm_client()

        call_kwargs = mock_openai_cls.call_args[1]
        self.assertEqual(call_kwargs["base_url"], "https://api.minimax.io/v1")

    def test_minimax_custom_base_url(self):
        """MINIMAX_BASE_URL env var should override the default base URL."""
        os.environ["MINIMAX_API_KEY"] = "test-key"
        os.environ["MINIMAX_BASE_URL"] = "https://custom.minimax.io/v1"
        get_llm_client, mock_openai_cls = self._build_helper()

        get_llm_client()

        call_kwargs = mock_openai_cls.call_args[1]
        self.assertEqual(call_kwargs["base_url"], "https://custom.minimax.io/v1")

    def test_openai_key_selects_default(self):
        """When only OPENAI_API_KEY is set, should use default OpenAI client."""
        os.environ["OPENAI_API_KEY"] = "sk-test"
        get_llm_client, mock_openai_cls = self._build_helper()

        client, model = get_llm_client()

        self.assertEqual(model, "glm-4-plus")
        # Called with no explicit arguments: the client is left to pick up
        # its configuration on its own.
        mock_openai_cls.assert_called_once_with()

    def test_no_api_key_returns_default(self):
        """Without any API key env vars, default client is returned."""
        get_llm_client, mock_openai_cls = self._build_helper()

        client, model = get_llm_client()

        self.assertEqual(model, "glm-4-plus")
        mock_openai_cls.assert_called_once_with()

    def test_minimax_priority_over_openai(self):
        """MINIMAX_API_KEY takes priority over OPENAI_API_KEY."""
        os.environ["MINIMAX_API_KEY"] = "minimax-key"
        os.environ["OPENAI_API_KEY"] = "openai-key"
        get_llm_client, mock_openai_cls = self._build_helper()

        client, model = get_llm_client()

        self.assertEqual(model, "MiniMax-M2.7")
        call_kwargs = mock_openai_cls.call_args[1]
        self.assertEqual(call_kwargs["api_key"], "minimax-key")
class TestConvertPromptSkipWithoutKey(unittest.TestCase):
    """Tests that convert_prompt returns early when no API key is set.

    NOTE(review): the tests run ``_HELPER_CODE``, a simplified stand-in kept in
    this file that returns the literal "enhanced" once a key is present. They
    check the early-return guard only, not the real convert_prompt.
    """

    # Environment variables consulted by the early-return guard.
    _ENV_VARS = ("MINIMAX_API_KEY", "OPENAI_API_KEY")

    # Single shared copy of the stand-in (previously pasted into each test).
    _HELPER_CODE = "\n".join(
        (
            "def convert_prompt(prompt, retry_times=3):",
            '    if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"):',
            "        return prompt",
            '    return "enhanced"',
            "",
        )
    )

    def setUp(self):
        """Start every test with no API key set.

        ``patch.dict`` snapshots ``os.environ`` and the registered cleanup
        restores it, so a key set by one test cannot leak into another test or
        into the surrounding process.
        """
        env_patcher = patch.dict(os.environ)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)
        for var in self._ENV_VARS:
            os.environ.pop(var, None)

    def _build_convert_prompt(self):
        """Define the stand-in ``convert_prompt`` and return it."""
        namespace = {"os": os}
        exec(self._HELPER_CODE, namespace)
        return namespace["convert_prompt"]

    def test_no_key_returns_original_prompt(self):
        """Without any API key, convert_prompt should return prompt unchanged."""
        convert_prompt = self._build_convert_prompt()

        result = convert_prompt("my original prompt")

        self.assertEqual(result, "my original prompt")

    def test_minimax_key_enables_conversion(self):
        """With MINIMAX_API_KEY set, convert_prompt should not skip early."""
        os.environ["MINIMAX_API_KEY"] = "test-key"
        convert_prompt = self._build_convert_prompt()

        result = convert_prompt("my original prompt")

        self.assertEqual(result, "enhanced")
# Running this file directly hands control to unittest's command-line runner.
if __name__ == "__main__":
    unittest.main()