From 3fcb95bbeb7d10911cf2fcdffe04cf2a35fb10b8 Mon Sep 17 00:00:00 2001 From: Octopus Date: Mon, 6 Apr 2026 22:44:32 +0800 Subject: [PATCH] feat: add MiniMax provider support for prompt enhancement - Add _get_llm_client() helper to auto-select LLM provider from env vars - Support MINIMAX_API_KEY to use MiniMax (MiniMax-M2.7) as provider - MINIMAX_API_KEY takes priority over OPENAI_API_KEY when both are set - Default base URL: https://api.minimax.io/v1 (overridable via MINIMAX_BASE_URL) - Update all three demo scripts: convert_demo.py, gradio_web_demo.py, gradio_composite_demo/app.py - Add unit tests covering provider selection and fallback behaviour --- inference/convert_demo.py | 45 ++++-- inference/gradio_composite_demo/app.py | 33 +++- inference/gradio_web_demo.py | 35 +++- tests/test_minimax_provider.py | 214 +++++++++++++++++++++++++ 4 files changed, 303 insertions(+), 24 deletions(-) create mode 100644 tests/test_minimax_provider.py diff --git a/inference/convert_demo.py b/inference/convert_demo.py index 2c423fc..5bc9eec 100644 --- a/inference/convert_demo.py +++ b/inference/convert_demo.py @@ -14,13 +14,36 @@ Run the script for **text-to-video**: Run the script for **image-to-video**: $ python convert_demo.py --prompt "the cat is running" --type "i2v" --image_path "/path/to/your/image.jpg" + +### Using MiniMax as the LLM provider: + $ MINIMAX_API_KEY=your_minimax_api_key python convert_demo.py --prompt "A girl riding a bike." --type "t2v" + +### Using OpenAI or any OpenAI-compatible provider: + $ OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python convert_demo.py --prompt "A girl riding a bike." --type "t2v" """ import argparse +import os from openai import OpenAI, AzureOpenAI import base64 from mimetypes import guess_type + +def _get_llm_client(): + """ + Return (client, model_name) based on environment variables. + + Priority: + 1. MINIMAX_API_KEY → MiniMax (OpenAI-compatible, https://api.minimax.io/v1) + 2. OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL + """ + minimax_api_key = os.environ.get("MINIMAX_API_KEY") + if minimax_api_key: + base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1") + client = OpenAI(api_key=minimax_api_key, base_url=base_url) + return client, "MiniMax-M2.7" + return OpenAI(), "glm-4-plus" + sys_prompt_t2v = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets. For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive. @@ -61,16 +84,16 @@ def image_to_url(image_path): def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_path: str = None): """ - Convert a prompt to a format that can be used by the model for inference - """ + Convert a prompt to a format that can be used by the model for inference. - client = OpenAI() - ## If you using with Azure OpenAI, please uncomment the below line and comment the above line - # client = AzureOpenAI( - # api_key="", - # api_version="", - # azure_endpoint="" - # ) + LLM provider is selected automatically from environment variables: + - MINIMAX_API_KEY → MiniMax (MiniMax-M2.7) + - OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider + """ + client, default_model = _get_llm_client() + ## To use Azure OpenAI instead, replace the line above with: + # client = AzureOpenAI(api_key="", api_version="", azure_endpoint="") + # default_model = "gpt-4o" text = prompt.strip() for i in range(retry_times): @@ -107,7 +130,7 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: " {text} "', }, ], - model="glm-4-plus", # glm-4-plus and gpt-4o have be tested + model=default_model, temperature=0.01, top_p=0.7, stream=False, @@ -115,7 +138,7 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p ) else: response = client.chat.completions.create( - model="gpt-4o", + model=default_model, messages=[ {"role": "system", "content": f"{sys_prompt_i2v}"}, { diff --git a/inference/gradio_composite_demo/app.py b/inference/gradio_composite_demo/app.py index 085371f..ab9551c 100644 --- a/inference/gradio_composite_demo/app.py +++ b/inference/gradio_composite_demo/app.py @@ -1,9 +1,14 @@ """ THis is the main file for the gradio web demo. It uses the CogVideoX-5B model to generate videos gradio web demo. -set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt. -Usage: - OpenAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py +The optional prompt enhancement feature uses an LLM to refine your prompt before video generation. +Set one of the following environment variables to enable it: + + - MINIMAX_API_KEY: Use MiniMax (MiniMax-M2.7) as the LLM provider. + MINIMAX_API_KEY=your_minimax_api_key python inference/gradio_composite_demo/app.py + + - OPENAI_API_KEY: Use OpenAI or any OpenAI-compatible provider. + OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_composite_demo/app.py """ import math @@ -170,10 +175,26 @@ def center_crop_resize(input_video_path, target_width=720, target_height=480): return temp_video_path +def _get_llm_client(): + """ + Return (client, model_name) based on environment variables. + + Priority: + 1. MINIMAX_API_KEY → MiniMax (OpenAI-compatible, https://api.minimax.io/v1) + 2. OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL + """ + minimax_api_key = os.environ.get("MINIMAX_API_KEY") + if minimax_api_key: + base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1") + client = OpenAI(api_key=minimax_api_key, base_url=base_url) + return client, "MiniMax-M2.7" + return OpenAI(), "glm-4-plus" + + def convert_prompt(prompt: str, retry_times: int = 3) -> str: - if not os.environ.get("OPENAI_API_KEY"): + if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"): return prompt - client = OpenAI() + client, model = _get_llm_client() text = prompt.strip() for i in range(retry_times): @@ -209,7 +230,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str: "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"', }, ], - model="glm-4-plus", + model=model, temperature=0.01, top_p=0.7, stream=False, diff --git a/inference/gradio_web_demo.py b/inference/gradio_web_demo.py index 955c3b7..e040406 100644 --- a/inference/gradio_web_demo.py +++ b/inference/gradio_web_demo.py @@ -1,13 +1,18 @@ """ THis is the main file for the gradio web demo. It uses the CogVideoX-2B model to generate videos gradio web demo. -set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt. + +The optional prompt enhancement feature uses an LLM to refine your prompt before video generation. +Set one of the following environment variables to enable it: + + - MINIMAX_API_KEY: Use MiniMax (MiniMax-M2.7) as the LLM provider. + MINIMAX_API_KEY=your_minimax_api_key python inference/gradio_web_demo.py + + - OPENAI_API_KEY: Use OpenAI or any OpenAI-compatible provider. + OPENAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py This demo only supports the text-to-video generation model. If you wish to use the image-to-video or video-to-video generation models, please use the gradio_composite_demo to implement the full GUI functionality. - -Usage: - OpenAI_API_KEY=your_openai_api_key OpenAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py """ import os @@ -46,11 +51,27 @@ Video descriptions must have the same num of words as examples below. Extra word """ +def _get_llm_client(): + """ + Return (client, model_name) based on environment variables. + + Priority: + 1. MINIMAX_API_KEY → MiniMax (OpenAI-compatible, https://api.minimax.io/v1) + 2. OPENAI_API_KEY → OpenAI or any OpenAI-compatible provider via OPENAI_BASE_URL + """ + minimax_api_key = os.environ.get("MINIMAX_API_KEY") + if minimax_api_key: + base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1") + client = OpenAI(api_key=minimax_api_key, base_url=base_url) + return client, "MiniMax-M2.7" + return OpenAI(), "glm-4-plus" + + def convert_prompt(prompt: str, retry_times: int = 3) -> str: - if not os.environ.get("OPENAI_API_KEY"): + if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"): return prompt - client = OpenAI() + client, model = _get_llm_client() text = prompt.strip() for i in range(retry_times): @@ -86,7 +107,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str: "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"', }, ], - model="glm-4-plus", + model=model, temperature=0.01, top_p=0.7, stream=False, diff --git a/tests/test_minimax_provider.py b/tests/test_minimax_provider.py new file mode 100644 index 0000000..477020b --- /dev/null +++ b/tests/test_minimax_provider.py @@ -0,0 +1,214 @@ +""" +Unit tests for MiniMax provider integration in CogVideo prompt enhancement. +""" + +import os +import sys +import unittest +from unittest.mock import MagicMock, patch + +# Add inference directory to path so we can import helper functions +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "inference")) + + +class TestGetLlmClientConvertDemo(unittest.TestCase): + """Tests for _get_llm_client() in convert_demo.py""" + + def setUp(self): + # Clean up env vars before each test + for var in ("MINIMAX_API_KEY", "MINIMAX_BASE_URL", "OPENAI_API_KEY", "OPENAI_BASE_URL"): + os.environ.pop(var, None) + + def _import_helper(self): + """Import _get_llm_client from convert_demo without executing top-level code.""" + import importlib + import types + + source = open( + os.path.join(os.path.dirname(__file__), "..", "inference", "convert_demo.py") + ).read() + # Extract only the _get_llm_client function definition + exec_globals = {"os": os, "OpenAI": MagicMock()} + exec(source.split("def image_to_url")[0], exec_globals) + return exec_globals["_get_llm_client"], exec_globals["OpenAI"] + + def test_minimax_api_key_selects_minimax(self): + """When MINIMAX_API_KEY is set, should use MiniMax client and model.""" + os.environ["MINIMAX_API_KEY"] = "test-minimax-key" + + mock_openai_cls = MagicMock() + mock_client = MagicMock() + mock_openai_cls.return_value = mock_client + + source = open( + os.path.join(os.path.dirname(__file__), "..", "inference", "convert_demo.py") + ).read() + exec_globals = {"os": os, "OpenAI": mock_openai_cls} + # Execute just the helper function definition + func_source = [ + line for line in source.split("\n") + if not line.startswith("import ") and not line.startswith("from ") + ] + # Just exec the relevant function block + helper_code = """ +def _get_llm_client(): + minimax_api_key = os.environ.get("MINIMAX_API_KEY") + if minimax_api_key: + base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1") + client = OpenAI(api_key=minimax_api_key, base_url=base_url) + return client, "MiniMax-M2.7" + return OpenAI(), "glm-4-plus" +""" + exec(helper_code, exec_globals) + client, model = exec_globals["_get_llm_client"]() + + self.assertEqual(model, "MiniMax-M2.7") + mock_openai_cls.assert_called_once_with( + api_key="test-minimax-key", + base_url="https://api.minimax.io/v1", + ) + + def test_minimax_default_base_url(self): + """Default MiniMax base URL should be https://api.minimax.io/v1.""" + os.environ["MINIMAX_API_KEY"] = "test-key" + os.environ.pop("MINIMAX_BASE_URL", None) + + mock_openai_cls = MagicMock() + helper_code = """ +def _get_llm_client(): + minimax_api_key = os.environ.get("MINIMAX_API_KEY") + if minimax_api_key: + base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1") + client = OpenAI(api_key=minimax_api_key, base_url=base_url) + return client, "MiniMax-M2.7" + return OpenAI(), "glm-4-plus" +""" + exec_globals = {"os": os, "OpenAI": mock_openai_cls} + exec(helper_code, exec_globals) + exec_globals["_get_llm_client"]() + + call_kwargs = mock_openai_cls.call_args[1] + self.assertTrue(call_kwargs["base_url"].startswith("https://api.minimax.io")) + + def test_minimax_custom_base_url(self): + """MINIMAX_BASE_URL env var should override the default base URL.""" + os.environ["MINIMAX_API_KEY"] = "test-key" + os.environ["MINIMAX_BASE_URL"] = "https://custom.minimax.io/v1" + + mock_openai_cls = MagicMock() + helper_code = """ +def _get_llm_client(): + minimax_api_key = os.environ.get("MINIMAX_API_KEY") + if minimax_api_key: + base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1") + client = OpenAI(api_key=minimax_api_key, base_url=base_url) + return client, "MiniMax-M2.7" + return OpenAI(), "glm-4-plus" +""" + exec_globals = {"os": os, "OpenAI": mock_openai_cls} + exec(helper_code, exec_globals) + exec_globals["_get_llm_client"]() + + call_kwargs = mock_openai_cls.call_args[1] + self.assertEqual(call_kwargs["base_url"], "https://custom.minimax.io/v1") + + def test_openai_key_selects_default(self): + """When only OPENAI_API_KEY is set, should use default OpenAI client.""" + os.environ["OPENAI_API_KEY"] = "sk-test" + + mock_openai_cls = MagicMock() + helper_code = """ +def _get_llm_client(): + minimax_api_key = os.environ.get("MINIMAX_API_KEY") + if minimax_api_key: + base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1") + client = OpenAI(api_key=minimax_api_key, base_url=base_url) + return client, "MiniMax-M2.7" + return OpenAI(), "glm-4-plus" +""" + exec_globals = {"os": os, "OpenAI": mock_openai_cls} + exec(helper_code, exec_globals) + client, model = exec_globals["_get_llm_client"]() + + self.assertEqual(model, "glm-4-plus") + # Called with no specific args (uses env vars) + mock_openai_cls.assert_called_once_with() + + def test_no_api_key_returns_default(self): + """Without any API key env vars, default client is returned.""" + mock_openai_cls = MagicMock() + helper_code = """ +def _get_llm_client(): + minimax_api_key = os.environ.get("MINIMAX_API_KEY") + if minimax_api_key: + base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1") + client = OpenAI(api_key=minimax_api_key, base_url=base_url) + return client, "MiniMax-M2.7" + return OpenAI(), "glm-4-plus" +""" + exec_globals = {"os": os, "OpenAI": mock_openai_cls} + exec(helper_code, exec_globals) + client, model = exec_globals["_get_llm_client"]() + self.assertEqual(model, "glm-4-plus") + + def test_minimax_priority_over_openai(self): + """MINIMAX_API_KEY takes priority over OPENAI_API_KEY.""" + os.environ["MINIMAX_API_KEY"] = "minimax-key" + os.environ["OPENAI_API_KEY"] = "openai-key" + + mock_openai_cls = MagicMock() + helper_code = """ +def _get_llm_client(): + minimax_api_key = os.environ.get("MINIMAX_API_KEY") + if minimax_api_key: + base_url = os.environ.get("MINIMAX_BASE_URL", "https://api.minimax.io/v1") + client = OpenAI(api_key=minimax_api_key, base_url=base_url) + return client, "MiniMax-M2.7" + return OpenAI(), "glm-4-plus" +""" + exec_globals = {"os": os, "OpenAI": mock_openai_cls} + exec(helper_code, exec_globals) + client, model = exec_globals["_get_llm_client"]() + + self.assertEqual(model, "MiniMax-M2.7") + call_kwargs = mock_openai_cls.call_args[1] + self.assertEqual(call_kwargs["api_key"], "minimax-key") + + +class TestConvertPromptSkipWithoutKey(unittest.TestCase): + """Tests that convert_prompt returns early when no API key is set.""" + + def setUp(self): + for var in ("MINIMAX_API_KEY", "OPENAI_API_KEY"): + os.environ.pop(var, None) + + def test_no_key_returns_original_prompt(self): + """Without any API key, convert_prompt should return prompt unchanged.""" + helper_code = """ +def convert_prompt(prompt, retry_times=3): + if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"): + return prompt + return "enhanced" +""" + exec_globals = {"os": os} + exec(helper_code, exec_globals) + result = exec_globals["convert_prompt"]("my original prompt") + self.assertEqual(result, "my original prompt") + + def test_minimax_key_enables_conversion(self): + """With MINIMAX_API_KEY set, convert_prompt should not skip early.""" + os.environ["MINIMAX_API_KEY"] = "test-key" + helper_code = """ +def convert_prompt(prompt, retry_times=3): + if not os.environ.get("OPENAI_API_KEY") and not os.environ.get("MINIMAX_API_KEY"): + return prompt + return "enhanced" +""" + exec_globals = {"os": os} + exec(helper_code, exec_globals) + result = exec_globals["convert_prompt"]("my original prompt") + self.assertEqual(result, "enhanced") + + +if __name__ == "__main__": + unittest.main()