Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions src/skillspector/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from __future__ import annotations

import shutil
from pathlib import Path
from typing import TYPE_CHECKING, Any

from skillspector import __version__
Expand All @@ -46,11 +47,30 @@
RISK_THRESHOLD = 50


def _is_local_target(target: str) -> bool:
"""Return True when ``target`` names local filesystem content."""
stripped = target.strip()
if stripped.startswith("file://"):
return True
if stripped.startswith(("http://", "https://", "git@", "ssh://", "git+ssh://")):
return False
if stripped.startswith(("\\\\", "//")):
return True

candidate = Path(stripped).expanduser()
if candidate.is_absolute() or candidate.drive:
return True
if "://" in stripped:
return False
return candidate.exists()


async def run_scan(
target: str,
*,
use_llm: bool = True,
output_format: str = "json",
allow_local_targets: bool = True,
yara_rules_dir: str | None = None,
) -> dict[str, Any]:
"""Invoke the SkillSpector graph and return a structured verdict.
Expand All @@ -62,6 +82,9 @@ async def run_scan(
the returned payload reports what actually happened.
output_format: Format of the embedded ``report`` string. One of
:data:`VALID_FORMATS`.
allow_local_targets: Whether local filesystem targets are allowed.
HTTP MCP calls set this to ``False`` so routable servers do not
accept caller-controlled local paths.
yara_rules_dir: Optional directory of additional YARA rules.

Returns:
Expand All @@ -73,6 +96,8 @@ async def run_scan(
"""
if output_format not in VALID_FORMATS:
raise ValueError(f"output_format must be one of {VALID_FORMATS}, got {output_format!r}")
if not allow_local_targets and _is_local_target(target):
raise ValueError("local targets are disabled for this MCP transport")

