Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@
"llama-3.3-70b": "Llama-3.3-70B-Instruct",
"llama-4-scout": "Llama-4-Scout-17B-16E-Instruct-FP8",
"llama-4-maverick": "Llama-4-Maverick-17B-128E-Instruct-FP8",
# Mistral AI models
"mistral-large-latest": "mistral-large-2411",
"mistral-medium-latest": "mistral-medium-2505",
"mistral-small-latest": "mistral-small-2503",
"codestral-latest": "codestral-2501",
"pixtral-large-latest": "pixtral-large-2411",
"ministral-3b-latest": "ministral-3b-2410",
"ministral-8b-latest": "ministral-8b-2410",
"open-mistral-nemo": "open-mistral-nemo-2407",
}

_MODEL_INFO: Dict[str, ModelInfo] = {
Expand Down Expand Up @@ -441,6 +450,87 @@
"structured_output": True,
"multiple_system_messages": True,
},
    # Mistral AI models
    # Each entry follows the same ModelInfo capability schema as the other
    # entries in _MODEL_INFO: vision / function_calling / json_output /
    # family / structured_output / multiple_system_messages.
    "mistral-large-2411": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
        "family": ModelFamily.MISTRAL,
        "structured_output": True,
        "multiple_system_messages": True,
    },
    "mistral-medium-2505": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
        "family": ModelFamily.MISTRAL,
        "structured_output": True,
        "multiple_system_messages": True,
    },
    "mistral-small-2503": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
        "family": ModelFamily.MISTRAL,
        "structured_output": True,
        "multiple_system_messages": True,
    },
    # NOTE(review): "CODESRAL" and "OPEN_CODESRAL_MAMBA" look misspelled
    # (missing the 'T'), but they must match the member names declared on
    # ModelFamily in autogen-core — confirm the spelling there before
    # "fixing" it here.
    "codestral-2501": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
        "family": ModelFamily.CODESRAL,
        "structured_output": True,
        "multiple_system_messages": True,
    },
    # Mamba-architecture code model: no tool calls, no structured output.
    "open-codestral-mamba": {
        "vision": False,
        "function_calling": False,
        "json_output": True,
        "family": ModelFamily.OPEN_CODESRAL_MAMBA,
        "structured_output": False,
        "multiple_system_messages": True,
    },
    # Pixtral models are the multimodal (vision-capable) Mistral variants.
    "pixtral-large-2411": {
        "vision": True,
        "function_calling": True,
        "json_output": True,
        "family": ModelFamily.PIXTRAL,
        "structured_output": True,
        "multiple_system_messages": True,
    },
    "pixtral-12b-2409": {
        "vision": True,
        "function_calling": True,
        "json_output": True,
        "family": ModelFamily.PIXTRAL,
        "structured_output": False,
        "multiple_system_messages": True,
    },
    "ministral-3b-2410": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
        "family": ModelFamily.MINISTRAL,
        "structured_output": False,
        "multiple_system_messages": True,
    },
    "ministral-8b-2410": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
        "family": ModelFamily.MINISTRAL,
        "structured_output": False,
        "multiple_system_messages": True,
    },
    "open-mistral-nemo-2407": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
        "family": ModelFamily.MISTRAL,
        "structured_output": False,
        "multiple_system_messages": True,
    },
}

_MODEL_TOKEN_LIMITS: Dict[str, int] = {
Expand Down Expand Up @@ -491,11 +581,38 @@
"Llama-3.3-70B-Instruct": 128000,
"Llama-4-Scout-17B-16E-Instruct-FP8": 128000,
"Llama-4-Maverick-17B-128E-Instruct-FP8": 128000,
# Mistral AI models
"mistral-large-2411": 131072,
"mistral-medium-2505": 131072,
"mistral-small-2503": 131072,
"codestral-2501": 262144,
"open-codestral-mamba": 262144,
"pixtral-large-2411": 131072,
"pixtral-12b-2409": 131072,
"ministral-3b-2410": 131072,
"ministral-8b-2410": 131072,
"open-mistral-nemo-2407": 131072,
}

