diff --git a/src/skillspector/mcp_server.py b/src/skillspector/mcp_server.py index 444b75f..83b2f9a 100644 --- a/src/skillspector/mcp_server.py +++ b/src/skillspector/mcp_server.py @@ -28,6 +28,7 @@ from __future__ import annotations import shutil +from pathlib import Path from typing import TYPE_CHECKING, Any from skillspector import __version__ @@ -46,11 +47,30 @@ RISK_THRESHOLD = 50 +def _is_local_target(target: str) -> bool: + """Return True when ``target`` names local filesystem content.""" + stripped = target.strip() + if stripped.startswith("file://"): + return True + if stripped.startswith(("http://", "https://", "git@", "ssh://", "git+ssh://")): + return False + if stripped.startswith(("\\\\", "//")): + return True + + candidate = Path(stripped).expanduser() + if candidate.is_absolute() or candidate.drive: + return True + if "://" in stripped: + return False + return candidate.exists() + + async def run_scan( target: str, *, use_llm: bool = True, output_format: str = "json", + allow_local_targets: bool = True, yara_rules_dir: str | None = None, ) -> dict[str, Any]: """Invoke the SkillSpector graph and return a structured verdict. @@ -62,6 +82,9 @@ async def run_scan( the returned payload reports what actually happened. output_format: Format of the embedded ``report`` string. One of :data:`VALID_FORMATS`. + allow_local_targets: Whether local filesystem targets are allowed. + HTTP MCP calls set this to ``False`` so routable servers do not + accept caller-controlled local paths. yara_rules_dir: Optional directory of additional YARA rules. Returns: @@ -73,6 +96,8 @@ async def run_scan( """ if output_format not in VALID_FORMATS: raise ValueError(f"output_format must be one of {VALID_FORMATS}, got {output_format!r}") + if not allow_local_targets and _is_local_target(target): + raise ValueError("local targets are disabled for this MCP transport") llm_available = resolve_provider_credentials() is not None llm_used = use_llm and llm_available @@ -131,7 +156,7 @@ async def run_scan( shutil.rmtree(temp_dir, ignore_errors=True) -def build_server(name: str = "skillspector") -> FastMCP: +def build_server(name: str = "skillspector", *, allow_local_targets: bool = True) -> FastMCP: """Construct the FastMCP server exposing the ``scan_skill`` tool. Requires the optional ``mcp`` dependency (``pip install 'skillspector[mcp]'``). @@ -164,14 +189,19 @@ async def scan_skill( actually ran, so a low score from a static-only scan is not mistaken for a clean full scan. """ - return await run_scan(target, use_llm=use_llm, output_format=output_format) + return await run_scan( + target, + use_llm=use_llm, + output_format=output_format, + allow_local_targets=allow_local_targets, + ) return server def run(transport: str = "stdio", host: str = "127.0.0.1", port: int = 8000) -> None: """Run the MCP server over ``stdio`` (local agents) or ``http`` (remote/A2A).""" - server = build_server() + server = build_server(allow_local_targets=transport != "http") if transport == "stdio": server.run(transport="stdio") elif transport == "http": diff --git a/tests/unit/test_mcp_server.py b/tests/unit/test_mcp_server.py index 10c5596..594fc14 100644 --- a/tests/unit/test_mcp_server.py +++ b/tests/unit/test_mcp_server.py @@ -16,6 +16,8 @@ """Tests for the MCP server wrapper (run_scan core + scan_skill tool).""" from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock import pytest @@ -83,6 +85,141 @@ async def test_run_scan_rejects_invalid_format(tmp_path: Path) -> None: await run_scan(str(tmp_path), output_format="xml") +async def test_run_scan_rejects_local_target_when_disallowed( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """HTTP-style scans reject local targets before the graph is invoked.""" + graph_ainvoke = AsyncMock() + monkeypatch.setattr(mcp_server.graph, "ainvoke", graph_ainvoke) + monkeypatch.setattr(mcp_server, "resolve_provider_credentials", lambda: None) + + with pytest.raises(ValueError, match="local targets are disabled"): + await run_scan(str(tmp_path), allow_local_targets=False) + + assert graph_ainvoke.await_count == 0 + + +async def test_run_scan_rejects_file_url_when_local_targets_disallowed( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """The same HTTP guard rejects file:// targets before any scan runs.""" + graph_ainvoke = AsyncMock() + monkeypatch.setattr(mcp_server.graph, "ainvoke", graph_ainvoke) + monkeypatch.setattr(mcp_server, "resolve_provider_credentials", lambda: None) + + with pytest.raises(ValueError, match="local targets are disabled"): + await run_scan(tmp_path.as_uri(), allow_local_targets=False) + + assert graph_ainvoke.await_count == 0 + + +@pytest.mark.parametrize( + ("target", "expected"), + [ + (r"\\server\share\skill", True), + ("//server/share/skill", True), + ("git@github.com:NVIDIA/SkillSpector.git", False), + ("ssh://git@github.com/NVIDIA/SkillSpector.git", False), + ("git+ssh://git@github.com/NVIDIA/SkillSpector.git", False), + ("custom://example/skill", False), + ], +) +def test_is_local_target_classifies_protocol_edges(target: str, expected: bool) -> None: + """Classifier treats UNC-style paths as local and known remote schemes as remote.""" + assert mcp_server._is_local_target(target) is expected + + +def test_is_local_target_checks_relative_paths_from_cwd( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Existing relative paths are local; missing relative paths stay unresolved.""" + (tmp_path / "skill").mkdir() + monkeypatch.chdir(tmp_path) + + assert mcp_server._is_local_target("skill") is True + assert mcp_server._is_local_target("missing-skill") is False + + +async def test_run_scan_allows_remote_target_when_local_targets_disallowed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Remote HTTP targets still reach the resolver path when local targets are blocked.""" + graph_ainvoke = AsyncMock( + return_value={ + "risk_score": 0, + "risk_severity": "low", + "risk_recommendation": "safe", + "filtered_findings": [], + "report_body": "ok", + } + ) + monkeypatch.setattr(mcp_server.graph, "ainvoke", graph_ainvoke) + monkeypatch.setattr(mcp_server, "resolve_provider_credentials", lambda: None) + + target = "https://example.com/skills/safe.git" + result = await run_scan(target, allow_local_targets=False) + + assert result["target"] == target + assert graph_ainvoke.await_count == 1 + assert graph_ainvoke.await_args.args[0]["input_path"] == target + + +async def test_run_scan_keeps_default_local_target_compatibility( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """The default run_scan path still accepts local targets.""" + graph_ainvoke = AsyncMock( + return_value={ + "risk_score": 0, + "risk_severity": "low", + "risk_recommendation": "safe", + "filtered_findings": [], + "report_body": "ok", + } + ) + monkeypatch.setattr(mcp_server.graph, "ainvoke", graph_ainvoke) + monkeypatch.setattr(mcp_server, "resolve_provider_credentials", lambda: None) + + result = await run_scan(str(tmp_path)) + + assert result["target"] == str(tmp_path) + assert graph_ainvoke.await_count == 1 + assert graph_ainvoke.await_args.args[0]["input_path"] == str(tmp_path) + + +@pytest.mark.parametrize( + ("transport", "expected_allow_local_targets"), + [("stdio", True), ("http", False)], +) +def test_run_passes_transport_local_target_policy( + transport: str, + expected_allow_local_targets: bool, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """run() keeps stdio local scans available and disables them for HTTP.""" + captured: dict[str, bool] = {} + server = SimpleNamespace( + settings=SimpleNamespace(host=None, port=None), + run=MagicMock(), + ) + + def fake_build_server(*, allow_local_targets: bool = True): + captured["allow_local_targets"] = allow_local_targets + return server + + monkeypatch.setattr(mcp_server, "build_server", fake_build_server) + + mcp_server.run(transport=transport, host="0.0.0.0", port=9000) + + assert captured["allow_local_targets"] is expected_allow_local_targets + if transport == "http": + assert server.settings.host == "0.0.0.0" + assert server.settings.port == 9000 + server.run.assert_called_once_with(transport="streamable-http") + else: + server.run.assert_called_once_with(transport="stdio") + + async def test_build_server_registers_scan_skill() -> None: """build_server wires up the scan_skill tool (requires the mcp extra).""" pytest.importorskip("mcp")