# Upload behavior is configured once per server process at build time, since a given
# process serves exactly one transport.
#   stdio: local filesystem paths are read directly; URL fetch also allowed.
#   http : the server runs remotely with no access to the client's filesystem, so local
#          paths are rejected with guidance; uploads come via URL fetch or inline bytes.
_UPLOAD_TRANSPORT: str = "stdio"

_MAX_FETCH_BYTES = 25 * 1024 * 1024  # 25 MB cap on server-fetched files
_MAX_INLINE_BYTES = 256 * 1024  # 256 KB cap on decoded inline content
_FETCH_TIMEOUT_SECONDS = 30.0
_FETCH_MAX_REDIRECTS = 5

_HOSTED_PATH_GUIDANCE = (
    "The hosted Appwrite MCP server cannot read local file paths. For '{param}', pass a "
    'public URL as {{"url": "https://..."}} (preferred), or a small file inline as '
    '{{"filename": "...", "content": "", "encoding": "base64"}}.'
)

# Pulls the value out of `filename=` / `filename*=charset'lang'value` parameters.
_DISPOSITION_FILENAME_RE = re.compile(
    r"filename\*?=(?:[^']*'[^']*')?\"?([^\";]+)\"?", re.IGNORECASE
)


def _configure_uploads(transport: str) -> None:
    """Set the upload mode for this server process. Called once from build_mcp_server."""
    global _UPLOAD_TRANSPORT
    _UPLOAD_TRANSPORT = transport


def _is_public_address(ip) -> bool:
    """Return True only for a globally routable, non-special-purpose address.

    `ip` is an `ipaddress.IPv4Address` or `ipaddress.IPv6Address`.
    """
    # Judge IPv4-mapped IPv6 (::ffff:a.b.c.d) by the embedded IPv4 address so the
    # verdict does not depend on how the mapped range itself is classified.
    mapped = getattr(ip, "ipv4_mapped", None)
    if mapped is not None:
        ip = mapped
    # `is_global` is False for everything `is_private` covers and additionally for
    # the 100.64.0.0/10 shared address space, which `is_private` alone lets through.
    if not ip.is_global:
        return False
    return not (
        ip.is_loopback
        or ip.is_link_local
        or ip.is_reserved
        or ip.is_multicast
        or ip.is_unspecified
    )


def _validate_fetch_url(url: str) -> None:
    """Reject non-http(s) schemes and hosts that resolve to non-public addresses.

    This is the SSRF guard for server-side URL fetches: it stops the model from making
    the hosted server reach internal services, loopback, or the cloud metadata endpoint
    (169.254.169.254). Note the resolve-then-reconnect DNS-rebinding gap is accepted.

    Raises:
        ValueError: if the scheme is not http/https, the host is missing or cannot
            be resolved, or any resolved address is not publicly routable.
    """
    parts = urlsplit(url)
    if parts.scheme not in ("http", "https"):
        raise ValueError(
            f"Unsupported URL scheme '{parts.scheme}' — only http and https are allowed."
        )
    host = parts.hostname
    if not host:
        raise ValueError("URL is missing a host.")

    port = parts.port or (443 if parts.scheme == "https" else 80)
    try:
        infos = socket.getaddrinfo(host, port)
    except socket.gaierror as exc:
        raise ValueError(f"Could not resolve host '{host}'.") from exc

    # Every resolved address must be public; one internal record is enough to refuse.
    for info in infos:
        if not _is_public_address(ipaddress.ip_address(info[4][0])):
            raise ValueError(
                "Refusing to fetch a URL that resolves to a private, loopback, or "
                "link-local address."
            )


def _derive_filename(resp: "httpx.Response", url: str) -> str:
    """Best-effort filename from Content-Disposition, then the URL path, then a fallback.

    The result is always a bare filename: directory components and traversal
    segments are stripped. When neither source yields a name, "upload" is returned,
    with an extension guessed from the response Content-Type when possible.
    """
    disposition = resp.headers.get("content-disposition", "")
    match = _DISPOSITION_FILENAME_RE.search(disposition)
    candidate = ""
    if match:
        candidate = unquote(match.group(1)).strip().strip('"')
    if not candidate:
        segment = urlsplit(url).path.rstrip("/").rsplit("/", 1)[-1]
        candidate = unquote(segment).strip()
    # Sanitize to a bare filename — strip any directory components or traversal.
    candidate = candidate.replace("\\", "/").rsplit("/", 1)[-1]
    if candidate in ("", ".", ".."):
        mime_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
        extension = mimetypes.guess_extension(mime_type) if mime_type else None
        candidate = f"upload{extension}" if extension else "upload"
    return candidate


