diff --git a/src/specify_cli/_download_security.py b/src/specify_cli/_download_security.py new file mode 100644 index 0000000000..d7db2682c2 --- /dev/null +++ b/src/specify_cli/_download_security.py @@ -0,0 +1,346 @@ +"""Helpers for bounded downloads and archive extraction.""" + +from __future__ import annotations + +import hashlib +import re +import stat +import zipfile +from pathlib import Path, PurePosixPath +from typing import NoReturn, TypeVar +from urllib.parse import urlparse + + +ErrorT = TypeVar("ErrorT", bound=Exception) + +MAX_DOWNLOAD_BYTES = 50 * 1024 * 1024 +MAX_ZIP_ENTRIES = 512 +MAX_ZIP_MEMBER_BYTES = 10 * 1024 * 1024 +MAX_ZIP_TOTAL_BYTES = 50 * 1024 * 1024 +READ_CHUNK_SIZE = 1024 * 1024 + +# Tighter ceilings for responses that are read fully into memory and parsed as +# JSON. The 50 MiB MAX_DOWNLOAD_BYTES default is sized for archive/payload +# downloads; JSON responses are far smaller, so capping them close to their real +# size shrinks the memory-DoS surface and keeps the "too large" error reachable +# (rather than only triggering on tens of MiB). Pass the matching constant +# explicitly at each JSON call site so the intended bound is pinned there. +# * METADATA - fixed-shape single-object responses (an OAuth token, one +# release's metadata): a few KiB in practice, 1 MiB is already generous. +# * CATALOG - listings that grow with the number of published items. The +# largest bundled catalog is ~130 KiB today, so 8 MiB leaves ~60x headroom +# for growth while staying well under the download ceiling. +MAX_JSON_METADATA_BYTES = 1 * 1024 * 1024 +MAX_JSON_CATALOG_BYTES = 8 * 1024 * 1024 +SHA256_RE = re.compile(r"^[0-9a-fA-F]{64}$") + + +def is_https_or_localhost_http(url: str) -> bool: + """Return True if *url* is HTTPS, or HTTP limited to loopback hosts. + + Shared scheme-safety predicate used by the auth HTTP redirect handler and + by the direct URL validations in the CLI download flows, so the rule (and + any future tightening of it) lives in one place. + + A hostname is always required: a URL without one (e.g. ``https:///x``) + has no real target and is rejected regardless of scheme. + + The loopback allowance is a deliberate *exact-string* match on + ``localhost`` / ``127.0.0.1`` / ``::1``, not an IP-range check: other + loopback addresses (e.g. ``127.0.0.2``) are intentionally not covered. + ``urlparse`` already lower-cases the hostname, so the comparison is + case-insensitive. + """ + parsed = urlparse(url) + if not parsed.hostname: + return False + is_localhost = parsed.hostname in ("localhost", "127.0.0.1", "::1") + return parsed.scheme == "https" or (parsed.scheme == "http" and is_localhost) + + +def _raise(error_type: type[ErrorT], message: str) -> NoReturn: + raise error_type(message) + + +def _raise_from(error_type: type[ErrorT], message: str, exc: Exception) -> NoReturn: + raise error_type(message) from exc + + +def read_response_limited( + response, + *, + max_bytes: int = MAX_DOWNLOAD_BYTES, + error_type: type[ErrorT] = ValueError, + label: str = "download", +) -> bytes: + """Read at most *max_bytes* from a response object. + + ``response.read(n)`` is only guaranteed to return *up to* ``n`` bytes and may + return fewer even when more data is pending (e.g. chunked transfer encoding), + so a single ``read(max_bytes + 1)`` cannot enforce the bound on its own. Read + in a loop until EOF or until one byte past the limit has been accumulated. + + *max_bytes* is keyword-only. It defaults to the module-wide + ``MAX_DOWNLOAD_BYTES`` (50 MiB) ceiling for archive/payload downloads; + callers with a tighter budget (e.g. small JSON responses) should pass an + explicit value so the intended bound is pinned at the call site rather than + tracking changes to the shared default. + """ + chunks: list[bytes] = [] + total = 0 + limit = max_bytes + 1 + while total < limit: + chunk = response.read(min(READ_CHUNK_SIZE, limit - total)) + if not chunk: + break + chunks.append(chunk) + total += len(chunk) + if total > max_bytes: + _raise(error_type, f"{label} exceeds maximum size of {max_bytes} bytes") + return b"".join(chunks) + + +def normalize_sha256(value: object, *, error_type: type[ErrorT] = ValueError) -> str | None: + """Normalize an optional sha256/sha256: checksum value.""" + if value is None: + return None + if not isinstance(value, str): + _raise(error_type, "sha256 checksum must be a string") + + checksum = value.strip() + if checksum.startswith("sha256:"): + checksum = checksum[len("sha256:") :] + if not SHA256_RE.fullmatch(checksum): + _raise(error_type, "sha256 checksum must be 64 hexadecimal characters") + return checksum.lower() + + +def verify_sha256( + data: bytes, + expected: object, + *, + error_type: type[ErrorT] = ValueError, + label: str = "download", +) -> None: + """Verify *data* against an optional sha256 checksum.""" + checksum = normalize_sha256(expected, error_type=error_type) + if checksum is None: + return + + actual = hashlib.sha256(data).hexdigest() + if actual != checksum: + _raise( + error_type, + f"{label} checksum mismatch: expected sha256:{checksum}, got sha256:{actual}", + ) + + +def read_zip_member_limited( + zf: zipfile.ZipFile, + name: str, + *, + max_bytes: int = MAX_ZIP_MEMBER_BYTES, + error_type: type[ErrorT] = ValueError, + label: str | None = None, +) -> bytes: + """Read a single ZIP member into memory under a hard size cap. + + Reading a member with ``zf.open(name).read()`` is unbounded: a crafted + archive can declare a tiny ``file_size`` yet decompress to many gigabytes (a + "zip bomb"), exhausting memory before the caller ever inspects the data. + This rejects members whose *declared* size already exceeds *max_bytes* and, + to defend against headers that lie, also reads in bounded chunks and stops + one byte past the limit. + + Use this for any inline manifest/metadata read that happens *before* + :func:`safe_extract_zip` (which already enforces the same per-member bound + during extraction); a raw ``zf.open(...).read()`` bypasses that protection. + """ + member_label = label or name + try: + info = zf.getinfo(name) + except KeyError as exc: + _raise_from(error_type, f"ZIP member not found: {name}", exc) + if info.file_size > max_bytes: + _raise( + error_type, + f"ZIP member {member_label} exceeds maximum size of {max_bytes} bytes", + ) + + chunks: list[bytes] = [] + total = 0 + limit = max_bytes + 1 + try: + with zf.open(name, "r") as source: + while total < limit: + chunk = source.read(min(READ_CHUNK_SIZE, limit - total)) + if not chunk: + break + chunks.append(chunk) + total += len(chunk) + except (OSError, zipfile.BadZipFile, RuntimeError) as exc: + _raise_from(error_type, f"Failed to read ZIP member {member_label}: {exc}", exc) + if total > max_bytes: + _raise( + error_type, + f"ZIP member {member_label} exceeds maximum size of {max_bytes} bytes", + ) + return b"".join(chunks) + + +def _safe_zip_name(name: str, *, error_type: type[ErrorT]) -> str: + """Return a normalized ZIP member name or raise on traversal.""" + if "\x00" in name: + _raise(error_type, f"Unsafe path in ZIP archive: {name!r}") + + normalized = name.replace("\\", "/") + path = PurePosixPath(normalized) + raw_parts = normalized.split("/") + # Strip a single trailing empty segment, i.e. the one-slash directory + # marker that legitimate ZIPs use ("mydir/", "mydir/subdir/"). Anything + # else that produces an empty segment - consecutive slashes ("a//b") or a + # second trailing slash - is left in place and rejected below as malformed. + if raw_parts and raw_parts[-1] == "": + raw_parts = raw_parts[:-1] + has_windows_drive = re.match(r"^[A-Za-z]:", normalized) is not None + if ( + not raw_parts + or path.is_absolute() + or has_windows_drive + or any(part in {"", ".", ".."} for part in raw_parts) + ): + _raise( + error_type, + f"Unsafe path in ZIP archive: {name} (potential path traversal)", + ) + return normalized + + +def safe_extract_zip( + zip_path: Path, + target_dir: Path, + *, + error_type: type[ErrorT] = ValueError, + max_entries: int = MAX_ZIP_ENTRIES, + max_member_bytes: int = MAX_ZIP_MEMBER_BYTES, + max_total_bytes: int = MAX_ZIP_TOTAL_BYTES, +) -> None: + """Extract a ZIP archive after path, symlink, and size validation.""" + try: + target_root = target_dir.resolve() + except OSError as exc: + _raise_from(error_type, f"Invalid ZIP extraction target: {target_dir}", exc) + + try: + zf = zipfile.ZipFile(zip_path, "r") + except (OSError, zipfile.BadZipFile) as exc: + _raise_from(error_type, f"Invalid ZIP archive: {zip_path}", exc) + + with zf: + try: + members = zf.infolist() + except zipfile.BadZipFile as exc: + _raise_from(error_type, f"Invalid ZIP archive: {zip_path}", exc) + if len(members) > max_entries: + _raise( + error_type, + f"ZIP archive contains too many entries ({len(members)} > {max_entries})", + ) + + normalized_members: list[tuple[zipfile.ZipInfo, str, bool]] = [] + total_size = 0 + for member in members: + normalized_name = _safe_zip_name(member.filename, error_type=error_type) + is_dir = member.is_dir() or normalized_name.endswith("/") + + mode = member.external_attr >> 16 + if stat.S_ISLNK(mode): + _raise(error_type, f"Unsafe symlink in ZIP archive: {member.filename}") + + member_path = (target_dir / normalized_name).resolve() + try: + member_path.relative_to(target_root) + except ValueError: + _raise( + error_type, + f"Unsafe path in ZIP archive: {member.filename} " + "(potential path traversal)", + ) + + if not is_dir: + if member.file_size > max_member_bytes: + _raise( + error_type, + f"ZIP member {member.filename} exceeds maximum size " + f"of {max_member_bytes} bytes", + ) + total_size += member.file_size + if total_size > max_total_bytes: + _raise( + error_type, + f"ZIP archive exceeds maximum uncompressed size " + f"of {max_total_bytes} bytes", + ) + + normalized_members.append((member, normalized_name, is_dir)) + + # The loop above bounds the *declared* total via member.file_size, but a + # crafted archive can understate those headers. Mirror the per-member + # guard below with a cumulative count of the bytes actually written so + # the total-size bound holds even when the headers lie. + total_written = 0 + for member, normalized_name, is_dir in normalized_members: + member_path = target_dir / normalized_name + if is_dir: + try: + member_path.mkdir(parents=True, exist_ok=True) + except OSError as exc: + _raise_from( + error_type, + f"Failed to create ZIP directory {member.filename}: {exc}", + exc, + ) + continue + + try: + member_path.parent.mkdir(parents=True, exist_ok=True) + except OSError as exc: + _raise_from( + error_type, + f"Failed to create parent directory for ZIP member {member.filename}: {exc}", + exc, + ) + written = 0 + # Raised outside the try below: if error_type subclasses OSError or + # RuntimeError, raising inside would re-wrap the limit error as + # "Failed to extract" and lose the size-bound message. + limit_error: str | None = None + try: + with zf.open(member, "r") as source, member_path.open("wb") as dest: + while True: + chunk = source.read(READ_CHUNK_SIZE) + if not chunk: + break + written += len(chunk) + if written > max_member_bytes: + limit_error = ( + f"ZIP member {member.filename} exceeds maximum size " + f"of {max_member_bytes} bytes" + ) + break + total_written += len(chunk) + if total_written > max_total_bytes: + limit_error = ( + f"ZIP archive exceeds maximum uncompressed size " + f"of {max_total_bytes} bytes" + ) + break + dest.write(chunk) + except (OSError, zipfile.BadZipFile, RuntimeError) as exc: + _raise_from( + error_type, + f"Failed to extract ZIP member {member.filename}: {exc}", + exc, + ) + if limit_error is not None: + _raise(error_type, limit_error) diff --git a/src/specify_cli/_github_http.py b/src/specify_cli/_github_http.py index d2030b57a8..e9a5f7a4b1 100644 --- a/src/specify_cli/_github_http.py +++ b/src/specify_cli/_github_http.py @@ -91,6 +91,11 @@ def resolve_github_release_asset_api_url( import json import urllib.error + from specify_cli._download_security import ( + MAX_JSON_METADATA_BYTES, + read_response_limited, + ) + parsed = urlparse(download_url) parts = [unquote(part) for part in parsed.path.strip("/").split("/")] @@ -118,8 +123,17 @@ def resolve_github_release_asset_api_url( try: with open_url_fn(release_url, timeout=timeout) as response: - release_data = json.loads(response.read()) - except (urllib.error.URLError, json.JSONDecodeError): + release_data = json.loads( + read_response_limited( + response, + max_bytes=MAX_JSON_METADATA_BYTES, + label=f"GitHub release metadata {release_url}", + ) + ) + # ValueError covers both an oversized body (raised by read_response_limited) + # and json.JSONDecodeError (a ValueError subclass); on any of these, fall + # back to the original URL by returning None. + except (urllib.error.URLError, ValueError): return None for asset in release_data.get("assets", []): diff --git a/src/specify_cli/_version.py b/src/specify_cli/_version.py index e634a4f286..7720cf2ab6 100644 --- a/src/specify_cli/_version.py +++ b/src/specify_cli/_version.py @@ -4,8 +4,8 @@ release tag. The ``self_app`` Typer sub-command group is co-located here so all version-related logic lives in one place. -Dependencies: stdlib + packaging + ._console only (no other internal imports -at module level, keeping this layer thin and circular-import-safe). +Dependencies: stdlib + packaging + ._console + ._download_security only +(keeping this layer thin and circular-import-safe). """ from __future__ import annotations @@ -28,6 +28,7 @@ import typer from packaging.version import InvalidVersion, Version +from ._download_security import MAX_JSON_METADATA_BYTES, read_response_limited from ._console import console GITHUB_API_LATEST = "https://api.github.com/repos/github/spec-kit/releases/latest" @@ -118,8 +119,15 @@ def _fetch_latest_release_tag() -> tuple[str | None, str | None]: GITHUB_API_LATEST, timeout=5, extra_headers={"Accept": "application/vnd.github+json"}, + strict_redirects=True, ) as resp: - payload = json.loads(resp.read().decode("utf-8")) + payload = json.loads( + read_response_limited( + resp, + max_bytes=MAX_JSON_METADATA_BYTES, + label="GitHub latest release", + ).decode("utf-8") + ) tag = payload.get("tag_name") if not isinstance(tag, str) or not tag: raise ValueError("GitHub API response missing valid tag_name") diff --git a/src/specify_cli/authentication/azure_devops.py b/src/specify_cli/authentication/azure_devops.py index 5d71a1957b..72e25de92b 100644 --- a/src/specify_cli/authentication/azure_devops.py +++ b/src/specify_cli/authentication/azure_devops.py @@ -8,6 +8,7 @@ import subprocess from typing import TYPE_CHECKING +from .._download_security import MAX_JSON_METADATA_BYTES, read_response_limited from .base import AuthProvider if TYPE_CHECKING: @@ -17,6 +18,10 @@ _ADO_RESOURCE_ID = "499b84ac-1321-427f-aa17-267ca6975798" +class _TokenResponseTooLarge(Exception): + """Raised when an Azure AD token response exceeds the bounded read limit.""" + + class AzureDevOpsAuth(AuthProvider): """Azure DevOps authentication provider. @@ -109,9 +114,31 @@ def _acquire_via_client_credentials(entry: AuthConfigEntry) -> str | None: headers={"Content-Type": "application/x-www-form-urlencoded"}, ) try: - with urllib.request.urlopen(req, timeout=30) as resp: # noqa: S310 - payload = _json.loads(resp.read().decode("utf-8")) + from specify_cli.authentication.http import _StripAuthOnRedirect + + # A 307/308 redirect preserves the POST body, which carries the + # client_secret. Reuse the package HTTPS-downgrade guard (empty host + # list means no auth header to strip, just the scheme check) so the + # secret can never be forwarded to a non-HTTPS, non-loopback host. + opener = urllib.request.build_opener(_StripAuthOnRedirect(())) + with opener.open(req, timeout=30) as resp: # noqa: S310 + payload = _json.loads( + read_response_limited( + resp, + max_bytes=MAX_JSON_METADATA_BYTES, + error_type=_TokenResponseTooLarge, + label="Azure DevOps token response", + ).decode("utf-8") + ) token = payload.get("access_token", "").strip() return token or None - except (urllib.error.URLError, OSError, _json.JSONDecodeError, KeyError): + except ( + urllib.error.URLError, + OSError, + _json.JSONDecodeError, + _TokenResponseTooLarge, + ): + # Network failure, malformed JSON, or an oversized response — fall + # through to the next strategy. Unrelated programming errors (other + # ValueErrors, KeyErrors) intentionally propagate so they surface. return None diff --git a/src/specify_cli/authentication/http.py b/src/specify_cli/authentication/http.py index e8ab8c1241..1aeacce9f8 100644 --- a/src/specify_cli/authentication/http.py +++ b/src/specify_cli/authentication/http.py @@ -17,6 +17,7 @@ from typing import Callable from urllib.parse import urlparse +from .._download_security import is_https_or_localhost_http from . import get_provider from .config import AuthConfigEntry, _default_config_path, find_entries_for_url, load_auth_config @@ -60,8 +61,23 @@ def _hostname_in_hosts(hostname: str, hosts: tuple[str, ...]) -> bool: RedirectValidator = Callable[[str, str], None] +def _validate_strict_redirect(_old_url: str, new_url: str) -> None: + if not is_https_or_localhost_http(new_url): + raise urllib.error.URLError( + "unsafe redirect: target must use HTTPS with a hostname, " + "or HTTP for localhost (127.0.0.1, ::1)" + ) + + class _StripAuthOnRedirect(urllib.request.HTTPRedirectHandler): - """Drop ``Authorization`` when a redirect leaves trusted hosts or downgrades.""" + """Redirect handler that guards every redirect it is installed for. + + 1. Run any caller-provided redirect validator. + 2. Reject redirects that are not HTTPS with a hostname, except HTTP to + localhost / 127.0.0.1 / ::1 (the exact hosts allowed by + ``is_https_or_localhost_http``). + 3. Drop ``Authorization`` when a redirect leaves trusted hosts or downgrades. + """ def __init__( self, @@ -75,6 +91,7 @@ def __init__( def redirect_request(self, req, fp, code, msg, headers, newurl): if self._redirect_validator is not None: self._redirect_validator(req.full_url, newurl) + _validate_strict_redirect(req.full_url, newurl) original_auth = ( req.get_header("Authorization") @@ -123,6 +140,7 @@ def open_url( timeout: int = 10, extra_headers: dict[str, str] | None = None, redirect_validator: RedirectValidator | None = None, + strict_redirects: bool = False, ): """Open *url* with config-driven auth, redirect stripping, and fallthrough. @@ -135,9 +153,19 @@ def open_url( *extra_headers* (e.g. ``Accept``) are merged into every attempt. *redirect_validator*, when provided, is called with ``(old_url, new_url)`` before following each redirect and may raise to reject the redirect. + + Redirect scheme safety: every authenticated attempt goes through + ``_StripAuthOnRedirect``, which always rejects redirects to non-HTTPS + URLs (except HTTP to localhost / 127.0.0.1 / ::1, the hosts allowed by + ``is_https_or_localhost_http``). The unauthenticated fallback installs the + same handler when *strict_redirects* is true or *redirect_validator* is + supplied; without either, it follows redirects without that handler. """ entries = find_entries_for_url(url, _load_config()) + effective_redirect_validator = redirect_validator + use_redirect_handler = strict_redirects or effective_redirect_validator is not None + def _make_req(auth_headers: dict[str, str]) -> urllib.request.Request: merged = {} if extra_headers: @@ -157,7 +185,7 @@ def _make_req(auth_headers: dict[str, str]) -> urllib.request.Request: continue req = _make_req(provider.auth_headers(token, entry.auth)) - opener = urllib.request.build_opener(_StripAuthOnRedirect(entry.hosts, redirect_validator)) + opener = urllib.request.build_opener(_StripAuthOnRedirect(entry.hosts, effective_redirect_validator)) try: return opener.open(req, timeout=timeout) except urllib.error.HTTPError as exc: @@ -168,7 +196,9 @@ def _make_req(auth_headers: dict[str, str]) -> urllib.request.Request: # No entry worked (or none matched) — unauthenticated fallback req = _make_req({}) - if redirect_validator is not None: - opener = urllib.request.build_opener(_StripAuthOnRedirect((), redirect_validator)) + if use_redirect_handler: + # No auth is attached on this path, so the handler's host list is empty: + # here it runs redirect validation only, not auth stripping. + opener = urllib.request.build_opener(_StripAuthOnRedirect((), effective_redirect_validator)) return opener.open(req, timeout=timeout) return urllib.request.urlopen(req, timeout=timeout) # noqa: S310 diff --git a/src/specify_cli/catalogs.py b/src/specify_cli/catalogs.py index 8bd3b2dc06..33cc0bc996 100644 --- a/src/specify_cli/catalogs.py +++ b/src/specify_cli/catalogs.py @@ -68,18 +68,18 @@ def _entry( @classmethod def _validate_catalog_url(cls, url: str) -> None: - """Validate that a catalog URL uses HTTPS, except localhost HTTP.""" + """Validate that a catalog URL uses HTTPS, except loopback HTTP.""" from urllib.parse import urlparse parsed = urlparse(url) + if not parsed.hostname: + raise cls._error("Catalog URL must be a valid URL with a host.") is_localhost = parsed.hostname in ("localhost", "127.0.0.1", "::1") if parsed.scheme != "https" and not (parsed.scheme == "http" and is_localhost): raise cls._error( f"Catalog URL must use HTTPS (got {parsed.scheme}://). " - "HTTP is only allowed for localhost." + "HTTP is only allowed for localhost, 127.0.0.1, and ::1." ) - if not parsed.netloc: - raise cls._error("Catalog URL must be a valid URL with a host.") def _load_catalog_config(self, config_path: Path) -> list[CatalogEntry] | None: """Load catalog stack configuration from a YAML file. diff --git a/src/specify_cli/extensions/__init__.py b/src/specify_cli/extensions/__init__.py index 19cc0f0910..ad2890520e 100644 --- a/src/specify_cli/extensions/__init__.py +++ b/src/specify_cli/extensions/__init__.py @@ -15,7 +15,6 @@ import re import shutil import tempfile -import zipfile from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path @@ -26,6 +25,12 @@ from packaging import version as pkg_version from packaging.specifiers import InvalidSpecifier, SpecifierSet +from .._download_security import ( + MAX_JSON_CATALOG_BYTES, + read_response_limited, + safe_extract_zip, + verify_sha256, +) from .._init_options import is_ai_skills_enabled from .._invocation_style import is_dollar_skills_agent, is_slash_skills_agent from .._utils import dump_frontmatter, relative_extension_path_violation @@ -1475,21 +1480,7 @@ def install_from_zip( with tempfile.TemporaryDirectory() as tmpdir: temp_path = Path(tmpdir) - # Extract ZIP safely (prevent Zip Slip attack) - with zipfile.ZipFile(zip_path, "r") as zf: - # Validate all paths first before extracting anything - temp_path_resolved = temp_path.resolve() - for member in zf.namelist(): - member_path = (temp_path / member).resolve() - # Use is_relative_to for safe path containment check - try: - member_path.relative_to(temp_path_resolved) - except ValueError: - raise ValidationError( - f"Unsafe path in ZIP archive: {member} (potential path traversal)" - ) - # Only extract after all paths are validated - zf.extractall(temp_path) + safe_extract_zip(zip_path, temp_path, error_type=ValidationError) # Find extension directory (may be nested) extension_dir = temp_path @@ -2262,7 +2253,14 @@ def _fetch_single_catalog( # Fetch from network try: with self._open_url(entry.url, timeout=10) as response: - catalog_data = json.loads(response.read()) + catalog_data = json.loads( + read_response_limited( + response, + max_bytes=MAX_JSON_CATALOG_BYTES, + error_type=ExtensionError, + label=f"extension catalog {entry.url}", + ) + ) self._validate_catalog_payload(catalog_data, entry.url) @@ -2439,7 +2437,14 @@ def fetch_catalog(self, force_refresh: bool = False) -> Dict[str, Any]: import urllib.error with self._open_url(catalog_url, timeout=10) as response: - catalog_data = json.loads(response.read()) + catalog_data = json.loads( + read_response_limited( + response, + max_bytes=MAX_JSON_CATALOG_BYTES, + error_type=ExtensionError, + label=f"extension catalog {catalog_url}", + ) + ) # Validate catalog structure. Reuses the same helper as # ``_fetch_single_catalog`` so all three branches (root type, @@ -2615,7 +2620,18 @@ def download_extension( with self._open_url( download_url, timeout=60, extra_headers=extra_headers ) as response: - zip_data = response.read() + zip_data = read_response_limited( + response, + error_type=ExtensionError, + label=f"extension '{extension_id}' download", + ) + + verify_sha256( + zip_data, + ext_info.get("sha256"), + error_type=ExtensionError, + label=f"extension '{extension_id}' download", + ) zip_path.write_bytes(zip_data) return zip_path diff --git a/src/specify_cli/extensions/_commands.py b/src/specify_cli/extensions/_commands.py index 3b60b6d52d..2fd5e8a24f 100644 --- a/src/specify_cli/extensions/_commands.py +++ b/src/specify_cli/extensions/_commands.py @@ -23,6 +23,7 @@ from .._console import console from .._assets import get_speckit_version +from .._download_security import read_zip_member_limited extension_app = typer.Typer( name="extension", @@ -1188,19 +1189,25 @@ def extension_update( manifest_data = None namelist = zf.namelist() + # Read the manifest under a hard size cap: this happens + # before install_from_zip()'s safe_extract_zip(), so a + # raw zf.open().read() here would bypass that bound and + # let a zip-bomb extension.yml exhaust memory. # First try root-level extension.yml if "extension.yml" in namelist: - with zf.open("extension.yml") as f: - parsed_manifest = yaml.safe_load(f) - manifest_data = parsed_manifest if parsed_manifest is not None else {} + parsed_manifest = yaml.safe_load( + read_zip_member_limited(zf, "extension.yml") + ) + manifest_data = parsed_manifest if parsed_manifest is not None else {} else: # Look for extension.yml in a single top-level subdirectory # (e.g., "repo-name-branch/extension.yml") manifest_paths = [n for n in namelist if n.endswith("/extension.yml") and n.count("/") == 1] if len(manifest_paths) == 1: - with zf.open(manifest_paths[0]) as f: - parsed_manifest = yaml.safe_load(f) - manifest_data = parsed_manifest if parsed_manifest is not None else {} + parsed_manifest = yaml.safe_load( + read_zip_member_limited(zf, manifest_paths[0]) + ) + manifest_data = parsed_manifest if parsed_manifest is not None else {} if manifest_data is None: raise ValueError("Downloaded extension archive is missing 'extension.yml'") diff --git a/src/specify_cli/presets/__init__.py b/src/specify_cli/presets/__init__.py index 66f1bbc5e5..c7dbd1b0e0 100644 --- a/src/specify_cli/presets/__init__.py +++ b/src/specify_cli/presets/__init__.py @@ -12,7 +12,6 @@ import hashlib import os import tempfile -import zipfile import shutil from dataclasses import dataclass from pathlib import Path @@ -27,6 +26,12 @@ from packaging import version as pkg_version from packaging.specifiers import SpecifierSet, InvalidSpecifier +from .._download_security import ( + MAX_JSON_CATALOG_BYTES, + read_response_limited, + safe_extract_zip, + verify_sha256, +) from ..extensions import REINSTALL_COMMAND, ExtensionRegistry, normalize_priority from .._init_options import is_ai_skills_enabled from ..integrations.base import IntegrationBase @@ -1646,18 +1651,7 @@ def install_from_zip( with tempfile.TemporaryDirectory() as tmpdir: temp_path = Path(tmpdir) - with zipfile.ZipFile(zip_path, 'r') as zf: - temp_path_resolved = temp_path.resolve() - for member in zf.namelist(): - member_path = (temp_path / member).resolve() - try: - member_path.relative_to(temp_path_resolved) - except ValueError: - raise PresetValidationError( - f"Unsafe path in ZIP archive: {member} " - "(potential path traversal)" - ) - zf.extractall(temp_path) + safe_extract_zip(zip_path, temp_path, error_type=PresetValidationError) pack_dir = temp_path manifest_path = pack_dir / "preset.yml" @@ -1852,17 +1846,17 @@ def _validate_catalog_url(self, url: str) -> None: from urllib.parse import urlparse parsed = urlparse(url) + if not parsed.hostname: + raise PresetValidationError( + "Catalog URL must be a valid URL with a host." + ) is_localhost = parsed.hostname in ("localhost", "127.0.0.1", "::1") if parsed.scheme != "https" and not ( parsed.scheme == "http" and is_localhost ): raise PresetValidationError( f"Catalog URL must use HTTPS (got {parsed.scheme}://). " - "HTTP is only allowed for localhost." - ) - if not parsed.netloc: - raise PresetValidationError( - "Catalog URL must be a valid URL with a host." + "HTTP is only allowed for localhost, 127.0.0.1, and ::1." ) def _make_request(self, url: str): @@ -2163,7 +2157,14 @@ def _fetch_single_catalog(self, entry: PresetCatalogEntry, force_refresh: bool = try: with self._open_url(entry.url, timeout=10) as response: - catalog_data = json.loads(response.read()) + catalog_data = json.loads( + read_response_limited( + response, + max_bytes=MAX_JSON_CATALOG_BYTES, + error_type=PresetError, + label=f"preset catalog {entry.url}", + ) + ) self._validate_catalog_payload(catalog_data, entry.url) @@ -2314,7 +2315,14 @@ def fetch_catalog(self, force_refresh: bool = False) -> Dict[str, Any]: try: with self._open_url(catalog_url, timeout=10) as response: - catalog_data = json.loads(response.read()) + catalog_data = json.loads( + read_response_limited( + response, + max_bytes=MAX_JSON_CATALOG_BYTES, + error_type=PresetError, + label=f"preset catalog {catalog_url}", + ) + ) # Validate catalog structure. Reuses the same helper as # ``_fetch_single_catalog`` so all three branches (root type, @@ -2503,7 +2511,18 @@ def download_pack( try: with self._open_url(download_url, timeout=60, extra_headers=extra_headers) as response: - zip_data = response.read() + zip_data = read_response_limited( + response, + error_type=PresetError, + label=f"preset '{pack_id}' download", + ) + + verify_sha256( + zip_data, + pack_info.get("sha256"), + error_type=PresetError, + label=f"preset '{pack_id}' download", + ) zip_path.write_bytes(zip_data) return zip_path diff --git a/src/specify_cli/presets/_commands.py b/src/specify_cli/presets/_commands.py index 682bfe919d..4e29c949c9 100644 --- a/src/specify_cli/presets/_commands.py +++ b/src/specify_cli/presets/_commands.py @@ -101,53 +101,43 @@ def preset_add( elif from_url: # Validate URL scheme before downloading - from ipaddress import ip_address - from urllib.parse import urlparse as _urlparse - - _parsed = _urlparse(from_url) - - def _is_allowed_download_url(parsed_url): - host = parsed_url.hostname - if not host: - return False - is_loopback = host == "localhost" - if not is_loopback: - try: - is_loopback = ip_address(host).is_loopback - except ValueError: - # Host is not an IP literal (e.g., a regular hostname); treat as non-loopback. - pass - return parsed_url.scheme == "https" or (parsed_url.scheme == "http" and is_loopback) + from specify_cli._download_security import is_https_or_localhost_http def _validate_download_redirect(old_url, new_url): - if not _is_allowed_download_url(_urlparse(new_url)): + if not is_https_or_localhost_http(new_url): import urllib.error raise urllib.error.URLError( "redirect target must use HTTPS with a hostname, " - "or HTTP for localhost/loopback" + "or HTTP for localhost (127.0.0.1, ::1)" ) - if not _is_allowed_download_url(_parsed): + if not is_https_or_localhost_http(from_url): console.print( - "[red]Error:[/red] URL must use HTTPS with a hostname, " - "or HTTP for localhost/loopback." + "[red]Error:[/red] URL must use HTTPS with a hostname and be " + "a valid URL with a host. HTTP is only allowed for localhost, " + "127.0.0.1, and ::1." ) raise typer.Exit(1) console.print(f"Installing preset from [cyan]{from_url}[/cyan]...") import urllib.error import tempfile - import shutil with tempfile.TemporaryDirectory() as tmpdir: zip_path = Path(tmpdir) / "preset.zip" try: + from functools import partial + + from specify_cli._download_security import read_response_limited from specify_cli.authentication.http import open_url as _open_url from specify_cli._github_http import resolve_github_release_asset_api_url _preset_extra_headers = None - _resolved_from_url = resolve_github_release_asset_api_url(from_url, _open_url) + _resolved_from_url = resolve_github_release_asset_api_url( + from_url, + partial(_open_url, strict_redirects=True), + ) if _resolved_from_url: from_url = _resolved_from_url _preset_extra_headers = {"Accept": "application/octet-stream"} @@ -159,19 +149,25 @@ def _validate_download_redirect(old_url, new_url): redirect_validator=_validate_download_redirect, ) as response: final_url = response.geturl() if hasattr(response, "geturl") else from_url - if not _is_allowed_download_url(_urlparse(final_url)): + if not is_https_or_localhost_http(final_url): console.print( "[red]Error:[/red] Preset URL redirected to a disallowed URL: " f"{final_url}. Redirect targets must use HTTPS with a hostname, " - "or HTTP for localhost/loopback." + "or HTTP for localhost (127.0.0.1, ::1)." ) raise typer.Exit(1) - with zip_path.open("wb") as output: - try: - shutil.copyfileobj(response, output) - except TypeError: - output.write(response.read()) - except urllib.error.URLError as e: + zip_path.write_bytes( + read_response_limited( + response, + error_type=PresetError, + label=f"preset {from_url}", + ) + ) + # The URL scheme is validated above, so the only failures here + # are network errors and an oversized body (raised as PresetError + # via error_type). Catching those specifically lets unrelated + # ValueErrors surface instead of masquerading as download errors. + except (urllib.error.URLError, PresetError) as e: console.print(f"[red]Error:[/red] Failed to download: {e}") raise typer.Exit(1) diff --git a/tests/http_helpers.py b/tests/http_helpers.py index 46e26806b4..00f549e397 100644 --- a/tests/http_helpers.py +++ b/tests/http_helpers.py @@ -1,15 +1,46 @@ """HTTP test helpers shared by version-related CLI tests.""" +import io import json +import urllib.request from unittest.mock import MagicMock +import pytest + def mock_urlopen_response(payload: dict) -> MagicMock: """Build a urlopen context-manager mock whose read returns JSON.""" body = json.dumps(payload).encode("utf-8") resp = MagicMock() - resp.read.return_value = body + resp.read.side_effect = io.BytesIO(body).read cm = MagicMock() cm.__enter__.return_value = resp cm.__exit__.return_value = False return cm + + +@pytest.fixture(autouse=True) +def route_opener_open_through_urlopen(monkeypatch): + """Route build_opener().open through urllib.request.urlopen. + + ``open_url(..., strict_redirects=True)`` fetches via + ``build_opener(...).open()``, which bypasses ``urllib.request.urlopen`` + — and with it the urlopen patches these test modules are built on. + Delegating ``open()`` to urlopen at call time keeps those patches + effective; the redirect handler's own behavior is covered by + ``TestRedirectStripping`` in test_authentication.py. + + Import this fixture into a test module to activate it there. + """ + + class _UrlopenDelegatingOpener: + def open(self, req, data=None, timeout=None): + if data is None: + return urllib.request.urlopen(req, timeout=timeout) + return urllib.request.urlopen(req, data=data, timeout=timeout) + + monkeypatch.setattr( + urllib.request, + "build_opener", + lambda *handlers: _UrlopenDelegatingOpener(), + ) diff --git a/tests/integrations/test_integration_catalog.py b/tests/integrations/test_integration_catalog.py index fae9e32d23..5c50591d17 100644 --- a/tests/integrations/test_integration_catalog.py +++ b/tests/integrations/test_integration_catalog.py @@ -984,7 +984,7 @@ def test_add_catalog_accepts_numeric_string_priority(self, tmp_path, monkeypatch ("bad_url", "reason"), [ ("http://insecure.example.com/catalog.json", "HTTPS"), - (123, "HTTPS"), + (123, "valid URL with a host"), ], ) def test_add_catalog_rejects_existing_entry_with_bad_url( diff --git a/tests/self_upgrade_helpers.py b/tests/self_upgrade_helpers.py index c363f57b13..fc0f339f92 100644 --- a/tests/self_upgrade_helpers.py +++ b/tests/self_upgrade_helpers.py @@ -18,7 +18,7 @@ _verify_upgrade, ) from tests.conftest import strip_ansi -from tests.http_helpers import mock_urlopen_response +from tests.http_helpers import mock_urlopen_response, route_opener_open_through_urlopen __all__ = ( "SENTINEL_GH_TOKEN", @@ -31,6 +31,7 @@ "_verify_upgrade", "mock_urlopen_response", "requires_posix", + "route_opener_open_through_urlopen", "runner", "strip_ansi", ) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 8b09245384..cce3ad9a7b 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -14,6 +14,7 @@ from __future__ import annotations import base64 +import io import json import os @@ -497,10 +498,15 @@ def test_resolve_token_azure_ad_success(self, monkeypatch): tenant_id="tid", client_id="cid", client_secret_env="MY_SECRET", ) mock_resp = MagicMock() - mock_resp.read.return_value = b'{"access_token": "ad-acquired-token"}' + mock_resp.read.side_effect = io.BytesIO(b'{"access_token": "ad-acquired-token"}').read mock_resp.__enter__ = lambda s: s mock_resp.__exit__ = MagicMock(return_value=False) - with patch("urllib.request.urlopen", return_value=mock_resp): + # The token request goes through a strict-redirect opener (so a 307/308 + # cannot forward the client_secret body to a non-HTTPS host), not bare + # urlopen; patch the opener it builds. + mock_opener = MagicMock() + mock_opener.open.return_value = mock_resp + with patch("urllib.request.build_opener", return_value=mock_opener): assert AzureDevOpsAuth().resolve_token(entry) == "ad-acquired-token" def test_resolve_token_azure_ad_missing_secret_returns_none(self, monkeypatch): @@ -793,17 +799,18 @@ def test_redirect_outside_hosts_strips_auth(self): assert new_req.headers.get("Authorization") is None assert new_req.unredirected_hdrs.get("Authorization") is None - def test_https_to_http_same_host_redirect_strips_auth(self): + def test_https_to_http_same_host_redirect_rejected(self): from specify_cli.authentication.http import _StripAuthOnRedirect from urllib.request import Request import io + import urllib.error + handler = _StripAuthOnRedirect(("github.com",)) req = Request("https://github.com/org/repo", headers={"Authorization": "Bearer tok"}) - new_req = handler.redirect_request(req, io.BytesIO(b""), 302, "Found", {}, - "http://github.com/org/repo") - assert new_req is not None - assert new_req.headers.get("Authorization") is None - assert new_req.unredirected_hdrs.get("Authorization") is None + + with pytest.raises(urllib.error.URLError, match="unsafe redirect"): + handler.redirect_request(req, io.BytesIO(b""), 302, "Found", {}, + "http://github.com/org/repo") def test_redirect_validator_can_reject_before_following_redirect(self): import urllib.error @@ -845,6 +852,18 @@ def test_multi_hop_redirect_within_hosts_preserves_auth(self): auth3 = req3.get_header("Authorization") or req3.unredirected_hdrs.get("Authorization") assert auth3 == "Bearer tok" + def test_redirect_rejects_https_downgrade(self): + """HTTPS downloads must not follow redirects to non-local HTTP URLs.""" + from specify_cli.authentication.http import _StripAuthOnRedirect + from urllib.request import Request + import io + import urllib.error + handler = _StripAuthOnRedirect(("example.com",)) + req = Request("https://example.com/archive.zip") + with pytest.raises(urllib.error.URLError, match="unsafe redirect"): + handler.redirect_request(req, io.BytesIO(b""), 302, "Found", {}, + "http://evil.example.com/archive.zip") + # --------------------------------------------------------------------------- # _fetch_latest_release_tag delegation @@ -864,7 +883,7 @@ def side_effect(req, timeout=None): captured["request"] = req body = _json.dumps({"tag_name": "v9.9.9"}).encode() resp = MagicMock() - resp.read.return_value = body + resp.read.side_effect = io.BytesIO(body).read cm = MagicMock() cm.__enter__.return_value = resp cm.__exit__.return_value = False @@ -884,19 +903,25 @@ def test_gh_token_forwarded_when_configured(self, monkeypatch): assert captured["request"].get_header("Authorization") == "Bearer forwarded-sentinel" def test_no_config_means_no_auth(self, monkeypatch): - from unittest.mock import patch + from unittest.mock import MagicMock, patch from specify_cli._version import _fetch_latest_release_tag self._set_config(monkeypatch, []) captured, side_effect = self._capture_request() - with patch("specify_cli.authentication.http.urllib.request.urlopen", side_effect=side_effect): + # The release fetch uses strict_redirects=True, so the unauthenticated + # path goes through build_opener().open(), not urlopen. + mock_opener = MagicMock() + mock_opener.open.side_effect = side_effect + with patch("specify_cli.authentication.http.urllib.request.build_opener", return_value=mock_opener): _fetch_latest_release_tag() assert captured["request"].get_header("Authorization") is None def test_accept_header_present(self, monkeypatch): - from unittest.mock import patch + from unittest.mock import MagicMock, patch from specify_cli._version import _fetch_latest_release_tag self._set_config(monkeypatch, []) captured, side_effect = self._capture_request() - with patch("specify_cli.authentication.http.urllib.request.urlopen", side_effect=side_effect): + mock_opener = MagicMock() + mock_opener.open.side_effect = side_effect + with patch("specify_cli.authentication.http.urllib.request.build_opener", return_value=mock_opener): _fetch_latest_release_tag() assert captured["request"].get_header("Accept") == "application/vnd.github+json" diff --git a/tests/test_download_security.py b/tests/test_download_security.py new file mode 100644 index 0000000000..3b06ff9491 --- /dev/null +++ b/tests/test_download_security.py @@ -0,0 +1,227 @@ +"""Tests for bounded download and ZIP extraction helpers.""" + +from __future__ import annotations + +import stat +import zipfile + +import pytest + +from specify_cli._download_security import ( + is_https_or_localhost_http, + read_response_limited, + read_zip_member_limited, + safe_extract_zip, + verify_sha256, +) + + +@pytest.mark.parametrize( + "url, allowed", + [ + ("https://example.com/preset.zip", True), + ("http://localhost:8000/preset.zip", True), + ("http://127.0.0.1/preset.zip", True), + ("http://[::1]/preset.zip", True), + # Non-loopback HTTP is rejected. + ("http://example.com/preset.zip", False), + # Loopback allowance is an exact-string match: 127.0.0.2 is not covered. + ("http://127.0.0.2/preset.zip", False), + # A hostname is always required, even for HTTPS. + ("https:///preset.zip", False), + ("https://", False), + ], +) +def test_is_https_or_localhost_http(url, allowed): + assert is_https_or_localhost_http(url) is allowed + + +class _Response: + """Faithful stream stand-in: read() advances a cursor and returns b"" at EOF.""" + + def __init__(self, data: bytes, *, chunk: int | None = None): + self.data = data + self.pos = 0 + self.chunk = chunk + + def read(self, size: int = -1) -> bytes: + if size < 0: + size = len(self.data) - self.pos + if self.chunk is not None: + size = min(size, self.chunk) + out = self.data[self.pos : self.pos + size] + self.pos += len(out) + return out + + +class _CustomZipError(ValueError): + pass + + +def test_read_response_limited_rejects_oversized_download(): + with pytest.raises(ValueError, match="exceeds maximum size"): + read_response_limited(_Response(b"abcde"), max_bytes=4) + + +def test_read_response_limited_returns_full_body_within_limit(): + assert read_response_limited(_Response(b"abcde"), max_bytes=10) == b"abcde" + + +def test_read_response_limited_enforces_bound_under_short_reads(): + response = _Response(b"x" * 100, chunk=8) + with pytest.raises(ValueError, match="exceeds maximum size"): + read_response_limited(response, max_bytes=16) + + +def test_verify_sha256_rejects_mismatch(): + with pytest.raises(ValueError, match="checksum mismatch"): + verify_sha256(b"payload", "sha256:" + "0" * 64) + + +@pytest.mark.parametrize( + "member_name", + [ + "../evil.txt", + "nested/../../evil.txt", + "nested\\..\\evil.txt", + "C:\\Windows\\evil.txt", + "C:drive-relative.txt", + ], +) +def test_safe_extract_zip_rejects_traversal(tmp_path, member_name): + zip_path = tmp_path / "bad.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr(member_name, "nope") + + with pytest.raises(ValueError, match="Unsafe path"): + safe_extract_zip(zip_path, tmp_path / "out") + + +@pytest.mark.parametrize("member_name", [".", "./file.txt", "nested/./file.txt", "nested//file.txt"]) +def test_safe_extract_zip_rejects_dot_path_segments(tmp_path, member_name): + zip_path = tmp_path / "bad.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr(member_name, "nope") + + with pytest.raises(_CustomZipError, match="Unsafe path"): + safe_extract_zip(zip_path, tmp_path / "out", error_type=_CustomZipError) + + +def test_safe_extract_zip_rejects_symlinks(tmp_path): + zip_path = tmp_path / "bad.zip" + info = zipfile.ZipInfo("link") + info.external_attr = (stat.S_IFLNK | 0o777) << 16 + + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr(info, "target") + + with pytest.raises(ValueError, match="Unsafe symlink"): + safe_extract_zip(zip_path, tmp_path / "out") + + +def test_safe_extract_zip_rejects_symlink_without_partial_extraction(tmp_path): + zip_path = tmp_path / "mixed.zip" + link = zipfile.ZipInfo("evil-link") + link.external_attr = (stat.S_IFLNK | 0o777) << 16 + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("safe/first.txt", "hello") + zf.writestr(link, "target") + zf.writestr("safe/second.txt", "world") + + out_dir = tmp_path / "out" + with pytest.raises(ValueError, match="Unsafe symlink"): + safe_extract_zip(zip_path, out_dir) + + assert not out_dir.exists() or not any(out_dir.rglob("*")) + + +def test_safe_extract_zip_rejects_oversized_member(tmp_path): + zip_path = tmp_path / "bad.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("big.txt", "abcde") + + with pytest.raises(ValueError, match="exceeds maximum size"): + safe_extract_zip(zip_path, tmp_path / "out", max_member_bytes=4) + + +def test_safe_extract_zip_rejects_too_many_entries(tmp_path): + zip_path = tmp_path / "bad.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("one.txt", "1") + zf.writestr("two.txt", "2") + + with pytest.raises(ValueError, match="too many entries"): + safe_extract_zip(zip_path, tmp_path / "out", max_entries=1) + + +def test_safe_extract_zip_rejects_total_uncompressed_size(tmp_path): + zip_path = tmp_path / "bad.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("one.txt", "123") + zf.writestr("two.txt", "456") + + with pytest.raises(ValueError, match="maximum uncompressed size"): + safe_extract_zip(zip_path, tmp_path / "out", max_total_bytes=5) + + +def test_safe_extract_zip_wraps_bad_zip_file(tmp_path): + zip_path = tmp_path / "bad.zip" + zip_path.write_bytes(b"not a zip archive") + + with pytest.raises(_CustomZipError, match="Invalid ZIP archive"): + safe_extract_zip(zip_path, tmp_path / "out", error_type=_CustomZipError) + + +def test_read_zip_member_limited_returns_member_within_limit(tmp_path): + zip_path = tmp_path / "ok.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("extension.yml", "extension:\n id: demo\n") + + with zipfile.ZipFile(zip_path, "r") as zf: + data = read_zip_member_limited(zf, "extension.yml") + + assert data == b"extension:\n id: demo\n" + + +def test_read_zip_member_limited_rejects_oversized_member(tmp_path): + zip_path = tmp_path / "bomb.zip" + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + zf.writestr("extension.yml", "a" * 5000) + + with zipfile.ZipFile(zip_path, "r") as zf: + with pytest.raises(ValueError, match="exceeds maximum size"): + read_zip_member_limited(zf, "extension.yml", max_bytes=16) + + +def test_read_zip_member_limited_wraps_missing_member(tmp_path): + zip_path = tmp_path / "ok.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("other.txt", "x") + + with zipfile.ZipFile(zip_path, "r") as zf: + with pytest.raises(_CustomZipError, match="ZIP member not found"): + read_zip_member_limited(zf, "extension.yml", error_type=_CustomZipError) + + +def test_safe_extract_zip_extracts_safe_archive(tmp_path): + zip_path = tmp_path / "ok.zip" + out_dir = tmp_path / "out" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("nested/file.txt", "hello") + + safe_extract_zip(zip_path, out_dir) + + assert (out_dir / "nested" / "file.txt").read_text(encoding="utf-8") == "hello" + + +def test_safe_extract_zip_treats_normalized_trailing_backslash_as_directory(tmp_path): + zip_path = tmp_path / "ok.zip" + out_dir = tmp_path / "out" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("nested\\", "") + zf.writestr("nested/file.txt", "hello") + + safe_extract_zip(zip_path, out_dir) + + assert (out_dir / "nested").is_dir() + assert (out_dir / "nested" / "file.txt").read_text(encoding="utf-8") == "hello" diff --git a/tests/test_extensions.py b/tests/test_extensions.py index c60a7e430f..9ae795848c 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -9,6 +9,7 @@ - Catalog stack (multi-catalog support) """ +import io import pytest import json import os @@ -3209,7 +3210,7 @@ def test_fetch_single_catalog_sends_auth_header(self, temp_dir, monkeypatch): catalog_data = {"schema_version": "1.0", "extensions": {}} mock_response = MagicMock() - mock_response.read.return_value = json.dumps(catalog_data).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(catalog_data).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) mock_response.geturl.return_value = "https://raw.githubusercontent.com/org/repo/main/catalog.json" @@ -3265,7 +3266,7 @@ def test_fetch_single_catalog_rejects_malformed_payload(self, temp_dir, payload) catalog = self._make_catalog(temp_dir) mock_response = MagicMock() - mock_response.read.return_value = json.dumps(payload).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(payload).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -3333,7 +3334,7 @@ def test_fetch_single_catalog_rejects_malformed_cached_payload( "extensions": {"foo": {"name": "Foo", "version": "1.0.0"}}, } mock_response = MagicMock() - mock_response.read.return_value = json.dumps(valid).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(valid).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -3380,7 +3381,7 @@ def test_fetch_catalog_rejects_malformed_payload(self, temp_dir, payload): catalog = self._make_catalog(temp_dir) mock_response = MagicMock() - mock_response.read.return_value = json.dumps(payload).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(payload).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -3420,7 +3421,7 @@ def test_fetch_catalog_recovers_from_unreadable_cache(self, temp_dir): "extensions": {"foo": {"name": "Foo", "version": "1.0.0"}}, } mock_response = MagicMock() - mock_response.read.return_value = json.dumps(valid).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(valid).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -3458,7 +3459,7 @@ def test_fetch_catalog_recovers_from_unreadable_metadata(self, temp_dir): "extensions": {"foo": {"name": "Foo", "version": "1.0.0"}}, } mock_response = MagicMock() - mock_response.read.return_value = json.dumps(valid).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(valid).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -3532,7 +3533,7 @@ def test_fetch_catalog_writes_cache_as_utf8(self, temp_dir, monkeypatch): "extensions": {"foo": {"name": "Foo", "version": "1.0.0"}}, } mock_response = MagicMock() - mock_response.read.return_value = json.dumps(payload).encode("utf-8") + mock_response.read.side_effect = io.BytesIO(json.dumps(payload).encode("utf-8")).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -3582,10 +3583,12 @@ def test_fetch_catalog_survives_unwritable_cache(self, temp_dir, monkeypatch): "schema_version": "1.0", "extensions": {"foo": {"name": "Foo", "version": "1.0.0"}}, } - mock_response = MagicMock() - mock_response.read.return_value = json.dumps(valid).encode() - mock_response.__enter__ = lambda s: s - mock_response.__exit__ = MagicMock(return_value=False) + def make_response(): + mock_response = MagicMock() + mock_response.read.side_effect = io.BytesIO(json.dumps(valid).encode()).read + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + return mock_response # Simulate an unwritable cache dir: every write_text under the # cache directory raises PermissionError (an OSError subclass). @@ -3598,7 +3601,7 @@ def failing_write_text(self, data, *args, **kwargs): monkeypatch.setattr(_PathCls, "write_text", failing_write_text) - with patch.object(catalog, "_open_url", return_value=mock_response): + with patch.object(catalog, "_open_url", side_effect=lambda *a, **kw: make_response()): # Legacy single-catalog path. assert catalog.fetch_catalog(force_refresh=True) == valid @@ -3634,7 +3637,7 @@ def test_get_merged_extensions_skips_non_mapping_entries(self, temp_dir): }, } mock_response = MagicMock() - mock_response.read.return_value = json.dumps(payload).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(payload).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -3670,7 +3673,7 @@ def test_download_extension_sends_auth_header(self, temp_dir, monkeypatch): zip_bytes = zip_buf.getvalue() release_response = MagicMock() - release_response.read.return_value = json.dumps( + release_response.read.side_effect = io.BytesIO(json.dumps( { "assets": [ { @@ -3679,12 +3682,12 @@ def test_download_extension_sends_auth_header(self, temp_dir, monkeypatch): } ] } - ).encode() + ).encode()).read release_response.__enter__ = lambda s: s release_response.__exit__ = MagicMock(return_value=False) asset_response = MagicMock() - asset_response.read.return_value = zip_bytes + asset_response.read.side_effect = io.BytesIO(zip_bytes).read asset_response.__enter__ = lambda s: s asset_response.__exit__ = MagicMock(return_value=False) @@ -3732,7 +3735,7 @@ def test_download_extension_accepts_direct_github_rest_asset_url(self, temp_dir, zip_bytes = zip_buf.getvalue() asset_response = MagicMock() - asset_response.read.return_value = zip_bytes + asset_response.read.side_effect = io.BytesIO(zip_bytes).read asset_response.__enter__ = lambda s: s asset_response.__exit__ = MagicMock(return_value=False) @@ -4073,7 +4076,7 @@ def test_load_catalog_config_defaults_blank_names(self, temp_dir): @pytest.mark.parametrize( ("url", "expected_detail"), [ - ("relative/catalog.json", "HTTPS"), + ("relative/catalog.json", "valid URL with a host"), ("https:///no-host", "valid URL with a host"), ], ) @@ -5327,7 +5330,7 @@ def test_download_extension_allows_bundled_with_url(self, temp_dir): } mock_response = MagicMock() - mock_response.read.return_value = b"fake zip data" + mock_response.read.side_effect = io.BytesIO(b"fake zip data").read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) diff --git a/tests/test_github_http.py b/tests/test_github_http.py index e258f4917f..0fb82b5b99 100644 --- a/tests/test_github_http.py +++ b/tests/test_github_http.py @@ -1,16 +1,20 @@ """Tests for GitHub-authenticated HTTP request helpers.""" +import io import json import os from contextlib import contextmanager from unittest.mock import MagicMock, patch +from urllib.request import Request import pytest from specify_cli._github_http import ( + GITHUB_HOSTS, build_github_request, resolve_github_release_asset_api_url, ) +from specify_cli.authentication.http import _StripAuthOnRedirect class TestBuildGitHubRequest: @@ -90,7 +94,7 @@ def _make_open_url_fn(self, release_json): @contextmanager def fake_open(url, timeout=None, extra_headers=None): resp = MagicMock() - resp.read.return_value = json.dumps(release_json).encode() + resp.read.side_effect = io.BytesIO(json.dumps(release_json).encode()).read yield resp return fake_open @@ -144,7 +148,7 @@ def test_returns_none_on_network_error(self): @contextmanager def failing_open(url, timeout=None, extra_headers=None): raise urllib.error.URLError("network error") - yield # noqa: unreachable + yield # pragma: no cover result = resolve_github_release_asset_api_url( "https://github.com/org/repo/releases/download/v1/pack.zip", @@ -160,7 +164,7 @@ def test_tag_with_special_characters_is_url_encoded(self): def capturing_open(url, timeout=None, extra_headers=None): captured_urls.append(url) resp = MagicMock() - resp.read.return_value = json.dumps({"assets": []}).encode() + resp.read.side_effect = io.BytesIO(json.dumps({"assets": []}).encode()).read yield resp resolve_github_release_asset_api_url( @@ -179,7 +183,7 @@ def test_tag_with_hash_is_url_encoded(self): def capturing_open(url, timeout=None, extra_headers=None): captured_urls.append(url) resp = MagicMock() - resp.read.return_value = json.dumps({"assets": []}).encode() + resp.read.side_effect = io.BytesIO(json.dumps({"assets": []}).encode()).read yield resp resolve_github_release_asset_api_url( @@ -188,3 +192,43 @@ def capturing_open(url, timeout=None, extra_headers=None): ) assert len(captured_urls) == 1 assert "releases/tags/v1%23beta" in captured_urls[0] + + +class TestGitHubRedirectAuth: + """Tests for GitHub-owned redirect auth handling.""" + + def test_multi_hop_github_redirect_preserves_unredirected_auth(self): + """Auth survives a multi-hop redirect chain within GitHub hosts.""" + handler = _StripAuthOnRedirect(tuple(GITHUB_HOSTS)) + req1 = Request( + "https://github.com/org/repo", + headers={"Authorization": "Bearer tok"}, + ) + + req2 = handler.redirect_request( + req1, + io.BytesIO(b""), + 302, + "Found", + {}, + "https://codeload.github.com/org/repo/zip", + ) + assert req2 is not None + auth2 = req2.get_header("Authorization") or req2.unredirected_hdrs.get( + "Authorization" + ) + assert auth2 == "Bearer tok" + + req3 = handler.redirect_request( + req2, + io.BytesIO(b""), + 302, + "Found", + {}, + "https://raw.githubusercontent.com/org/repo/main/file", + ) + assert req3 is not None + auth3 = req3.get_header("Authorization") or req3.unredirected_hdrs.get( + "Authorization" + ) + assert auth3 == "Bearer tok" diff --git a/tests/test_presets.py b/tests/test_presets.py index 58574bbc9c..fee20d417a 100644 --- a/tests/test_presets.py +++ b/tests/test_presets.py @@ -1516,7 +1516,7 @@ def test_fetch_single_catalog_sends_auth_header(self, project_dir, monkeypatch): catalog_data = {"schema_version": "1.0", "presets": {}} mock_response = MagicMock() - mock_response.read.return_value = json.dumps(catalog_data).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(catalog_data).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) mock_response.geturl.return_value = "https://raw.githubusercontent.com/org/repo/main/presets/catalog.json" @@ -1572,7 +1572,7 @@ def test_fetch_single_catalog_rejects_malformed_payload(self, project_dir, paylo catalog = PresetCatalog(project_dir) mock_response = MagicMock() - mock_response.read.return_value = json.dumps(payload).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(payload).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -1641,7 +1641,7 @@ def test_fetch_single_catalog_rejects_malformed_cached_payload( "presets": {"foo": {"name": "Foo", "version": "1.0.0"}}, } mock_response = MagicMock() - mock_response.read.return_value = json.dumps(valid).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(valid).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -1688,7 +1688,7 @@ def test_fetch_catalog_rejects_malformed_payload(self, project_dir, payload): catalog = PresetCatalog(project_dir) mock_response = MagicMock() - mock_response.read.return_value = json.dumps(payload).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(payload).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -1729,7 +1729,7 @@ def test_fetch_catalog_recovers_from_unreadable_cache(self, project_dir): "presets": {"foo": {"name": "Foo", "version": "1.0.0"}}, } mock_response = MagicMock() - mock_response.read.return_value = json.dumps(valid).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(valid).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -1767,7 +1767,7 @@ def test_fetch_catalog_recovers_from_unreadable_metadata(self, project_dir): "presets": {"foo": {"name": "Foo", "version": "1.0.0"}}, } mock_response = MagicMock() - mock_response.read.return_value = json.dumps(valid).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(valid).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -1837,7 +1837,7 @@ def test_fetch_catalog_writes_cache_as_utf8(self, project_dir, monkeypatch): "presets": {"foo": {"name": "Foo", "version": "1.0.0"}}, } mock_response = MagicMock() - mock_response.read.return_value = json.dumps(payload).encode("utf-8") + mock_response.read.side_effect = io.BytesIO(json.dumps(payload).encode("utf-8")).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -1885,10 +1885,12 @@ def test_fetch_catalog_survives_unwritable_cache(self, project_dir, monkeypatch) "schema_version": "1.0", "presets": {"foo": {"name": "Foo", "version": "1.0.0"}}, } - mock_response = MagicMock() - mock_response.read.return_value = json.dumps(valid).encode() - mock_response.__enter__ = lambda s: s - mock_response.__exit__ = MagicMock(return_value=False) + def make_response(): + mock_response = MagicMock() + mock_response.read.side_effect = io.BytesIO(json.dumps(valid).encode()).read + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + return mock_response # Simulate an unwritable cache dir: every write_text under the # cache directory raises PermissionError (an OSError subclass). @@ -1901,7 +1903,7 @@ def failing_write_text(self, data, *args, **kwargs): monkeypatch.setattr(_PathCls, "write_text", failing_write_text) - with patch.object(catalog, "_open_url", return_value=mock_response): + with patch.object(catalog, "_open_url", side_effect=lambda *a, **kw: make_response()): # Legacy single-catalog path. assert catalog.fetch_catalog(force_refresh=True) == valid @@ -1938,7 +1940,7 @@ def test_get_merged_packs_skips_non_mapping_entries(self, project_dir): }, } mock_response = MagicMock() - mock_response.read.return_value = json.dumps(payload).encode() + mock_response.read.side_effect = io.BytesIO(json.dumps(payload).encode()).read mock_response.__enter__ = lambda s: s mock_response.__exit__ = MagicMock(return_value=False) @@ -1972,7 +1974,7 @@ def test_download_pack_sends_auth_header(self, project_dir, monkeypatch): zip_bytes = zip_buf.getvalue() release_response = MagicMock() - release_response.read.return_value = json.dumps( + release_response.read.side_effect = io.BytesIO(json.dumps( { "assets": [ { @@ -1981,12 +1983,12 @@ def test_download_pack_sends_auth_header(self, project_dir, monkeypatch): } ] } - ).encode() + ).encode()).read release_response.__enter__ = lambda s: s release_response.__exit__ = MagicMock(return_value=False) asset_response = MagicMock() - asset_response.read.return_value = zip_bytes + asset_response.read.side_effect = io.BytesIO(zip_bytes).read asset_response.__enter__ = lambda s: s asset_response.__exit__ = MagicMock(return_value=False) @@ -2034,7 +2036,7 @@ def test_download_pack_accepts_direct_github_rest_asset_url(self, project_dir, m zip_bytes = zip_buf.getvalue() asset_response = MagicMock() - asset_response.read.return_value = zip_bytes + asset_response.read.side_effect = io.BytesIO(zip_bytes).read asset_response.__enter__ = lambda s: s asset_response.__exit__ = MagicMock(return_value=False) @@ -4576,10 +4578,10 @@ def test_preset_add_from_github_release_url_resolves_and_downloads(self, project class FakeResponse: def __init__(self, data): - self._data = data + self._stream = io.BytesIO(data) - def read(self): - return self._data + def read(self, size=-1): + return self._stream.read(size) def __enter__(self): return self @@ -4587,7 +4589,9 @@ def __enter__(self): def __exit__(self, *a): return False - def fake_open_url(url, timeout=None, extra_headers=None, redirect_validator=None): + def fake_open_url( + url, timeout=None, extra_headers=None, redirect_validator=None, strict_redirects=False + ): captured_urls.append((url, extra_headers)) if "releases/tags/" in url: return FakeResponse(json.dumps({ @@ -4634,10 +4638,10 @@ def test_preset_add_from_direct_api_asset_url_passes_through(self, project_dir): class FakeResponse: def __init__(self, data): - self._data = data + self._stream = io.BytesIO(data) - def read(self): - return self._data + def read(self, size=-1): + return self._stream.read(size) def __enter__(self): return self @@ -4645,7 +4649,9 @@ def __enter__(self): def __exit__(self, *a): return False - def fake_open_url(url, timeout=None, extra_headers=None, redirect_validator=None): + def fake_open_url( + url, timeout=None, extra_headers=None, redirect_validator=None, strict_redirects=False + ): captured_urls.append((url, extra_headers)) return FakeResponse(zip_bytes) diff --git a/tests/test_self_upgrade_detection.py b/tests/test_self_upgrade_detection.py index ab575e7435..73b55ebb79 100644 --- a/tests/test_self_upgrade_detection.py +++ b/tests/test_self_upgrade_detection.py @@ -13,6 +13,7 @@ from specify_cli import app from tests.self_upgrade_helpers import ( + route_opener_open_through_urlopen, # noqa: F401 (autouse fixture) _InstallMethod, _assemble_installer_argv, _completed_process, diff --git a/tests/test_self_upgrade_execution.py b/tests/test_self_upgrade_execution.py index 6696b4fc79..5c761014be 100644 --- a/tests/test_self_upgrade_execution.py +++ b/tests/test_self_upgrade_execution.py @@ -7,6 +7,7 @@ from specify_cli import app from tests.self_upgrade_helpers import ( + route_opener_open_through_urlopen, # noqa: F401 (autouse fixture) _completed_process, mock_urlopen_response, requires_posix, diff --git a/tests/test_self_upgrade_verification.py b/tests/test_self_upgrade_verification.py index f1a018f06c..c4e7eecf1b 100644 --- a/tests/test_self_upgrade_verification.py +++ b/tests/test_self_upgrade_verification.py @@ -8,6 +8,7 @@ from specify_cli import app from tests.self_upgrade_helpers import ( + route_opener_open_through_urlopen, # noqa: F401 (autouse fixture) SENTINEL_GH_TOKEN, SENTINEL_GITHUB_TOKEN, _InstallMethod, diff --git a/tests/test_upgrade.py b/tests/test_upgrade.py index 3ad8c84f62..6a8b069b5c 100644 --- a/tests/test_upgrade.py +++ b/tests/test_upgrade.py @@ -9,6 +9,8 @@ `--disable-socket` as an extra safety net. """ +import io +import json import urllib.error import importlib.metadata from unittest.mock import MagicMock, patch @@ -17,6 +19,7 @@ from typer.testing import CliRunner from specify_cli import app +from specify_cli._download_security import read_response_limited as _real_read_response_limited from specify_cli._version import ( _fetch_latest_release_tag, _get_installed_version, @@ -24,7 +27,10 @@ _normalize_tag, ) from tests.conftest import strip_ansi -from tests.http_helpers import mock_urlopen_response +from tests.http_helpers import ( + mock_urlopen_response, + route_opener_open_through_urlopen, # noqa: F401 (autouse fixture) +) runner = CliRunner() @@ -36,6 +42,19 @@ ) +def _mock_urlopen_response(payload: dict) -> MagicMock: + body = json.dumps(payload).encode("utf-8") + resp = MagicMock() + # Back read() with a real stream so it advances and returns b"" at EOF, + # matching http.client.HTTPResponse (a fixed return_value would loop forever + # under read_response_limited's bounded read loop). + resp.read.side_effect = io.BytesIO(body).read + cm = MagicMock() + cm.__enter__.return_value = resp + cm.__exit__.return_value = False + return cm + + def _http_error(code: int, message: str = "error") -> urllib.error.HTTPError: return urllib.error.HTTPError( url="https://api.github.com/repos/github/spec-kit/releases/latest", @@ -235,6 +254,46 @@ def test_generic_exception_propagates(self): _fetch_latest_release_tag() +class TestBoundedRead: + """Regression test for the read_response_limited hardening. + + A future refactor could silently revert `_fetch_latest_release_tag` to + `resp.read()` (the unbounded form) — this test pins the contract that + the response body is read through ``read_response_limited`` with a + bounded ``max_bytes``. + """ + + def test_response_body_is_bounded(self): + recorded: dict[str, int | str] = {} + + def _spy(response, *, max_bytes: int, label: str, **kwargs): + # max_bytes and label are keyword-only with no defaults: if the + # caller forgets to pass either, the call raises TypeError here + # (instead of recording a misleading None). + recorded["max_bytes"] = max_bytes + recorded["label"] = label + # Forward to the real implementation so the function under test + # still gets a parseable body. + return _real_read_response_limited( + response, max_bytes=max_bytes, label=label, **kwargs + ) + + with patch( + "specify_cli.authentication.http.urllib.request.urlopen", + return_value=_mock_urlopen_response({"tag_name": "v9.9.9"}), + ), patch("specify_cli._version.read_response_limited", side_effect=_spy): + tag, reason = _fetch_latest_release_tag() + + assert tag == "v9.9.9" + assert reason is None + # The cap (1 MiB) is a deliberate ceiling for the GitHub release + # JSON — keep it explicit so a future refactor that drops the + # `max_bytes=` argument fails this test instead of regressing + # silently to the default. + assert recorded["max_bytes"] == 1024 * 1024 + assert recorded["label"] == "GitHub latest release" + + _FAILURE_CASES = [ ("offline or timeout", urllib.error.URLError("down")), (_RATE_LIMITED_REASON, _http_error(403)),