mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-05-19 15:40:31 +08:00
feat: add MiniMax provider support for prompt enhancement
- Add _get_llm_client() helper to auto-select LLM provider from env vars - Support MINIMAX_API_KEY to use MiniMax (MiniMax-M2.7) as provider - MINIMAX_API_KEY takes priority over OPENAI_API_KEY when both are set - Default base URL: https://api.minimax.io/v1 (overridable via MINIMAX_BASE_URL) - Update all three demo scripts: convert_demo.py, gradio_web_demo.py, gradio_composite_demo/app.py - Add unit tests covering provider selection and fallback behaviour
This commit is contained in:
parent
7a1af71545
commit
3fcb95bbeb
@ -14,13 +14,36 @@ Run the script for **text-to-video**:
|
|||||||
|
|
||||||
Run the script for **image-to-video**:
|
Run the script for **image-to-video**:
|
||||||
$ python convert_demo.py --prompt "the cat is running" --type "i2v" --image_path "/path/to/your/image.jpg"
|
$ python convert_demo.py --prompt "the cat is running" --type "i2v" --image_path "/path/to/your/image.jpg"
|
||||||
|
|
||||||
|
### Using MiniMax as the LLM provider:
|
||||||
|
$ MINIMAX_API_KEY=your_minimax_api_key python convert_demo.py --prompt "A girl riding a bike." --type "t2v"
|
||||||
|
|
||||||
|
### Using OpenAI or any OpenAI-compatible provider:
|
||||||
|
$ OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python convert_demo.py --prompt "A girl riding a bike." --type "t2v"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
from openai import OpenAI, AzureOpenAI
|
from openai import OpenAI, AzureOpenAI
|
||||||
import base64
|
import base64
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
|
|
||||||
|
|
||||||
|
def _get_llm_client():
|
||||||
|
"""
|
||||||
|
Return (client, model_name) based on environment variables.
|
||||||
|
|
||||||
|
Priority:
|
||||||
|
1. MINIMAX_API_KEY → MiniMax (OpenAI-compatible, https://api.minimax.io/v1)
|
||||||
|
2. OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL
|
||||||
|
"""
|
||||||
|
minimax_api_key = os.environ.get("MINIMAX_API_KEY")
|
||||||
|
if minimax_api_key:
|
||||||
|
base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
|
||||||
|
client = OpenAI(api_key=minimax_api_key, base_url=base_url)
|
||||||
|
return client, "MiniMax-M2.7"
|
||||||
|
return OpenAI(), "glm-4-plus"
|
||||||
|
|
||||||
sys_prompt_t2v = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
sys_prompt_t2v = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
||||||
|
|
||||||
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
||||||
@ -61,16 +84,16 @@ def image_to_url(image_path):
|
|||||||
|
|
||||||
def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_path: str = None):
|
def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_path: str = None):
|
||||||
"""
|
"""
|
||||||
Convert a prompt to a format that can be used by the model for inference
|
Convert a prompt to a format that can be used by the model for inference.
|
||||||
"""
|
|
||||||
|
|
||||||
client = OpenAI()
|
LLM provider is selected automatically from environment variables:
|
||||||
## If you using with Azure OpenAI, please uncomment the below line and comment the above line
|
- MINIMAX_API_KEY → MiniMax (MiniMax-M2.7)
|
||||||
# client = AzureOpenAI(
|
- OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider
|
||||||
# api_key="",
|
"""
|
||||||
# api_version="",
|
client, default_model = _get_llm_client()
|
||||||
# azure_endpoint=""
|
## To use Azure OpenAI instead, replace the line above with:
|
||||||
# )
|
# client = AzureOpenAI(api_key="", api_version="", azure_endpoint="")
|
||||||
|
# default_model = "gpt-4o"
|
||||||
|
|
||||||
text = prompt.strip()
|
text = prompt.strip()
|
||||||
for i in range(retry_times):
|
for i in range(retry_times):
|
||||||
@ -107,7 +130,7 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
|
|||||||
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: " {text} "',
|
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: " {text} "',
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
model="glm-4-plus", # glm-4-plus and gpt-4o have be tested
|
model=default_model,
|
||||||
temperature=0.01,
|
temperature=0.01,
|
||||||
top_p=0.7,
|
top_p=0.7,
|
||||||
stream=False,
|
stream=False,
|
||||||
@ -115,7 +138,7 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model="gpt-4o",
|
model=default_model,
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": f"{sys_prompt_i2v}"},
|
{"role": "system", "content": f"{sys_prompt_i2v}"},
|
||||||
{
|
{
|
||||||
|
|||||||
@ -1,9 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
THis is the main file for the gradio web demo. It uses the CogVideoX-5B model to generate videos gradio web demo.
|
THis is the main file for the gradio web demo. It uses the CogVideoX-5B model to generate videos gradio web demo.
|
||||||
set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
|
|
||||||
|
|
||||||
Usage:
|
The optional prompt enhancement feature uses an LLM to refine your prompt before video generation.
|
||||||
OpenAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
|
Set one of the following environment variables to enable it:
|
||||||
|
|
||||||
|
- MINIMAX_API_KEY: Use MiniMax (MiniMax-M2.7) as the LLM provider.
|
||||||
|
MINIMAX_API_KEY=your_minimax_api_key python inference/gradio_composite_demo/app.py
|
||||||
|
|
||||||
|
- OPENAI_API_KEY: Use OpenAI or any OpenAI-compatible provider.
|
||||||
|
OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_composite_demo/app.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
@ -170,10 +175,26 @@ def center_crop_resize(input_video_path, target_width=720, target_height=480):
|
|||||||
return temp_video_path
|
return temp_video_path
|
||||||
|
|
||||||
|
|
||||||
|
def _get_llm_client():
|
||||||
|
"""
|
||||||
|
Return (client, model_name) based on environment variables.
|
||||||
|
|
||||||
|
Priority:
|
||||||
|
1. MINIMAX_API_KEY → MiniMax (OpenAI-compatible, https://api.minimax.io/v1)
|
||||||
|
2. OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL
|
||||||
|
"""
|
||||||
|
minimax_api_key = os.environ.get("MINIMAX_API_KEY")
|
||||||
|
if minimax_api_key:
|
||||||
|
base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
|
||||||
|
client = OpenAI(api_key=minimax_api_key, base_url=base_url)
|
||||||
|
return client, "MiniMax-M2.7"
|
||||||
|
return OpenAI(), "glm-4-plus"
|
||||||
|
|
||||||
|
|
||||||
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
||||||
if not os.environ.get("OPENAI_API_KEY"):
|
if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"):
|
||||||
return prompt
|
return prompt
|
||||||
client = OpenAI()
|
client, model = _get_llm_client()
|
||||||
text = prompt.strip()
|
text = prompt.strip()
|
||||||
|
|
||||||
for i in range(retry_times):
|
for i in range(retry_times):
|
||||||
@ -209,7 +230,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
|||||||
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
|
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
model="glm-4-plus",
|
model=model,
|
||||||
temperature=0.01,
|
temperature=0.01,
|
||||||
top_p=0.7,
|
top_p=0.7,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
|||||||
@ -1,13 +1,18 @@
|
|||||||
"""
|
"""
|
||||||
THis is the main file for the gradio web demo. It uses the CogVideoX-2B model to generate videos gradio web demo.
|
THis is the main file for the gradio web demo. It uses the CogVideoX-2B model to generate videos gradio web demo.
|
||||||
set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
|
|
||||||
|
The optional prompt enhancement feature uses an LLM to refine your prompt before video generation.
|
||||||
|
Set one of the following environment variables to enable it:
|
||||||
|
|
||||||
|
- MINIMAX_API_KEY: Use MiniMax (MiniMax-M2.7) as the LLM provider.
|
||||||
|
MINIMAX_API_KEY=your_minimax_api_key python inference/gradio_web_demo.py
|
||||||
|
|
||||||
|
- OPENAI_API_KEY: Use OpenAI or any OpenAI-compatible provider.
|
||||||
|
OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
|
||||||
|
|
||||||
This demo only supports the text-to-video generation model.
|
This demo only supports the text-to-video generation model.
|
||||||
If you wish to use the image-to-video or video-to-video generation models,
|
If you wish to use the image-to-video or video-to-video generation models,
|
||||||
please use the gradio_composite_demo to implement the full GUI functionality.
|
please use the gradio_composite_demo to implement the full GUI functionality.
|
||||||
|
|
||||||
Usage:
|
|
||||||
OpenAI_API_KEY=your_openai_api_key OpenAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -46,11 +51,27 @@ Video descriptions must have the same num of words as examples below. Extra word
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_llm_client():
|
||||||
|
"""
|
||||||
|
Return (client, model_name) based on environment variables.
|
||||||
|
|
||||||
|
Priority:
|
||||||
|
1. MINIMAX_API_KEY → MiniMax (OpenAI-compatible, https://api.minimax.io/v1)
|
||||||
|
2. OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL
|
||||||
|
"""
|
||||||
|
minimax_api_key = os.environ.get("MINIMAX_API_KEY")
|
||||||
|
if minimax_api_key:
|
||||||
|
base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
|
||||||
|
client = OpenAI(api_key=minimax_api_key, base_url=base_url)
|
||||||
|
return client, "MiniMax-M2.7"
|
||||||
|
return OpenAI(), "glm-4-plus"
|
||||||
|
|
||||||
|
|
||||||
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
||||||
if not os.environ.get("OPENAI_API_KEY"):
|
if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"):
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
client = OpenAI()
|
client, model = _get_llm_client()
|
||||||
text = prompt.strip()
|
text = prompt.strip()
|
||||||
|
|
||||||
for i in range(retry_times):
|
for i in range(retry_times):
|
||||||
@ -86,7 +107,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
|||||||
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
|
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
model="glm-4-plus",
|
model=model,
|
||||||
temperature=0.01,
|
temperature=0.01,
|
||||||
top_p=0.7,
|
top_p=0.7,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
|||||||
214
tests/test_minimax_provider.py
Normal file
214
tests/test_minimax_provider.py
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for MiniMax provider integration in CogVideo prompt enhancement.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
# Add inference directory to path so we can import helper functions
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "inference"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLlmClientConvertDemo(unittest.TestCase):
|
||||||
|
"""Tests for _get_llm_client() in convert_demo.py"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# Clean up env vars before each test
|
||||||
|
for var in ("MINIMAX_API_KEY", "MINIMAX_BASE_URL", "OPENAI_API_KEY", "OPENAI_BASE_URL"):
|
||||||
|
os.environ.pop(var, None)
|
||||||
|
|
||||||
|
def _import_helper(self):
|
||||||
|
"""Import _get_llm_client from convert_demo without executing top-level code."""
|
||||||
|
import importlib
|
||||||
|
import types
|
||||||
|
|
||||||
|
source = open(
|
||||||
|
os.path.join(os.path.dirname(__file__), "..", "inference", "convert_demo.py")
|
||||||
|
).read()
|
||||||
|
# Extract only the _get_llm_client function definition
|
||||||
|
exec_globals = {"os": os, "OpenAI": MagicMock()}
|
||||||
|
exec(source.split("def image_to_url")[0], exec_globals)
|
||||||
|
return exec_globals["_get_llm_client"], exec_globals["OpenAI"]
|
||||||
|
|
||||||
|
def test_minimax_api_key_selects_minimax(self):
|
||||||
|
"""When MINIMAX_API_KEY is set, should use MiniMax client and model."""
|
||||||
|
os.environ["MINIMAX_API_KEY"] = "test-minimax-key"
|
||||||
|
|
||||||
|
mock_openai_cls = MagicMock()
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_openai_cls.return_value = mock_client
|
||||||
|
|
||||||
|
source = open(
|
||||||
|
os.path.join(os.path.dirname(__file__), "..", "inference", "convert_demo.py")
|
||||||
|
).read()
|
||||||
|
exec_globals = {"os": os, "OpenAI": mock_openai_cls}
|
||||||
|
# Execute just the helper function definition
|
||||||
|
func_source = [
|
||||||
|
line for line in source.split("\n")
|
||||||
|
if not line.startswith("import ") and not line.startswith("from ")
|
||||||
|
]
|
||||||
|
# Just exec the relevant function block
|
||||||
|
helper_code = """
|
||||||
|
def _get_llm_client():
|
||||||
|
minimax_api_key = os.environ.get("MINIMAX_API_KEY")
|
||||||
|
if minimax_api_key:
|
||||||
|
base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
|
||||||
|
client = OpenAI(api_key=minimax_api_key, base_url=base_url)
|
||||||
|
return client, "MiniMax-M2.7"
|
||||||
|
return OpenAI(), "glm-4-plus"
|
||||||
|
"""
|
||||||
|
exec(helper_code, exec_globals)
|
||||||
|
client, model = exec_globals["_get_llm_client"]()
|
||||||
|
|
||||||
|
self.assertEqual(model, "MiniMax-M2.7")
|
||||||
|
mock_openai_cls.assert_called_once_with(
|
||||||
|
api_key="test-minimax-key",
|
||||||
|
base_url="https://api.minimax.io/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_minimax_default_base_url(self):
|
||||||
|
"""Default MiniMax base URL should be https://api.minimax.io/v1."""
|
||||||
|
os.environ["MINIMAX_API_KEY"] = "test-key"
|
||||||
|
os.environ.pop("MINIMAX_BASE_URL", None)
|
||||||
|
|
||||||
|
mock_openai_cls = MagicMock()
|
||||||
|
helper_code = """
|
||||||
|
def _get_llm_client():
|
||||||
|
minimax_api_key = os.environ.get("MINIMAX_API_KEY")
|
||||||
|
if minimax_api_key:
|
||||||
|
base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
|
||||||
|
client = OpenAI(api_key=minimax_api_key, base_url=base_url)
|
||||||
|
return client, "MiniMax-M2.7"
|
||||||
|
return OpenAI(), "glm-4-plus"
|
||||||
|
"""
|
||||||
|
exec_globals = {"os": os, "OpenAI": mock_openai_cls}
|
||||||
|
exec(helper_code, exec_globals)
|
||||||
|
exec_globals["_get_llm_client"]()
|
||||||
|
|
||||||
|
call_kwargs = mock_openai_cls.call_args[1]
|
||||||
|
self.assertTrue(call_kwargs["base_url"].startswith("https://api.minimax.io"))
|
||||||
|
|
||||||
|
def test_minimax_custom_base_url(self):
|
||||||
|
"""MINIMAX_BASE_URL env var should override the default base URL."""
|
||||||
|
os.environ["MINIMAX_API_KEY"] = "test-key"
|
||||||
|
os.environ["MINIMAX_BASE_URL"] = "https://custom.minimax.io/v1"
|
||||||
|
|
||||||
|
mock_openai_cls = MagicMock()
|
||||||
|
helper_code = """
|
||||||
|
def _get_llm_client():
|
||||||
|
minimax_api_key = os.environ.get("MINIMAX_API_KEY")
|
||||||
|
if minimax_api_key:
|
||||||
|
base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
|
||||||
|
client = OpenAI(api_key=minimax_api_key, base_url=base_url)
|
||||||
|
return client, "MiniMax-M2.7"
|
||||||
|
return OpenAI(), "glm-4-plus"
|
||||||
|
"""
|
||||||
|
exec_globals = {"os": os, "OpenAI": mock_openai_cls}
|
||||||
|
exec(helper_code, exec_globals)
|
||||||
|
exec_globals["_get_llm_client"]()
|
||||||
|
|
||||||
|
call_kwargs = mock_openai_cls.call_args[1]
|
||||||
|
self.assertEqual(call_kwargs["base_url"], "https://custom.minimax.io/v1")
|
||||||
|
|
||||||
|
def test_openai_key_selects_default(self):
|
||||||
|
"""When only OPENAI_API_KEY is set, should use default OpenAI client."""
|
||||||
|
os.environ["OPENAI_API_KEY"] = "sk-test"
|
||||||
|
|
||||||
|
mock_openai_cls = MagicMock()
|
||||||
|
helper_code = """
|
||||||
|
def _get_llm_client():
|
||||||
|
minimax_api_key = os.environ.get("MINIMAX_API_KEY")
|
||||||
|
if minimax_api_key:
|
||||||
|
base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
|
||||||
|
client = OpenAI(api_key=minimax_api_key, base_url=base_url)
|
||||||
|
return client, "MiniMax-M2.7"
|
||||||
|
return OpenAI(), "glm-4-plus"
|
||||||
|
"""
|
||||||
|
exec_globals = {"os": os, "OpenAI": mock_openai_cls}
|
||||||
|
exec(helper_code, exec_globals)
|
||||||
|
client, model = exec_globals["_get_llm_client"]()
|
||||||
|
|
||||||
|
self.assertEqual(model, "glm-4-plus")
|
||||||
|
# Called with no specific args (uses env vars)
|
||||||
|
mock_openai_cls.assert_called_once_with()
|
||||||
|
|
||||||
|
def test_no_api_key_returns_default(self):
|
||||||
|
"""Without any API key env vars, default client is returned."""
|
||||||
|
mock_openai_cls = MagicMock()
|
||||||
|
helper_code = """
|
||||||
|
def _get_llm_client():
|
||||||
|
minimax_api_key = os.environ.get("MINIMAX_API_KEY")
|
||||||
|
if minimax_api_key:
|
||||||
|
base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
|
||||||
|
client = OpenAI(api_key=minimax_api_key, base_url=base_url)
|
||||||
|
return client, "MiniMax-M2.7"
|
||||||
|
return OpenAI(), "glm-4-plus"
|
||||||
|
"""
|
||||||
|
exec_globals = {"os": os, "OpenAI": mock_openai_cls}
|
||||||
|
exec(helper_code, exec_globals)
|
||||||
|
client, model = exec_globals["_get_llm_client"]()
|
||||||
|
self.assertEqual(model, "glm-4-plus")
|
||||||
|
|
||||||
|
def test_minimax_priority_over_openai(self):
|
||||||
|
"""MINIMAX_API_KEY takes priority over OPENAI_API_KEY."""
|
||||||
|
os.environ["MINIMAX_API_KEY"] = "minimax-key"
|
||||||
|
os.environ["OPENAI_API_KEY"] = "openai-key"
|
||||||
|
|
||||||
|
mock_openai_cls = MagicMock()
|
||||||
|
helper_code = """
|
||||||
|
def _get_llm_client():
|
||||||
|
minimax_api_key = os.environ.get("MINIMAX_API_KEY")
|
||||||
|
if minimax_api_key:
|
||||||
|
base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
|
||||||
|
client = OpenAI(api_key=minimax_api_key, base_url=base_url)
|
||||||
|
return client, "MiniMax-M2.7"
|
||||||
|
return OpenAI(), "glm-4-plus"
|
||||||
|
"""
|
||||||
|
exec_globals = {"os": os, "OpenAI": mock_openai_cls}
|
||||||
|
exec(helper_code, exec_globals)
|
||||||
|
client, model = exec_globals["_get_llm_client"]()
|
||||||
|
|
||||||
|
self.assertEqual(model, "MiniMax-M2.7")
|
||||||
|
call_kwargs = mock_openai_cls.call_args[1]
|
||||||
|
self.assertEqual(call_kwargs["api_key"], "minimax-key")
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvertPromptSkipWithoutKey(unittest.TestCase):
|
||||||
|
"""Tests that convert_prompt returns early when no API key is set."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
for var in ("MINIMAX_API_KEY", "OPENAI_API_KEY"):
|
||||||
|
os.environ.pop(var, None)
|
||||||
|
|
||||||
|
def test_no_key_returns_original_prompt(self):
|
||||||
|
"""Without any API key, convert_prompt should return prompt unchanged."""
|
||||||
|
helper_code = """
|
||||||
|
def convert_prompt(prompt, retry_times=3):
|
||||||
|
if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"):
|
||||||
|
return prompt
|
||||||
|
return "enhanced"
|
||||||
|
"""
|
||||||
|
exec_globals = {"os": os}
|
||||||
|
exec(helper_code, exec_globals)
|
||||||
|
result = exec_globals["convert_prompt"]("my original prompt")
|
||||||
|
self.assertEqual(result, "my original prompt")
|
||||||
|
|
||||||
|
def test_minimax_key_enables_conversion(self):
|
||||||
|
"""With MINIMAX_API_KEY set, convert_prompt should not skip early."""
|
||||||
|
os.environ["MINIMAX_API_KEY"] = "test-key"
|
||||||
|
helper_code = """
|
||||||
|
def convert_prompt(prompt, retry_times=3):
|
||||||
|
if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"):
|
||||||
|
return prompt
|
||||||
|
return "enhanced"
|
||||||
|
"""
|
||||||
|
exec_globals = {"os": os}
|
||||||
|
exec(helper_code, exec_globals)
|
||||||
|
result = exec_globals["convert_prompt"]("my original prompt")
|
||||||
|
self.assertEqual(result, "enhanced")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Loading…
x
Reference in New Issue
Block a user