diff --git a/dreadnode/main.py b/dreadnode/main.py index 0bea356b..9b578877 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -2,11 +2,12 @@ import inspect import os import random +import socket import typing as t from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from urllib.parse import urljoin, urlparse, urlunparse +from urllib.parse import ParseResult, urljoin, urlparse, urlunparse import coolname # type: ignore [import-untyped] import logfire @@ -32,7 +33,14 @@ ENV_SERVER, ENV_SERVER_URL, ) -from dreadnode.metric import Metric, MetricAggMode, MetricDict, Scorer, ScorerCallable, T +from dreadnode.metric import ( + Metric, + MetricAggMode, + MetricDict, + Scorer, + ScorerCallable, + T, +) from dreadnode.task import P, R, Task from dreadnode.tracing.exporters import ( FileExportConfig, @@ -54,7 +62,7 @@ JsonDict, JsonValue, ) -from dreadnode.util import clean_str, handle_internal_errors +from dreadnode.util import clean_str, handle_internal_errors, logger from dreadnode.version import VERSION if t.TYPE_CHECKING: @@ -128,6 +136,97 @@ def __init__( self._initialized = False + @staticmethod + def _resolve_endpoint(endpoint: str | None) -> str | None: + """Automatically resolve endpoints based on environment + + Args: + endpoint: The endpoint URL to resolve. + + Returns: + str: The resolved endpoint URL. + + Raises: + ValueError: If the endpoint URL is invalid. + """ + if not endpoint: + return None + parsed = urlparse(endpoint) + + # If it's a real domain (has dots), use as-is + if not parsed.hostname: + raise ValueError(f"Invalid endpoint URL: {endpoint}") + + if "." in parsed.hostname: + return endpoint + + # If it's a service name, try to resolve it + if Dreadnode._is_docker_service_name(parsed.hostname): + return Dreadnode._resolve_docker_service(endpoint, parsed) + + return endpoint + + @staticmethod + def _is_docker_service_name(hostname: str) -> bool: + """Check if this looks like a Docker service name + + Args: + hostname: The hostname to check. + + Returns: + bool: True if the hostname looks like a Docker service name, False otherwise. + """ + return bool(hostname and "." not in hostname and hostname != "localhost") + + @staticmethod + def _resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str: + """Try different resolution strategies for Docker services + + Args: + original_endpoint: The original endpoint URL. + parsed: The parsed URL object. + + Returns: + str: The resolved endpoint URL. + + Raises: + RuntimeError: If no valid endpoint is found. + """ + strategies = [ + original_endpoint, # Try original first (works if running in same network) + f"{parsed.scheme}://localhost:{parsed.port}", # Try localhost + f"{parsed.scheme}://host.docker.internal:{parsed.port}", # Docker Desktop + f"{parsed.scheme}://172.17.0.1:{parsed.port}", # Docker bridge IP + ] + + for endpoint in strategies: + if Dreadnode._test_connection(endpoint): + logger.warning( + f"Resolved Docker service for s3 connection '{parsed.hostname}' to '{endpoint}'." + ) + return str(endpoint) + + # If nothing works, return original and let it fail with a helpful error + raise RuntimeError(f"Failed to connect to the Dreadnode Artifact storage at {endpoint}.") + + @staticmethod + def _test_connection(endpoint: str) -> bool: + """Quick connectivity test + + Args: + endpoint: The endpoint URL to test. + + Returns: + bool: True if the connection is successful, False otherwise. + """ + try: + parsed = urlparse(endpoint) + socket.create_connection((parsed.hostname, parsed.port or 443), timeout=1) + except Exception: # noqa: BLE001 + return False + + return True + def configure( self, *, @@ -261,12 +360,13 @@ def initialize(self) -> None: # ) credentials = self._api.get_user_data_credentials() + resolved_endpoint = self._resolve_endpoint(credentials.endpoint) self._fs = S3FileSystem( key=credentials.access_key_id, secret=credentials.secret_access_key, token=credentials.session_token, client_kwargs={ - "endpoint_url": credentials.endpoint, + "endpoint_url": resolved_endpoint, "region_name": credentials.region, }, ) @@ -1002,7 +1102,10 @@ def log_metric( value if isinstance(value, Metric) else Metric( - float(value), step, timestamp or datetime.now(timezone.utc), attributes or {} + float(value), + step, + timestamp or datetime.now(timezone.utc), + attributes or {}, ) ) return target.log_metric(name, metric, origin=origin, mode=mode) diff --git a/pyproject.toml b/pyproject.toml index c2e083f0..306e8fc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,4 +122,5 @@ skip-magic-trailing-comma = false "tests/**/*.py" = [ "INP001", # namespace not required for pytest "S101", # asserts allowed in tests... + "SLF001", # allow access to private members ]