def _fetch_input_file(url: str, param_name: str) -> "InputFile":
    """Download a public URL (SSRF-guarded, size-capped) into an in-memory InputFile.

    Args:
        url: http(s) URL to download.
        param_name: Name of the tool parameter, used only in error messages.

    Raises:
        ValueError: if the URL (or any redirect hop) fails the SSRF guard, the
            download fails, or the body exceeds `_MAX_FETCH_BYTES`.
    """
    # Validate up front so an obviously bad URL never reaches the HTTP client.
    _validate_fetch_url(url)

    def _guard_outgoing_request(request) -> None:
        # httpx runs "request" hooks before each request it sends, including the
        # follow-up requests it issues for redirects, so every hop is checked
        # *before* the server connects to it rather than only the final URL.
        _validate_fetch_url(str(request.url))

    try:
        with httpx.Client(
            timeout=_FETCH_TIMEOUT_SECONDS,
            follow_redirects=True,
            max_redirects=_FETCH_MAX_REDIRECTS,
            limits=httpx.Limits(max_connections=1),
            event_hooks={"request": [_guard_outgoing_request]},
        ) as client:
            with client.stream("GET", url) as resp:
                resp.raise_for_status()
                # The final URL after redirects must also be public.
                _validate_fetch_url(str(resp.url))

                # Fail fast on an honest Content-Length; the streaming cap below
                # still covers a missing or understated header.
                declared = resp.headers.get("content-length")
                if declared is not None and declared.isdigit():
                    if int(declared) > _MAX_FETCH_BYTES:
                        raise ValueError(
                            f"File at URL for '{param_name}' is too large "
                            f"({declared} bytes); max is {_MAX_FETCH_BYTES} bytes."
                        )

                chunks: list[bytes] = []
                total = 0
                for chunk in resp.iter_bytes():
                    total += len(chunk)
                    if total > _MAX_FETCH_BYTES:
                        raise ValueError(
                            f"File at URL for '{param_name}' exceeds the max of "
                            f"{_MAX_FETCH_BYTES} bytes."
                        )
                    chunks.append(chunk)

                data = b"".join(chunks)
                mime_type = (
                    (resp.headers.get("content-type") or "").split(";")[0].strip()
                )
                filename = _derive_filename(resp, url)
    except (httpx.HTTPError, httpx.InvalidURL) as exc:
        # InvalidURL is not an HTTPError subclass; catch it too so callers only
        # ever see ValueError from this function.
        raise ValueError(
            f"Failed to fetch file from URL for '{param_name}': {exc}"
        ) from exc

    return InputFile.from_bytes(data, filename, mime_type or None)


def _coerce_inline_content(value: Mapping, param_name: str) -> "InputFile":
    """Build an in-memory InputFile from `{"filename", "content", "encoding"}`.

    `encoding` is "utf-8" (default) or "base64". Base64 is decoded strictly after
    removing whitespace, so malformed input is reported instead of silently
    producing corrupted bytes.

    Raises:
        ValueError: if `filename`/`content` are missing, the encoding is unknown,
            the base64 is invalid, or the decoded data exceeds `_MAX_INLINE_BYTES`.
    """
    filename = value.get("filename")
    content = value.get("content")
    if not filename or content is None:
        raise ValueError(
            f"Invalid value for '{param_name}'. Provide `url`, or both `filename` and "
            "`content`."
        )

    encoding = str(value.get("encoding", "utf-8")).lower()
    if encoding == "base64":
        # Line-wrapped base64 is legitimate, so drop whitespace before the strict
        # decode; any other stray character is an error, not something to discard.
        if isinstance(content, str):
            payload = "".join(content.split())
        elif isinstance(content, (bytes, bytearray)):
            payload = b"".join(bytes(content).split())
        else:
            payload = content
        try:
            data = base64.b64decode(payload, validate=True)
        except (ValueError, TypeError) as exc:  # binascii.Error is a ValueError
            raise ValueError(f"Invalid base64 content for '{param_name}'.") from exc
    elif encoding == "utf-8":
        data = str(content).encode("utf-8")
    else:
        raise ValueError(
            f"Invalid encoding for '{param_name}'. Expected 'utf-8' or 'base64'."
        )

    if len(data) > _MAX_INLINE_BYTES:
        raise ValueError(
            f"Inline content for '{param_name}' is too large "
            f"({len(data)} bytes, max {_MAX_INLINE_BYTES}). For larger files pass "
            '{"url": "https://..."} so the server can download it directly.'
        )

    return InputFile.from_bytes(data, str(filename), value.get("mime_type"))