llm_available = resolve_provider_credentials() is not None
llm_used = use_llm and llm_available
Expand Down Expand Up @@ -131,7 +156,7 @@ async def run_scan(
shutil.rmtree(temp_dir, ignore_errors=True)


def build_server(name: str = "skillspector") -> FastMCP:
def build_server(name: str = "skillspector", *, allow_local_targets: bool = True) -> FastMCP:
"""Construct the FastMCP server exposing the ``scan_skill`` tool.

Requires the optional ``mcp`` dependency (``pip install 'skillspector[mcp]'``).
Expand Down Expand Up @@ -164,14 +189,19 @@ async def scan_skill(
actually ran, so a low score from a static-only scan is not mistaken for
a clean full scan.
"""
return await run_scan(target, use_llm=use_llm, output_format=output_format)
return await run_scan(
target,
use_llm=use_llm,
output_format=output_format,
allow_local_targets=allow_local_targets,
)

return server


def run(transport: str = "stdio", host: str = "127.0.0.1", port: int = 8000) -> None:
"""Run the MCP server over ``stdio`` (local agents) or ``http`` (remote/A2A)."""
server = build_server()
server = build_server(allow_local_targets=transport != "http")
if transport == "stdio":
server.run(transport="stdio")
elif transport == "http":
Expand Down
137 changes: 137 additions & 0 deletions tests/unit/test_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
"""Tests for the MCP server wrapper (run_scan core + scan_skill tool)."""

from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

Expand Down Expand Up @@ -83,6 +85,141 @@ async def test_run_scan_rejects_invalid_format(tmp_path: Path) -> None:
await run_scan(str(tmp_path), output_format="xml")


async def test_run_scan_rejects_local_target_when_disallowed(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""HTTP-style scans reject local targets before the graph is invoked."""
graph_ainvoke = AsyncMock()
monkeypatch.setattr(mcp_server.graph, "ainvoke", graph_ainvoke)
monkeypatch.setattr(mcp_server, "resolve_provider_credentials", lambda: None)

with pytest.raises(ValueError, match="local targets are disabled"):
await run_scan(str(tmp_path), allow_local_targets=False)

assert graph_ainvoke.await_count == 0


async def test_run_scan_rejects_file_url_when_local_targets_disallowed(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""The same HTTP guard rejects file:// targets before any scan runs."""
graph_ainvoke = AsyncMock()
monkeypatch.setattr(mcp_server.graph, "ainvoke", graph_ainvoke)
monkeypatch.setattr(mcp_server, "resolve_provider_credentials", lambda: None)

with pytest.raises(ValueError, match="local targets are disabled"):
await run_scan(tmp_path.as_uri(), allow_local_targets=False)

assert graph_ainvoke.await_count == 0


@pytest.mark.parametrize(
("target", "expected"),
[
(r"\\server\share\skill", True),
("//server/share/skill", True),
("git@github.com:NVIDIA/SkillSpector.git", False),
("ssh://git@github.com/NVIDIA/SkillSpector.git", False),
("git+ssh://git@github.com/NVIDIA/SkillSpector.git", False),
("custom://example/skill", False),
],
)
def test_is_local_target_classifies_protocol_edges(target: str, expected: bool) -> None:
"""Classifier treats UNC-style paths as local and known remote schemes as remote."""
assert mcp_server._is_local_target(target) is expected


def test_is_local_target_checks_relative_paths_from_cwd(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Existing relative paths are local; missing relative paths stay unresolved."""
(tmp_path / "skill").mkdir()
monkeypatch.chdir(tmp_path)

assert mcp_server._is_local_target("skill") is True
assert mcp_server._is_local_target("missing-skill") is False


async def test_run_scan_allows_remote_target_when_local_targets_disallowed(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Remote HTTP targets still reach the resolver path when local targets are blocked."""
graph_ainvoke = AsyncMock(
return_value={
"risk_score": 0,
"risk_severity": "low",
"risk_recommendation": "safe",
"filtered_findings": [],
"report_body": "ok",
}
)
monkeypatch.setattr(mcp_server.graph, "ainvoke", graph_ainvoke)
monkeypatch.setattr(mcp_server, "resolve_provider_credentials", lambda: None)

target = "https://example.com/skills/safe.git"
result = await run_scan(target, allow_local_targets=False)

assert result["target"] == target
assert graph_ainvoke.await_count == 1
assert graph_ainvoke.await_args.args[0]["input_path"] == target


async def test_run_scan_keeps_default_local_target_compatibility(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""The default run_scan path still accepts local targets."""
graph_ainvoke = AsyncMock(
return_value={
"risk_score": 0,
"risk_severity": "low",
"risk_recommendation": "safe",
"filtered_findings": [],
"report_body": "ok",
}
)
monkeypatch.setattr(mcp_server.graph, "ainvoke", graph_ainvoke)
monkeypatch.setattr(mcp_server, "resolve_provider_credentials", lambda: None)

result = await run_scan(str(tmp_path))

assert result["target"] == str(tmp_path)
assert graph_ainvoke.await_count == 1
assert graph_ainvoke.await_args.args[0]["input_path"] == str(tmp_path)


@pytest.mark.parametrize(
("transport", "expected_allow_local_targets"),
[("stdio", True), ("http", False)],
)
def test_run_passes_transport_local_target_policy(
transport: str,
expected_allow_local_targets: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""run() keeps stdio local scans available and disables them for HTTP."""
captured: dict[str, bool] = {}
server = SimpleNamespace(
settings=SimpleNamespace(host=None, port=None),
run=MagicMock(),
)

def fake_build_server(*, allow_local_targets: bool = True):
captured["allow_local_targets"] = allow_local_targets
return server

monkeypatch.setattr(mcp_server, "build_server", fake_build_server)

mcp_server.run(transport=transport, host="0.0.0.0", port=9000)

assert captured["allow_local_targets"] is expected_allow_local_targets
if transport == "http":
assert server.settings.host == "0.0.0.0"
assert server.settings.port == 9000
server.run.assert_called_once_with(transport="streamable-http")
else:
server.run.assert_called_once_with(transport="stdio")


async def test_build_server_registers_scan_skill() -> None:
"""build_server wires up the scan_skill tool (requires the mcp extra)."""
pytest.importorskip("mcp")
Expand Down