Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions src/specify_cli/_download_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Helpers for bounded HTTP downloads."""

from __future__ import annotations

from typing import NoReturn, TypeVar
from urllib.parse import urlparse


# Type of the exception class a caller asks to have raised on its behalf.
ErrorT = TypeVar("ErrorT", bound=Exception)

# Default ceiling for a single download (50 MiB) and the largest amount
# requested from the response in any one ``read`` call (1 MiB).
MAX_DOWNLOAD_BYTES = 50 * 1024**2
READ_CHUNK_SIZE = 1024**2

# Stricter ceiling for bodies that are held entirely in memory and then parsed
# as JSON. MAX_DOWNLOAD_BYTES is sized for archives and other payloads; JSON
# metadata is much smaller, so bounding it near its realistic size narrows the
# memory-DoS surface and means the "too large" error can actually be hit
# instead of only firing after tens of MiB. Pass this constant explicitly at
# every JSON call site so the intended bound is pinned where it is used.
# "METADATA" means fixed-shape, single-object responses (an OAuth token, the
# metadata of one release): a few KiB in practice, so 1 MiB is generous.
MAX_JSON_METADATA_BYTES = 1024**2


def is_https_or_localhost_http(url: str) -> bool:
    """Return True if *url* is HTTPS, or HTTP limited to loopback hosts.

    Shared scheme-safety predicate used by the auth HTTP redirect handler and
    by the direct URL validations in the CLI download flows, so the rule (and
    any future tightening of it) lives in one place.

    A hostname is always required: a URL without one (e.g. ``https:///x``)
    has no real target and is rejected regardless of scheme.

    A URL that ``urlparse`` refuses to parse at all (e.g. an unbalanced IPv6
    bracket such as ``http://[::1``) is likewise treated as unsafe: the
    function returns ``False`` instead of letting the ``ValueError`` escape,
    so callers can rely on always getting a bool back.

    The loopback allowance is a deliberate *exact-string* match on
    ``localhost`` / ``127.0.0.1`` / ``::1``, not an IP-range check: other
    loopback addresses (e.g. ``127.0.0.2``) are intentionally not covered.
    The parsed ``hostname`` is already lower-cased by ``urlparse``, so the
    comparison is case-insensitive.
    """
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname
    except ValueError:
        # Malformed authority (e.g. "Invalid IPv6 URL"): not a safe target.
        return False
    if not hostname:
        return False
    if parsed.scheme == "https":
        return True
    # Plain HTTP is tolerated only for the exact loopback names below.
    return parsed.scheme == "http" and hostname in ("localhost", "127.0.0.1", "::1")


def _raise(error_type: type[ErrorT], message: str) -> NoReturn:
    """Instantiate *error_type* with *message* and raise it; never returns."""
    failure = error_type(message)
    raise failure


def read_response_limited(
    response,
    *,
    max_bytes: int = MAX_DOWNLOAD_BYTES,
    error_type: type[ErrorT] = ValueError,
    label: str = "download",
) -> bytes:
    """Return the body of *response*, refusing anything above *max_bytes*.

    A single ``response.read(max_bytes + 1)`` is not enough to enforce the
    cap: ``read(n)`` promises *at most* ``n`` bytes and may hand back fewer
    while more data is still pending (chunked transfer encoding, for
    instance). The body is therefore pulled in a loop, in requests of at most
    ``READ_CHUNK_SIZE`` bytes, until EOF or until one byte more than the cap
    has been collected.

    All options are keyword-only:

    * *max_bytes* - the cap; defaults to the module-wide ``MAX_DOWNLOAD_BYTES``
      (50 MiB) meant for archive/payload downloads. Callers with a tighter
      budget (e.g. small JSON responses) should pass an explicit value so the
      bound is pinned at the call site rather than following the shared
      default.
    * *error_type* - exception class raised when the cap is exceeded.
    * *label* - name of the download, used as the start of the error message.
    """
    # Budget is one byte beyond the cap: spending all of it proves overflow.
    remaining = max_bytes + 1
    pieces: list[bytes] = []
    while remaining > 0:
        piece = response.read(min(READ_CHUNK_SIZE, remaining))
        if not piece:
            # EOF before the budget ran out.
            break
        pieces.append(piece)
        remaining -= len(piece)
    if remaining <= 0:
        _raise(error_type, f"{label} exceeds maximum size of {max_bytes} bytes")
    return b"".join(pieces)
18 changes: 16 additions & 2 deletions src/specify_cli/_github_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def resolve_github_release_asset_api_url(
import json
import urllib.error

from specify_cli._download_security import (
MAX_JSON_METADATA_BYTES,
read_response_limited,
)

parsed = urlparse(download_url)
parts = [unquote(part) for part in parsed.path.strip("/").split("/")]

Expand Down Expand Up @@ -118,8 +123,17 @@ def resolve_github_release_asset_api_url(

try:
with open_url_fn(release_url, timeout=timeout) as response:
release_data = json.loads(response.read())
except (urllib.error.URLError, json.JSONDecodeError):
release_data = json.loads(
read_response_limited(
response,
max_bytes=MAX_JSON_METADATA_BYTES,
label=f"GitHub release metadata {release_url}",
)
)
# ValueError covers both an oversized body (raised by read_response_limited)
# and json.JSONDecodeError (a ValueError subclass); on any of these, fall
# back to the original URL by returning None.
except (urllib.error.URLError, ValueError):
return None

for asset in release_data.get("assets", []):
Expand Down
14 changes: 11 additions & 3 deletions src/specify_cli/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
release tag. The ``self_app`` Typer sub-command group is co-located here so
all version-related logic lives in one place.

Dependencies: stdlib + packaging + ._console only (no other internal imports
at module level, keeping this layer thin and circular-import-safe).
Dependencies: stdlib + packaging + ._console + ._download_security only
(keeping this layer thin and circular-import-safe).
"""
from __future__ import annotations

Expand All @@ -28,6 +28,7 @@
import typer
from packaging.version import InvalidVersion, Version

from ._download_security import MAX_JSON_METADATA_BYTES, read_response_limited
from ._console import console

GITHUB_API_LATEST = "https://api.github.com/repos/github/spec-kit/releases/latest"
Expand Down Expand Up @@ -118,8 +119,15 @@ def _fetch_latest_release_tag() -> tuple[str | None, str | None]:
GITHUB_API_LATEST,
timeout=5,
extra_headers={"Accept": "application/vnd.github+json"},
strict_redirects=True,
) as resp:
payload = json.loads(resp.read().decode("utf-8"))
payload = json.loads(
read_response_limited(
resp,
max_bytes=MAX_JSON_METADATA_BYTES,
label="GitHub latest release",
).decode("utf-8")
)
tag = payload.get("tag_name")
if not isinstance(tag, str) or not tag:
raise ValueError("GitHub API response missing valid tag_name")
Expand Down
33 changes: 30 additions & 3 deletions src/specify_cli/authentication/azure_devops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import subprocess
from typing import TYPE_CHECKING

from .._download_security import MAX_JSON_METADATA_BYTES, read_response_limited
from .base import AuthProvider

if TYPE_CHECKING:
Expand All @@ -17,6 +18,10 @@
_ADO_RESOURCE_ID = "499b84ac-1321-427f-aa17-267ca6975798"


class _TokenResponseTooLarge(Exception):
    """Raised when an Azure AD token response exceeds the bounded read limit.

    A dedicated exception type (rather than a plain ``ValueError``) lets the
    token-acquisition code catch the oversized-response case specifically,
    while unrelated ``ValueError``s are left to propagate.
    """


class AzureDevOpsAuth(AuthProvider):
"""Azure DevOps authentication provider.

Expand Down Expand Up @@ -109,9 +114,31 @@ def _acquire_via_client_credentials(entry: AuthConfigEntry) -> str | None:
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
try:
with urllib.request.urlopen(req, timeout=30) as resp: # noqa: S310
payload = _json.loads(resp.read().decode("utf-8"))
from specify_cli.authentication.http import _StripAuthOnRedirect

# A 307/308 redirect preserves the POST body, which carries the
# client_secret. Reuse the package HTTPS-downgrade guard (empty host
# list means no auth header to strip, just the scheme check) so the
# secret can never be forwarded to a non-HTTPS, non-loopback host.
opener = urllib.request.build_opener(_StripAuthOnRedirect(()))
with opener.open(req, timeout=30) as resp: # noqa: S310
payload = _json.loads(
read_response_limited(
resp,
max_bytes=MAX_JSON_METADATA_BYTES,
error_type=_TokenResponseTooLarge,
label="Azure DevOps token response",
).decode("utf-8")
)
token = payload.get("access_token", "").strip()
return token or None
except (urllib.error.URLError, OSError, _json.JSONDecodeError, KeyError):
except (
urllib.error.URLError,
OSError,
_json.JSONDecodeError,
_TokenResponseTooLarge,
):
# Network failure, malformed JSON, or an oversized response — fall
# through to the next strategy. Unrelated programming errors (other
# ValueErrors, KeyErrors) intentionally propagate so they surface.
return None
38 changes: 34 additions & 4 deletions src/specify_cli/authentication/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Callable
from urllib.parse import urlparse

from .._download_security import is_https_or_localhost_http
from . import get_provider
from .config import AuthConfigEntry, _default_config_path, find_entries_for_url, load_auth_config

Expand Down Expand Up @@ -60,8 +61,23 @@ def _hostname_in_hosts(hostname: str, hosts: tuple[str, ...]) -> bool:
RedirectValidator = Callable[[str, str], None]


def _validate_strict_redirect(_old_url: str, new_url: str) -> None:
    """Reject a redirect whose target fails ``is_https_or_localhost_http``.

    Takes the ``(old_url, new_url)`` pair of a redirect validator; only the
    target is inspected and *_old_url* is unused. Returns ``None`` for an
    acceptable target and raises ``urllib.error.URLError`` otherwise.
    """
    if is_https_or_localhost_http(new_url):
        return
    reason = (
        "unsafe redirect: target must use HTTPS with a hostname, "
        "or HTTP for localhost (127.0.0.1, ::1)"
    )
    raise urllib.error.URLError(reason)


class _StripAuthOnRedirect(urllib.request.HTTPRedirectHandler):
"""Drop ``Authorization`` when a redirect leaves trusted hosts or downgrades."""
"""Redirect handler that guards every redirect it is installed for.

1. Run any caller-provided redirect validator.
2. Reject redirects that are not HTTPS with a hostname, except HTTP to
localhost / 127.0.0.1 / ::1 (the exact hosts allowed by
``is_https_or_localhost_http``).
3. Drop ``Authorization`` when a redirect leaves trusted hosts or downgrades.
"""

def __init__(
self,
Expand All @@ -75,6 +91,7 @@ def __init__(
def redirect_request(self, req, fp, code, msg, headers, newurl):
if self._redirect_validator is not None:
self._redirect_validator(req.full_url, newurl)
_validate_strict_redirect(req.full_url, newurl)

original_auth = (
req.get_header("Authorization")
Expand Down Expand Up @@ -123,6 +140,7 @@ def open_url(
timeout: int = 10,
extra_headers: dict[str, str] | None = None,
redirect_validator: RedirectValidator | None = None,
strict_redirects: bool = False,
):
"""Open *url* with config-driven auth, redirect stripping, and fallthrough.

Expand All @@ -135,9 +153,19 @@ def open_url(
*extra_headers* (e.g. ``Accept``) are merged into every attempt.
*redirect_validator*, when provided, is called with ``(old_url, new_url)``
before following each redirect and may raise to reject the redirect.

Redirect scheme safety: every authenticated attempt goes through
``_StripAuthOnRedirect``, which always rejects redirects to non-HTTPS
URLs (except HTTP to localhost / 127.0.0.1 / ::1, the hosts allowed by
``is_https_or_localhost_http``). The unauthenticated fallback installs the
same handler when *strict_redirects* is true or *redirect_validator* is
supplied; without either, it follows redirects without that handler.
"""
entries = find_entries_for_url(url, _load_config())

effective_redirect_validator = redirect_validator
use_redirect_handler = strict_redirects or effective_redirect_validator is not None

def _make_req(auth_headers: dict[str, str]) -> urllib.request.Request:
merged = {}
if extra_headers:
Expand All @@ -157,7 +185,7 @@ def _make_req(auth_headers: dict[str, str]) -> urllib.request.Request:
continue

req = _make_req(provider.auth_headers(token, entry.auth))
opener = urllib.request.build_opener(_StripAuthOnRedirect(entry.hosts, redirect_validator))
opener = urllib.request.build_opener(_StripAuthOnRedirect(entry.hosts, effective_redirect_validator))
try:
return opener.open(req, timeout=timeout)
except urllib.error.HTTPError as exc:
Expand All @@ -168,7 +196,9 @@ def _make_req(auth_headers: dict[str, str]) -> urllib.request.Request:

# No entry worked (or none matched) — unauthenticated fallback
req = _make_req({})
if redirect_validator is not None:
opener = urllib.request.build_opener(_StripAuthOnRedirect((), redirect_validator))
if use_redirect_handler:
# No auth is attached on this path, so the handler's host list is empty:
# here it runs redirect validation only, not auth stripping.
opener = urllib.request.build_opener(_StripAuthOnRedirect((), effective_redirect_validator))
return opener.open(req, timeout=timeout)
return urllib.request.urlopen(req, timeout=timeout) # noqa: S310
33 changes: 32 additions & 1 deletion tests/http_helpers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,46 @@
"""HTTP test helpers shared by version-related CLI tests."""

import io
import json
import urllib.request
from unittest.mock import MagicMock

import pytest


def mock_urlopen_response(payload: dict) -> MagicMock:
    """Build a urlopen context-manager mock whose read returns JSON.

    *payload* is serialised to UTF-8 JSON. The returned mock is meant to be
    used as ``with <mock> as resp:``; ``resp.read`` is backed by an in-memory
    byte stream, so it honours size arguments and returns ``b""`` once the
    body is exhausted, which is what a chunked, bounded reader needs in order
    to terminate.
    """
    body = json.dumps(payload).encode("utf-8")
    resp = MagicMock()
    # Only side_effect is set: a fixed return_value would be overridden by it
    # on every call and would wrongly suggest each read yields the whole body.
    resp.read.side_effect = io.BytesIO(body).read
    cm = MagicMock()
    cm.__enter__.return_value = resp
    # A falsy __exit__ result lets exceptions raised inside the ``with`` escape.
    cm.__exit__.return_value = False
    return cm


@pytest.fixture(autouse=True)
def route_opener_open_through_urlopen(monkeypatch):
    """Make ``build_opener().open`` go through ``urllib.request.urlopen``.

    ``open_url(..., strict_redirects=True)`` fetches via
    ``build_opener(...).open()``, which never touches
    ``urllib.request.urlopen`` and therefore sidesteps the urlopen patches
    these test modules rely on. Replacing ``build_opener`` with a factory
    whose opener forwards ``open()`` to urlopen, resolved at call time, keeps
    those patches in effect. The redirect handler's own behaviour is covered
    by ``TestRedirectStripping`` in test_authentication.py.

    Import this fixture into a test module to activate it there.
    """

    class _UrlopenDelegatingOpener:
        """Stand-in opener that hands every ``open`` call to urlopen."""

        def open(self, req, data=None, timeout=None):
            # ``data`` is forwarded only when the caller supplied it, so the
            # patched urlopen sees the same keywords as a direct call would.
            forwarded = {"timeout": timeout}
            if data is not None:
                forwarded["data"] = data
            return urllib.request.urlopen(req, **forwarded)

    def _fake_build_opener(*handlers):
        # Handlers are accepted for signature compatibility and ignored.
        return _UrlopenDelegatingOpener()

    monkeypatch.setattr(urllib.request, "build_opener", _fake_build_opener)
3 changes: 2 additions & 1 deletion tests/self_upgrade_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
_verify_upgrade,
)
from tests.conftest import strip_ansi
from tests.http_helpers import mock_urlopen_response
from tests.http_helpers import mock_urlopen_response, route_opener_open_through_urlopen

__all__ = (
"SENTINEL_GH_TOKEN",
Expand All @@ -31,6 +31,7 @@
"_verify_upgrade",
"mock_urlopen_response",
"requires_posix",
"route_opener_open_through_urlopen",
"runner",
"strip_ansi",
)
Expand Down
Loading