22import os
33from contextlib import AsyncExitStack
44from dataclasses import dataclass
5- from typing import Any , Dict , List , Optional
5+ from typing import Any , Dict , List , Optional , Union
66
77from dotenv import load_dotenv
88from mcp import ClientSession , StdioServerParameters
99from mcp .client .stdio import stdio_client
10+ from mcp .client .streamable_http import streamablehttp_client
1011from mcp .types import CallToolResult
1112from openai .types import FunctionDefinition
1213from openai .types .chat import ChatCompletionToolParam
1314
14- from eval_protocol .types .types import MCPMultiClientConfiguration
15+ from eval_protocol .models import (
16+ MCPConfigurationServerStdio ,
17+ MCPConfigurationServerUrl ,
18+ MCPMultiClientConfiguration ,
19+ )
1520
1621load_dotenv () # load environment variables from .env
1722
@@ -38,10 +43,10 @@ def _load_config(self, config_path: Optional[str] = None) -> MCPMultiClientConfi
3843 """Load MCP server configuration from file or use default"""
3944 if config_path and os .path .exists (config_path ):
4045 with open (config_path , "r" ) as f :
41- return json .load (f )
46+ return MCPMultiClientConfiguration ( ** json .load (f ) )
4247
4348 # Default configuration - can be overridden by config file
44- return { " mcpServers" : {}}
49+ return MCPMultiClientConfiguration ( mcpServers = {})
4550
4651 def _validate_environment_variables (self , server_name : str , required_env : List [str ]) -> None :
4752 """Validate that required environment variables are set in os.environ"""
@@ -59,35 +64,54 @@ def _validate_environment_variables(self, server_name: str, required_env: List[s
5964
6065 async def connect_to_servers (self ):
6166 """Connect to all configured MCP servers"""
62- if not self .config .get ( " mcpServers" ) :
67+ if not self .config .mcpServers :
6368 print ("No MCP servers configured. Please provide a configuration file." )
6469 return
6570
66- for server_name , server_config in self .config [ " mcpServers" ] .items ():
71+ for server_name , server_config in self .config . mcpServers .items ():
6772 try :
6873 await self ._connect_to_server (server_name , server_config )
6974 except Exception as e :
7075 print (f"Failed to connect to server '{ server_name } ': { e } " )
7176
72- async def _connect_to_server (self , server_name : str , server_config : Dict [str , Any ]):
77+ async def _connect_to_server (
78+ self , server_name : str , server_config : Union [MCPConfigurationServerStdio , MCPConfigurationServerUrl ]
79+ ):
7380 """Connect to a specific MCP server using its configuration"""
74- command = server_config .get ("command" )
75- args = server_config .get ("args" , [])
76- env_config = server_config .get ("env" , [])
77-
78- if not command :
79- raise ValueError (f"Server '{ server_name } ' must have a 'command' specified" )
80-
81- # Validate that required environment variables are set
82- if env_config :
83- self ._validate_environment_variables (server_name , env_config )
84-
85- # Use the current system environment (os.environ) - don't override with config
86- server_params = StdioServerParameters (command = command , args = args , env = os .environ )
87-
88- stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
89- stdio , write = stdio_transport
90- session = await self .exit_stack .enter_async_context (ClientSession (stdio , write ))
81+ session : ClientSession
82+
83+ if isinstance (server_config , MCPConfigurationServerStdio ):
84+ # Handle stdio-based MCP server
85+ command = server_config .command
86+ args = server_config .args
87+ env_config = server_config .env
88+
89+ if not command :
90+ raise ValueError (f"Server '{ server_name } ' must have a 'command' specified" )
91+
92+ # Validate that required environment variables are set
93+ if env_config :
94+ self ._validate_environment_variables (server_name , env_config )
95+
96+ # Use the current system environment (os.environ) - don't override with config
97+ server_params = StdioServerParameters (command = command , args = args , env = os .environ )
98+
99+ stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
100+ stdio , write = stdio_transport
101+ session = await self .exit_stack .enter_async_context (ClientSession (stdio , write ))
102+
103+ elif isinstance (server_config , MCPConfigurationServerUrl ):
104+ # Handle HTTP-based MCP server
105+ url = server_config .url
106+ if not url :
107+ raise ValueError (f"Server '{ server_name } ' must have a 'url' specified" )
108+
109+ # Connect using streamable HTTP client - manage resources manually
110+ http_transport = await self .exit_stack .enter_async_context (streamablehttp_client (url ))
111+ read_stream , write_stream , get_session_id = http_transport
112+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
113+ else :
114+ raise ValueError (f"Unsupported server configuration type: { type (server_config )} " )
91115
92116 await session .initialize ()
93117 self .sessions [server_name ] = session
0 commit comments