22
33from dataclasses import dataclass
44import time
5- from typing import Any
5+ from typing import Any , Callable
66from uuid import uuid4
77
88import httpx
@@ -24,6 +24,10 @@ class AxmeClientConfig:
2424 max_retries : int = 2
2525 retry_backoff_seconds : float = 0.2
2626 auto_trace_id : bool = True
27+ default_owner_agent : str | None = None
28+ mcp_endpoint_path : str = "/mcp"
29+ mcp_protocol_version : str = "2024-11-05"
30+ mcp_observer : Callable [[dict [str , Any ]], None ] | None = None
2731
2832
2933class AxmeClient :
@@ -38,6 +42,7 @@ def __init__(self, config: AxmeClientConfig, *, http_client: httpx.Client | None
3842 "Content-Type" : "application/json" ,
3943 },
4044 )
45+ self ._mcp_tool_schemas : dict [str , dict [str , Any ]] = {}
4146
4247 def close (self ) -> None :
4348 if self ._owns_http_client :
@@ -485,6 +490,70 @@ def replay_webhook_event(
485490 )
486491 return response
487492
493+ def mcp_initialize (self , * , protocol_version : str | None = None , trace_id : str | None = None ) -> dict [str , Any ]:
494+ payload = {
495+ "jsonrpc" : "2.0" ,
496+ "id" : str (uuid4 ()),
497+ "method" : "initialize" ,
498+ "params" : {"protocolVersion" : protocol_version or self ._config .mcp_protocol_version },
499+ }
500+ return self ._mcp_request (payload = payload , trace_id = trace_id , retryable = True )
501+
502+ def mcp_list_tools (self , * , trace_id : str | None = None ) -> dict [str , Any ]:
503+ payload = {
504+ "jsonrpc" : "2.0" ,
505+ "id" : str (uuid4 ()),
506+ "method" : "tools/list" ,
507+ "params" : {},
508+ }
509+ result = self ._mcp_request (payload = payload , trace_id = trace_id , retryable = True )
510+ tools = result .get ("tools" )
511+ if isinstance (tools , list ):
512+ self ._mcp_tool_schemas = {}
513+ for tool in tools :
514+ if not isinstance (tool , dict ):
515+ continue
516+ name = tool .get ("name" )
517+ input_schema = tool .get ("inputSchema" )
518+ if isinstance (name , str ) and isinstance (input_schema , dict ):
519+ self ._mcp_tool_schemas [name ] = input_schema
520+ return result
521+
522+ def mcp_call_tool (
523+ self ,
524+ name : str ,
525+ * ,
526+ arguments : dict [str , Any ] | None = None ,
527+ owner_agent : str | None = None ,
528+ idempotency_key : str | None = None ,
529+ trace_id : str | None = None ,
530+ validate_input_schema : bool = True ,
531+ retryable : bool | None = None ,
532+ ) -> dict [str , Any ]:
533+ if not isinstance (name , str ) or not name .strip ():
534+ raise ValueError ("tool name must be non-empty string" )
535+ args = dict (arguments or {})
536+ resolved_owner = owner_agent or self ._config .default_owner_agent
537+ if resolved_owner and "owner_agent" not in args :
538+ args ["owner_agent" ] = resolved_owner
539+ if idempotency_key and "idempotency_key" not in args :
540+ args ["idempotency_key" ] = idempotency_key
541+
542+ if validate_input_schema :
543+ self ._validate_mcp_tool_arguments (name = name .strip (), arguments = args )
544+
545+ params : dict [str , Any ] = {"name" : name .strip (), "arguments" : args }
546+ if resolved_owner :
547+ params ["owner_agent" ] = resolved_owner
548+ payload = {
549+ "jsonrpc" : "2.0" ,
550+ "id" : str (uuid4 ()),
551+ "method" : "tools/call" ,
552+ "params" : params ,
553+ }
554+ should_retry = retryable if retryable is not None else bool (idempotency_key )
555+ return self ._mcp_request (payload = payload , trace_id = trace_id , retryable = should_retry )
556+
488557 def _request_json (
489558 self ,
490559 method : str ,
@@ -529,6 +598,99 @@ def _request_json(
529598
530599 raise RuntimeError ("unreachable retry loop state" )
531600
601+ def _mcp_request (
602+ self ,
603+ * ,
604+ payload : dict [str , Any ],
605+ trace_id : str | None ,
606+ retryable : bool ,
607+ ) -> dict [str , Any ]:
608+ self ._notify_mcp_observer (
609+ {
610+ "phase" : "request" ,
611+ "method" : payload .get ("method" ),
612+ "rpc_id" : payload .get ("id" ),
613+ "retryable" : retryable ,
614+ }
615+ )
616+ response = self ._request_json (
617+ "POST" ,
618+ self ._config .mcp_endpoint_path ,
619+ json_body = payload ,
620+ trace_id = trace_id ,
621+ retryable = retryable ,
622+ )
623+ if isinstance (response .get ("error" ), dict ):
624+ self ._raise_mcp_rpc_error (response )
625+ result = response .get ("result" )
626+ if not isinstance (result , dict ):
627+ raise AxmeHttpError (502 , "invalid MCP response: missing result object" , body = response )
628+ self ._notify_mcp_observer (
629+ {
630+ "phase" : "response" ,
631+ "method" : payload .get ("method" ),
632+ "rpc_id" : payload .get ("id" ),
633+ "result_keys" : sorted (result .keys ()),
634+ }
635+ )
636+ return result
637+
638+ def _raise_mcp_rpc_error (self , response_payload : dict [str , Any ]) -> None :
639+ error = response_payload .get ("error" )
640+ if not isinstance (error , dict ):
641+ raise AxmeHttpError (502 , "invalid MCP response: error is not object" , body = response_payload )
642+ code = error .get ("code" )
643+ message = error .get ("message" )
644+ if not isinstance (code , int ):
645+ code = - 32000
646+ if not isinstance (message , str ) or not message :
647+ message = "MCP RPC error"
648+ data = error .get ("data" )
649+ kwargs = {"body" : {"code" : code , "message" : message , "data" : data }}
650+ if code in {- 32001 , - 32003 }:
651+ raise AxmeAuthError (403 , message , ** kwargs )
652+ if code == - 32004 :
653+ raise AxmeRateLimitError (429 , message , ** kwargs )
654+ if code == - 32602 :
655+ raise AxmeValidationError (422 , message , ** kwargs )
656+ if code <= - 32000 :
657+ raise AxmeServerError (502 , message , ** kwargs )
658+ raise AxmeHttpError (400 , message , ** kwargs )
659+
660+ def _validate_mcp_tool_arguments (self , * , name : str , arguments : dict [str , Any ]) -> None :
661+ schema = self ._mcp_tool_schemas .get (name )
662+ if not isinstance (schema , dict ):
663+ return
664+ required = schema .get ("required" )
665+ if isinstance (required , list ):
666+ missing = [item for item in required if isinstance (item , str ) and item not in arguments ]
667+ if missing :
668+ raise ValueError (f"missing required MCP tool arguments for { name } : { ', ' .join (sorted (missing ))} " )
669+ properties = schema .get ("properties" )
670+ if not isinstance (properties , dict ):
671+ return
672+ for key , value in arguments .items ():
673+ if key not in properties :
674+ continue
675+ prop = properties [key ]
676+ if not isinstance (prop , dict ):
677+ continue
678+ declared_type = prop .get ("type" )
679+ if isinstance (declared_type , list ):
680+ accepted_types = [item for item in declared_type if isinstance (item , str )]
681+ elif isinstance (declared_type , str ):
682+ accepted_types = [declared_type ]
683+ else :
684+ accepted_types = []
685+ if accepted_types and not _matches_json_type (value = value , accepted_types = accepted_types ):
686+ raise ValueError (f"invalid MCP argument type for { name } .{ key } : expected { accepted_types } " )
687+
688+ def _notify_mcp_observer (self , event : dict [str , Any ]) -> None :
689+ observer = self ._config .mcp_observer
690+ if observer is None :
691+ return
692+ observer (event )
693+
532694 def _sleep_before_retry (self , attempt_idx : int , * , retry_after : int | None ) -> None :
533695 if retry_after is not None :
534696 time .sleep (max (0 , retry_after ))
@@ -599,3 +761,22 @@ def _parse_retry_after(value: str | None) -> int | None:
599761
600762def _is_retryable_status (status_code : int ) -> bool :
601763 return status_code == 429 or status_code >= 500
764+
765+
766+ def _matches_json_type (* , value : Any , accepted_types : list [str ]) -> bool :
767+ for type_name in accepted_types :
768+ if type_name == "null" and value is None :
769+ return True
770+ if type_name == "string" and isinstance (value , str ):
771+ return True
772+ if type_name == "boolean" and isinstance (value , bool ):
773+ return True
774+ if type_name == "integer" and isinstance (value , int ) and not isinstance (value , bool ):
775+ return True
776+ if type_name == "number" and isinstance (value , (int , float )) and not isinstance (value , bool ):
777+ return True
778+ if type_name == "object" and isinstance (value , dict ):
779+ return True
780+ if type_name == "array" and isinstance (value , list ):
781+ return True
782+ return False
0 commit comments