# OpenAI-compatible endpoint base URLs for non-OpenAI providers, used to
# auto-configure the client when a provider's model name is detected.
GEMINI_OPENAI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
ANTHROPIC_OPENAI_BASE_URL = "https://api.anthropic.com/v1/"
LLAMA_API_BASE_URL = "https://api.llama.com/compat/v1/"
MISTRAL_API_BASE_URL = "https://api.mistral.ai/v1/"

# Mistral model name prefixes used for auto-detection.
_MISTRAL_MODEL_PREFIXES = (
"mistral-",
"codestral-",
"open-codestral-",
"pixtral-",
"ministral-",
"open-mistral-",
)


def _is_mistral_model(model: str) -> bool:
"""Check if the model name matches a Mistral AI model prefix."""
return any(model.startswith(prefix) for prefix in _MISTRAL_MODEL_PREFIXES)


def resolve_model(model: str) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,15 @@ def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]):
copied_args["base_url"] = _model_info.LLAMA_API_BASE_URL
if "api_key" not in copied_args and "LLAMA_API_KEY" in os.environ:
copied_args["api_key"] = os.environ["LLAMA_API_KEY"]
if _model_info._is_mistral_model(copied_args["model"]):
if "base_url" not in copied_args:
copied_args["base_url"] = _model_info.MISTRAL_API_BASE_URL
if "api_key" not in copied_args and "MISTRAL_API_KEY" in os.environ:
copied_args["api_key"] = os.environ["MISTRAL_API_KEY"]
# Mistral API rejects the 'name' field in messages (HTTP 422),
# so disable it by default unless the user explicitly set it.
if "include_name_in_message" not in kwargs:
include_name_in_message = False

