Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 108 additions & 5 deletions dreadnode/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import inspect
import os
import random
import socket
import typing as t
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from urllib.parse import urljoin, urlparse, urlunparse
from urllib.parse import ParseResult, urljoin, urlparse, urlunparse

import coolname # type: ignore [import-untyped]
import logfire
Expand All @@ -32,7 +33,14 @@
ENV_SERVER,
ENV_SERVER_URL,
)
from dreadnode.metric import Metric, MetricAggMode, MetricDict, Scorer, ScorerCallable, T
from dreadnode.metric import (
Metric,
MetricAggMode,
MetricDict,
Scorer,
ScorerCallable,
T,
)
from dreadnode.task import P, R, Task
from dreadnode.tracing.exporters import (
FileExportConfig,
Expand All @@ -54,7 +62,7 @@
JsonDict,
JsonValue,
)
from dreadnode.util import clean_str, handle_internal_errors
from dreadnode.util import clean_str, handle_internal_errors, logger
from dreadnode.version import VERSION

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -128,6 +136,97 @@ def __init__(

self._initialized = False

@staticmethod
def _resolve_endpoint(endpoint: str | None) -> str | None:
"""Automatically resolve endpoints based on environment

Args:
endpoint: The endpoint URL to resolve.

Returns:
str: The resolved endpoint URL.

Raises:
ValueError: If the endpoint URL is invalid.
"""
if not endpoint:
return None
parsed = urlparse(endpoint)

# If it's a real domain (has dots), use as-is
if not parsed.hostname:
raise ValueError(f"Invalid endpoint URL: {endpoint}")

if "." in parsed.hostname:
return endpoint

# If it's a service name, try to resolve it
if Dreadnode._is_docker_service_name(parsed.hostname):
return Dreadnode._resolve_docker_service(endpoint, parsed)

return endpoint

@staticmethod
def _is_docker_service_name(hostname: str) -> bool:
"""Check if this looks like a Docker service name

Args:
hostname: The hostname to check.

Returns:
bool: True if the hostname looks like a Docker service name, False otherwise.
"""
return bool(hostname and "." not in hostname and hostname != "localhost")

@staticmethod
def _resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str:
"""Try different resolution strategies for Docker services

Args:
original_endpoint: The original endpoint URL.
parsed: The parsed URL object.

Returns:
str: The resolved endpoint URL.

Raises:
RuntimeError: If no valid endpoint is found.
"""
strategies = [
original_endpoint, # Try original first (works if running in same network)
f"{parsed.scheme}://localhost:{parsed.port}", # Try localhost
f"{parsed.scheme}://host.docker.internal:{parsed.port}", # Docker Desktop
f"{parsed.scheme}://172.17.0.1:{parsed.port}", # Docker bridge IP
]

for endpoint in strategies:
if Dreadnode._test_connection(endpoint):
logger.warning(
f"Resolved Docker service for s3 connection '{parsed.hostname}' to '{endpoint}'."
)
return str(endpoint)

# If nothing works, return original and let it fail with a helpful error
raise RuntimeError(f"Failed to connect to the Dreadnode Artifact storage at {endpoint}.")

@staticmethod
def _test_connection(endpoint: str) -> bool:
"""Quick connectivity test

Args:
endpoint: The endpoint URL to test.

Returns:
bool: True if the connection is successful, False otherwise.
"""
try:
parsed = urlparse(endpoint)
socket.create_connection((parsed.hostname, parsed.port or 443), timeout=1)
except Exception: # noqa: BLE001
return False

return True

def configure(
self,
*,
Expand Down Expand Up @@ -261,12 +360,13 @@ def initialize(self) -> None:
# )

credentials = self._api.get_user_data_credentials()
resolved_endpoint = self._resolve_endpoint(credentials.endpoint)
self._fs = S3FileSystem(
key=credentials.access_key_id,
secret=credentials.secret_access_key,
token=credentials.session_token,
client_kwargs={
"endpoint_url": credentials.endpoint,
"endpoint_url": resolved_endpoint,
"region_name": credentials.region,
},
)
Expand Down Expand Up @@ -1002,7 +1102,10 @@ def log_metric(
value
if isinstance(value, Metric)
else Metric(
float(value), step, timestamp or datetime.now(timezone.utc), attributes or {}
float(value),
step,
timestamp or datetime.now(timezone.utc),
attributes or {},
)
)
return target.log_metric(name, metric, origin=origin, mode=mode)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,5 @@ skip-magic-trailing-comma = false
"tests/**/*.py" = [
"INP001", # namespace not required for pytest
"S101", # asserts allowed in tests...
"SLF001", # allow access to private members
]