def _coerce_path(path: str, param_name: str) -> "InputFile":
    """Wrap a local path as an InputFile; only permitted on the stdio transport.

    Raises:
        ValueError: with upload guidance when the server is not running on stdio.
    """
    if _UPLOAD_TRANSPORT != "stdio":
        raise ValueError(_HOSTED_PATH_GUIDANCE.format(param=param_name))
    return InputFile.from_path(path)


def _coerce_input_file(value: Any, param_name: str) -> "InputFile":
    """Turn a tool argument into an InputFile.

    Accepted shapes, checked in this order:
      * an existing InputFile (returned unchanged);
      * a string: an http(s) URL is downloaded, anything else is a local path;
      * a mapping with `url`, then `path`, then `filename` + `content`.

    Raises:
        ValueError: if the value matches none of the accepted shapes, or the
            selected source fails its own validation.
    """
    if isinstance(value, InputFile):
        return value

    if isinstance(value, str):
        if urlsplit(value).scheme in ("http", "https"):
            return _fetch_input_file(value, param_name)
        return _coerce_path(value, param_name)

    if not isinstance(value, Mapping):
        raise ValueError(
            f"Invalid value for '{param_name}'. Provide a public URL string, a `url`, or "
            "an object with `filename` and `content`."
        )

    url = value.get("url")
    if url:
        return _fetch_input_file(str(url), param_name)

    path = value.get("path")
    if path:
        return _coerce_path(str(path), param_name)

    filename = value.get("filename")
    content = value.get("content")
    if filename and content is not None:
        return _coerce_inline_content(value, param_name)

    raise ValueError(
        f"Invalid value for '{param_name}'. Provide `url`, or both `filename` and "
        "`content`."
    )
' f"{common}" ) def build_mcp_server(operator: Operator, *, transport: str = "http") -> Server: + _configure_uploads(transport) instructions = build_instructions(transport) server = Server("Appwrite MCP Server", instructions=instructions) diff --git a/src/mcp_server_appwrite/service.py b/src/mcp_server_appwrite/service.py index 7f2b84c..d1de30d 100644 --- a/src/mcp_server_appwrite/service.py +++ b/src/mcp_server_appwrite/service.py @@ -45,14 +45,18 @@ def _input_file_schema(self) -> dict: "oneOf": [ { "type": "string", - "description": "Path to a local file on the machine running the MCP server.", + "description": "A public http(s) URL the server downloads the file from, or a local file path (paths only work when the server runs locally).", }, { "type": "object", "properties": { + "url": { + "type": "string", + "description": "Public http(s) URL the server downloads the file from. Preferred for images and any non-trivial file on the hosted server.", + }, "path": { "type": "string", - "description": "Path to a local file on the machine running the MCP server.", + "description": "Path to a local file. 
class _FakeResponse:
    """Minimal stand-in for a streamed HTTP response used by the upload tests."""

    def __init__(self, *, data=b"", headers=None, url="https://example.com/pic.png"):
        self.url = url
        self.headers = headers or {}
        self._data = data

    def raise_for_status(self):
        # Every fake response counts as successful.
        return None

    def iter_bytes(self):
        # Hand the payload back in 64-byte slices so streaming code sees chunks.
        payload = self._data
        yield from (
            payload[start : start + 64] for start in range(0, len(payload), 64)
        )


class _FakeStream:
    """Context manager that yields a prepared fake response."""

    def __init__(self, response):
        self._response = response

    def __enter__(self):
        return self._response

    def __exit__(self, *exc_info):
        return False


class _FakeClient:
    """Context-manager client whose `stream` always serves the same response."""

    def __init__(self, response):
        self._response = response

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        return False

    def stream(self, method, url):
        return _FakeStream(self._response)


_PUBLIC_ADDRINFO = [(None, None, None, None, ("93.184.216.34", 80))]


