Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions src/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from typing import Any, Dict, List, Optional, Callable

from src.logger import get_logger
from src.exceptions import (
MissingConfigurationError,
InvalidConfigurationError,
MCPServiceError,
)
from .mcp import MCPStdioServer, MCPHttpServer
from .utils import TokenUsageTracker

Expand Down Expand Up @@ -144,13 +149,20 @@ async def _create_mcp_server(self) -> Any:
return self._create_stdio_server()
if self.mcp_service in self.HTTP_SERVICES:
return self._create_http_server()
raise ValueError(f"Unsupported MCP service: {self.mcp_service}")
raise InvalidConfigurationError(
config_key="mcp_service",
value=self.mcp_service,
reason=f"Unsupported MCP service. STDIO services: {self.STDIO_SERVICES}, HTTP services: {self.HTTP_SERVICES}"
)

def _create_stdio_server(self) -> MCPStdioServer:
if self.mcp_service == "notion":
notion_key = self.service_config.get("notion_key")
if not notion_key:
raise ValueError("Notion API key required")
raise MissingConfigurationError(
config_key="notion_key",
service="notion"
)
return MCPStdioServer(
command="npx",
args=["-y", "@notionhq/notion-mcp-server"],
Expand All @@ -165,7 +177,10 @@ def _create_stdio_server(self) -> MCPStdioServer:
if self.mcp_service == "filesystem":
test_directory = self.service_config.get("test_directory")
if not test_directory:
raise ValueError("Test directory required for filesystem service")
raise MissingConfigurationError(
config_key="test_directory",
service="filesystem"
)
return MCPStdioServer(
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", str(test_directory)],
Expand Down Expand Up @@ -199,7 +214,11 @@ def _create_stdio_server(self) -> MCPStdioServer:
password = self.service_config.get("password")
database = self.service_config.get("current_database") or self.service_config.get("database")
if not all([username, password, database]):
raise ValueError("PostgreSQL requires username, password, and database")
missing = [k for k, v in [("username", username), ("password", password), ("database", database)] if not v]
raise MissingConfigurationError(
config_key=", ".join(missing),
service="postgres"
)
database_url = f"postgresql://{username}:{password}@{host}:{port}/{database}"
return MCPStdioServer(
command="pipx",
Expand All @@ -211,7 +230,11 @@ def _create_stdio_server(self) -> MCPStdioServer:
api_key = self.service_config.get("api_key")
backend_url = self.service_config.get("backend_url")
if not all([api_key, backend_url]):
raise ValueError("Insforge requires api_key and backend_url")
missing = [k for k, v in [("api_key", api_key), ("backend_url", backend_url)] if not v]
raise MissingConfigurationError(
config_key=", ".join(missing),
service="insforge"
)
return MCPStdioServer(
command="npx",
args=["-y", "@insforge/mcp@dev"],
Expand All @@ -221,21 +244,32 @@ def _create_stdio_server(self) -> MCPStdioServer:
},
)

raise ValueError(f"Unsupported stdio service: {self.mcp_service}")
raise InvalidConfigurationError(
config_key="mcp_service",
value=self.mcp_service,
reason=f"Unsupported stdio service. Supported: {self.STDIO_SERVICES}"
)

def _create_http_server(self) -> MCPHttpServer:
if self.mcp_service == "github":
github_token = self.service_config.get("github_token")
if not github_token:
raise ValueError("GitHub token required")
raise MissingConfigurationError(
config_key="github_token",
service="github"
)
return MCPHttpServer(
url="https://api.githubcopilot.com/mcp/",
headers={
"Authorization": f"Bearer {github_token}",
"User-Agent": "MCPMark/1.0",
},
)
raise ValueError(f"Unsupported HTTP service: {self.mcp_service}")
raise InvalidConfigurationError(
config_key="mcp_service",
value=self.mcp_service,
reason=f"Unsupported HTTP service. Supported: {self.HTTP_SERVICES}"
)

# ------------------------------------------------------------------
# Message/Tool formatting helpers
Expand Down
141 changes: 115 additions & 26 deletions src/agents/mcpmark_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
import nest_asyncio

from src.logger import get_logger
from src.exceptions import (
AgentExecutionError,
AgentTimeoutError,
LLMRateLimitError,
LLMQuotaExceededError,
LLMContextWindowExceededError,
MissingConfigurationError,
InvalidConfigurationError,
)
from .base_agent import BaseMCPAgent
from .mcp import MCPStdioServer, MCPHttpServer

Expand Down Expand Up @@ -133,9 +142,22 @@ async def _execute_with_strategy():
if isinstance(e, asyncio.TimeoutError):
error_msg = f"Execution timed out after {self.timeout} seconds"
logger.error(error_msg)
# Convert to AgentTimeoutError but don't raise - return error result
timeout_error = AgentTimeoutError(
agent_name=self.__class__.__name__,
timeout=self.timeout,
cause=e
)
error_msg = str(timeout_error)
else:
error_msg = f"Agent execution failed: {e}"
logger.error(error_msg, exc_info=True)
# Check if it's already an MCPMarkException
if isinstance(e, (AgentExecutionError, AgentTimeoutError,
LLMRateLimitError, LLMQuotaExceededError,
LLMContextWindowExceededError)):
error_msg = str(e)
else:
error_msg = f"Agent execution failed: {e}"
logger.error(error_msg, exc_info=True)

self.usage_tracker.update(
success=False,
Expand Down Expand Up @@ -327,14 +349,17 @@ async def _execute_anthropic_native_tool_loop(
tools=tools,
system=system_text
)
if turn_count == 1:
self.litellm_run_model_name = response['model'].split("/")[-1]

# Check for errors immediately after API call, before accessing response
if error_msg:
break

# Now safe to access response fields
if turn_count == 1 and response and "model" in response:
self.litellm_run_model_name = response['model'].split("/")[-1]

# Update token usage
if "usage" in response:
if response and "usage" in response:
usage = response["usage"]
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
Expand All @@ -348,7 +373,7 @@ async def _execute_anthropic_native_tool_loop(
## TODO: add reasoning tokens for claude

# Extract blocks from response
blocks = response.get("content", [])
blocks = response.get("content", []) if response else []
tool_uses = [b for b in blocks if b.get("type") == "tool_use"]
thinking_blocks = [b for b in blocks if b.get("type") == "thinking"]
text_blocks = [b for b in blocks if b.get("type") == "text"]
Expand Down Expand Up @@ -606,17 +631,47 @@ async def _execute_litellm_tool_loop(
logger.warning(f"| ✗ LLM call timed out on turn {turn_count + 1}")
consecutive_failures += 1
if consecutive_failures >= max_consecutive_failures:
raise Exception(f"Too many consecutive failures ({consecutive_failures})")
raise AgentTimeoutError(
agent_name=self.__class__.__name__,
timeout=self.timeout / 2,
cause=asyncio.TimeoutError(f"Too many consecutive failures ({consecutive_failures})")
)
await asyncio.sleep(8 ** consecutive_failures) # Exponential backoff
continue
except Exception as e:
logger.error(f"| ✗ LLM call failed on turn {turn_count + 1}: {e}")
consecutive_failures += 1

# Handle specific error types that should not be retried
error_str = str(e)
if "ContextWindowExceededError" in error_str:
raise LLMContextWindowExceededError(
model_name=self.litellm_input_model_name,
cause=e
)

# Check if we've exceeded max consecutive failures
if consecutive_failures >= max_consecutive_failures:
raise
if "ContextWindowExceededError" in str(e):
raise
elif "RateLimitError" in str(e):
# Convert to appropriate exception type
if "RateLimitError" in error_str or "rate limit" in error_str.lower():
raise LLMRateLimitError(
model_name=self.litellm_input_model_name,
cause=e
)
elif "quota" in error_str.lower() or "account balance" in error_str.lower():
raise LLMQuotaExceededError(
model_name=self.litellm_input_model_name,
cause=e
)
else:
raise AgentExecutionError(
agent_name=self.__class__.__name__,
reason=str(e),
cause=e
)

# Retry with exponential backoff for recoverable errors
if "RateLimitError" in error_str or "rate limit" in error_str.lower():
await asyncio.sleep(12 ** consecutive_failures)
else:
await asyncio.sleep(2 ** consecutive_failures)
Expand Down Expand Up @@ -645,9 +700,13 @@ async def _execute_litellm_tool_loop(

# Get response message
choices = response.choices
if len(choices):
message = choices[0].message
message_dict = message.model_dump() if hasattr(message, 'model_dump') else dict(message)
if not len(choices):
logger.warning("| No choices in response, ending task")
ended_normally = False
break

message = choices[0].message
message_dict = message.model_dump() if hasattr(message, 'model_dump') else dict(message)

# Log assistant's text content if present
if hasattr(message, 'content') and message.content:
Expand Down Expand Up @@ -712,9 +771,7 @@ async def _execute_litellm_tool_loop(
continue
else:
# Log end reason
if not choices:
logger.info("|\n|\n| Task ended with no messages generated by the model.")
elif choices[0].finish_reason == "stop":
if choices[0].finish_reason == "stop":
logger.info("|\n|\n| Task ended with the finish reason from messages being 'stop'.")

# No tool/function call, add message and we're done
Expand Down Expand Up @@ -789,15 +846,22 @@ async def _create_mcp_server(self) -> Any:
elif self.mcp_service in self.HTTP_SERVICES:
return self._create_http_server()
else:
raise ValueError(f"Unsupported MCP service: {self.mcp_service}")
raise InvalidConfigurationError(
config_key="mcp_service",
value=self.mcp_service,
reason=f"Unsupported MCP service. STDIO services: {self.STDIO_SERVICES}, HTTP services: {self.HTTP_SERVICES}"
)


def _create_stdio_server(self) -> MCPStdioServer:
"""Create stdio-based MCP server."""
if self.mcp_service == "notion":
notion_key = self.service_config.get("notion_key")
if not notion_key:
raise ValueError("Notion API key required")
raise MissingConfigurationError(
config_key="notion_key",
service="notion"
)

return MCPStdioServer(
command="npx",
Expand All @@ -813,7 +877,10 @@ def _create_stdio_server(self) -> MCPStdioServer:
elif self.mcp_service == "filesystem":
test_directory = self.service_config.get("test_directory")
if not test_directory:
raise ValueError("Test directory required for filesystem service")
raise MissingConfigurationError(
config_key="test_directory",
service="filesystem"
)

return MCPStdioServer(
command="npx",
Expand Down Expand Up @@ -846,7 +913,11 @@ def _create_stdio_server(self) -> MCPStdioServer:
database = self.service_config.get("current_database") or self.service_config.get("database")

if not all([username, password, database]):
raise ValueError("PostgreSQL requires username, password, and database")
missing = [k for k, v in [("username", username), ("password", password), ("database", database)] if not v]
raise MissingConfigurationError(
config_key=", ".join(missing),
service="postgres"
)

database_url = f"postgresql://{username}:{password}@{host}:{port}/{database}"

Expand All @@ -860,7 +931,11 @@ def _create_stdio_server(self) -> MCPStdioServer:
api_key = self.service_config.get("api_key")
backend_url = self.service_config.get("backend_url")
if not all([api_key, backend_url]):
raise ValueError("Insforge requires api_key and backend_url")
missing = [k for k, v in [("api_key", api_key), ("backend_url", backend_url)] if not v]
raise MissingConfigurationError(
config_key=", ".join(missing),
service="insforge"
)
return MCPStdioServer(
command="npx",
args=["-y", "@insforge/mcp@dev"],
Expand All @@ -871,15 +946,22 @@ def _create_stdio_server(self) -> MCPStdioServer:
)

else:
raise ValueError(f"Unsupported stdio service: {self.mcp_service}")
raise InvalidConfigurationError(
config_key="mcp_service",
value=self.mcp_service,
reason=f"Unsupported stdio service. Supported: {self.STDIO_SERVICES}"
)


def _create_http_server(self) -> MCPHttpServer:
"""Create HTTP-based MCP server."""
if self.mcp_service == "github":
github_token = self.service_config.get("github_token")
if not github_token:
raise ValueError("GitHub token required")
raise MissingConfigurationError(
config_key="github_token",
service="github"
)

return MCPHttpServer(
url="https://api.githubcopilot.com/mcp/",
Expand All @@ -895,7 +977,10 @@ def _create_http_server(self) -> MCPHttpServer:
api_key = self.service_config.get("api_key", "")

if not api_key:
raise ValueError("Supabase requires api_key (use secret key from 'supabase status')")
raise MissingConfigurationError(
config_key="api_key",
service="supabase"
)

# Supabase CLI exposes MCP at /mcp endpoint
mcp_url = f"{api_url}/mcp"
Expand All @@ -909,5 +994,9 @@ def _create_http_server(self) -> MCPHttpServer:
)

else:
raise ValueError(f"Unsupported HTTP service: {self.mcp_service}")
raise InvalidConfigurationError(
config_key="mcp_service",
value=self.mcp_service,
reason=f"Unsupported HTTP service. Supported: {self.HTTP_SERVICES}"
)

Loading