Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ai/agents/ai_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def call_with_provider(
elif provider_lower == PROVIDER_ANTHROPIC:
result = self._call_anthropic(model, system_message, prompt, temperature)
elif provider_lower == PROVIDER_OLLAMA:
result = self._call_ollama(system_message, prompt, temperature)
result = self._call_ollama(system_message, prompt, temperature, model)
elif provider_lower == PROVIDER_GEMINI:
result = self._call_gemini(model, system_message, prompt, temperature)
elif provider_lower == PROVIDER_GROQ:
Expand Down
36 changes: 21 additions & 15 deletions src/ai/chat_context_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,11 @@ def _construct_prompt(self, user_message: str, context_data: 'ChatContextData')
])

# Add conversation history (last few exchanges)
if self.conversation_history:
with self._history_lock:
history_snapshot = list(self.conversation_history[-6:])
if history_snapshot:
prompt_parts.append("Recent Conversation:")
for item in self.conversation_history[-6:]: # Last 3 exchanges
for item in history_snapshot: # Last 3 exchanges
role = item["role"].title()
message = item["message"][:200] + "..." if len(item["message"]) > 200 else item["message"]
prompt_parts.append(f"{role}: {message}")
Expand All @@ -132,31 +134,35 @@ def _construct_prompt(self, user_message: str, context_data: 'ChatContextData')

def _add_to_history(self, role: str, message: str):
"""Add a message to conversation history."""
self.conversation_history.append({
"role": role,
"message": message,
"timestamp": datetime.now().isoformat()
})
with self._history_lock:
self.conversation_history.append({
"role": role,
"message": message,
"timestamp": datetime.now().isoformat()
})

# Keep only recent history
if len(self.conversation_history) > self.max_history_items:
self.conversation_history = self.conversation_history[-self.max_history_items:]
# Keep only recent history
if len(self.conversation_history) > self.max_history_items:
self.conversation_history = self.conversation_history[-self.max_history_items:]

def clear_history(self):
"""Clear conversation history."""
self.conversation_history = []
with self._history_lock:
self.conversation_history = []
logger.info("Chat conversation history cleared")

def get_history(self) -> list:
"""Get conversation history."""
return self.conversation_history.copy()
with self._history_lock:
return self.conversation_history.copy()

def get_context_from_history(self, max_entries: int = 5) -> str:
"""Get context from recent conversation history."""
if not self.conversation_history:
return ""
with self._history_lock:
if not self.conversation_history:
return ""
recent_history = list(self.conversation_history[-max_entries:])

recent_history = self.conversation_history[-max_entries:]
context_parts = []

for entry in recent_history:
Expand Down
1 change: 1 addition & 0 deletions src/ai/chat_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(self, app):
self.app = app
self.is_processing = False
self.conversation_history = []
self._history_lock = threading.Lock()

# Typing indicator state
self._typing_indicator_mark = None
Expand Down
164 changes: 94 additions & 70 deletions src/ai/mcp/mcp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,14 @@ def __init__(self, process: subprocess.Popen, server: MCPServer = None):
self.process = process
self.server = server # Reference to store errors
self.request_id = 0
self.response_queue = queue.Queue()
self._pending_requests: Dict[int, queue.Queue] = {}
self._pending_lock = threading.Lock()
self.reader_thread = threading.Thread(target=self._read_responses, daemon=True)
self.reader_thread.start()
# Start stderr reader to capture errors
self.stderr_thread = threading.Thread(target=self._read_stderr, daemon=True)
self.stderr_thread.start()

def _read_responses(self):
"""Read responses from the MCP server stdout"""
try:
Expand All @@ -100,14 +101,23 @@ def _read_responses(self):
try:
response = json.loads(line)
logger.debug(f"MCP response: {response}")
self.response_queue.put(response)
resp_id = response.get("id")
with self._pending_lock:
req_queue = self._pending_requests.get(resp_id)
if req_queue is not None:
req_queue.put(response)
else:
logger.warning(f"Received response for unknown request id {resp_id}, discarding")
except json.JSONDecodeError:
logger.warning(f"Invalid JSON from MCP server: {line}")
except Exception as e:
logger.error(f"Error reading MCP responses: {e}")
finally:
# Notify any waiters that the reader has exited so they don't hang forever
self.response_queue.put({"error": {"code": -1, "message": "MCP reader thread exited"}})
# Notify all pending waiters that the reader has exited so they don't hang forever
exit_msg = {"error": {"code": -1, "message": "MCP reader thread exited"}}
with self._pending_lock:
for req_queue in self._pending_requests.values():
req_queue.put(exit_msg)