class UploadInputFileTests(unittest.TestCase):
    """File-upload coercion: URL fetch, SSRF guard, size caps, transport gating."""

    def setUp(self):
        # Most cases exercise the hosted (http) behavior.
        _configure_uploads("http")

    def tearDown(self):
        _configure_uploads("stdio")

    def _patch_fetch(self, response, addrinfo=_PUBLIC_ADDRINFO):
        """Return (dns_patch, client_patch) that stub resolution and the HTTP client."""
        dns_patch = patch(
            "mcp_server_appwrite.server.socket.getaddrinfo", return_value=addrinfo
        )
        client_patch = patch(
            "mcp_server_appwrite.server.httpx.Client",
            return_value=_FakeClient(response),
        )
        return (dns_patch, client_patch)

    def test_url_object_uses_content_disposition_filename(self):
        png_headers = {
            "content-type": "image/png",
            "content-disposition": 'attachment; filename="pic.png"',
        }
        fake = _FakeResponse(data=b"\x89PNG\r\n", headers=png_headers)
        dns_patch, client_patch = self._patch_fetch(fake)
        with dns_patch:
            with client_patch:
                result = _coerce_argument(
                    "file", {"url": "https://example.com/x"}, InputFile
                )

        self.assertEqual(result.source_type, "bytes")
        self.assertEqual(result.data, b"\x89PNG\r\n")
        self.assertEqual(result.filename, "pic.png")
        self.assertEqual(result.mime_type, "image/png")

    def test_bare_url_string_derives_filename_from_path(self):
        fake = _FakeResponse(data=b"abc", headers={"content-type": "image/png"})
        dns_patch, client_patch = self._patch_fetch(fake)
        with dns_patch:
            with client_patch:
                result = _coerce_argument(
                    "file", "https://example.com/dir/a.png", InputFile
                )

        self.assertEqual(result.source_type, "bytes")
        self.assertEqual(result.filename, "a.png")

    def test_url_fetch_rejects_private_ip(self):
        fake = _FakeResponse(data=b"secret")
        blocked_ips = ("127.0.0.1", "169.254.169.254", "10.0.0.1")
        for ip in blocked_ips:
            with self.subTest(ip=ip):
                resolved = [(None, None, None, None, (ip, 80))]
                dns_patch, client_patch = self._patch_fetch(fake, addrinfo=resolved)
                with dns_patch:
                    with client_patch as client_mock:
                        with self.assertRaises(ValueError) as caught:
                            _coerce_argument(
                                "file", {"url": "https://evil.example/x"}, InputFile
                            )
                        self.assertIn("private", str(caught.exception).lower())
                        # The guard must trip before any HTTP client is built.
                        client_mock.assert_not_called()

    def test_url_fetch_rejects_non_http_scheme(self):
        argument = {"url": "file:///etc/passwd"}
        with self.assertRaises(ValueError) as caught:
            _coerce_argument("file", argument, InputFile)
        self.assertIn("scheme", str(caught.exception).lower())

    def test_url_fetch_size_cap_via_stream(self):
        fake = _FakeResponse(data=b"0123456789")  # 10 bytes, no content-length
        dns_patch, client_patch = self._patch_fetch(fake)
        tiny_cap = patch.object(server_module, "_MAX_FETCH_BYTES", 4)
        with dns_patch, client_patch, tiny_cap:
            with self.assertRaises(ValueError) as caught:
                _coerce_argument("file", {"url": "https://example.com/x"}, InputFile)
        self.assertIn("max", str(caught.exception).lower())

    def test_inline_content_size_cap(self):
        argument = {
            "filename": "big.bin",
            "content": base64.b64encode(b"hello").decode("ascii"),
            "encoding": "base64",
        }
        with patch.object(server_module, "_MAX_INLINE_BYTES", 4):
            with self.assertRaises(ValueError) as caught:
                _coerce_argument("file", argument, InputFile)
        self.assertIn("url", str(caught.exception).lower())

    def test_path_string_rejected_on_http(self):
        with self.assertRaises(ValueError) as caught:
            _coerce_argument("file", "/home/me/pic.png", InputFile)
        lowered = str(caught.exception).lower()
        self.assertIn("url", lowered)
        self.assertNotIn("stdio", lowered)
        self.assertNotIn("self-host", lowered)

    def test_path_string_allowed_on_stdio(self):
        _configure_uploads("stdio")
        with tempfile.NamedTemporaryFile(suffix=".txt") as handle:
            result = _coerce_argument("file", handle.name, InputFile)
        self.assertEqual(result.source_type, "path")

    def test_http_instructions_mention_url_upload(self):
        http_text = build_instructions("http").lower()
        stdio_text = build_instructions("stdio").lower()
        self.assertIn("url", http_text)
        self.assertIn("upload", http_text)
        self.assertNotIn("upload", stdio_text)


if __name__ == "__main__":
    unittest.main()