Skip to content

Commit d6ac48b

Browse files
author
Dylan Huang
authored
Mcp config url (#18)
* save * support remote MCP in evaulation_test * format string better * fix cleanup
1 parent 52b46a7 commit d6ac48b

File tree

6 files changed

+140
-57
lines changed

6 files changed

+140
-57
lines changed

eval_protocol/mcp/mcp_multi_client.py

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,21 @@
22
import os
33
from contextlib import AsyncExitStack
44
from dataclasses import dataclass
5-
from typing import Any, Dict, List, Optional
5+
from typing import Any, Dict, List, Optional, Union
66

77
from dotenv import load_dotenv
88
from mcp import ClientSession, StdioServerParameters
99
from mcp.client.stdio import stdio_client
10+
from mcp.client.streamable_http import streamablehttp_client
1011
from mcp.types import CallToolResult
1112
from openai.types import FunctionDefinition
1213
from openai.types.chat import ChatCompletionToolParam
1314

14-
from eval_protocol.types.types import MCPMultiClientConfiguration
15+
from eval_protocol.models import (
16+
MCPConfigurationServerStdio,
17+
MCPConfigurationServerUrl,
18+
MCPMultiClientConfiguration,
19+
)
1520

1621
load_dotenv() # load environment variables from .env
1722

@@ -38,10 +43,10 @@ def _load_config(self, config_path: Optional[str] = None) -> MCPMultiClientConfi
3843
"""Load MCP server configuration from file or use default"""
3944
if config_path and os.path.exists(config_path):
4045
with open(config_path, "r") as f:
41-
return json.load(f)
46+
return MCPMultiClientConfiguration(**json.load(f))
4247

4348
# Default configuration - can be overridden by config file
44-
return {"mcpServers": {}}
49+
return MCPMultiClientConfiguration(mcpServers={})
4550

4651
def _validate_environment_variables(self, server_name: str, required_env: List[str]) -> None:
4752
"""Validate that required environment variables are set in os.environ"""
@@ -59,35 +64,54 @@ def _validate_environment_variables(self, server_name: str, required_env: List[s
5964

6065
async def connect_to_servers(self):
6166
"""Connect to all configured MCP servers"""
62-
if not self.config.get("mcpServers"):
67+
if not self.config.mcpServers:
6368
print("No MCP servers configured. Please provide a configuration file.")
6469
return
6570

66-
for server_name, server_config in self.config["mcpServers"].items():
71+
for server_name, server_config in self.config.mcpServers.items():
6772
try:
6873
await self._connect_to_server(server_name, server_config)
6974
except Exception as e:
7075
print(f"Failed to connect to server '{server_name}': {e}")
7176

72-
async def _connect_to_server(self, server_name: str, server_config: Dict[str, Any]):
77+
async def _connect_to_server(
78+
self, server_name: str, server_config: Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]
79+
):
7380
"""Connect to a specific MCP server using its configuration"""
74-
command = server_config.get("command")
75-
args = server_config.get("args", [])
76-
env_config = server_config.get("env", [])
77-
78-
if not command:
79-
raise ValueError(f"Server '{server_name}' must have a 'command' specified")
80-
81-
# Validate that required environment variables are set
82-
if env_config:
83-
self._validate_environment_variables(server_name, env_config)
84-
85-
# Use the current system environment (os.environ) - don't override with config
86-
server_params = StdioServerParameters(command=command, args=args, env=os.environ)
87-
88-
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
89-
stdio, write = stdio_transport
90-
session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
81+
session: ClientSession
82+
83+
if isinstance(server_config, MCPConfigurationServerStdio):
84+
# Handle stdio-based MCP server
85+
command = server_config.command
86+
args = server_config.args
87+
env_config = server_config.env
88+
89+
if not command:
90+
raise ValueError(f"Server '{server_name}' must have a 'command' specified")
91+
92+
# Validate that required environment variables are set
93+
if env_config:
94+
self._validate_environment_variables(server_name, env_config)
95+
96+
# Use the current system environment (os.environ) - don't override with config
97+
server_params = StdioServerParameters(command=command, args=args, env=os.environ)
98+
99+
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
100+
stdio, write = stdio_transport
101+
session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
102+
103+
elif isinstance(server_config, MCPConfigurationServerUrl):
104+
# Handle HTTP-based MCP server
105+
url = server_config.url
106+
if not url:
107+
raise ValueError(f"Server '{server_name}' must have a 'url' specified")
108+
109+
# Connect using streamable HTTP client - manage resources manually
110+
http_transport = await self.exit_stack.enter_async_context(streamablehttp_client(url))
111+
read_stream, write_stream, get_session_id = http_transport
112+
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
113+
else:
114+
raise ValueError(f"Unsupported server configuration type: {type(server_config)}")
91115

92116
await session.initialize()
93117
self.sessions[server_name] = session

eval_protocol/models.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Union
1+
from typing import Any, Dict, List, Literal, Optional, Union
22

33
from openai.types import CompletionUsage
44
from openai.types.chat.chat_completion_message import (
@@ -8,11 +8,18 @@
88
from pydantic import BaseModel, ConfigDict, Field
99

1010

11+
class ChatCompletionContentPartTextParam(BaseModel):
12+
text: str = Field(..., description="The text content.")
13+
type: Literal["text"] = Field("text", description="The type of the content part.")
14+
15+
1116
class Message(BaseModel):
1217
"""Chat message model with trajectory evaluation support."""
1318

14-
role: str
15-
content: Optional[str] = "" # Content can be None for tool calls in OpenAI API
19+
role: str # assistant, user, system, tool
20+
content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]] = Field(
21+
default="", description="The content of the message."
22+
)
1623
name: Optional[str] = None
1724
tool_call_id: Optional[str] = None
1825
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
@@ -426,3 +433,23 @@ class Config:
426433
# from pydantic import ConfigDict
427434
# model_config = ConfigDict(extra='allow')
428435
# For Pydantic v1, `Config.extra = "allow"` is correct.
436+
437+
438+
class MCPConfigurationServerStdio(BaseModel):
439+
"""Represents a MCP configuration server."""
440+
441+
command: str # command to run the MCP server
442+
args: List[str] = Field(default_factory=list) # to pass to the command
443+
env: List[str] = Field(default_factory=list) # List of environment variables to verify exist in the environment
444+
445+
446+
class MCPConfigurationServerUrl(BaseModel):
447+
"""Represents a Remote MCP configuration server."""
448+
449+
url: str # url to the MCP server
450+
451+
452+
class MCPMultiClientConfiguration(BaseModel):
453+
"""Represents a MCP configuration."""
454+
455+
mcpServers: Dict[str, Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]]

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import os
44
from typing import Any, List, Optional, Union
55

