4 changes: 4 additions & 0 deletions langkit/openai/__init__.py
@@ -8,6 +8,8 @@
    OpenAIGPT4,
    OpenAIDefault,
    OpenAILegacy,
    MiniMaxDefault,
    MiniMaxHighSpeed,
)

__ALL__ = [
@@ -20,4 +22,6 @@
    OpenAIDefault,
    OpenAIGPT4,
    OpenAILegacy,
    MiniMaxDefault,
    MiniMaxHighSpeed,
]
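
With the two names re-exported here, callers can import the MiniMax providers the same way as the existing OpenAI classes. A quick smoke check, assuming this patch is applied (the printed values come from the defaults defined in openai.py below):

from langkit.openai import MiniMaxDefault, MiniMaxHighSpeed

print(MiniMaxDefault().model)    # MiniMax-M2.7
print(MiniMaxHighSpeed().model)  # MiniMax-M2.7-highspeed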
67 changes: 67 additions & 0 deletions langkit/openai/openai.py
@@ -6,6 +6,21 @@


openai.api_key = os.getenv("OPENAI_API_KEY")

_minimax_default_model = "MiniMax-M2.7"
_minimax_highspeed_model = "MiniMax-M2.7-highspeed"
_minimax_default_base_url = "https://api.minimax.io/v1"


def create_minimax_chat_completion(messages, api_key, base_url, **params):
    openai_version = get_openai_version()
    if openai_version.startswith("0."):
        # Legacy openai<1.0 SDK: point the module-level globals at MiniMax,
        # then call the old ChatCompletion API.
        openai.api_base = base_url
        openai.api_key = api_key
        return openai.ChatCompletion.create(messages=messages, **params)
    else:
        # openai>=1.0 SDK: use a client instance scoped to the MiniMax base URL.
        client = openai.OpenAI(api_key=api_key, base_url=base_url)
        return client.chat.completions.create(messages=messages, **params)
_openai_llm_model = os.getenv("LANGKIT_OPENAI_LLM_MODEL_NAME") or "gpt-3.5-turbo"
_llm_model_temperature = 0.9
_llm_model_max_tokens = 1024
@@ -373,6 +388,58 @@ def send_prompt(self, prompt: str) -> ChatLog:
)


@dataclass
class MiniMaxDefault(LLMInvocationParams):
    """MiniMax chat model provider using the OpenAI-compatible API."""

    model: str = field(default_factory=lambda: _minimax_default_model)
    temperature: float = field(default_factory=lambda: 1.0)
    max_tokens: int = field(default_factory=lambda: _llm_model_max_tokens)
    frequency_penalty: float = field(default_factory=lambda: _llm_model_frequency_penalty)
    presence_penalty: float = field(default_factory=lambda: _llm_model_presence_penalty)

    def completion(self, messages: List[Dict[str, str]], **kwargs):
        params = asdict(self)
        api_key = os.getenv("MINIMAX_API_KEY")
        base_url = os.getenv("MINIMAX_BASE_URL") or _minimax_default_base_url
        return create_minimax_chat_completion(messages=messages, api_key=api_key, base_url=base_url, **params)

    def copy(self) -> LLMInvocationParams:
        return MiniMaxDefault(
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )


@dataclass
class MiniMaxHighSpeed(LLMInvocationParams):
    """MiniMax high-speed chat model provider using the OpenAI-compatible API."""

    model: str = field(default_factory=lambda: _minimax_highspeed_model)
    temperature: float = field(default_factory=lambda: 1.0)
    max_tokens: int = field(default_factory=lambda: _llm_model_max_tokens)
    frequency_penalty: float = field(default_factory=lambda: _llm_model_frequency_penalty)
    presence_penalty: float = field(default_factory=lambda: _llm_model_presence_penalty)

    def completion(self, messages: List[Dict[str, str]], **kwargs):
        params = asdict(self)
        api_key = os.getenv("MINIMAX_API_KEY")
        base_url = os.getenv("MINIMAX_BASE_URL") or _minimax_default_base_url
        return create_minimax_chat_completion(messages=messages, api_key=api_key, base_url=base_url, **params)

    def copy(self) -> LLMInvocationParams:
        return MiniMaxHighSpeed(
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )


# this is just for demonstration purposes
def send_prompt(prompt: str) -> ChatLog:
    try:
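
Because both classes implement LLMInvocationParams, they can be dropped in wherever langkit currently accepts one of the OpenAI providers. A minimal direct-call sketch, assuming this patch is applied and a valid MINIMAX_API_KEY is exported; the response access pattern assumes the usual OpenAI chat-completion shape, which both SDK branches above return:

import os

from langkit.openai import MiniMaxHighSpeed

assert os.getenv("MINIMAX_API_KEY"), "export MINIMAX_API_KEY before running"

llm = MiniMaxHighSpeed()
response = llm.completion([{"role": "user", "content": "Say hello."}])
print(response.choices[0].message.content)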
108 changes: 108 additions & 0 deletions langkit/tests/test_minimax.py
@@ -0,0 +1,108 @@
"""Unit tests for MiniMax LLM provider classes."""
from dataclasses import asdict
from unittest.mock import MagicMock, patch
from langkit.openai.openai import MiniMaxDefault, MiniMaxHighSpeed, _minimax_default_base_url


def test_minimax_default_model():
    provider = MiniMaxDefault()
    assert provider.model == "MiniMax-M2.7"


def test_minimax_highspeed_model():
    provider = MiniMaxHighSpeed()
    assert provider.model == "MiniMax-M2.7-highspeed"


def test_minimax_default_temperature():
    provider = MiniMaxDefault()
    assert provider.temperature == 1.0


def test_minimax_highspeed_temperature():
    provider = MiniMaxHighSpeed()
    assert provider.temperature == 1.0


def test_minimax_default_base_url():
    assert _minimax_default_base_url == "https://api.minimax.io/v1"


def test_minimax_default_copy():
    provider = MiniMaxDefault()
    copy = provider.copy()
    assert isinstance(copy, MiniMaxDefault)
    assert copy.model == provider.model
    assert copy.temperature == provider.temperature


def test_minimax_highspeed_copy():
    provider = MiniMaxHighSpeed()
    copy = provider.copy()
    assert isinstance(copy, MiniMaxHighSpeed)
    assert copy.model == provider.model
    assert copy.temperature == provider.temperature


def test_minimax_default_completion(monkeypatch):
    monkeypatch.setenv("MINIMAX_API_KEY", "test-key")
    monkeypatch.delenv("MINIMAX_BASE_URL", raising=False)  # a leaked override would break the base_url assert

    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "Hello!"
    mock_response.usage.total_tokens = 10

    with patch("langkit.openai.openai.create_minimax_chat_completion", return_value=mock_response) as mock_fn:
        provider = MiniMaxDefault()
        messages = [{"role": "user", "content": "Hi"}]
        result = provider.completion(messages)

    mock_fn.assert_called_once()
    call_kwargs = mock_fn.call_args[1]
    assert call_kwargs["api_key"] == "test-key"
    assert call_kwargs["base_url"] == "https://api.minimax.io/v1"
    assert call_kwargs["model"] == "MiniMax-M2.7"
    assert call_kwargs["temperature"] == 1.0
    assert result == mock_response


def test_minimax_highspeed_completion(monkeypatch):
    monkeypatch.setenv("MINIMAX_API_KEY", "test-key")
    monkeypatch.delenv("MINIMAX_BASE_URL", raising=False)  # a leaked override would break the base_url assert

    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "Hello fast!"
    mock_response.usage.total_tokens = 5

    with patch("langkit.openai.openai.create_minimax_chat_completion", return_value=mock_response) as mock_fn:
        provider = MiniMaxHighSpeed()
        messages = [{"role": "user", "content": "Hi"}]
        result = provider.completion(messages)

    mock_fn.assert_called_once()
    call_kwargs = mock_fn.call_args[1]
    assert call_kwargs["api_key"] == "test-key"
    assert call_kwargs["base_url"] == "https://api.minimax.io/v1"
    assert call_kwargs["model"] == "MiniMax-M2.7-highspeed"
    assert call_kwargs["temperature"] == 1.0
    assert result == mock_response


def test_minimax_custom_base_url(monkeypatch):
    monkeypatch.setenv("MINIMAX_API_KEY", "test-key")
    monkeypatch.setenv("MINIMAX_BASE_URL", "https://custom.minimax.io/v1")

    mock_response = MagicMock()
    with patch("langkit.openai.openai.create_minimax_chat_completion", return_value=mock_response) as mock_fn:
        provider = MiniMaxDefault()
        messages = [{"role": "user", "content": "Hi"}]
        provider.completion(messages)

    call_kwargs = mock_fn.call_args[1]
    assert call_kwargs["base_url"] == "https://custom.minimax.io/v1"


def test_minimax_exported_from_openai_module():
    from langkit.openai import MiniMaxDefault as MD, MiniMaxHighSpeed as MH

    assert MD is MiniMaxDefault
    assert MH is MiniMaxHighSpeed
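
All network traffic in this suite is stubbed out via patch("langkit.openai.openai.create_minimax_chat_completion"), so the tests run offline with no real key. A typical local invocation, assuming pytest is available in the environment:

pytest langkit/tests/test_minimax.py -v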