66and session management.
77"""
88
9+ import contextlib
910import logging
1011from collections .abc import AsyncGenerator , Awaitable , Callable
1112from contextlib import asynccontextmanager
1920from httpx_sse import EventSource , ServerSentEvent , aconnect_sse
2021from typing_extensions import deprecated
2122
22- from mcp .shared ._httpx_utils import McpHttpClientFactory , create_mcp_http_client
23+ from mcp .shared ._httpx_utils import (
24+ MCP_DEFAULT_SSE_READ_TIMEOUT ,
25+ MCP_DEFAULT_TIMEOUT ,
26+ McpHttpClientFactory ,
27+ create_mcp_http_client ,
28+ )
2329from mcp .shared .message import ClientMessageMetadata , SessionMessage
2430from mcp .types import (
2531 ErrorData ,
@@ -102,9 +108,9 @@ def __init__(
102108 self .session_id = None
103109 self .protocol_version = None
104110 self .request_headers = {
111+ ** self .headers ,
105112 ACCEPT : f"{ JSON } , { SSE } " ,
106113 CONTENT_TYPE : JSON ,
107- ** self .headers ,
108114 }
109115
110116 def _prepare_request_headers (self , base_headers : dict [str , str ]) -> dict [str , str ]:
@@ -450,12 +456,9 @@ def get_session_id(self) -> str | None:
450456@asynccontextmanager
451457async def streamable_http_client (
452458 url : str ,
453- headers : dict [str , str ] | None = None ,
454- timeout : float | timedelta = 30 ,
455- sse_read_timeout : float | timedelta = 60 * 5 ,
459+ * ,
460+ httpx_client : httpx .AsyncClient | None = None ,
456461 terminate_on_close : bool = True ,
457- httpx_client_factory : McpHttpClientFactory = create_mcp_http_client ,
458- auth : httpx .Auth | None = None ,
459462) -> AsyncGenerator [
460463 tuple [
461464 MemoryObjectReceiveStream [SessionMessage | Exception ],
@@ -467,30 +470,57 @@ async def streamable_http_client(
467470 """
468471 Client transport for StreamableHTTP.
469472
470- `sse_read_timeout` determines how long (in seconds) the client will wait for a new
471- event before disconnecting. All other HTTP operations are controlled by `timeout`.
473+ Args:
474+ url: The MCP server endpoint URL.
475+ httpx_client: Optional pre-configured httpx.AsyncClient. If None, a default
476+ client with recommended MCP timeouts will be created. To configure headers,
477+ authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here.
478+ terminate_on_close: If True, send a DELETE request to terminate the session
479+ when the context exits.
472480
473481 Yields:
474482 Tuple containing:
475483 - read_stream: Stream for reading messages from the server
476484 - write_stream: Stream for sending messages to the server
477485 - get_session_id_callback: Function to retrieve the current session ID
478- """
479- transport = StreamableHTTPTransport (url , headers , timeout , sse_read_timeout , auth )
480486
487+ Example:
488+ See examples/snippets/clients/ for usage patterns.
489+ """
481490 read_stream_writer , read_stream = anyio .create_memory_object_stream [SessionMessage | Exception ](0 )
482491 write_stream , write_stream_reader = anyio .create_memory_object_stream [SessionMessage ](0 )
483492
493+ # Determine if we need to create and manage the client
494+ client_provided = httpx_client is not None
495+ client = httpx_client
496+
497+ if client is None :
498+ # Create default client with recommended MCP timeouts
499+ client = create_mcp_http_client ()
500+
501+ # Extract configuration from the client to pass to transport
502+ headers_dict = dict (client .headers ) if client .headers else None
503+ timeout = client .timeout .connect if (client .timeout and client .timeout .connect is not None ) else MCP_DEFAULT_TIMEOUT
504+ sse_read_timeout = (
505+ client .timeout .read if (client .timeout and client .timeout .read is not None ) else MCP_DEFAULT_SSE_READ_TIMEOUT
506+ )
507+ auth = client .auth
508+
509+ # Create transport with extracted configuration
510+ transport = StreamableHTTPTransport (url , headers_dict , timeout , sse_read_timeout , auth )
511+
512+ # Sync client headers with transport's merged headers (includes MCP protocol requirements)
513+ client .headers .update (transport .request_headers )
514+
484515 async with anyio .create_task_group () as tg :
485516 try :
486517 logger .debug (f"Connecting to StreamableHTTP endpoint: { url } " )
487518
488- async with httpx_client_factory (
489- headers = transport .request_headers ,
490- timeout = httpx .Timeout (transport .timeout , read = transport .sse_read_timeout ),
491- auth = transport .auth ,
492- ) as client :
493- # Define callbacks that need access to tg
519+ async with contextlib .AsyncExitStack () as stack :
520+ # Only manage client lifecycle if we created it
521+ if not client_provided :
522+ await stack .enter_async_context (client )
523+
494524 def start_get_stream () -> None :
495525 tg .start_soon (transport .handle_get_stream , client , read_stream_writer )
496526
@@ -537,7 +567,24 @@ async def streamablehttp_client(
537567 ],
538568 None ,
539569]:
540- async with streamable_http_client (
541- url , headers , timeout , sse_read_timeout , terminate_on_close , httpx_client_factory , auth
542- ) as streams :
543- yield streams
570+ # Convert timeout parameters
571+ timeout_seconds = timeout .total_seconds () if isinstance (timeout , timedelta ) else timeout
572+ sse_read_timeout_seconds = (
573+ sse_read_timeout .total_seconds () if isinstance (sse_read_timeout , timedelta ) else sse_read_timeout
574+ )
575+
576+ # Create httpx client using the factory with old-style parameters
577+ client = httpx_client_factory (
578+ headers = headers ,
579+ timeout = httpx .Timeout (timeout_seconds , read = sse_read_timeout_seconds ),
580+ auth = auth ,
581+ )
582+
583+ # Manage client lifecycle since we created it
584+ async with client :
585+ async with streamable_http_client (
586+ url ,
587+ httpx_client = client ,
588+ terminate_on_close = terminate_on_close ,
589+ ) as streams :
590+ yield streams
0 commit comments