6-
from mcp.types import CallToolResult
6+
from mcp.types import CallToolResult, TextContent
77
from openai import NOT_GIVEN, NotGiven
8-
from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam
8+
from openai.types.chat import ChatCompletionContentPartTextParam, ChatCompletionMessage, ChatCompletionToolParam
99
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
1010

1111
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
@@ -57,16 +57,16 @@ async def call_agent(self) -> str:
5757
tool_tasks.append(task)
5858

5959
# Execute all tool calls in parallel
60-
tool_results = await asyncio.gather(*tool_tasks)
60+
tool_results: List[List[TextContent]] = await asyncio.gather(*tool_tasks)
6161

6262
# Add all tool results to messages (they will be in the same order as tool_calls)
6363
for tool_call, (tool_call_id, content) in zip(message["tool_calls"], tool_results):
6464
self.messages.append(
65-
{
66-
"role": "tool",
67-
"content": content,
68-
"tool_call_id": tool_call_id,
69-
}
65+
Message(
66+
role="tool",
67+
content=content,
68+
tool_call_id=tool_call_id,
69+
)
7070
)
7171
return await self.call_agent()
7272
return message["content"]
@@ -88,15 +88,12 @@ async def _execute_tool_call(self, tool_call_id: str, tool_name: str, tool_args_
8888
content = self._get_content_from_tool_result(tool_result)
8989
return tool_call_id, content
9090

91-
def _get_content_from_tool_result(self, tool_result: CallToolResult) -> str:
91+
def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[TextContent]:
9292
if tool_result.structuredContent:
9393
return json.dumps(tool_result.structuredContent)
94-
if len(tool_result.content) > 1:
95-
raise NotImplementedError("Multiple content is not supported yet")
96-
first_content = tool_result.content[0]
97-
if first_content.type != "text":
94+
if not all(isinstance(content, TextContent) for content in tool_result.content):
9895
raise NotImplementedError("Non-text content is not supported yet")
99-
return first_content.text
96+
return tool_result.content[0].text
10097

10198

10299
async def default_agent_rollout_processor(
@@ -108,4 +105,6 @@ async def default_agent_rollout_processor(
108105
await agent.setup()
109106
await agent.call_agent()
110107
dataset.append(EvaluationRow(messages=agent.messages, ground_truth=row.ground_truth))
108+
if agent.mcp_client:
109+
await agent.mcp_client.cleanup()
111110
return dataset

eval_protocol/types/types.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -71,19 +71,3 @@ class Trajectory:
7171
termination_reason: str
7272
conversation_history: List[Dict[str, Any]]
7373
usage: Dict[str, int] = field(default_factory=dict)
74-
75-
76-
@dataclass
77-
class MCPConfigurationServer:
78-
"""Represents a MCP configuration server."""
79-
80-
command: str # command to run the MCP server
81-
args: List[str] # to pass to the command
82-
env: List[str] # List of environment variables to verify exist in the environment
83-
84-
85-
@dataclass
86-
class MCPMultiClientConfiguration:
87-
"""Represents a MCP configuration."""
88-
89-
mcp_servers: Dict[str, MCPConfigurationServer]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"mcpServers": {
3+
"docs.fireworks.ai": {
4+
"url": "https://docs.fireworks.ai/mcp"
5+
}
6+
}
7+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from eval_protocol.models import EvaluateResult, Message, EvaluationRow
2+
from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test
3+
4+
5+
@evaluation_test(
6+
input_messages=[
7+
[
8+
Message(
9+
role="system",
10+
content=(
11+
"You are a helpful assistant that can answer questions about Fireworks.\n"
12+
"ALWAYS provide code or commands to execute to answer the question."
13+
),
14+
),
15+
Message(
16+
role="user",
17+
content=("Can you teach me about how to manage deployments on Fireworks"),
18+
),
19+
]
20+
],
21+
rollout_processor=default_agent_rollout_processor,
22+
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
23+
mode="pointwise",
24+
mcp_config_path="tests/pytest/mcp_configurations/docs_mcp_config.json",
25+
)
26+
def test_pytest_mcp_url(row: EvaluationRow) -> EvaluationRow:
27+
"""Run math evaluation on sample dataset using pytest interface."""
28+
# filter for all tool calls
29+
tool_calls = [msg for msg in row.messages if msg.role == "tool"]
30+
31+
if len(tool_calls) == 0:
32+
row.evaluation_result = EvaluateResult(
33+
score=0,
34+
feedback="No tool calls made",
35+
)
36+
return row
37+
38+
row.evaluation_result = EvaluateResult(
39+
score=1,
40+
feedback="At least one tool call was made",
41+
)
42+
return row

0 commit comments

Comments
 (0)