From 0180c6c5b4e4ee6ea201953cee86fa6c62e7225d Mon Sep 17 00:00:00 2001 From: Brian Greunke Date: Thu, 12 Jun 2025 22:55:49 -0500 Subject: [PATCH 1/2] feat: add automatic Docker service endpoint resolution - Add endpoint resolution for Docker service names - Implement connection testing with fallback strategies - Support localhost, host.docker.internal, and bridge IP fallbacks - Add socket-based connectivity testing - Resolve S3 filesystem endpoints automatically in Docker environments --- dreadnode/main.py | 113 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 107 insertions(+), 6 deletions(-) diff --git a/dreadnode/main.py b/dreadnode/main.py index 0bea356b..c0583a94 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -2,6 +2,7 @@ import inspect import os import random +import socket import typing as t from dataclasses import dataclass from datetime import datetime, timezone @@ -32,7 +33,14 @@ ENV_SERVER, ENV_SERVER_URL, ) -from dreadnode.metric import Metric, MetricAggMode, MetricDict, Scorer, ScorerCallable, T +from dreadnode.metric import ( + Metric, + MetricAggMode, + MetricDict, + Scorer, + ScorerCallable, + T, +) from dreadnode.task import P, R, Task from dreadnode.tracing.exporters import ( FileExportConfig, @@ -54,7 +62,7 @@ JsonDict, JsonValue, ) -from dreadnode.util import clean_str, handle_internal_errors +from dreadnode.util import clean_str, handle_internal_errors, logger from dreadnode.version import VERSION if t.TYPE_CHECKING: @@ -128,6 +136,91 @@ def __init__( self._initialized = False + @staticmethod + def _resolve_endpoint(endpoint): + """Automatically resolve endpoints based on environment + + Args: + endpoint: The endpoint URL to resolve. + + Returns: + str: The resolved endpoint URL. + """ + parsed = urlparse(endpoint) + + # If it's a real domain (has dots), use as-is + if "." in parsed.hostname: + return endpoint + + # If it's a service name, try to resolve it + if Dreadnode._is_docker_service_name(parsed.hostname): + return Dreadnode._resolve_docker_service(endpoint, parsed) + + return endpoint + + @staticmethod + def _is_docker_service_name(hostname): + """Check if this looks like a Docker service name + + Args: + hostname: The hostname to check. + + Returns: + bool: True if the hostname looks like a Docker service name, False otherwise. + """ + return hostname and "." not in hostname and hostname != "localhost" + + @staticmethod + def _resolve_docker_service(original_endpoint, parsed): + """Try different resolution strategies for Docker services + + Args: + original_endpoint: The original endpoint URL. + parsed: The parsed URL object. + + Returns: + str: The resolved endpoint URL. + + Raises: + RuntimeError: If no valid endpoint is found. + """ + strategies = [ + original_endpoint, # Try original first (works if running in same network) + f"{parsed.scheme}://localhost:{parsed.port}", # Try localhost + f"{parsed.scheme}://host.docker.internal:{parsed.port}", # Docker Desktop + f"{parsed.scheme}://172.17.0.1:{parsed.port}", # Docker bridge IP + ] + + for endpoint in strategies: + if Dreadnode._test_connection(endpoint): + logger.warning( + f"Resolved Docker service for s3 connection '{parsed.hostname}' to '{endpoint}'." + ) + return endpoint + + # If nothing works, return original and let it fail with a helpful error + raise RuntimeError( + f"Failed to connect to the Dreadnode Artifact storage at {endpoint}." + ) + + @staticmethod + def _test_connection(endpoint): + """Quick connectivity test + + Args: + endpoint: The endpoint URL to test. + + Returns: + bool: True if the connection is successful, False otherwise. + """ + try: + parsed = urlparse(endpoint) + socket.create_connection((parsed.hostname, parsed.port), timeout=1) + except Exception: # noqa: BLE001 + return False + + return True + def configure( self, *, @@ -166,8 +259,12 @@ def configure( self._initialized = False - self.server = server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) - self.token = token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + self.server = ( + server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) + ) + self.token = ( + token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + ) if local_dir is False and ENV_LOCAL_DIR in os.environ: env_local_dir = os.environ.get(ENV_LOCAL_DIR) @@ -261,12 +358,13 @@ def initialize(self) -> None: # ) credentials = self._api.get_user_data_credentials() + resolved_endpoint = self._resolve_endpoint(credentials.endpoint) self._fs = S3FileSystem( key=credentials.access_key_id, secret=credentials.secret_access_key, token=credentials.session_token, client_kwargs={ - "endpoint_url": credentials.endpoint, + "endpoint_url": resolved_endpoint, "region_name": credentials.region, }, ) @@ -1002,7 +1100,10 @@ def log_metric( value if isinstance(value, Metric) else Metric( - float(value), step, timestamp or datetime.now(timezone.utc), attributes or {} + float(value), + step, + timestamp or datetime.now(timezone.utc), + attributes or {}, ) ) return target.log_metric(name, metric, origin=origin, mode=mode) From 2b7badd05278536167998ed786f5eb4b4edf5b13 Mon Sep 17 00:00:00 2001 From: Brian Greunke Date: Fri, 13 Jun 2025 00:10:10 -0500 Subject: [PATCH 2/2] refactor: add type hints and improve error handling in Dreadnode - Add type annotations to _resolve_endpoint, _is_docker_service_name, _resolve_docker_service, and _test_connection methods - Add ValueError for invalid endpoint URLs in _resolve_endpoint - Improve null safety with proper None checks - Use default port 443 for socket connections when port is None - Simplify variable assignments by removing unnecessary parentheses - Add SLF001 exception to ruff config for test files to allow private member access --- dreadnode/main.py | 36 +++++++++++++++++++----------------- pyproject.toml | 1 + 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/dreadnode/main.py b/dreadnode/main.py index c0583a94..9b578877 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from urllib.parse import urljoin, urlparse, urlunparse +from urllib.parse import ParseResult, urljoin, urlparse, urlunparse import coolname # type: ignore [import-untyped] import logfire @@ -137,7 +137,7 @@ def __init__( self._initialized = False @staticmethod - def _resolve_endpoint(endpoint): + def _resolve_endpoint(endpoint: str | None) -> str | None: """Automatically resolve endpoints based on environment Args: @@ -145,10 +145,18 @@ def _resolve_endpoint(endpoint): Returns: str: The resolved endpoint URL. + + Raises: + ValueError: If the endpoint URL is invalid. """ + if not endpoint: + return None parsed = urlparse(endpoint) # If it's a real domain (has dots), use as-is + if not parsed.hostname: + raise ValueError(f"Invalid endpoint URL: {endpoint}") + if "." in parsed.hostname: return endpoint @@ -159,7 +167,7 @@ def _resolve_endpoint(endpoint): return endpoint @staticmethod - def _is_docker_service_name(hostname): + def _is_docker_service_name(hostname: str) -> bool: """Check if this looks like a Docker service name Args: @@ -168,10 +176,10 @@ def _is_docker_service_name(hostname): Returns: bool: True if the hostname looks like a Docker service name, False otherwise. """ - return hostname and "." not in hostname and hostname != "localhost" + return bool(hostname and "." not in hostname and hostname != "localhost") @staticmethod - def _resolve_docker_service(original_endpoint, parsed): + def _resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str: """Try different resolution strategies for Docker services Args: @@ -196,15 +204,13 @@ def _resolve_docker_service(original_endpoint, parsed): logger.warning( f"Resolved Docker service for s3 connection '{parsed.hostname}' to '{endpoint}'." ) - return endpoint + return str(endpoint) # If nothing works, return original and let it fail with a helpful error - raise RuntimeError( - f"Failed to connect to the Dreadnode Artifact storage at {endpoint}." - ) + raise RuntimeError(f"Failed to connect to the Dreadnode Artifact storage at {endpoint}.") @staticmethod - def _test_connection(endpoint): + def _test_connection(endpoint: str) -> bool: """Quick connectivity test Args: @@ -215,7 +221,7 @@ def _test_connection(endpoint): """ try: parsed = urlparse(endpoint) - socket.create_connection((parsed.hostname, parsed.port), timeout=1) + socket.create_connection((parsed.hostname, parsed.port or 443), timeout=1) except Exception: # noqa: BLE001 return False @@ -259,12 +265,8 @@ def configure( self._initialized = False - self.server = ( - server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) - ) - self.token = ( - token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) - ) + self.server = server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) + self.token = token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) if local_dir is False and ENV_LOCAL_DIR in os.environ: env_local_dir = os.environ.get(ENV_LOCAL_DIR) diff --git a/pyproject.toml b/pyproject.toml index c2e083f0..306e8fc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,4 +122,5 @@ skip-magic-trailing-comma = false "tests/**/*.py" = [ "INP001", # namespace not required for pytest "S101", # asserts allowed in tests... + "SLF001", # allow access to private members ]