client = _openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)
Expand Down
152 changes: 151 additions & 1 deletion python/packages/autogen-ext/tests/models/test_openai_model_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@
from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
from autogen_ext.models.openai._model_info import resolve_model
from autogen_ext.models.openai._model_info import (
MISTRAL_API_BASE_URL,
_is_mistral_model,
get_info,
get_token_limit,
resolve_model,
)
from autogen_ext.models.openai._openai_client import (
BaseOpenAIChatCompletionClient,
calculate_vision_tokens,
Expand Down Expand Up @@ -191,6 +197,150 @@ async def test_openai_chat_completion_client_with_gemini_model() -> None:
assert client


@pytest.mark.asyncio
async def test_openai_chat_completion_client_with_mistral_model() -> None:
    """A Mistral model name should be accepted by OpenAIChatCompletionClient."""
    mistral_client = OpenAIChatCompletionClient(model="mistral-small-latest", api_key="api_key")
    assert mistral_client


@pytest.mark.asyncio
async def test_openai_chat_completion_client_with_mistral_model_variants() -> None:
    """Every supported Mistral model alias should yield a working client."""
    for model_name in (
        "mistral-large-latest",
        "mistral-medium-latest",
        "mistral-small-latest",
        "codestral-latest",
        "pixtral-large-latest",
        "ministral-3b-latest",
        "ministral-8b-latest",
        "open-mistral-nemo",
    ):
        variant_client = OpenAIChatCompletionClient(model=model_name, api_key="api_key")
        assert variant_client, f"Failed to create client for model: {model_name}"


def test_is_mistral_model() -> None:
    """Exercise _is_mistral_model with matching and non-matching names."""
    mistral_names = (
        "mistral-large-latest",
        "mistral-small-2503",
        "codestral-latest",
        "codestral-2501",
        "open-codestral-mamba",
        "pixtral-large-2411",
        "pixtral-12b-2409",
        "ministral-3b-2410",
        "ministral-8b-latest",
        "open-mistral-nemo",
        "open-mistral-nemo-2407",
    )
    other_names = (
        "gpt-4o",
        "gemini-1.5-flash",
        "claude-3-5-sonnet-20241022",
        "Llama-4-Scout-17B-16E-Instruct-FP8",
    )
    # Every Mistral name (alias or versioned) must be detected.
    for name in mistral_names:
        assert _is_mistral_model(name)
    # Names from other providers must not match.
    for name in other_names:
        assert not _is_mistral_model(name)


def test_mistral_model_info() -> None:
    """Registered Mistral model info must carry the expected capabilities."""
    # model name -> (expected family, expected boolean capability flags)
    expectations = {
        "mistral-large-latest": (
            ModelFamily.MISTRAL,
            {"function_calling": True, "json_output": True, "vision": False, "structured_output": True},
        ),
        # pixtral (vision model)
        "pixtral-large-latest": (
            ModelFamily.PIXTRAL,
            {"vision": True, "function_calling": True},
        ),
        # codestral
        "codestral-latest": (
            ModelFamily.CODESRAL,
            {"function_calling": True},
        ),
        # open-codestral-mamba (no function calling)
        "open-codestral-mamba": (
            ModelFamily.OPEN_CODESRAL_MAMBA,
            {"function_calling": False},
        ),
        # ministral
        "ministral-8b-latest": (
            ModelFamily.MINISTRAL,
            {"function_calling": True},
        ),
        # open-mistral-nemo
        "open-mistral-nemo": (
            ModelFamily.MISTRAL,
            {"function_calling": True},
        ),
    }
    for model_name, (expected_family, expected_flags) in expectations.items():
        model_info = get_info(model_name)
        assert model_info["family"] == expected_family
        # Flags are compared with `is` to pin real booleans, matching the
        # original per-field assertions.
        for flag_name, flag_value in expected_flags.items():
            assert model_info[flag_name] is flag_value


def test_mistral_model_token_limits() -> None:
    """Registered Mistral models must report the documented token limits."""
    expected_limits = {
        "mistral-large-latest": 131072,
        "mistral-small-latest": 131072,
        "codestral-latest": 262144,
        "open-codestral-mamba": 262144,
        "pixtral-large-latest": 131072,
        "ministral-3b-latest": 131072,
        "open-mistral-nemo": 131072,
    }
    for model_name, token_limit in expected_limits.items():
        assert get_token_limit(model_name) == token_limit


def test_mistral_resolve_model() -> None:
    """Mistral ``-latest`` aliases must resolve to their pinned versions."""
    alias_to_version = {
        "mistral-large-latest": "mistral-large-2411",
        "mistral-medium-latest": "mistral-medium-2505",
        "mistral-small-latest": "mistral-small-2503",
        "codestral-latest": "codestral-2501",
        "pixtral-large-latest": "pixtral-large-2411",
        "ministral-3b-latest": "ministral-3b-2410",
        "ministral-8b-latest": "ministral-8b-2410",
        "open-mistral-nemo": "open-mistral-nemo-2407",
    }
    for alias, pinned in alias_to_version.items():
        assert resolve_model(alias) == pinned
    # Already-versioned names must resolve to themselves (identity).
    for pinned in ("mistral-large-2411", "codestral-2501"):
        assert resolve_model(pinned) == pinned


@pytest.mark.asyncio
async def test_mistral_auto_base_url() -> None:
    """Constructing a client for a Mistral model should auto-set base_url."""
    mistral_client = OpenAIChatCompletionClient(model="mistral-small-latest", api_key="test_key")
    configured_base_url = mistral_client._raw_config.get("base_url")
    assert configured_base_url == MISTRAL_API_BASE_URL


@pytest.mark.asyncio
async def test_mistral_auto_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """The MISTRAL_API_KEY env var should be picked up for Mistral models."""
    monkeypatch.setenv("MISTRAL_API_KEY", "test_mistral_key")
    mistral_client = OpenAIChatCompletionClient(model="mistral-small-latest")
    configured_key = mistral_client._raw_config.get("api_key")
    assert configured_key == "test_mistral_key"


@pytest.mark.asyncio
async def test_mistral_include_name_disabled_by_default() -> None:
    """For Mistral models, include_name_in_message must default to False.

    The Mistral API returns HTTP 422 when the 'name' field is present in
    messages (see issue #6147), so the client disables it by default.
    """
    mistral_client = OpenAIChatCompletionClient(model="mistral-small-latest", api_key="test_key")
    assert mistral_client._include_name_in_message is False


@pytest.mark.asyncio
async def test_mistral_include_name_can_be_overridden() -> None:
    """An explicit include_name_in_message=True must win over the Mistral default."""
    mistral_client = OpenAIChatCompletionClient(
        model="mistral-small-latest", api_key="test_key", include_name_in_message=True
    )
    assert mistral_client._include_name_in_message is True


@pytest.mark.asyncio
async def test_openai_chat_completion_client_serialization() -> None:
client = OpenAIChatCompletionClient(model="gpt-4.1-nano", api_key="sk-password")
Expand Down