def _read_stderr(self):
"""Read stderr output from the MCP server"""
Expand Down Expand Up @@ -135,43 +145,49 @@ def send_request(self, method: str, params: Dict[str, Any] = None, timeout: floa
# Check if process is still running
if self.process.poll() is not None:
raise Exception(f"MCP server process has terminated with code: {self.process.returncode}")

self.request_id += 1
request = {
"jsonrpc": "2.0",
"method": method,
"id": self.request_id
}
if params:
request["params"] = params

# Send request
request_str = json.dumps(request) + "\n"
logger.debug(f"Sending MCP request: {request_str.strip()}")

# Allocate a per-request queue and atomically assign request_id
req_queue = queue.Queue()
with self._pending_lock:
self.request_id += 1
current_id = self.request_id
self._pending_requests[current_id] = req_queue

try:
self.process.stdin.write(request_str)
self.process.stdin.flush()
except Exception as e:
raise Exception(f"Failed to send request to MCP server: {e}")

# Wait for response
start_time = time.time()
while time.time() - start_time < timeout:
# Check if process terminated while waiting
if self.process.poll() is not None:
raise Exception(f"MCP server process terminated while waiting for response (code: {self.process.returncode})")

request = {
"jsonrpc": "2.0",
"method": method,
"id": current_id
}
if params:
request["params"] = params

# Send request
request_str = json.dumps(request) + "\n"
logger.debug(f"Sending MCP request: {request_str.strip()}")
try:
response = self.response_queue.get(timeout=0.1)
if response.get("id") == self.request_id:
self.process.stdin.write(request_str)
self.process.stdin.flush()
except Exception as e:
raise Exception(f"Failed to send request to MCP server: {e}")

# Wait for response on our dedicated queue
start_time = time.time()
while time.time() - start_time < timeout:
# Check if process terminated while waiting
if self.process.poll() is not None:
raise Exception(f"MCP server process terminated while waiting for response (code: {self.process.returncode})")

try:
response = req_queue.get(timeout=0.1)
if "error" in response:
error = response['error']
# Extract error details
if isinstance(error, dict):
code = error.get('code', 'Unknown')
message = error.get('message', 'Unknown error')
data = error.get('data', {})

# Check for rate limit error
if code == 429 or 'rate' in str(message).lower():
retry_after = data.get('retry_after', 60)
Expand All @@ -181,10 +197,14 @@ def send_request(self, method: str, params: Dict[str, Any] = None, timeout: floa
else:
raise Exception(f"MCP error: {error}")
return response.get("result")
except queue.Empty:
continue

raise Exception(f"Timeout waiting for MCP response to {method}")
except queue.Empty:
continue

raise Exception(f"Timeout waiting for MCP response to {method}")
finally:
# Always clean up the per-request queue
with self._pending_lock:
self._pending_requests.pop(current_id, None)


class MCPManager:
Expand Down Expand Up @@ -330,27 +350,30 @@ def _discover_tools(self, protocol: MCPProtocol) -> List[Dict[str, Any]]:

def execute_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""Execute a tool on an MCP server"""
# Hold lock only for server/protocol lookup
with self._lock:
server = self.servers.get(server_name)
if not server or not server.process:
raise Exception(f"MCP server {server_name} is not running")

if not server.protocol:
raise Exception(f"MCP server {server_name} has no active protocol")

try:
# Log the request for debugging
request_params = {
"name": tool_name,
"arguments": arguments
}
logger.debug(f"Sending tool call request: {request_params}")

result = server.protocol.send_request("tools/call", request_params)
return result
except Exception as e:
logger.error(f"Failed to execute tool {tool_name} on {server_name}: {e}")
raise

protocol = server.protocol

# Perform the (potentially long) I/O outside the lock
try:
request_params = {
"name": tool_name,
"arguments": arguments
}
logger.debug(f"Sending tool call request: {request_params}")

result = protocol.send_request("tools/call", request_params)
return result
except Exception as e:
logger.error(f"Failed to execute tool {tool_name} on {server_name}: {e}")
raise

def get_all_tools(self) -> List[Tuple[str, Dict[str, Any]]]:
"""Get all available tools from all running servers"""
Expand Down Expand Up @@ -475,6 +498,9 @@ def _monitor_loop(self) -> None:

def _check_servers(self) -> None:
"""Check all servers and restart any that have crashed."""
crashed_servers = []

# Hold lock only briefly to detect crashed servers and clean up stale refs
with self.mcp_manager._lock:
for name, server in list(self.mcp_manager.servers.items()):
if not server.enabled:
Expand All @@ -485,17 +511,25 @@ def _check_servers(self) -> None:
continue

if server.process.poll() is not None:
# Process has terminated
# Process has terminated — collect name and clean up stale refs
exit_code = server.process.returncode
logger.warning(f"MCP server '{name}' crashed with exit code {exit_code}")
self._handle_crash(name, server)
server.process = None
server.protocol = None
server.tools = None
crashed_servers.append(name)

def _handle_crash(self, name: str, server: MCPServer) -> None:
# Handle restarts outside the lock
for name in crashed_servers:
self._handle_crash(name)

def _handle_crash(self, name: str) -> None:
"""Handle a crashed server with restart logic.

Does NOT assume any lock is held. start_server acquires its own lock.

Args:
name: Server name
server: Server instance
"""
attempts = self.restart_attempts.get(name, 0)

Expand All @@ -507,26 +541,16 @@ def _handle_crash(self, name: str, server: MCPServer) -> None:
backoff = 2 ** attempts
logger.info(f"Attempting to restart MCP server '{name}' in {backoff}s (attempt {attempts + 1}/{self.max_restarts})")

# Clean up old process
server.process = None
server.protocol = None
server.tools = None

# Wait with backoff
# Wait with backoff (no lock held, so this is safe)
time.sleep(backoff)

if not self._running:
return

try:
# Release lock temporarily for restart
self.mcp_manager._lock.release()
try:
self.mcp_manager.start_server(name)
self.restart_attempts[name] = 0 # Reset on success
logger.info(f"Successfully restarted MCP server '{name}'")
finally:
self.mcp_manager._lock.acquire()
self.mcp_manager.start_server(name)
self.restart_attempts[name] = 0 # Reset on success
logger.info(f"Successfully restarted MCP server '{name}'")
except Exception as e:
self.restart_attempts[name] = attempts + 1
logger.error(f"Failed to restart MCP server '{name}': {e}")
Expand Down
45 changes: 16 additions & 29 deletions src/ai/mcp/mcp_tool_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,36 +233,23 @@ def execute(self, **kwargs) -> ToolResult:

# Check rate limits for this server
if self.server_name in RATE_LIMITS:
with rate_limit_lock:
rate_config = RATE_LIMITS[self.server_name]
min_interval = rate_config.get("minimum_interval", 1.0 / rate_config["requests_per_second"])

current_time = time.time()

# Check global rate limit for this server
global_last = rate_config.get("global_last_request", 0)
global_time_since_last = current_time - global_last

if global_time_since_last < min_interval:
wait_time = min_interval - global_time_since_last
logger.info(f"Global rate limiting for {self.server_name}: waiting {wait_time:.2f}s")
time.sleep(wait_time)
tool_key = f"{self.server_name}:{self.original_name}"
while True:
with rate_limit_lock:
rate_config = RATE_LIMITS[self.server_name]
min_interval = rate_config.get("minimum_interval", 1.0 / rate_config["requests_per_second"])
current_time = time.time()

# Also check per-tool rate limit
tool_key = f"{self.server_name}:{self.original_name}"
last_time = rate_config["last_request_time"].get(tool_key, 0)
time_since_last = current_time - last_time

# If not enough time has passed for this specific tool, wait
if time_since_last < min_interval:
wait_time = min_interval - time_since_last
logger.info(f"Tool rate limiting: waiting {wait_time:.2f}s before calling {self.original_name}")
time.sleep(wait_time)

# Update last request times
rate_config["global_last_request"] = time.time()
rate_config["last_request_time"][tool_key] = time.time()
global_last = rate_config.get("global_last_request", 0)
global_wait = min_interval - (current_time - global_last)
last_time = rate_config["last_request_time"].get(tool_key, 0)
tool_wait = min_interval - (current_time - last_time)
wait_time = max(global_wait, tool_wait, 0.0)
if wait_time <= 0:
rate_config["global_last_request"] = current_time
rate_config["last_request_time"][tool_key] = current_time
break
logger.info(f"Rate limiting for {self.server_name}: waiting {wait_time:.2f}s")
time.sleep(wait_time)

# Execute the tool via MCP with retry logic
logger.info(f"Executing MCP tool {self.original_name} on server {self.server_name}")
Expand Down
Loading
Loading