From cd00a859914ea0b17dcfb96f2f3dc6f6d3b860dc Mon Sep 17 00:00:00 2001 From: monoxgas Date: Fri, 18 Jul 2025 14:16:05 -0600 Subject: [PATCH 1/8] Initial CLI port --- docs/sdk/api.mdx | 156 +++++++++++++++++++++++++++--- dreadnode/__main__.py | 10 ++ dreadnode/api/client.py | 130 ++++++++++++++++++++----- dreadnode/api/models.py | 43 +++++--- dreadnode/api/util.py | 2 +- dreadnode/cli/__init__.py | 3 + dreadnode/cli/api.py | 79 +++++++++++++++ dreadnode/cli/config.py | 94 ++++++++++++++++++ dreadnode/cli/main.py | 127 ++++++++++++++++++++++++ dreadnode/cli/profile/__init__.py | 3 + dreadnode/cli/profile/cli.py | 76 +++++++++++++++ dreadnode/constants.py | 52 ++++++++-- dreadnode/data_types/__init__.py | 14 +-- dreadnode/tracing/span.py | 13 ++- dreadnode/util.py | 24 +++++ poetry.lock | 81 +++++++++++++--- pyproject.toml | 8 ++ 17 files changed, 829 insertions(+), 86 deletions(-) create mode 100644 dreadnode/__main__.py create mode 100644 dreadnode/cli/__init__.py create mode 100644 dreadnode/cli/api.py create mode 100644 dreadnode/cli/config.py create mode 100644 dreadnode/cli/main.py create mode 100644 dreadnode/cli/profile/__init__.py create mode 100644 dreadnode/cli/profile/cli.py diff --git a/docs/sdk/api.mdx b/docs/sdk/api.mdx index b5017b0f..377c2202 100644 --- a/docs/sdk/api.mdx +++ b/docs/sdk/api.mdx @@ -12,7 +12,11 @@ ApiClient ```python ApiClient( - base_url: str, api_key: str, *, debug: bool = False + base_url: str, + *, + api_key: str | None = None, + cookies: dict[str, str] | None = None, + debug: bool = False, ) ``` @@ -29,7 +33,9 @@ Initializes the API client. (`str`) –The base URL of the Dreadnode API. * **`api_key`** - (`str`) + (`str`, default: + `None` + ) –The API key for authentication. * **`debug`** (`bool`, default: @@ -42,11 +48,13 @@ Initializes the API client. def __init__( self, base_url: str, - api_key: str, *, + api_key: str | None = None, + cookies: dict[str, str] | None = None, debug: bool = False, ): - """Initializes the API client. + """ + Initializes the API client. Args: base_url (str): The base URL of the Dreadnode API. @@ -57,12 +65,27 @@ def __init__( if not self._base_url.endswith("/api"): self._base_url += "/api" + _cookies = httpx.Cookies() + cookie_domain = urlparse(base_url).hostname + if cookie_domain is None: + raise ValueError(f"Invalid URL: {base_url}") + + if cookie_domain == "localhost": + cookie_domain = "localhost.local" + + for key, value in (cookies or {}).items(): + _cookies.set(key, value, domain=cookie_domain) + + headers = { + "User-Agent": f"dreadnode-sdk/{VERSION}", + "Accept": "application/json", + } + + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + self._client = httpx.Client( - headers={ - "User-Agent": f"dreadnode-sdk/{VERSION}", - "Accept": "application/json", - "X-API-Key": api_key, - }, + headers=headers, base_url=self._base_url, timeout=30, ) @@ -133,7 +156,8 @@ def export_metrics( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports metric data for a specific project. + """ + Exports metric data for a specific project. Args: project: The project identifier. @@ -224,7 +248,8 @@ def export_parameters( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports parameter data for a specific project. + """ + Exports parameter data for a specific project. Args: project: The project identifier. @@ -306,7 +331,8 @@ def export_runs( status: StatusFilter = "completed", aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports run data for a specific project. + """ + Exports run data for a specific project. Args: project: The project identifier. @@ -398,7 +424,8 @@ def export_timeseries( time_axis: TimeAxisType = "relative", aggregations: list[TimeAggregationType] | None = None, ) -> pd.DataFrame: - """Exports timeseries data for a specific project. + """ + Exports timeseries data for a specific project. Args: project: The project identifier. @@ -427,6 +454,26 @@ def export_timeseries( ``` + + +### get\_device\_codes + +```python +get_device_codes() -> DeviceCodeResponse +``` + +Start the authentication flow by requesting user and device codes. + + +```python +def get_device_codes(self) -> DeviceCodeResponse: + """Start the authentication flow by requesting user and device codes.""" + + response = self.request("POST", "/auth/device/code") + return DeviceCodeResponse(**response.json()) +``` + + ### get\_project @@ -634,6 +681,26 @@ def get_run_trace( ``` + + +### get\_user + +```python +get_user() -> UserResponse +``` + +Get the user email and username. + + +```python +def get_user(self) -> UserResponse: + """Get the user email and username.""" + + response = self.request("GET", "/user") + return UserResponse(**response.json()) +``` + + ### get\_user\_data\_credentials @@ -729,6 +796,47 @@ def list_runs(self, project: str) -> list[RunSummary]: ``` + + +### poll\_for\_token + +```python +poll_for_token( + device_code: str, + interval: int = DEFAULT_POLL_INTERVAL, + max_poll_time: int = DEFAULT_MAX_POLL_TIME, +) -> AccessRefreshTokenResponse +``` + +Poll for the access token with the given device code. + + +```python +def poll_for_token( + self, + device_code: str, + interval: int = DEFAULT_POLL_INTERVAL, + max_poll_time: int = DEFAULT_MAX_POLL_TIME, +) -> AccessRefreshTokenResponse: + """Poll for the access token with the given device code.""" + + start_time = datetime.now(timezone.utc) + while (datetime.now(timezone.utc) - start_time).total_seconds() < max_poll_time: + response = self._request( + "POST", "/auth/device/token", json_data={"device_code": device_code} + ) + + if response.status_code == 200: # noqa: PLR2004 + return AccessRefreshTokenResponse(**response.json()) + if response.status_code != 401: # noqa: PLR2004 + raise RuntimeError(self._get_error_message(response)) + + time.sleep(interval) + + raise RuntimeError("Polling for token timed out") +``` + + ### request @@ -782,7 +890,8 @@ def request( params: dict[str, t.Any] | None = None, json_data: dict[str, t.Any] | None = None, ) -> httpx.Response: - """Makes an HTTP request to the API and raises exceptions for errors. + """ + Makes an HTTP request to the API and raises exceptions for errors. Args: method (str): The HTTP method (e.g., "GET", "POST"). @@ -808,6 +917,25 @@ def request( ``` + + +### url\_for\_user\_code + +```python +url_for_user_code(user_code: str) -> str +``` + +Get the URL to verify the user code. + + +```python +def url_for_user_code(self, user_code: str) -> str: + """Get the URL to verify the user code.""" + + return f"{self._base_url.removesuffix('/api')}/account/device?code={user_code}" +``` + + ExportFormat ------------ diff --git a/dreadnode/__main__.py b/dreadnode/__main__.py new file mode 100644 index 00000000..cab7e045 --- /dev/null +++ b/dreadnode/__main__.py @@ -0,0 +1,10 @@ +from dreadnode.cli import cli + + +def run() -> None: + """Run the Dreadnode CLI.""" + cli.meta() + + +if __name__ == "__main__": + run() diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index 789edc23..199363c3 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -1,22 +1,18 @@ import io import json +import time import typing as t +from datetime import datetime, timezone +from urllib.parse import urlparse import httpx import pandas as pd from pydantic import BaseModel from ulid import ULID -from dreadnode.api.util import ( - convert_flat_tasks_to_tree, - convert_flat_trace_to_tree, - process_run, - process_task, -) -from dreadnode.util import logger -from dreadnode.version import VERSION - -from .models import ( +from dreadnode.api.models import ( + AccessRefreshTokenResponse, + DeviceCodeResponse, MetricAggregationType, Project, RawRun, @@ -31,7 +27,17 @@ TraceSpan, TraceTree, UserDataCredentials, + UserResponse, +) +from dreadnode.api.util import ( + convert_flat_tasks_to_tree, + convert_flat_trace_to_tree, + process_run, + process_task, ) +from dreadnode.constants import DEFAULT_MAX_POLL_TIME, DEFAULT_POLL_INTERVAL +from dreadnode.util import logger +from dreadnode.version import VERSION ModelT = t.TypeVar("ModelT", bound=BaseModel) @@ -47,11 +53,13 @@ class ApiClient: def __init__( self, base_url: str, - api_key: str, *, + api_key: str | None = None, + cookies: dict[str, str] | None = None, debug: bool = False, ): - """Initializes the API client. + """ + Initializes the API client. Args: base_url (str): The base URL of the Dreadnode API. @@ -62,12 +70,27 @@ def __init__( if not self._base_url.endswith("/api"): self._base_url += "/api" + _cookies = httpx.Cookies() + cookie_domain = urlparse(base_url).hostname + if cookie_domain is None: + raise ValueError(f"Invalid URL: {base_url}") + + if cookie_domain == "localhost": + cookie_domain = "localhost.local" + + for key, value in (cookies or {}).items(): + _cookies.set(key, value, domain=cookie_domain) + + headers = { + "User-Agent": f"dreadnode-sdk/{VERSION}", + "Accept": "application/json", + } + + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + self._client = httpx.Client( - headers={ - "User-Agent": f"dreadnode-sdk/{VERSION}", - "Accept": "application/json", - "X-API-Key": api_key, - }, + headers=headers, base_url=self._base_url, timeout=30, ) @@ -77,7 +100,8 @@ def __init__( self._client.event_hooks["response"].append(self._log_response) def _log_request(self, request: httpx.Request) -> None: - """Logs HTTP requests if debug mode is enabled. + """ + Logs HTTP requests if debug mode is enabled. Args: request (httpx.Request): The HTTP request object. @@ -90,7 +114,8 @@ def _log_request(self, request: httpx.Request) -> None: logger.debug("-------------------------------------------") def _log_response(self, response: httpx.Response) -> None: - """Logs HTTP responses if debug mode is enabled. + """ + Logs HTTP responses if debug mode is enabled. Args: response (httpx.Response): The HTTP response object. @@ -103,7 +128,8 @@ def _log_response(self, response: httpx.Response) -> None: logger.debug("--------------------------------------------") def _get_error_message(self, response: httpx.Response) -> str: - """Extracts the error message from an HTTP response. + """ + Extracts the error message from an HTTP response. Args: response (httpx.Response): The HTTP response object. @@ -125,7 +151,8 @@ def _request( params: dict[str, t.Any] | None = None, json_data: dict[str, t.Any] | None = None, ) -> httpx.Response: - """Makes a raw HTTP request to the API. + """ + Makes a raw HTTP request to the API. Args: method (str): The HTTP method (e.g., "GET", "POST"). @@ -146,7 +173,8 @@ def request( params: dict[str, t.Any] | None = None, json_data: dict[str, t.Any] | None = None, ) -> httpx.Response: - """Makes an HTTP request to the API and raises exceptions for errors. + """ + Makes an HTTP request to the API and raises exceptions for errors. Args: method (str): The HTTP method (e.g., "GET", "POST"). @@ -170,6 +198,52 @@ def request( return response + # Auth + + def url_for_user_code(self, user_code: str) -> str: + """Get the URL to verify the user code.""" + + return f"{self._base_url.removesuffix('/api')}/account/device?code={user_code}" + + def get_device_codes(self) -> DeviceCodeResponse: + """Start the authentication flow by requesting user and device codes.""" + + response = self.request("POST", "/auth/device/code") + return DeviceCodeResponse(**response.json()) + + def poll_for_token( + self, + device_code: str, + interval: int = DEFAULT_POLL_INTERVAL, + max_poll_time: int = DEFAULT_MAX_POLL_TIME, + ) -> AccessRefreshTokenResponse: + """Poll for the access token with the given device code.""" + + start_time = datetime.now(timezone.utc) + while (datetime.now(timezone.utc) - start_time).total_seconds() < max_poll_time: + response = self._request( + "POST", "/auth/device/token", json_data={"device_code": device_code} + ) + + if response.status_code == 200: # noqa: PLR2004 + return AccessRefreshTokenResponse(**response.json()) + if response.status_code != 401: # noqa: PLR2004 + raise RuntimeError(self._get_error_message(response)) + + time.sleep(interval) + + raise RuntimeError("Polling for token timed out") + + # User + + def get_user(self) -> UserResponse: + """Get the user email and username.""" + + response = self.request("GET", "/user") + return UserResponse(**response.json()) + + # Strikes + def list_projects(self) -> list[Project]: """Retrieves a list of projects. @@ -294,7 +368,8 @@ def export_runs( status: StatusFilter = "completed", aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports run data for a specific project. + """ + Exports run data for a specific project. Args: project: The project identifier. @@ -327,7 +402,8 @@ def export_metrics( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports metric data for a specific project. + """ + Exports metric data for a specific project. Args: project: The project identifier. @@ -363,7 +439,8 @@ def export_parameters( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports parameter data for a specific project. + """ + Exports parameter data for a specific project. Args: project: The project identifier. @@ -401,7 +478,8 @@ def export_timeseries( time_axis: TimeAxisType = "relative", aggregations: list[TimeAggregationType] | None = None, ) -> pd.DataFrame: - """Exports timeseries data for a specific project. + """ + Exports timeseries data for a specific project. Args: project: The project identifier. diff --git a/dreadnode/api/models.py b/dreadnode/api/models.py index 53685180..d1d594de 100644 --- a/dreadnode/api/models.py +++ b/dreadnode/api/models.py @@ -32,6 +32,35 @@ class UserResponse(BaseModel): api_key: UserAPIKey +class UserDataCredentials(BaseModel): + access_key_id: str + secret_access_key: str + session_token: str + expiration: datetime + region: str + bucket: str + prefix: str + endpoint: str | None + + +# Auth + + +class DeviceCodeResponse(BaseModel): + id: UUID + completed: bool + device_code: str + expires_at: datetime + expires_in: int + user_code: str + verification_url: str + + +class AccessRefreshTokenResponse(BaseModel): + access_token: str + refresh_token: str + + # Strikes SpanStatus = t.Literal[ @@ -404,17 +433,3 @@ class TraceTree(BaseModel): """Span at this node, can be a Task or a TraceSpan.""" children: list["TraceTree"] = [] """Children of this span, representing nested spans or tasks.""" - - -# User data credentials - - -class UserDataCredentials(BaseModel): - access_key_id: str - secret_access_key: str - session_token: str - expiration: datetime - region: str - bucket: str - prefix: str - endpoint: str | None diff --git a/dreadnode/api/util.py b/dreadnode/api/util.py index 004e3a91..b5b42adb 100644 --- a/dreadnode/api/util.py +++ b/dreadnode/api/util.py @@ -1,6 +1,6 @@ from logging import getLogger -from .models import ( +from dreadnode.api.models import ( Object, ObjectUri, ObjectVal, diff --git a/dreadnode/cli/__init__.py b/dreadnode/cli/__init__.py new file mode 100644 index 00000000..b9e50854 --- /dev/null +++ b/dreadnode/cli/__init__.py @@ -0,0 +1,3 @@ +from dreadnode.cli.main import cli + +__all__ = ["cli"] diff --git a/dreadnode/cli/api.py b/dreadnode/cli/api.py new file mode 100644 index 00000000..ad2034c8 --- /dev/null +++ b/dreadnode/cli/api.py @@ -0,0 +1,79 @@ +import atexit +import base64 +import json +from datetime import datetime, timezone + +from dreadnode.api.client import ApiClient +from dreadnode.cli.config import UserConfig +from dreadnode.constants import ( + DEFAULT_TOKEN_MAX_TTL, +) + + +class Token: + """A JWT token with an expiration time.""" + + data: str + expires_at: datetime + + @staticmethod + def parse_jwt_token_expiration(token: str) -> datetime: + """Return the expiration date from a JWT token.""" + + _, b64payload, _ = token.split(".") + payload = base64.urlsafe_b64decode(b64payload + "==").decode("utf-8") + return datetime.fromtimestamp(json.loads(payload).get("exp"), tz=timezone.utc) + + def __init__(self, token: str): + self.data = token + self.expires_at = Token.parse_jwt_token_expiration(token) + + def ttl(self) -> int: + """Get number of seconds left until the token expires.""" + return int((self.expires_at - datetime.now(tz=timezone.utc)).total_seconds()) + + def is_expired(self) -> bool: + """Return True if the token is expired.""" + return self.ttl() <= 0 + + def is_close_to_expiry(self) -> bool: + """Return True if the token is close to expiry.""" + return self.ttl() <= DEFAULT_TOKEN_MAX_TTL + + +def create_api_client(*, profile: str | None = None) -> ApiClient: + """Create an authenticated API client using stored configuration data.""" + + user_config = UserConfig.read() + config = user_config.get_server_config(profile) + + client = ApiClient( + config.url, + cookies={"access_token": config.access_token, "refresh_token": config.refresh_token}, + ) + + # Preemptively check if the token is expired + if Token(config.refresh_token).is_expired(): + raise RuntimeError("Authentication expired, use [bold]dreadnode login[/]") + + def _flush_auth_changes() -> None: + """Flush the authentication data to disk if it has been updated.""" + + access_token = client._client.cookies.get("access_token") # noqa: SLF001 + refresh_token = client._client.cookies.get("refresh_token") # noqa: SLF001 + + changed: bool = False + if access_token and access_token != config.access_token: + changed = True + config.access_token = access_token + + if refresh_token and refresh_token != config.refresh_token: + changed = True + config.refresh_token = refresh_token + + if changed: + user_config.set_server_config(config, profile).write() + + atexit.register(_flush_auth_changes) + + return client diff --git a/dreadnode/cli/config.py b/dreadnode/cli/config.py new file mode 100644 index 00000000..10cfa127 --- /dev/null +++ b/dreadnode/cli/config.py @@ -0,0 +1,94 @@ +import rich +from pydantic import BaseModel +from ruamel.yaml import YAML + +from dreadnode.constants import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH + + +class ServerConfig(BaseModel): + """Server specific authentication data and API URL.""" + + url: str + email: str + username: str + api_key: str + access_token: str + refresh_token: str + + +class UserConfig(BaseModel): + """User configuration supporting multiple server profiles.""" + + active: str | None = None + servers: dict[str, ServerConfig] = {} + + def _update_active(self) -> None: + """If active is not set, set it to the first available server and raise an error if no servers are configured.""" + + if self.active not in self.servers: + self.active = next(iter(self.servers)) if self.servers else None + + def _update_urls(self) -> bool: + updated = False + for search, replace in { + "//staging-crucible.dreadnode.io": "//staging-platform.dreadnode.io", + "//dev-crucible.dreadnode.io": "//dev-platform.dreadnode.io", + "//crucible.dreadnode.io": "//platform.dreadnode.io", + }.items(): + for server in self.servers.values(): + if search in server.url: + server.url = server.url.replace(search, replace) + updated = True + return updated + + @classmethod + def read(cls) -> "UserConfig": + """Read the user configuration from the file system or return an empty instance.""" + + if not USER_CONFIG_PATH.exists(): + return cls() + + with USER_CONFIG_PATH.open("r") as f: + self = cls.model_validate(YAML().load(f)) + + if self._update_urls(): + self.write() + + return self + + def write(self) -> None: + """Write the user configuration to the file system.""" + + self._update_active() + + if not USER_CONFIG_PATH.parent.exists(): + rich.print(f":rocket: Creating config at {USER_CONFIG_PATH.parent}") + USER_CONFIG_PATH.parent.mkdir(parents=True) + + with USER_CONFIG_PATH.open("w") as f: + YAML().dump(self.model_dump(mode="json"), f) + + @property + def active_profile_name(self) -> str | None: + """Get the name of the active profile.""" + self._update_active() + return self.active + + def get_server_config(self, profile: str | None = None) -> ServerConfig: + """Get the server configuration for the given profile or None if not set.""" + + profile = profile or self.active + if not profile: + raise RuntimeError("No profile is set, use [bold]dreadnode login[/] to authenticate") + + if profile not in self.servers: + raise RuntimeError(f"No server configuration for profile: {profile}") + + return self.servers[profile] + + def set_server_config(self, config: ServerConfig, profile: str | None = None) -> "UserConfig": + """Set the server configuration for the given profile.""" + + profile = profile or self.active or DEFAULT_PROFILE_NAME + self.servers[profile] = config + return self diff --git a/dreadnode/cli/main.py b/dreadnode/cli/main.py new file mode 100644 index 00000000..061872c3 --- /dev/null +++ b/dreadnode/cli/main.py @@ -0,0 +1,127 @@ +import contextlib +import typing as t +import webbrowser + +import cyclopts +import rich + +from dreadnode.api.client import ApiClient +from dreadnode.cli.api import create_api_client +from dreadnode.cli.config import ServerConfig, UserConfig +from dreadnode.cli.profile import cli as profile_cli +from dreadnode.constants import PLATFORM_BASE_URL + +cli = cyclopts.App(help="Interact with Dreadnode platforms", version_flags=[], help_on_error=True) + +cli["--help"].group = "Meta" + +cli.command(profile_cli) + + +@cli.meta.default +def meta( + *tokens: t.Annotated[str, cyclopts.Parameter(show=False, allow_leading_hyphen=True)], +) -> None: + rich.print() + cli(tokens) + + +@cli.command(help="Authenticate to a platform server.", group="Auth") +def login( + *, + server: t.Annotated[ + str | None, cyclopts.Parameter(name=["--server", "-s"], help="URL of the server") + ] = None, + profile: t.Annotated[ + str | None, + cyclopts.Parameter(name=["--profile", "-p"], help="Profile alias to assign / update"), + ] = None, +) -> None: + if not server: + server = PLATFORM_BASE_URL + with contextlib.suppress(Exception): + existing_config = UserConfig.read().get_server_config(profile) + server = existing_config.url + + # create client with no auth data + client = ApiClient(base_url=server) + + rich.print(":laptop_computer: Requesting device code ...") + + # request user and device codes + codes = client.get_device_codes() + + # present verification URL to user + verification_url = client.url_for_user_code(codes.user_code) + verification_url_base = verification_url.split("?")[0] + + rich.print() + rich.print( + f"""\ +Attempting to automatically open the authorization page in your default browser. +If the browser does not open or you wish to use a different device, open the following URL: + +:link: [bold]{verification_url_base}[/] + +Then enter the code: [bold]{codes.user_code}[/] +""" + ) + + webbrowser.open(verification_url) + + # poll for the access token after user verification + tokens = client.poll_for_token(codes.device_code) + + client = ApiClient( + server, cookies={"refresh_token": tokens.refresh_token, "access_token": tokens.access_token} + ) + user = client.get_user() + + UserConfig.read().set_server_config( + ServerConfig( + url=server, + access_token=tokens.access_token, + refresh_token=tokens.refresh_token, + email=user.email_address, + username=user.username, + api_key=user.api_key.key, + ), + profile, + ).write() + + rich.print(f":white_check_mark: Authenticated as {user.email_address} ({user.username})") + + +@cli.command(help="Refresh data for the active server profile.", group="Auth") +def refresh() -> None: + user_config = UserConfig.read() + server_config = user_config.get_server_config() + + client = create_api_client() + user = client.get_user() + + server_config.email = user.email_address + server_config.username = user.username + server_config.api_key = user.api_key.key + + user_config.set_server_config(server_config).write() + + rich.print( + f":white_check_mark: Refreshed '[bold]{user_config.active}[/bold]' ([magenta]{user.email_address}[/] / [cyan]{user.username}[/])" + ) + + +@cli.command(help="Show versions and exit.", group="Meta") +def version() -> None: + import importlib.metadata + import platform + import sys + + version = importlib.metadata.version("dreadnode") + python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + + os_name = platform.system() + arch = platform.machine() + rich.print(f"Platform: {os_name} ({arch})") + rich.print(f"Python: {python_version}") + rich.print(f"Dreadnode: {version}") diff --git a/dreadnode/cli/profile/__init__.py b/dreadnode/cli/profile/__init__.py new file mode 100644 index 00000000..77af6edd --- /dev/null +++ b/dreadnode/cli/profile/__init__.py @@ -0,0 +1,3 @@ +from dreadnode.cli.profile.cli import cli + +__all__ = ["cli"] diff --git a/dreadnode/cli/profile/cli.py b/dreadnode/cli/profile/cli.py new file mode 100644 index 00000000..5394026c --- /dev/null +++ b/dreadnode/cli/profile/cli.py @@ -0,0 +1,76 @@ +import typing as t + +import cyclopts +import rich +from rich import box +from rich.table import Table + +from dreadnode.cli.api import Token +from dreadnode.cli.config import UserConfig +from dreadnode.util import time_to + +cli = cyclopts.App(name="profile", help="Manage server profiles") + + +@cli.command(name=["show", "list"], help="List all server profiles") +def show() -> None: + config = UserConfig.read() + if not config.servers: + rich.print(":exclamation: No server profiles are configured") + return + + table = Table(box=box.ROUNDED) + table.add_column("Profile", style="magenta") + table.add_column("URL", style="cyan") + table.add_column("Email") + table.add_column("Username") + table.add_column("Valid Until") + + for profile, server in config.servers.items(): + active = profile == config.active + refresh_token = Token(server.refresh_token) + + table.add_row( + profile + ("*" if active else ""), + server.url, + server.email, + server.username, + "[red]expired[/]" + if refresh_token.is_expired() + else f"{refresh_token.expires_at.astimezone().strftime('%c')} ({time_to(refresh_token.expires_at)})", + style="bold" if active else None, + ) + + rich.print(table) + + +@cli.command(help="Set the active server profile") +def switch(profile: t.Annotated[str, cyclopts.Parameter(help="Profile to switch to")]) -> None: + config = UserConfig.read() + if profile not in config.servers: + rich.print(f":exclamation: Profile [bold]{profile}[/] does not exist") + return + + config.active = profile + config.write() + + rich.print(f":laptop_computer: Switched to [bold magenta]{profile}[/]") + rich.print(f"|- email: [bold]{config.servers[profile].email}[/]") + rich.print(f"|- username: {config.servers[profile].username}") + rich.print(f"|- url: {config.servers[profile].url}") + rich.print() + + +@cli.command(help="Remove a server profile") +def forget( + profile: t.Annotated[str, cyclopts.Parameter(help="Profile of the server to remove")], +) -> None: + config = UserConfig.read() + if profile not in config.servers: + rich.print(f":exclamation: Profile [bold]{profile}[/] does not exist") + return + + del config.servers[profile] + config.write() + + rich.print(f":axe: Forgot about [bold]{profile}[/]") diff --git a/dreadnode/constants.py b/dreadnode/constants.py index ae9dc730..ca0fa7b0 100644 --- a/dreadnode/constants.py +++ b/dreadnode/constants.py @@ -1,4 +1,36 @@ -# Environment variable names +import os +import pathlib + +# +# Defaults +# + +# name of the default server profile +DEFAULT_PROFILE_NAME = "main" +# default poll interval for the authentication flow +DEFAULT_POLL_INTERVAL = 5 +# default maximum poll time for the authentication flow +DEFAULT_MAX_POLL_TIME = 300 +# default maximum token TTL in seconds +DEFAULT_TOKEN_MAX_TTL = 60 +# Default values for the S3 storage +DEFAULT_MAX_INLINE_OBJECT_BYTES = 10 * 1024 # 10KB +# default platform domain +DEFAULT_PLATFORM_BASE_DOMAIN = "dreadnode.io" +# default server URL +DEFAULT_SERVER_URL = f"https://platform.{DEFAULT_PLATFORM_BASE_DOMAIN}" +# default local directory for dreadnode objects +DEFAULT_LOCAL_OBJECT_DIR = ".dreadnode/objects" +# default docker registry subdomain +DEFAULT_DOCKER_REGISTRY_SUBDOMAIN = "registry" +# default docker registry local port +DEFAULT_DOCKER_REGISTRY_LOCAL_PORT = 5005 +# default docker registry image tag +DEFAULT_DOCKER_REGISTRY_IMAGE_TAG = "registry" + +# +# Environment Variable Names +# ENV_SERVER_URL = "DREADNODE_SERVER_URL" ENV_SERVER = "DREADNODE_SERVER" # alternative to SERVER_URL @@ -7,10 +39,18 @@ ENV_LOCAL_DIR = "DREADNODE_LOCAL_DIR" ENV_PROJECT = "DREADNODE_PROJECT" -# Default values +# +# Environment +# -DEFAULT_SERVER_URL = "https://platform.dreadnode.io" -DEFAULT_LOCAL_OBJECT_DIR = ".dreadnode/objects" +# enable debugging +DEBUG = bool(os.getenv("DREADNODE_DEBUG")) or False -# Default values for the S3 storage -MAX_INLINE_OBJECT_BYTES = 10 * 1024 # 10KB +# server url +PLATFORM_BASE_URL = os.getenv(ENV_SERVER, os.getenv(ENV_SERVER_URL, DEFAULT_SERVER_URL)) + +# path to the user configuration file +USER_CONFIG_PATH = pathlib.Path( + # allow overriding the user config file via env variable + os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config" +) diff --git a/dreadnode/data_types/__init__.py b/dreadnode/data_types/__init__.py index 04a95f21..11eac1d4 100644 --- a/dreadnode/data_types/__init__.py +++ b/dreadnode/data_types/__init__.py @@ -1,9 +1,9 @@ -from .audio import Audio -from .base import WithMeta -from .image import Image -from .object_3d import Object3D -from .table import Table -from .text import Code, Markdown, Text -from .video import Video +from dreadnode.data_types.audio import Audio +from dreadnode.data_types.base import WithMeta +from dreadnode.data_types.image import Image +from dreadnode.data_types.object_3d import Object3D +from dreadnode.data_types.table import Table +from dreadnode.data_types.text import Code, Markdown, Text +from dreadnode.data_types.video import Video __all__ = ["Audio", "Code", "Image", "Markdown", "Object3D", "Table", "Text", "Video", "WithMeta"] diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 525d5124..6b5258c0 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -31,16 +31,12 @@ from dreadnode.artifact.merger import ArtifactMerger from dreadnode.artifact.storage import ArtifactStorage from dreadnode.artifact.tree_builder import ArtifactTreeBuilder, DirectoryNode -from dreadnode.constants import MAX_INLINE_OBJECT_BYTES +from dreadnode.constants import DEFAULT_MAX_INLINE_OBJECT_BYTES from dreadnode.convert import run_span_to_graph from dreadnode.metric import Metric, MetricAggMode, MetricsDict from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal from dreadnode.serialization import Serialized, serialize -from dreadnode.types import UNSET, AnyDict, JsonDict, Unset -from dreadnode.util import clean_str -from dreadnode.version import VERSION - -from .constants import ( +from dreadnode.tracing.constants import ( EVENT_ATTRIBUTE_LINK_HASH, EVENT_ATTRIBUTE_OBJECT_HASH, EVENT_ATTRIBUTE_OBJECT_LABEL, @@ -69,6 +65,9 @@ SPAN_ATTRIBUTE_VERSION, SpanType, ) +from dreadnode.types import UNSET, AnyDict, JsonDict, Unset +from dreadnode.util import clean_str +from dreadnode.version import VERSION if t.TYPE_CHECKING: import networkx as nx # type: ignore [import-untyped] @@ -630,7 +629,7 @@ def _create_object_by_hash(self, serialized: Serialized, object_hash: str) -> Ob data_hash = serialized.data_hash schema_hash = serialized.schema_hash - if data is None or data_bytes is None or data_len <= MAX_INLINE_OBJECT_BYTES: + if data is None or data_bytes is None or data_len <= DEFAULT_MAX_INLINE_OBJECT_BYTES: return ObjectVal( hash=object_hash, value=data, diff --git a/dreadnode/util.py b/dreadnode/util.py index 89262d23..a78230ef 100644 --- a/dreadnode/util.py +++ b/dreadnode/util.py @@ -8,6 +8,7 @@ import sys import typing as t from contextlib import contextmanager +from datetime import datetime from pathlib import Path from types import TracebackType @@ -54,6 +55,29 @@ def safe_repr(obj: t.Any) -> str: return "" +def time_to(future_datetime: datetime) -> str: + """Get a string describing the time difference between a future datetime and now.""" + + now = datetime.now(tz=future_datetime.tzinfo) + time_difference = future_datetime - now + + days = time_difference.days + seconds = time_difference.seconds + hours = seconds // 3600 + minutes = (seconds % 3600) // 60 + seconds = seconds % 60 + + result = [] + if days > 0: + result.append(f"{days}d") + if hours > 0: + result.append(f"{hours}hr") + if minutes > 0: + result.append(f"{minutes}m") + + return ", ".join(result) if result else "Just now" + + def log_internal_error() -> None: try: current_test = os.environ.get("PYTEST_CURRENT_TEST", "") diff --git a/poetry.lock b/poetry.lock index 92aa60ad..4c7bc671 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -173,7 +173,7 @@ description = "Timeout context manager for asyncio programs" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "python_version == \"3.10\"" +markers = "python_version < \"3.11\"" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, @@ -185,7 +185,7 @@ version = "25.3.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main", "dev"] files = [ {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"}, {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"}, @@ -932,6 +932,30 @@ files = [ {file = "coolname-2.2.0.tar.gz", hash = "sha256:6c5d5731759104479e7ca195a9b64f7900ac5bead40183c09323c7d0be9e75c7"}, ] +[[package]] +name = "cyclopts" +version = "3.22.2" +description = "Intuitive, easy CLIs based on type hints." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "cyclopts-3.22.2-py3-none-any.whl", hash = "sha256:6681b0815fa2de2bccc364468fd25b15aa9617cb505c0b16ca62e2b18a57619e"}, + {file = "cyclopts-3.22.2.tar.gz", hash = "sha256:d3495231af6ae86479579777d212ddf77b113200f828badeaf401162ed87227d"}, +] + +[package.dependencies] +attrs = ">=23.1.0" +docstring-parser = {version = ">=0.15", markers = "python_version < \"4.0\""} +rich = ">=13.6.0" +rich-rst = ">=1.3.1,<2.0.0" +typing-extensions = {version = ">=4.8.0", markers = "python_version < \"3.11\""} + +[package.extras] +toml = ["tomli (>=2.0.0) ; python_version < \"3.11\""] +trio = ["trio (>=0.10.0)"] +yaml = ["pyyaml (>=6.0.1)"] + [[package]] name = "datasets" version = "3.6.0" @@ -1046,6 +1070,30 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "docstring-parser" +version = "0.16" +description = "Parse Python docstrings in reST, Google and Numpydoc format" +optional = false +python-versions = ">=3.6,<4.0" +groups = ["main"] +files = [ + {file = "docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637"}, + {file = "docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e"}, +] + +[[package]] +name = "docutils" +version = "0.21.2" +description = "Docutils -- Python Documentation Utilities" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2"}, + {file = "docutils-0.21.2.tar.gz", hash = "sha256:3a6b18732edf182daa3cd12775bbb338cf5691468f91eeeb109deff6ebfa986f"}, +] + [[package]] name = "elastic-transport" version = "8.17.1" @@ -1113,7 +1161,7 @@ description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" groups = ["main", "dev"] -markers = "python_version == \"3.10\"" +markers = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"}, {file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"}, @@ -3453,6 +3501,22 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.1 [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "rich-rst" +version = "1.3.1" +description = "A beautiful reStructuredText renderer for rich" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "rich_rst-1.3.1-py3-none-any.whl", hash = "sha256:498a74e3896507ab04492d326e794c3ef76e7cda078703aa592d1853d91098c1"}, + {file = "rich_rst-1.3.1.tar.gz", hash = "sha256:fad46e3ba42785ea8c1785e2ceaa56e0ffa32dbe5410dec432f37e4107c4f383"}, +] + +[package.dependencies] +docutils = "*" +rich = ">=12.0.0" + [[package]] name = "rigging" version = "2.3.0" @@ -3644,7 +3708,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd"}, - {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a52d48f4e7bf9005e8f0a89209bf9a73f7190ddf0489eee5eb51377385f59f2a"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6"}, @@ -3653,7 +3716,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2"}, - {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1492a6051dab8d912fc2adeef0e8c72216b24d57bd896ea607cb90bb0c4981d3"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632"}, @@ -3662,7 +3724,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680"}, - {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b82a7c94a498853aa0b272fd5bc67f29008da798d4f93a2f9f289feb8426a58d"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a"}, @@ -3671,7 +3732,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1"}, - {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4f6f3eac23941b32afccc23081e1f50612bdbe4e982012ef4f5797986828cd01"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = "sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987"}, @@ -3680,7 +3740,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed"}, - {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2c59aa6170b990d8d2719323e628aaf36f3bfbc1c26279c0eeeb24d05d2d11c7"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b"}, {file = "ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f"}, @@ -3935,7 +3994,7 @@ description = "A lil' TOML parser" optional = false python-versions = ">=3.8" groups = ["main", "dev"] -markers = "python_version == \"3.10\"" +markers = "python_version < \"3.11\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -4624,4 +4683,4 @@ training = ["transformers"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "21fe5cf29eefa6f77e8bb811529fa19adff4f32d8e64f13432402631c4d3808f" +content-hash = "11b2807c639414563027bcea596f37dc4350225231a48c487288adfc8aec39ea" diff --git a/pyproject.toml b/pyproject.toml index 91df662a..a435ac94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ pandas = "^2.2.3" fsspec = { version = ">=2023.1.0,<=2025.3.0", extras = [ "s3", ] } # Pinned for datasets compatibility +cyclopts = "^3.22.2" transformers = { version = "^4.41.0", optional = true } soundfile = { version = "^0.13.1", optional = true } @@ -58,6 +59,13 @@ packages = ["src"] [tool.hatch.build.targets.sdist] packages = ["src"] +[project.scripts] +dreadnode = 'dreadnode.__main__:run' +dn = 'dreadnode.__main__:run' + +[tool.poetry.plugins."pipx.run"] +dreadnode = 'dreadnode.__main__:run' + [tool.pytest.ini_options] asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function" From 1b1046be64f28b101de1f97da8bb6ca3b1066188 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Fri, 18 Jul 2025 14:16:29 -0600 Subject: [PATCH 2/8] Update dependencies --- poetry.lock | 311 ++++++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 3 + 2 files changed, 313 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 4c7bc671..88b243b3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -199,6 +199,29 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""] +[[package]] +name = "beautifulsoup4" +version = "4.13.4" +description = "Screen-scraping library" +optional = false +python-versions = ">=3.7.0" +groups = ["dev"] +files = [ + {file = "beautifulsoup4-4.13.4-py3-none-any.whl", hash = "sha256:9bbbb14bfde9d79f38b8cd5f8c7c85f4b8f2523190ebed90e950a8dea4cb1c4b"}, + {file = "beautifulsoup4-4.13.4.tar.gz", hash = "sha256:dbb3c4e1ceae6aefebdaf2423247260cd062430a410e38c66f2baa50a8437195"}, +] + +[package.dependencies] +soupsieve = ">1.2" +typing-extensions = ">=4.0.0" + +[package.extras] +cchardet = ["cchardet"] +chardet = ["chardet"] +charset-normalizer = ["charset-normalizer"] +html5lib = ["html5lib"] +lxml = ["lxml"] + [[package]] name = "boto3" version = "1.38.14" @@ -1364,6 +1387,24 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] +[[package]] +name = "ghp-import" +version = "2.1.0" +description = "Copy your docs directly to the gh-pages branch." +optional = false +python-versions = "*" +groups = ["dev"] +files = [ + {file = "ghp-import-2.1.0.tar.gz", hash = "sha256:9c535c4c61193c2df8871222567d7fd7e5014d835f97dc7b7439069e2413d343"}, + {file = "ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619"}, +] + +[package.dependencies] +python-dateutil = ">=2.8.1" + +[package.extras] +dev = ["flake8", "markdown", "twine", "wheel"] + [[package]] name = "googleapis-common-protos" version = "1.70.0" @@ -1382,6 +1423,21 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0)"] +[[package]] +name = "griffe" +version = "1.7.3" +description = "Signatures for entire Python programs. Extract the structure, the frame, the skeleton of your project, to generate API documentation or find breaking changes in your API." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "griffe-1.7.3-py3-none-any.whl", hash = "sha256:c6b3ee30c2f0f17f30bcdef5068d6ab7a2a4f1b8bf1a3e74b56fffd21e1c5f75"}, + {file = "griffe-1.7.3.tar.gz", hash = "sha256:52ee893c6a3a968b639ace8015bec9d36594961e156e23315c8e8e51401fa50b"}, +] + +[package.dependencies] +colorama = ">=0.4" + [[package]] name = "h11" version = "0.16.0" @@ -1905,6 +1961,22 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} [package.extras] dev = ["Sphinx (==8.1.3) ; python_version >= \"3.11\"", "build (==1.2.2) ; python_version >= \"3.11\"", "colorama (==0.4.5) ; python_version < \"3.8\"", "colorama (==0.4.6) ; python_version >= \"3.8\"", "exceptiongroup (==1.1.3) ; python_version >= \"3.7\" and python_version < \"3.11\"", "freezegun (==1.1.0) ; python_version < \"3.8\"", "freezegun (==1.5.0) ; python_version >= \"3.8\"", "mypy (==v0.910) ; python_version < \"3.6\"", "mypy (==v0.971) ; python_version == \"3.6\"", "mypy (==v1.13.0) ; python_version >= \"3.8\"", "mypy (==v1.4.1) ; python_version == \"3.7\"", "myst-parser (==4.0.0) ; python_version >= \"3.11\"", "pre-commit (==4.0.1) ; python_version >= \"3.9\"", "pytest (==6.1.2) ; python_version < \"3.8\"", "pytest (==8.3.2) ; python_version >= \"3.8\"", "pytest-cov (==2.12.1) ; python_version < \"3.8\"", "pytest-cov (==5.0.0) ; python_version == \"3.8\"", "pytest-cov (==6.0.0) ; python_version >= \"3.9\"", "pytest-mypy-plugins (==1.9.3) ; python_version >= \"3.6\" and python_version < \"3.8\"", "pytest-mypy-plugins (==3.1.0) ; python_version >= \"3.8\"", "sphinx-rtd-theme (==3.0.2) ; python_version >= \"3.11\"", "tox (==3.27.1) ; python_version < \"3.8\"", "tox (==4.23.2) ; python_version >= \"3.8\"", "twine (==6.0.1) ; python_version >= \"3.11\""] +[[package]] +name = "markdown" +version = "3.8.2" +description = "Python implementation of John Gruber's Markdown." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "markdown-3.8.2-py3-none-any.whl", hash = "sha256:5c83764dbd4e00bdd94d85a19b8d55ccca20fe35b2e678a1422b380324dd5f24"}, + {file = "markdown-3.8.2.tar.gz", hash = "sha256:247b9a70dd12e27f67431ce62523e675b866d254f900c4fe75ce3dda62237c45"}, +] + +[package.extras] +docs = ["mdx_gh_links (>=0.2)", "mkdocs (>=1.6)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -1930,6 +2002,22 @@ profiling = ["gprof2dot"] rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] +[[package]] +name = "markdownify" +version = "1.1.0" +description = "Convert HTML to markdown." +optional = false +python-versions = "*" +groups = ["dev"] +files = [ + {file = "markdownify-1.1.0-py3-none-any.whl", hash = "sha256:32a5a08e9af02c8a6528942224c91b933b4bd2c7d078f9012943776fc313eeef"}, + {file = "markdownify-1.1.0.tar.gz", hash = "sha256:449c0bbbf1401c5112379619524f33b63490a8fa479456d41de9dc9e37560ebd"}, +] + +[package.dependencies] +beautifulsoup4 = ">=4.9,<5" +six = ">=1.15,<2" + [[package]] name = "markupsafe" version = "3.0.2" @@ -2013,6 +2101,126 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mergedeep" +version = "1.3.4" +description = "A deep merge function for 🐍." +optional = false +python-versions = ">=3.6" +groups = ["dev"] +files = [ + {file = "mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307"}, + {file = "mergedeep-1.3.4.tar.gz", hash = "sha256:0096d52e9dad9939c3d975a774666af186eda617e6ca84df4c94dec30004f2a8"}, +] + +[[package]] +name = "mkdocs" +version = "1.6.1" +description = "Project documentation with Markdown." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "mkdocs-1.6.1-py3-none-any.whl", hash = "sha256:db91759624d1647f3f34aa0c3f327dd2601beae39a366d6e064c03468d35c20e"}, + {file = "mkdocs-1.6.1.tar.gz", hash = "sha256:7b432f01d928c084353ab39c57282f29f92136665bdd6abf7c1ec8d822ef86f2"}, +] + +[package.dependencies] +click = ">=7.0" +colorama = {version = ">=0.4", markers = "platform_system == \"Windows\""} +ghp-import = ">=1.0" +jinja2 = ">=2.11.1" +markdown = ">=3.3.6" +markupsafe = ">=2.0.1" +mergedeep = ">=1.3.4" +mkdocs-get-deps = ">=0.2.0" +packaging = ">=20.5" +pathspec = ">=0.11.1" +pyyaml = ">=5.1" +pyyaml-env-tag = ">=0.1" +watchdog = ">=2.0" + +[package.extras] +i18n = ["babel (>=2.9.0)"] +min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4) ; platform_system == \"Windows\"", "ghp-import (==1.0)", "importlib-metadata (==4.4) ; python_version < \"3.10\"", "jinja2 (==2.11.1)", "markdown (==3.3.6)", "markupsafe (==2.0.1)", "mergedeep (==1.3.4)", "mkdocs-get-deps (==0.2.0)", "packaging (==20.5)", "pathspec (==0.11.1)", "pyyaml (==5.1)", "pyyaml-env-tag (==0.1)", "watchdog (==2.0)"] + +[[package]] +name = "mkdocs-autorefs" +version = "1.4.2" +description = "Automatically link across pages in MkDocs." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "mkdocs_autorefs-1.4.2-py3-none-any.whl", hash = "sha256:83d6d777b66ec3c372a1aad4ae0cf77c243ba5bcda5bf0c6b8a2c5e7a3d89f13"}, + {file = "mkdocs_autorefs-1.4.2.tar.gz", hash = "sha256:e2ebe1abd2b67d597ed19378c0fff84d73d1dbce411fce7a7cc6f161888b6749"}, +] + +[package.dependencies] +Markdown = ">=3.3" +markupsafe = ">=2.0.1" +mkdocs = ">=1.1" + +[[package]] +name = "mkdocs-get-deps" +version = "0.2.0" +description = "MkDocs extension that lists all dependencies according to a mkdocs.yml file" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "mkdocs_get_deps-0.2.0-py3-none-any.whl", hash = "sha256:2bf11d0b133e77a0dd036abeeb06dec8775e46efa526dc70667d8863eefc6134"}, + {file = "mkdocs_get_deps-0.2.0.tar.gz", hash = "sha256:162b3d129c7fad9b19abfdcb9c1458a651628e4b1dea628ac68790fb3061c60c"}, +] + +[package.dependencies] +mergedeep = ">=1.3.4" +platformdirs = ">=2.2.0" +pyyaml = ">=5.1" + +[[package]] +name = "mkdocstrings" +version = "0.29.1" +description = "Automatic documentation from sources, for MkDocs." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "mkdocstrings-0.29.1-py3-none-any.whl", hash = "sha256:37a9736134934eea89cbd055a513d40a020d87dfcae9e3052c2a6b8cd4af09b6"}, + {file = "mkdocstrings-0.29.1.tar.gz", hash = "sha256:8722f8f8c5cd75da56671e0a0c1bbed1df9946c0cef74794d6141b34011abd42"}, +] + +[package.dependencies] +Jinja2 = ">=2.11.1" +Markdown = ">=3.6" +MarkupSafe = ">=1.1" +mkdocs = ">=1.6" +mkdocs-autorefs = ">=1.4" +pymdown-extensions = ">=6.3" + +[package.extras] +crystal = ["mkdocstrings-crystal (>=0.3.4)"] +python = ["mkdocstrings-python (>=1.16.2)"] +python-legacy = ["mkdocstrings-python-legacy (>=0.2.1)"] + +[[package]] +name = "mkdocstrings-python" +version = "1.16.12" +description = "A Python handler for mkdocstrings." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "mkdocstrings_python-1.16.12-py3-none-any.whl", hash = "sha256:22ded3a63b3d823d57457a70ff9860d5a4de9e8b1e482876fc9baabaf6f5f374"}, + {file = "mkdocstrings_python-1.16.12.tar.gz", hash = "sha256:9b9eaa066e0024342d433e332a41095c4e429937024945fea511afe58f63175d"}, +] + +[package.dependencies] +griffe = ">=1.6.2" +mkdocs-autorefs = ">=1.4" +mkdocstrings = ">=0.28.3" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} + [[package]] name = "moviepy" version = "2.2.1" @@ -2600,6 +2808,18 @@ files = [ numpy = ">=1.23.5" types-pytz = ">=2022.1.1" +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + [[package]] name = "pillow" version = "11.3.0" @@ -3172,6 +3392,25 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pymdown-extensions" +version = "10.16" +description = "Extension pack for Python Markdown." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pymdown_extensions-10.16-py3-none-any.whl", hash = "sha256:f5dd064a4db588cb2d95229fc4ee63a1b16cc8b4d0e6145c0899ed8723da1df2"}, + {file = "pymdown_extensions-10.16.tar.gz", hash = "sha256:71dac4fca63fabeffd3eb9038b756161a33ec6e8d230853d3cecf562155ab3de"}, +] + +[package.dependencies] +markdown = ">=3.6" +pyyaml = "*" + +[package.extras] +extra = ["pygments (>=2.19.1)"] + [[package]] name = "pytest" version = "8.4.1" @@ -3337,6 +3576,21 @@ files = [ ] markers = {main = "extra == \"training\""} +[[package]] +name = "pyyaml-env-tag" +version = "1.1" +description = "A custom YAML tag for referencing environment variables in YAML files." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pyyaml_env_tag-1.1-py3-none-any.whl", hash = "sha256:17109e1a528561e32f026364712fee1264bc2ea6715120891174ed1b980d2e04"}, + {file = "pyyaml_env_tag-1.1.tar.gz", hash = "sha256:2eb38b75a2d21ee0475d6d97ec19c63287a7e140231e4214969d0eac923cd7ff"}, +] + +[package.dependencies] +pyyaml = "*" + [[package]] name = "referencing" version = "0.36.2" @@ -3905,6 +4159,18 @@ files = [ cffi = ">=1.0" numpy = "*" +[[package]] +name = "soupsieve" +version = "2.7" +description = "A modern CSS selector implementation for Beautiful Soup." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "soupsieve-2.7-py3-none-any.whl", hash = "sha256:6e60cc5c1ffaf1cebcc12e8188320b72071e922c2e897f737cadce79ad5d30c4"}, + {file = "soupsieve-2.7.tar.gz", hash = "sha256:ad282f9b6926286d2ead4750552c8a6142bc4c783fd66b0293547c8fe6ae126a"}, +] + [[package]] name = "tiktoken" version = "0.9.0" @@ -4286,6 +4552,49 @@ platformdirs = ">=3.9.1,<5" docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"GraalVM\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] +[[package]] +name = "watchdog" +version = "6.0.0" +description = "Filesystem events monitoring" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "watchdog-6.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d1cdb490583ebd691c012b3d6dae011000fe42edb7a82ece80965b42abd61f26"}, + {file = "watchdog-6.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc64ab3bdb6a04d69d4023b29422170b74681784ffb9463ed4870cf2f3e66112"}, + {file = "watchdog-6.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c897ac1b55c5a1461e16dae288d22bb2e412ba9807df8397a635d88f671d36c3"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6eb11feb5a0d452ee41f824e271ca311a09e250441c262ca2fd7ebcf2461a06c"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ef810fbf7b781a5a593894e4f439773830bdecb885e6880d957d5b9382a960d2"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:afd0fe1b2270917c5e23c2a65ce50c2a4abb63daafb0d419fde368e272a76b7c"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bdd4e6f14b8b18c334febb9c4425a878a2ac20efd1e0b231978e7b150f92a948"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c7c15dda13c4eb00d6fb6fc508b3c0ed88b9d5d374056b239c4ad1611125c860"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f10cb2d5902447c7d0da897e2c6768bca89174d0c6e1e30abec5421af97a5b0"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:490ab2ef84f11129844c23fb14ecf30ef3d8a6abafd3754a6f75ca1e6654136c"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:76aae96b00ae814b181bb25b1b98076d5fc84e8a53cd8885a318b42b6d3a5134"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a175f755fc2279e0b7312c0035d52e27211a5bc39719dd529625b1930917345b"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e6f0e77c9417e7cd62af82529b10563db3423625c5fce018430b249bf977f9e8"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:90c8e78f3b94014f7aaae121e6b909674df5b46ec24d6bebc45c44c56729af2a"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e7631a77ffb1f7d2eefa4445ebbee491c720a5661ddf6df3498ebecae5ed375c"}, + {file = "watchdog-6.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:c7ac31a19f4545dd92fc25d200694098f42c9a8e391bc00bdd362c5736dbf881"}, + {file = "watchdog-6.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9513f27a1a582d9808cf21a07dae516f0fab1cf2d7683a742c498b93eedabb11"}, + {file = "watchdog-6.0.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7a0e56874cfbc4b9b05c60c8a1926fedf56324bb08cfbc188969777940aef3aa"}, + {file = "watchdog-6.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:e6439e374fc012255b4ec786ae3c4bc838cd7309a540e5fe0952d03687d8804e"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:212ac9b8bf1161dc91bd09c048048a95ca3a4c4f5e5d4a7d1b1a7d5752a7f96f"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:e3df4cbb9a450c6d49318f6d14f4bbc80d763fa587ba46ec86f99f9e6876bb26"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:2cce7cfc2008eb51feb6aab51251fd79b85d9894e98ba847408f662b3395ca3c"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:20ffe5b202af80ab4266dcd3e91aae72bf2da48c0d33bdb15c66658e685e94e2"}, + {file = "watchdog-6.0.0-py3-none-win32.whl", hash = "sha256:07df1fdd701c5d4c8e55ef6cf55b8f0120fe1aef7ef39a1c6fc6bc2e606d517a"}, + {file = "watchdog-6.0.0-py3-none-win_amd64.whl", hash = "sha256:cbafb470cf848d93b5d013e2ecb245d4aa1c8fd0504e863ccefa32445359d680"}, + {file = "watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f"}, + {file = "watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282"}, +] + +[package.extras] +watchmedo = ["PyYAML (>=3.10)"] + [[package]] name = "win32-setctime" version = "1.2.0" @@ -4683,4 +4992,4 @@ training = ["transformers"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "11b2807c639414563027bcea596f37dc4350225231a48c487288adfc8aec39ea" +content-hash = "b00184c3d067d9748c6866e7223d3d3a95bb24bc56c96837e30b51f4bf154f2e" diff --git a/pyproject.toml b/pyproject.toml index a435ac94..322e1b94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,9 @@ rigging = "^2.3.0" typer = "^0.15.2" datasets = "^3.5.0" pyarrow = "^19.0.1" +markdown = "^3.8.2" +markdownify = "^1.1.0" +mkdocstrings-python = "^1.16.12" [build-system] requires = ["poetry-core>=1.0.0", "setuptools>=42", "wheel"] From 2dc6137434e86a4da7a2f273b305030eb59a00e0 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Thu, 24 Jul 2025 01:52:32 -0600 Subject: [PATCH 3/8] Wrapping first pass on CLI port with github cloning. --- docs/sdk/api.mdx | 22 +++++ dreadnode/api/client.py | 9 ++ dreadnode/api/models.py | 9 ++ dreadnode/cli/github.py | 198 ++++++++++++++++++++++++++++++++++++++++ dreadnode/cli/main.py | 72 ++++++++++++++- 5 files changed, 307 insertions(+), 3 deletions(-) create mode 100644 dreadnode/cli/github.py diff --git a/docs/sdk/api.mdx b/docs/sdk/api.mdx index 377c2202..6d44f253 100644 --- a/docs/sdk/api.mdx +++ b/docs/sdk/api.mdx @@ -86,6 +86,7 @@ def __init__( self._client = httpx.Client( headers=headers, + cookies=_cookies, base_url=self._base_url, timeout=30, ) @@ -474,6 +475,27 @@ def get_device_codes(self) -> DeviceCodeResponse: ``` + + +### get\_github\_access\_token + +```python +get_github_access_token( + repos: list[str], +) -> GithubTokenResponse +``` + +Try to get a GitHub access token for the given repositories. + + +```python +def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse: + """Try to get a GitHub access token for the given repositories.""" + response = self.request("POST", "/github/token", json_data={"repos": repos}) + return GithubTokenResponse(**response.json()) +``` + + ### get\_project diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index 199363c3..caf805f2 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -13,6 +13,7 @@ from dreadnode.api.models import ( AccessRefreshTokenResponse, DeviceCodeResponse, + GithubTokenResponse, MetricAggregationType, Project, RawRun, @@ -91,6 +92,7 @@ def __init__( self._client = httpx.Client( headers=headers, + cookies=_cookies, base_url=self._base_url, timeout=30, ) @@ -242,6 +244,13 @@ def get_user(self) -> UserResponse: response = self.request("GET", "/user") return UserResponse(**response.json()) + # Github + + def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse: + """Try to get a GitHub access token for the given repositories.""" + response = self.request("POST", "/github/token", json_data={"repos": repos}) + return GithubTokenResponse(**response.json()) + # Strikes def list_projects(self) -> list[Project]: diff --git a/dreadnode/api/models.py b/dreadnode/api/models.py index d1d594de..61c52dda 100644 --- a/dreadnode/api/models.py +++ b/dreadnode/api/models.py @@ -433,3 +433,12 @@ class TraceTree(BaseModel): """Span at this node, can be a Task or a TraceSpan.""" children: list["TraceTree"] = [] """Children of this span, representing nested spans or tasks.""" + + +# Github + + +class GithubTokenResponse(BaseModel): + token: str + expires_at: datetime + repos: list[str] diff --git a/dreadnode/cli/github.py b/dreadnode/cli/github.py new file mode 100644 index 00000000..2e8a6bf3 --- /dev/null +++ b/dreadnode/cli/github.py @@ -0,0 +1,198 @@ +import os +import pathlib +import re +import tempfile +import typing as t +import zipfile + +import httpx +import rich + + +class GithubRepo(str): # noqa: SLOT000 + """ + A string subclass that normalizes various GitHub repository string formats. + + Supported formats: + - Full URLs: https://github.com/owner/repo + - SSH URLs: git@github.com:owner/repo.git + - Simple format: owner/repo + - With ref: owner/repo/tree/main + - With complex ref: owner/repo/tree/feature/custom + - With ref (URL): https://github.com/owner/repo/tree/main + - With .git: owner/repo.git + - Raw URLs: https://raw.githubusercontent.com/owner/repo/main + - Release URLs: owner/repo/releases/tag/v1.0.0 + - ZIP URLs: https://github.com/owner/repo/zipball/main + - Simple with ref: owner/repo@ref + """ + + # Instance properties + namespace: str + repo: str + ref: str + + # Regex patterns + SSH_PATTERN = re.compile(r"git@github\.com:([^/]+)/([^/]+?)(\.git)?$") + SIMPLE_PATTERN = re.compile(r"^([^/]+)/([^/]+?)(\.git)?$") + URL_PATTERN = re.compile(r"github\.com/([^/]+)/([^/]+?)(?:\.git|/(?:tree|blob)/(.+?))?$") + RAW_PATTERN = re.compile(r"raw\.githubusercontent\.com/([^/]+)/([^/]+)/(.+)") + RELEASE_PATTERN = re.compile(r"([^/]+)/([^/]+)/releases/tag/(.+)$") + OWN_FORMAT_PATTERN = re.compile(r"^([^/]+)/([^/@:]+)@(.+)$") + ZIPBALL_PATTERN = re.compile(r"github\.com/([^/]+)/([^/]+?)/zipball/(.+)$") + + def __new__(cls, value: t.Any, *_: t.Any, **__: t.Any) -> "GithubRepo": # noqa: PLR0912, PLR0915 + if not isinstance(value, str): + return super().__new__(cls, str(value)) + + namespace = None + repo = None + ref = "main" + + value = value.strip() + + # Try our own format first (owner/repo@ref) + match = cls.OWN_FORMAT_PATTERN.match(value) + if match: + namespace = match.group(1) + repo = match.group(2) + ref = match.group(3) + + # Try as an SSH URL + elif value.startswith("git@"): + match = cls.SSH_PATTERN.search(value) + if match: + namespace, repo = match.group(1), match.group(2) + + # Try as a full URL + elif value.startswith(("http://", "https://")): + url_parts = value.split("//", 1)[1] + + # Try zipball pattern first + match = cls.ZIPBALL_PATTERN.search(url_parts) + if match: + namespace = match.group(1) + repo = match.group(2) + ref = match.group(3) + + # Try raw githubusercontent pattern + elif url_parts.startswith("raw.githubusercontent.com"): + match = cls.RAW_PATTERN.search(url_parts) + if match: + namespace, repo, ref = match.group(1), match.group(2), match.group(3) + + # Try standard GitHub URL pattern + else: + match = cls.URL_PATTERN.search(url_parts) + if match: + namespace = match.group(1) + repo = match.group(2) + ref = match.group(3) or ref + + # Try release tag format + elif "/releases/tag/" in value: + match = cls.RELEASE_PATTERN.match(value) + if match: + namespace, repo, ref = match.group(1), match.group(2), match.group(3) + + # Try simple owner/repo format + else: + # First try to extract any ref + tree_parts = value.split("/tree/") + blob_parts = value.split("/blob/") + + if len(tree_parts) > 1: + value, ref = tree_parts[0], tree_parts[1] + elif len(blob_parts) > 1: + value, ref = blob_parts[0], blob_parts[1] + + # Now check for owner/repo pattern + match = cls.SIMPLE_PATTERN.match(value) + if match: + namespace, repo = match.group(1), match.group(2) + + if not namespace or not repo: + raise ValueError(f"Invalid GitHub repository format: {value}") + + repo = repo.removesuffix(".git") + + obj = super().__new__(cls, f"{namespace}/{repo}@{ref}") + + obj.namespace = namespace + obj.repo = repo + obj.ref = ref + + return obj + + @property + def zip_url(self) -> str: + """ZIP archive URL for the repository.""" + return f"https://github.com/{self.namespace}/{self.repo}/zipball/{self.ref}" + + @property + def api_zip_url(self) -> str: + """API ZIP archive URL for the repository.""" + return f"https://api.github.com/repos/{self.namespace}/{self.repo}/zipball/{self.ref}" + + @property + def tree_url(self) -> str: + """URL to view the tree at this reference.""" + return f"https://github.com/{self.namespace}/{self.repo}/tree/{self.ref}" + + @property + def exists(self) -> bool: + """Check if a repo exists (or is private) on GitHub.""" + response = httpx.get(f"https://github.com/{self.namespace}/{self.repo}") + return response.status_code == 200 # noqa: PLR2004 + + def __repr__(self) -> str: + return f"GithubRepo(namespace='{self.namespace}', repo='{self.repo}', ref='{self.ref}')" + + +def get_repo_archive_source_path(source_dir: pathlib.Path) -> pathlib.Path: + """Return the actual source directory from a git repositoryZIP archive.""" + + if not (source_dir / "Dockerfile").exists() and not (source_dir / "Dockerfile.j2").exists(): + # if src has been downloaded from a ZIP archive, it may contain a single + # '--' folder, that is the actual source we want to use. + # Check if source_dir contains only one folder and update it if so. + children = list(source_dir.iterdir()) + if len(children) == 1 and children[0].is_dir(): + source_dir = children[0] + + return source_dir + + +def download_and_unzip_archive(url: str, *, headers: dict[str, str] | None = None) -> pathlib.Path: + """ + Downloads a ZIP archive from the given URL and unzips it into a temporary directory. + """ + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + local_zip_path = temp_dir / "archive.zip" + + rich.print(f":arrow_double_down: Downloading {url} ...") + + # download to temporary file + with httpx.stream("GET", url, follow_redirects=True, verify=True, headers=headers) as response: + response.raise_for_status() + with local_zip_path.open("wb") as zip_file: + for chunk in response.iter_bytes(chunk_size=8192): + zip_file.write(chunk) + + # unzip to temporary directory + try: + with zipfile.ZipFile(local_zip_path, "r") as zf: + for member in zf.infolist(): + file_path = os.path.realpath(temp_dir / member.filename) + if file_path.startswith(os.path.realpath(temp_dir)): + zf.extract(member, temp_dir) + else: + raise RuntimeError("Attempted Path Traversal Attack Detected") + + finally: + # always remove the zip file + if local_zip_path.exists(): + local_zip_path.unlink() + + return temp_dir diff --git a/dreadnode/cli/main.py b/dreadnode/cli/main.py index 061872c3..0370d07b 100644 --- a/dreadnode/cli/main.py +++ b/dreadnode/cli/main.py @@ -1,15 +1,21 @@ import contextlib +import pathlib +import shutil +import sys import typing as t import webbrowser import cyclopts import rich +from rich.panel import Panel +from rich.prompt import Prompt from dreadnode.api.client import ApiClient from dreadnode.cli.api import create_api_client from dreadnode.cli.config import ServerConfig, UserConfig +from dreadnode.cli.github import GithubRepo, download_and_unzip_archive from dreadnode.cli.profile import cli as profile_cli -from dreadnode.constants import PLATFORM_BASE_URL +from dreadnode.constants import DEBUG, PLATFORM_BASE_URL cli = cyclopts.App(help="Interact with Dreadnode platforms", version_flags=[], help_on_error=True) @@ -22,8 +28,16 @@ def meta( *tokens: t.Annotated[str, cyclopts.Parameter(show=False, allow_leading_hyphen=True)], ) -> None: - rich.print() - cli(tokens) + try: + rich.print() + cli(tokens) + except Exception as e: + if DEBUG: + raise + + rich.print() + rich.print(Panel(str(e), title="Error", title_align="left", border_style="red")) + sys.exit(1) @cli.command(help="Authenticate to a platform server.", group="Auth") @@ -72,6 +86,8 @@ def login( # poll for the access token after user verification tokens = client.poll_for_token(codes.device_code) + print(tokens) + client = ApiClient( server, cookies={"refresh_token": tokens.refresh_token, "access_token": tokens.access_token} ) @@ -111,6 +127,56 @@ def refresh() -> None: ) +@cli.command(help="Clone a github repository.") +def clone( + repo: t.Annotated[str, cyclopts.Parameter(help="Repository name or URL")], + target: t.Annotated[ + pathlib.Path | None, + cyclopts.Parameter(help="The target directory"), + ] = None, +) -> None: + github_repo = GithubRepo(repo) + + # Check if the target directory exists + target = target or pathlib.Path(github_repo.repo) + if target.exists(): + if ( + Prompt.ask(f":axe: Overwrite {target.absolute()}?", choices=["y", "n"], default="n") + == "n" + ): + return + rich.print() + shutil.rmtree(target) + + # Check if the repo is accessible + if github_repo.exists: + temp_dir = download_and_unzip_archive(github_repo.zip_url) + + # This could be a private repo that the user can access + # by getting an access token from our API + elif github_repo.namespace == "dreadnode": + github_access_token = create_api_client().get_github_access_token([github_repo.repo]) + rich.print(":key: Accessed private repository") + temp_dir = download_and_unzip_archive( + github_repo.api_zip_url, + headers={"Authorization": f"Bearer {github_access_token.token}"}, + ) + + else: + raise RuntimeError(f"Repository '{github_repo}' not found or inaccessible") + + # We assume the repo download results in a single + # child folder which is the real target + sub_dirs = list(temp_dir.iterdir()) + if len(sub_dirs) == 1 and sub_dirs[0].is_dir(): + temp_dir = sub_dirs[0] + + shutil.move(temp_dir, target) + + rich.print() + rich.print(f":tada: Cloned [b]{repo}[/] to [b]{target.absolute()}[/]") + + @cli.command(help="Show versions and exit.", group="Meta") def version() -> None: import importlib.metadata From ea03194daabce28d42d8ba282254532977fa0653 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Thu, 24 Jul 2025 04:03:52 -0600 Subject: [PATCH 4/8] Finalizing docs, auth flows, config management --- .secrets.baseline | 11 +- docs/intro.mdx | 38 +++---- docs/sdk/api.mdx | 2 +- docs/sdk/main.mdx | 97 ++++++++++++++--- docs/usage/cli.mdx | 195 ++++++++++++++++++++++++++++++++++ docs/usage/config.mdx | 194 +++++++++++++++++++++++++++++---- dreadnode/api/client.py | 2 +- dreadnode/cli/api.py | 2 +- dreadnode/cli/github.py | 75 +++++++++++++ dreadnode/cli/main.py | 19 ++-- dreadnode/cli/profile/cli.py | 33 +++++- dreadnode/{cli => }/config.py | 14 +++ dreadnode/constants.py | 1 + dreadnode/main.py | 189 +++++++++++++++----------------- dreadnode/util.py | 95 +++++++++++++++++ 15 files changed, 792 insertions(+), 175 deletions(-) create mode 100644 docs/usage/cli.mdx rename dreadnode/{cli => }/config.py (84%) diff --git a/.secrets.baseline b/.secrets.baseline index 59ecd314..8b97ad83 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -151,9 +151,16 @@ "filename": "docs/usage/config.mdx", "hashed_secret": "3f4f9a14a2d4d72a7074c2969dd34c89f2cbe61a", "is_verified": false, - "line_number": 23 + "line_number": 33 + }, + { + "type": "Secret Keyword", + "filename": "docs/usage/config.mdx", + "hashed_secret": "01eddf49c6b18f99f87ac7ba45e81d4a227e8d3f", + "is_verified": false, + "line_number": 171 } ] }, - "generated_at": "2025-07-14T09:19:13Z" + "generated_at": "2025-07-24T10:02:58Z" } diff --git a/docs/intro.mdx b/docs/intro.mdx index ede0baa0..407d6a1e 100644 --- a/docs/intro.mdx +++ b/docs/intro.mdx @@ -16,15 +16,26 @@ Which means, in order to evaluate Offensive Security agents, we need to develop ## Basic Example +Before you start, ensure you have the `dreadnode` package installed (see [installation](/install)). You can authenticate to a platform using the CLI, which is the recommended way to get started. + +```bash +# Authenticate to platform.dreadnode.io +dreadnode login + +# For self-hosted platforms, specify the server URL +dreadnode login --server http://self-hosted +``` + + +For complete authentication and configuration guidance, see the [Configuration](/usage/config) documentation. + + The most basic use of Strikes is a run with some logged data: ```python import asyncio import dreadnode -# Initialize with default settings -dreadnode.configure() - NAMES = ["Nick", "Will", "Brad", "Brian"] # Create a new task @@ -42,7 +53,7 @@ async def main() -> None: ) # Log inputs - dn.log_input("names", NAMES) + dreadnode.log_input("names", NAMES) # Run your tasks greetings = [ @@ -51,7 +62,7 @@ async def main() -> None: ] # Save outputs - dn.log_output("greetings", greetings) + dreadnode.log_output("greetings", greetings) # Track metrics dreadnode.log_metric("accuracy", 0.65, step=0) @@ -63,19 +74,6 @@ async def main() -> None: asyncio.run(main()) ``` - -We'll assume you have installed the `dreadnode` package and have your environment variables set up. Make sure you have `DREADNODE_API_KEY=...` set to your Platform API key. - -For more information on `dreadnode.configure()`, review the [Configuration](/usage/config) topic. - -If you call `dreadnode.configure()` without any token and your environment variables are not set, you'll receive a warning in the console, so keep an eye out! You can still run any of your code without sending data to the Dreadnode Platform. - - - -**Server Configuration** -By default, the SDK connects to the hosted Dreadnode platform at `https://platform.dreadnode.io`. If you're using a self-hosted instance, you must configure the server URL explicitly in your `dreadnode.configure()` call or via the `DREADNODE_SERVER` environment variable. See the [Configuration](/usage/config) guide for details. - - This code should be very familiar if you've used an ML-experimentation library before, and all the functions you're familiar with work exactly like you would expect. Under the hood, this code did a few things: @@ -114,8 +112,6 @@ Runs are the core unit of work in Strikes. They provide the context for all your ```python import dreadnode -dreadnode.configure() - with dreadnode.run("my-experiment"): # Everything that happens here is part of the run # All data collected is associated with this run @@ -147,8 +143,6 @@ Tasks are units of work within runs. They help you structure your code and provi ```python import dreadnode -dreadnode.configure() - @dreadnode.task() async def say_hello(name: str) -> str: return f"Hello, {name}!" diff --git a/docs/sdk/api.mdx b/docs/sdk/api.mdx index 6d44f253..46acf043 100644 --- a/docs/sdk/api.mdx +++ b/docs/sdk/api.mdx @@ -82,7 +82,7 @@ def __init__( } if api_key: - headers["Authorization"] = f"Bearer {api_key}" + headers["X-Api-Key"] = api_key self._client = httpx.Client( headers=headers, diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index c479c08c..a218f2be 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -118,7 +118,7 @@ def api(self, *, server: str | None = None, token: str | None = None) -> ApiClie An ApiClient instance. """ if server is not None and token is not None: - return ApiClient(server, token) + return ApiClient(server, api_key=token) if not self._initialized: raise RuntimeError("Call .configure() before accessing the API") @@ -139,6 +139,7 @@ configure( *, server: str | None = None, token: str | None = None, + profile: str | None = None, local_dir: str | Path | Literal[False] = False, project: str | None = None, service_name: str | None = None, @@ -154,11 +155,16 @@ Configure the Dreadnode SDK and call `initialize()`. This method should always be called before using the SDK. -If `server` and `token` are not provided, the SDK will look in -the associated environment variables: +If `server` and `token` are not provided, the SDK will look for them +in the following order: -* `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` -* `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` +1. Environment variables: +2. `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` +3. `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` +4. Dreadnode profile (from `dreadnode login`) +5. Uses `profile` parameter if provided +6. Falls back to `DREADNODE_PROFILE` environment variable +7. Defaults to active profile **Parameters:** @@ -172,6 +178,11 @@ the associated environment variables: `None` ) –The Dreadnode API token. +* **`profile`** + (`str | None`, default: + `None` + ) + –The Dreadnode profile name to use (only used if env vars are not set). * **`local_dir`** (`str | Path | Literal[False]`, default: `False` @@ -215,6 +226,7 @@ def configure( *, server: str | None = None, token: str | None = None, + profile: str | None = None, local_dir: str | Path | t.Literal[False] = False, project: str | None = None, service_name: str | None = None, @@ -228,15 +240,21 @@ def configure( This method should always be called before using the SDK. - If `server` and `token` are not provided, the SDK will look in - the associated environment variables: + If `server` and `token` are not provided, the SDK will look for them + in the following order: - - `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` - - `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` + 1. Environment variables: + - `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` + - `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` + 2. Dreadnode profile (from `dreadnode login`) + - Uses `profile` parameter if provided + - Falls back to `DREADNODE_PROFILE` environment variable + - Defaults to active profile Args: server: The Dreadnode server URL. token: The Dreadnode API token. + profile: The Dreadnode profile name to use (only used if env vars are not set). local_dir: The local directory to store data in. project: The default project name to associate all runs with. service_name: The service name to use for OpenTelemetry. @@ -248,8 +266,43 @@ def configure( self._initialized = False - self.server = server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) - self.token = token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + # Determine configuration source and active profile for logging + config_source = "explicit parameters" + active_profile = None + + if not server or not token: + # Check environment variables first + env_server = os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) + env_token = os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + + if env_server or env_token: + config_source = "environment vars" + else: + # Fall back to profile + config_source = "profile" + with contextlib.suppress(Exception): + user_config = UserConfig.read() + profile_name = profile or os.environ.get(ENV_PROFILE) + if profile_name: + active_profile = profile_name + else: + active_profile = user_config.active_profile_name + + if active_profile: + config_source = f"profile: {active_profile}" + + self.server = ( + server + or os.environ.get(ENV_SERVER_URL) + or os.environ.get(ENV_SERVER) + or self._get_profile_server(profile) + ) + self.token = ( + token + or os.environ.get(ENV_API_TOKEN) + or os.environ.get(ENV_API_KEY) + or self._get_profile_api_key(profile) + ) if local_dir is False and ENV_LOCAL_DIR in os.environ: env_local_dir = os.environ.get(ENV_LOCAL_DIR) @@ -267,6 +320,17 @@ def configure( self.send_to_logfire = send_to_logfire self.otel_scope = otel_scope + # Log config information for clarity + if self.server or self.token or self.local_dir: + destination = self.server or DEFAULT_SERVER_URL or "local storage" + rich.print(f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})") + + # Warn the user if the profile didn't resolve + elif active_profile and not (self.server or self.token): + rich.print( + f":exclamation: Dreadnode profile [orange_red1]{active_profile}[/] appears invalid." + ) + self.initialize() ``` @@ -305,7 +369,7 @@ def continue_run(self, run_context: RunContext) -> RunSpan: A RunSpan object that can be used as a context manager. """ if not self._initialized: - self.initialize() + self.configure() return RunSpan.from_context( context=run_context, @@ -398,7 +462,8 @@ def initialize(self) -> None: if not (self.server or self.token or self.local_dir): warn_at_user_stacklevel( "Your current configuration won't persist run data anywhere. " - "Use `dreadnode.init(server=..., token=...)`, `dreadnode.init(local_dir=...)`, " + "Login with `dreadnode login` to set up a server and token, " + "Use `dreadnode.configure(server=..., token=...)`, `dreadnode.configure(profile=...)`, " f"or use environment variables ({ENV_SERVER_URL}, {ENV_API_TOKEN}, {ENV_LOCAL_DIR}).", category=DreadnodeConfigWarning, ) @@ -422,7 +487,7 @@ def initialize(self) -> None: ) self.server = urlunparse(parsed_new) - self._api = ApiClient(self.server, self.token) + self._api = ApiClient(self.server, api_key=self.token) self._api.list_projects() except Exception as e: @@ -456,7 +521,7 @@ def initialize(self) -> None: # ) credentials = self._api.get_user_data_credentials() - resolved_endpoint = self._resolve_endpoint(credentials.endpoint) + resolved_endpoint = resolve_endpoint(credentials.endpoint) self._fs = S3FileSystem( key=credentials.access_key_id, secret=credentials.secret_access_key, @@ -1639,7 +1704,7 @@ def run( The run will automatically be completed when the context manager exits. """ if not self._initialized: - self.initialize() + self.configure() if name is None: name = f"{coolname.generate_slug(2)}-{random.randint(100, 999)}" # noqa: S311 # nosec diff --git a/docs/usage/cli.mdx b/docs/usage/cli.mdx new file mode 100644 index 00000000..58524026 --- /dev/null +++ b/docs/usage/cli.mdx @@ -0,0 +1,195 @@ +--- +title: "CLI" +description: "Use the native command-line interface" +public: true +--- + +The Dreadnode CLI provides a command-line interface for authenticating with Dreadnode platforms, managing profiles, and cloning repositories. It's installed automatically with the `dreadnode` package. + +## Quick Start + +After installing the package, authenticate with your platform: + +```bash +dreadnode login +``` + +This opens your browser to authenticate and stores your credentials locally. + +## Commands + +### Authentication + +#### `login` + +Authenticate to a Dreadnode platform server. + +```bash +dreadnode login [--server URL] [--profile NAME] +``` + +**Options:** +- `--server`, `-s`: URL of the server (defaults to hosted platform) +- `--profile`, `-p`: Profile alias to assign or update + +**Examples:** +```bash +# Login to hosted platform +dreadnode login + +# Login to self-hosted server +dreadnode login --server https://my-server.com + +# Login with a specific profile name +dreadnode login --profile production +``` + +#### `refresh` + +Refresh data for the active server profile. + +```bash +dreadnode refresh +``` + +Updates your local profile with the latest user information from the server. + +### Profile Management + +#### `profile show` + +List all configured server profiles. + +```bash +dreadnode profile show +``` + +Shows a table with profile names, URLs, emails, usernames, and token expiration times. The active profile is marked with an asterisk. + +#### `profile switch` + +Set the active server profile. + +```bash +dreadnode profile switch PROFILE +``` + +**Arguments:** +- `PROFILE`: Name of the profile to switch to + +#### `profile forget` + +Remove a server profile. + +```bash +dreadnode profile forget PROFILE +``` + +**Arguments:** +- `PROFILE`: Name of the profile to remove + +### Repository Management + +#### `clone` + +Clone a GitHub repository. + +```bash +dreadnode clone REPO [TARGET] +``` + +**Arguments:** +- `REPO`: Repository name (e.g., `dreadnode/example-agents`) or full GitHub URL +- `TARGET`: Optional target directory (defaults to repository name) + +**Examples:** +```bash +# Clone a public repository +dreadnode clone dreadnode/example-agents + +# Clone to a specific directory +dreadnode clone dreadnode/example-agents ./my-agents + +# Clone a private dreadnode repository (requires authentication) +dreadnode clone dreadnode/private-repo +``` + + +The `clone` command can access privately shared `dreadnode/*` repositories using your authentication token. + +**Server Validation:** Private `dreadnode/*` repositories require authentication via a Dreadnode SaaS server (ending with `.dreadnode.io`). If your current profile points to a self-hosted server, the CLI will: + +1. Warn you about the server mismatch +2. Offer to switch to an available SaaS profile if one exists +3. Allow you to continue with a warning if you choose + +For other private repositories, use standard Git authentication. + + +### Meta Commands + +#### `version` + +Show version information. + +```bash +dreadnode version +``` + +Displays platform, Python version, and Dreadnode package version. + +#### `--help` + +Show help information for any command. + +```bash +dreadnode --help +dreadnode login --help +dreadnode profile --help +``` + +## Profile Configuration + +The CLI stores authentication data in `~/.dreadnode/config`. Each profile contains: + +- Server URL +- User credentials (access/refresh tokens, API key) +- User information (email, username) + +You can have multiple profiles for different servers or accounts: + +```bash +# Add different server profiles +dreadnode login --profile public +dreadnode login --server https://self-hosted --profile self-hosted + +# Switch between them +dreadnode profile switch self-hosted +dreadnode profile switch public +``` + +### Environment Variable Profile Selection + +You can override the active profile using the `DREADNODE_PROFILE` environment variable: + +```bash +# Temporarily use a different profile +export DREADNODE_PROFILE=production +dreadnode clone dreadnode/private-repo # Uses production profile + +# Or for a single command +DREADNODE_PROFILE=staging dreadnode clone dreadnode/test-repo +``` + +This affects both CLI commands and SDK configuration when using `dreadnode.configure()` without explicit server/token parameters. + +### Which Profile Gets Used? + +The CLI picks a profile in this order: + +1. **`--profile` flag** (if provided) +2. **`DREADNODE_PROFILE` environment variable** +3. **Active profile** (set via `dreadnode profile switch`) +4. **"main" profile** (default) + +The CLI will remember your server URL for future commands within that profile. \ No newline at end of file diff --git a/docs/usage/config.mdx b/docs/usage/config.mdx index 79109adc..710cd4cf 100644 --- a/docs/usage/config.mdx +++ b/docs/usage/config.mdx @@ -4,34 +4,117 @@ description: "Set configuration values" public: true --- -The quickest way to configure Strikes is to set the `DREADNODE_API_KEY` environment variable and let the library handle the rest with `dreadnode.configure()`. However, there are quite a few additional options you can set as needed. +## Self-Hosted Platforms -## Self-Hosted Platform +If you're using a **self-hosted Dreadnode platform**, you must specify your server URL during authentication: -If you're using a **self-hosted Dreadnode platform**, you must always specify your server URL explicitly. The SDK defaults to `https://platform.dreadnode.io` otherwise. +```bash +dreadnode login --server https://your-server.com +``` + +This creates a profile for your self-hosted instance. You can manage multiple servers by creating profiles with custom names: + +```bash +# Create profiles for different environments +dreadnode login --server https://dev.company.com --profile dev +dreadnode login --server https://prod.company.com --profile production +``` + +Switch between profiles anytime: + +```bash +dreadnode profile switch +``` + +Your code automatically uses the active profile - no changes needed. For automation and CI/CD with self-hosted platforms, use environment variables: + +```bash +export DREADNODE_SERVER="https://your-server.com" +export DREADNODE_API_KEY="your-api-token" +``` + +## When You Need `configure()` + +Most users never need to call `configure()` explicitly. The SDK auto-configures itself using CLI authentication or environment variables. + +**You only need `configure()` if you want to:** + +### Customize Configuration + +```python +dreadnode.configure( + local_dir="./my-custom-storage", # Custom local storage + project="my-project", # Default project name + console=False, # Disable console logging +) +``` + +### Override Auto-Detection - -```python in code +```python dreadnode.configure( - server="https://hosted-server", # Your self-hosted server URL - token="your-api-token", + server="https://platform.dreadnode.io", + token="your-api-token", # Explicit credentials + profile="production" # Specific Dreadnode profile ) ``` -```bash environment variables -export DREADNODE_SERVER="https://hosted-server" +### Use Environment Variables (CI/CD) + +```bash export DREADNODE_API_KEY="your-api-token" +# No configure() call needed - SDK auto-detects +``` + +**💡 For most users:** Skip `configure()` entirely. Use `dreadnode login` once and you're set. + +## Advanced Configuration + +### Managing Multiple Environments with Profiles + +Profiles let you manage multiple Dreadnode servers (development, staging, production, etc.) and switch between them seamlessly: + +**Create profiles for different environments:** + +```bash +# Hosted environments +dreadnode login --profile dev +dreadnode login --profile staging +dreadnode login --profile production + +# Self-hosted environments +dreadnode login --server https://dev.company.com --profile dev-internal +dreadnode login --server https://prod.company.com --profile prod-internal ``` - -## Using `configure()` +**View and manage profiles:** + +```bash +dreadnode profile show # List all profiles +dreadnode profile switch staging # Switch active profile +dreadnode profile forget dev # Remove a profile +``` -Initialize and set up connections with `configure()`. +**Use specific profiles in code:** + +```python +# Use a specific profile programmatically +dreadnode.configure(profile="production") + +# Or with environment variable +# DREADNODE_PROFILE=production +dreadnode.configure() +``` + +**Profile priority:** Environment variable `DREADNODE_PROFILE` overrides the active CLI profile. + +### Full Configuration Options ```python dreadnode.configure( - server="https://platform.dreadnode.io", # Platform URL - token="your-api-token", # API token for authentication + server="https://platform.dreadnode.io", # Platform URL (optional if using CLI/env) + token="your-api-token", # API token (optional if using CLI/env) + profile="production", # Dreadnode profile (only used if server/token not provided) local_dir="./runs", # Directory for local span storage project="my-project", # Default project name console=True, # Enable console logging @@ -43,13 +126,84 @@ dreadnode.configure( ) ``` -## Using Environment Variables +## Environment Variables Reference + +Environment variables are a great alternative for automated deployments and CI/CD pipelines. They override CLI profiles but are overridden by explicit `configure()` parameters. -Set variables to call `.config()` more easily. +### Complete Reference ```bash -export DREADNODE_SERVER="https://platform.dreadnode.io" # or DREADNODE_SERVER_URL -export DREADNODE_API_KEY="your-api-token" # or DREADNODE_API_TOKEN -export DREADNODE_LOCAL_DIR="./runs" -export DREADNODE_PROJECT="my-project" +# Authentication (choose one) +export DREADNODE_API_KEY="your-api-token" # Recommended +export DREADNODE_API_TOKEN="your-api-token" # Alternative + +# Server configuration +export DREADNODE_SERVER="https://your-server.com" # Recommended +export DREADNODE_SERVER_URL="https://your-server.com" # Alternative + +# Profile selection (when not using explicit server/token) +export DREADNODE_PROFILE="production" + +# Optional settings +export DREADNODE_LOCAL_DIR="./runs" # Local storage directory +export DREADNODE_PROJECT="my-project" # Default project name ``` + +## Configuration Priority Order + +The SDK resolves configuration in this priority order: + +### 1. Explicit Parameters (Highest Priority) + +```python +# These always override everything else +dreadnode.configure( + server="https://override.com", + token="explicit-token" +) +``` + +### 2. Environment Variables + +```python +# Set via shell or CI/CD +# export DREADNODE_SERVER="https://env.com" +# export DREADNODE_API_KEY="env-token" + +dreadnode.configure() # Uses env vars +``` + +### 3. CLI Profiles + +```python +# After: dreadnode login --profile production +dreadnode.configure(profile="production") # Uses CLI profile + +# Or let SDK auto-detect active profile +dreadnode.configure() # Uses active CLI profile +``` + +### 4. Local-Only Mode (Fallback) + +```python +# No credentials found anywhere +dreadnode.configure() # ⚠️ Works locally with warning +``` + +**Examples demonstrating priority:** + +```python +# Environment overrides profile +# export DREADNODE_API_KEY="env-token" +dreadnode.configure(profile="production") # Uses env-token (not profile) + +# Explicit param overrides environment +# export DREADNODE_SERVER="https://env.com" +dreadnode.configure(server="https://explicit.com") # Uses explicit.com + +# Profile selection with DREADNODE_PROFILE env var +# export DREADNODE_PROFILE="staging" +dreadnode.configure() # Uses "staging" profile (not active profile) +``` + +**💡 Bottom line:** Just start coding. The SDK will find your credentials or work locally. \ No newline at end of file diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index caf805f2..5b311482 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -88,7 +88,7 @@ def __init__( } if api_key: - headers["Authorization"] = f"Bearer {api_key}" + headers["X-Api-Key"] = api_key self._client = httpx.Client( headers=headers, diff --git a/dreadnode/cli/api.py b/dreadnode/cli/api.py index ad2034c8..96a59dfc 100644 --- a/dreadnode/cli/api.py +++ b/dreadnode/cli/api.py @@ -4,7 +4,7 @@ from datetime import datetime, timezone from dreadnode.api.client import ApiClient -from dreadnode.cli.config import UserConfig +from dreadnode.config import UserConfig from dreadnode.constants import ( DEFAULT_TOKEN_MAX_TTL, ) diff --git a/dreadnode/cli/github.py b/dreadnode/cli/github.py index 2e8a6bf3..7d883585 100644 --- a/dreadnode/cli/github.py +++ b/dreadnode/cli/github.py @@ -7,6 +7,9 @@ import httpx import rich +from rich.prompt import Prompt + +from dreadnode.config import UserConfig, find_dreadnode_saas_profiles, is_dreadnode_saas_server class GithubRepo(str): # noqa: SLOT000 @@ -196,3 +199,75 @@ def download_and_unzip_archive(url: str, *, headers: dict[str, str] | None = Non local_zip_path.unlink() return temp_dir + + +def validate_server_for_clone(user_config: UserConfig, current_profile: str | None) -> str | None: + """ + Validate the server configuration for git clone operations. + + Returns: + The profile name to use, or None if the user cancelled. + """ + config = user_config.get_server_config(current_profile) + current_server = config.url + + # If current server is a Dreadnode SaaS server, all good + if is_dreadnode_saas_server(current_server): + return current_profile or user_config.active_profile_name + + # Current server is not a Dreadnode SaaS server - warn user + rich.print() + rich.print(":warning: [yellow]Warning: Current server is not a Dreadnode SaaS server[/]") + rich.print(f" Current server: [cyan]{current_server}[/]") + rich.print(f" Current profile: [cyan]{current_profile or user_config.active_profile_name}[/]") + rich.print() + rich.print("Git clone for private dreadnode repositories requires a Dreadnode SaaS server") + rich.print("(ending with '.dreadnode.io') for authentication to work properly.") + rich.print() + + # Check if there are any SaaS profiles available + saas_profiles = find_dreadnode_saas_profiles(user_config) + + if saas_profiles: + rich.print("Available Dreadnode SaaS profiles:") + for profile in saas_profiles: + server_url = user_config.servers[profile].url + rich.print(f" - [green]{profile}[/] ({server_url})") + rich.print() + + choices = ["continue", "switch", "cancel"] + choice = Prompt.ask( + "Choose an option", choices=choices, default="cancel", show_choices=True + ) + + if choice == "continue": + rich.print( + ":warning: [yellow]Continuing with current server - private repository access may fail[/]" + ) + return current_profile or user_config.active_profile_name + if choice == "cancel": + rich.print("Cancelled.") + return None + if choice == "switch": + # Let user pick a profile + profile_choice = Prompt.ask( + "Select profile to use", choices=saas_profiles, default=saas_profiles[0] + ) + rich.print( + f":arrows_counterclockwise: Using profile '[green]{profile_choice}[/]' for this operation" + ) + return profile_choice + else: + # No SaaS profiles available + choice = Prompt.ask("Continue anyway?", choices=["y", "n"], default="n") + + if choice == "y": + rich.print( + ":warning: [yellow]Continuing with current server - private repository access may fail[/]" + ) + return current_profile or user_config.active_profile_name + rich.print( + "Cancelled. Use [bold]dreadnode login --server https://platform.dreadnode.io[/] to add a SaaS profile." + ) + + return None diff --git a/dreadnode/cli/main.py b/dreadnode/cli/main.py index 0370d07b..16dbef9e 100644 --- a/dreadnode/cli/main.py +++ b/dreadnode/cli/main.py @@ -12,9 +12,9 @@ from dreadnode.api.client import ApiClient from dreadnode.cli.api import create_api_client -from dreadnode.cli.config import ServerConfig, UserConfig -from dreadnode.cli.github import GithubRepo, download_and_unzip_archive +from dreadnode.cli.github import GithubRepo, download_and_unzip_archive, validate_server_for_clone from dreadnode.cli.profile import cli as profile_cli +from dreadnode.config import ServerConfig, UserConfig from dreadnode.constants import DEBUG, PLATFORM_BASE_URL cli = cyclopts.App(help="Interact with Dreadnode platforms", version_flags=[], help_on_error=True) @@ -86,8 +86,6 @@ def login( # poll for the access token after user verification tokens = client.poll_for_token(codes.device_code) - print(tokens) - client = ApiClient( server, cookies={"refresh_token": tokens.refresh_token, "access_token": tokens.access_token} ) @@ -127,7 +125,7 @@ def refresh() -> None: ) -@cli.command(help="Clone a github repository.") +@cli.command(help="Clone a github repository, typically privately shared dreadnode repositories.") def clone( repo: t.Annotated[str, cyclopts.Parameter(help="Repository name or URL")], target: t.Annotated[ @@ -155,7 +153,16 @@ def clone( # This could be a private repo that the user can access # by getting an access token from our API elif github_repo.namespace == "dreadnode": - github_access_token = create_api_client().get_github_access_token([github_repo.repo]) + # Validate server configuration for private repository access + user_config = UserConfig.read() + profile_to_use = validate_server_for_clone(user_config, None) + + if profile_to_use is None: + return # User cancelled + + github_access_token = create_api_client(profile=profile_to_use).get_github_access_token( + [github_repo.repo] + ) rich.print(":key: Accessed private repository") temp_dir = download_and_unzip_archive( github_repo.api_zip_url, diff --git a/dreadnode/cli/profile/cli.py b/dreadnode/cli/profile/cli.py index 5394026c..95e736ad 100644 --- a/dreadnode/cli/profile/cli.py +++ b/dreadnode/cli/profile/cli.py @@ -6,7 +6,7 @@ from rich.table import Table from dreadnode.cli.api import Token -from dreadnode.cli.config import UserConfig +from dreadnode.config import UserConfig from dreadnode.util import time_to cli = cyclopts.App(name="profile", help="Manage server profiles") @@ -20,7 +20,7 @@ def show() -> None: return table = Table(box=box.ROUNDED) - table.add_column("Profile", style="magenta") + table.add_column("Profile", style="orange_red1") table.add_column("URL", style="cyan") table.add_column("Email") table.add_column("Username") @@ -45,8 +45,33 @@ def show() -> None: @cli.command(help="Set the active server profile") -def switch(profile: t.Annotated[str, cyclopts.Parameter(help="Profile to switch to")]) -> None: +def switch( + profile: t.Annotated[str | None, cyclopts.Parameter(help="Profile to switch to")] = None, +) -> None: config = UserConfig.read() + + if not config.servers: + rich.print(":exclamation: No server profiles are configured") + return + + # If no profile provided, prompt user to choose + if profile is None: + from rich.prompt import Prompt + + profiles = list(config.servers.keys()) + rich.print("\nAvailable profiles:") + for i, p in enumerate(profiles, 1): + active_marker = " (current)" if p == config.active else "" + rich.print(f" {i}. [bold orange_red1]{p}[/]{active_marker}") + + choice = Prompt.ask( + "\nSelect a profile", + choices=[str(i) for i in range(1, len(profiles) + 1)] + profiles, + show_choices=False, + ) + + profile = profiles[int(choice) - 1] if choice.isdigit() else choice + if profile not in config.servers: rich.print(f":exclamation: Profile [bold]{profile}[/] does not exist") return @@ -54,7 +79,7 @@ def switch(profile: t.Annotated[str, cyclopts.Parameter(help="Profile to switch config.active = profile config.write() - rich.print(f":laptop_computer: Switched to [bold magenta]{profile}[/]") + rich.print(f":laptop_computer: Switched to [bold orange_red1]{profile}[/]") rich.print(f"|- email: [bold]{config.servers[profile].email}[/]") rich.print(f"|- username: {config.servers[profile].username}") rich.print(f"|- url: {config.servers[profile].url}") diff --git a/dreadnode/cli/config.py b/dreadnode/config.py similarity index 84% rename from dreadnode/cli/config.py rename to dreadnode/config.py index 10cfa127..f1daa806 100644 --- a/dreadnode/cli/config.py +++ b/dreadnode/config.py @@ -92,3 +92,17 @@ def set_server_config(self, config: ServerConfig, profile: str | None = None) -> profile = profile or self.active or DEFAULT_PROFILE_NAME self.servers[profile] = config return self + + +def is_dreadnode_saas_server(url: str) -> bool: + """Check if the server URL is a Dreadnode SaaS server (ends with dreadnode.io).""" + return url.rstrip("/").endswith(".dreadnode.io") + + +def find_dreadnode_saas_profiles(user_config: UserConfig) -> list[str]: + """Find all profiles that point to Dreadnode SaaS servers.""" + saas_profiles = [] + for profile_name, server_config in user_config.servers.items(): + if is_dreadnode_saas_server(server_config.url): + saas_profiles.append(profile_name) + return saas_profiles diff --git a/dreadnode/constants.py b/dreadnode/constants.py index ca0fa7b0..e8df3b79 100644 --- a/dreadnode/constants.py +++ b/dreadnode/constants.py @@ -38,6 +38,7 @@ ENV_API_KEY = "DREADNODE_API_KEY" # pragma: allowlist secret (alternative to API_TOKEN) ENV_LOCAL_DIR = "DREADNODE_LOCAL_DIR" ENV_PROJECT = "DREADNODE_PROJECT" +ENV_PROFILE = "DREADNODE_PROFILE" # # Environment diff --git a/dreadnode/main.py b/dreadnode/main.py index 6fee3c8e..3e690df5 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -2,15 +2,15 @@ import inspect import os import random -import socket import typing as t from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from urllib.parse import ParseResult, urljoin, urlparse, urlunparse +from urllib.parse import urljoin, urlparse, urlunparse import coolname # type: ignore [import-untyped] import logfire +import rich from fsspec.implementations.local import ( # type: ignore [import-untyped] LocalFileSystem, ) @@ -24,11 +24,13 @@ from s3fs import S3FileSystem # type: ignore [import-untyped] from dreadnode.api.client import ApiClient +from dreadnode.config import UserConfig from dreadnode.constants import ( DEFAULT_SERVER_URL, ENV_API_KEY, ENV_API_TOKEN, ENV_LOCAL_DIR, + ENV_PROFILE, ENV_PROJECT, ENV_SERVER, ENV_SERVER_URL, @@ -61,7 +63,7 @@ Inherited, JsonValue, ) -from dreadnode.util import clean_str, handle_internal_errors, logger +from dreadnode.util import clean_str, handle_internal_errors, resolve_endpoint from dreadnode.version import VERSION if t.TYPE_CHECKING: @@ -135,102 +137,32 @@ def __init__( self._initialized = False - @staticmethod - def _resolve_endpoint(endpoint: str | None) -> str | None: - """Automatically resolve endpoints based on environment + def _get_profile_server(self, profile: str | None = None) -> str | None: + with contextlib.suppress(Exception): + user_config = UserConfig.read() + profile = profile or os.environ.get(ENV_PROFILE) + server_config = user_config.get_server_config(profile) + return server_config.url - Args: - endpoint: The endpoint URL to resolve. - - Returns: - str: The resolved endpoint URL. - - Raises: - ValueError: If the endpoint URL is invalid. - """ - if not endpoint: - return None - parsed = urlparse(endpoint) - - # If it's a real domain (has dots), use as-is - if not parsed.hostname: - raise ValueError(f"Invalid endpoint URL: {endpoint}") - - if "." in parsed.hostname: - return endpoint - - # If it's a service name, try to resolve it - if Dreadnode._is_docker_service_name(parsed.hostname): - return Dreadnode._resolve_docker_service(endpoint, parsed) - - return endpoint + # Silently fail if profile config is not available or invalid + return None - @staticmethod - def _is_docker_service_name(hostname: str) -> bool: - """Check if this looks like a Docker service name + def _get_profile_api_key(self, profile: str | None = None) -> str | None: + with contextlib.suppress(Exception): + user_config = UserConfig.read() + profile = profile or os.environ.get(ENV_PROFILE) + server_config = user_config.get_server_config(profile) + return server_config.api_key - Args: - hostname: The hostname to check. - - Returns: - bool: True if the hostname looks like a Docker service name, False otherwise. - """ - return bool(hostname and "." not in hostname and hostname != "localhost") - - @staticmethod - def _resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str: - """Try different resolution strategies for Docker services - - Args: - original_endpoint: The original endpoint URL. - parsed: The parsed URL object. - - Returns: - str: The resolved endpoint URL. - - Raises: - RuntimeError: If no valid endpoint is found. - """ - strategies = [ - original_endpoint, # Try original first (works if running in same network) - f"{parsed.scheme}://localhost:{parsed.port}", # Try localhost - f"{parsed.scheme}://host.docker.internal:{parsed.port}", # Docker Desktop - f"{parsed.scheme}://172.17.0.1:{parsed.port}", # Docker bridge IP - ] - - for endpoint in strategies: - if Dreadnode._test_connection(endpoint): - logger.warning( - f"Resolved Docker service for s3 connection '{parsed.hostname}' to '{endpoint}'." - ) - return str(endpoint) - - # If nothing works, return original and let it fail with a helpful error - raise RuntimeError(f"Failed to connect to the Dreadnode Artifact storage at {endpoint}.") - - @staticmethod - def _test_connection(endpoint: str) -> bool: - """Quick connectivity test - - Args: - endpoint: The endpoint URL to test. - - Returns: - bool: True if the connection is successful, False otherwise. - """ - try: - parsed = urlparse(endpoint) - socket.create_connection((parsed.hostname, parsed.port or 443), timeout=1) - except Exception: # noqa: BLE001 - return False - - return True + # Silently fail if profile config is not available or invalid + return None def configure( self, *, server: str | None = None, token: str | None = None, + profile: str | None = None, local_dir: str | Path | t.Literal[False] = False, project: str | None = None, service_name: str | None = None, @@ -244,15 +176,21 @@ def configure( This method should always be called before using the SDK. - If `server` and `token` are not provided, the SDK will look in - the associated environment variables: + If `server` and `token` are not provided, the SDK will look for them + in the following order: - - `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` - - `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` + 1. Environment variables: + - `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` + - `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` + 2. Dreadnode profile (from `dreadnode login`) + - Uses `profile` parameter if provided + - Falls back to `DREADNODE_PROFILE` environment variable + - Defaults to active profile Args: server: The Dreadnode server URL. token: The Dreadnode API token. + profile: The Dreadnode profile name to use (only used if env vars are not set). local_dir: The local directory to store data in. project: The default project name to associate all runs with. service_name: The service name to use for OpenTelemetry. @@ -264,8 +202,43 @@ def configure( self._initialized = False - self.server = server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) - self.token = token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + # Determine configuration source and active profile for logging + config_source = "explicit parameters" + active_profile = None + + if not server or not token: + # Check environment variables first + env_server = os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) + env_token = os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + + if env_server or env_token: + config_source = "environment vars" + else: + # Fall back to profile + config_source = "profile" + with contextlib.suppress(Exception): + user_config = UserConfig.read() + profile_name = profile or os.environ.get(ENV_PROFILE) + if profile_name: + active_profile = profile_name + else: + active_profile = user_config.active_profile_name + + if active_profile: + config_source = f"profile: {active_profile}" + + self.server = ( + server + or os.environ.get(ENV_SERVER_URL) + or os.environ.get(ENV_SERVER) + or self._get_profile_server(profile) + ) + self.token = ( + token + or os.environ.get(ENV_API_TOKEN) + or os.environ.get(ENV_API_KEY) + or self._get_profile_api_key(profile) + ) if local_dir is False and ENV_LOCAL_DIR in os.environ: env_local_dir = os.environ.get(ENV_LOCAL_DIR) @@ -283,6 +256,17 @@ def configure( self.send_to_logfire = send_to_logfire self.otel_scope = otel_scope + # Log config information for clarity + if self.server or self.token or self.local_dir: + destination = self.server or DEFAULT_SERVER_URL or "local storage" + rich.print(f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})") + + # Warn the user if the profile didn't resolve + elif active_profile and not (self.server or self.token): + rich.print( + f":exclamation: Dreadnode profile [orange_red1]{active_profile}[/] appears invalid." + ) + self.initialize() def initialize(self) -> None: @@ -301,7 +285,8 @@ def initialize(self) -> None: if not (self.server or self.token or self.local_dir): warn_at_user_stacklevel( "Your current configuration won't persist run data anywhere. " - "Use `dreadnode.init(server=..., token=...)`, `dreadnode.init(local_dir=...)`, " + "Login with `dreadnode login` to set up a server and token, " + "Use `dreadnode.configure(server=..., token=...)`, `dreadnode.configure(profile=...)`, " f"or use environment variables ({ENV_SERVER_URL}, {ENV_API_TOKEN}, {ENV_LOCAL_DIR}).", category=DreadnodeConfigWarning, ) @@ -325,7 +310,7 @@ def initialize(self) -> None: ) self.server = urlunparse(parsed_new) - self._api = ApiClient(self.server, self.token) + self._api = ApiClient(self.server, api_key=self.token) self._api.list_projects() except Exception as e: @@ -359,7 +344,7 @@ def initialize(self) -> None: # ) credentials = self._api.get_user_data_credentials() - resolved_endpoint = self._resolve_endpoint(credentials.endpoint) + resolved_endpoint = resolve_endpoint(credentials.endpoint) self._fs = S3FileSystem( key=credentials.access_key_id, secret=credentials.secret_access_key, @@ -406,7 +391,7 @@ def api(self, *, server: str | None = None, token: str | None = None) -> ApiClie An ApiClient instance. """ if server is not None and token is not None: - return ApiClient(server, token) + return ApiClient(server, api_key=token) if not self._initialized: raise RuntimeError("Call .configure() before accessing the API") @@ -776,7 +761,7 @@ def run( The run will automatically be completed when the context manager exits. """ if not self._initialized: - self.initialize() + self.configure() if name is None: name = f"{coolname.generate_slug(2)}-{random.randint(100, 999)}" # noqa: S311 # nosec @@ -830,7 +815,7 @@ def continue_run(self, run_context: RunContext) -> RunSpan: A RunSpan object that can be used as a context manager. """ if not self._initialized: - self.initialize() + self.configure() return RunSpan.from_context( context=run_context, diff --git a/dreadnode/util.py b/dreadnode/util.py index a78230ef..cbd99b7e 100644 --- a/dreadnode/util.py +++ b/dreadnode/util.py @@ -5,12 +5,14 @@ import logging import os import re +import socket import sys import typing as t from contextlib import contextmanager from datetime import datetime from pathlib import Path from types import TracebackType +from urllib.parse import ParseResult, urlparse from logfire import suppress_instrumentation from logfire._internal.stack_info import add_non_user_code_prefix, is_user_code @@ -180,3 +182,96 @@ def handle_internal_errors() -> t.Iterator[None]: _HANDLE_INTERNAL_ERRORS_CODE = inspect.unwrap(handle_internal_errors).__code__ + + +def is_docker_service_name(hostname: str) -> bool: + """Check if this looks like a Docker service name + + Args: + hostname: The hostname to check. + + Returns: + bool: True if the hostname looks like a Docker service name, False otherwise. + """ + return bool(hostname and "." not in hostname and hostname != "localhost") + + +def resolve_endpoint(endpoint: str | None) -> str | None: + """Automatically resolve endpoints based on environment + + Args: + endpoint: The endpoint URL to resolve. + + Returns: + str: The resolved endpoint URL. + + Raises: + ValueError: If the endpoint URL is invalid. + """ + if not endpoint: + return None + parsed = urlparse(endpoint) + + # If it's a real domain (has dots), use as-is + if not parsed.hostname: + raise ValueError(f"Invalid endpoint URL: {endpoint}") + + if "." in parsed.hostname: + return endpoint + + # If it's a service name, try to resolve it + if is_docker_service_name(parsed.hostname): + return resolve_docker_service(endpoint, parsed) + + return endpoint + + +def test_connection(endpoint: str) -> bool: + """ + Simple test to check if the endpoint is reachable. + + Args: + endpoint: The endpoint URL to test. + + Returns: + bool: True if the endpoint is reachable, False otherwise. + """ + try: + parsed = urlparse(endpoint) + socket.create_connection((parsed.hostname, parsed.port or 443), timeout=1) + except Exception: # noqa: BLE001 + return False + + return True + + +def resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str: + """ + Try different resolution strategies for Docker services + + Args: + original_endpoint: The original endpoint URL. + parsed: The parsed URL object. + + Returns: + str: The resolved endpoint URL. + + Raises: + RuntimeError: If no valid endpoint is found. + """ + strategies = [ + original_endpoint, # Try original first (works if running in same network) + f"{parsed.scheme}://localhost:{parsed.port}", # Try localhost + f"{parsed.scheme}://host.docker.internal:{parsed.port}", # Docker Desktop + f"{parsed.scheme}://172.17.0.1:{parsed.port}", # Docker bridge IP + ] + + for endpoint in strategies: + if test_connection(endpoint): + logger.warning( + f"Resolved Docker service for s3 connection '{parsed.hostname}' to '{endpoint}'." + ) + return str(endpoint) + + # If nothing works, return original and let it fail with a helpful error + raise RuntimeError(f"Failed to connect to the Dreadnode Artifact storage at {endpoint}.") From 0e29d5da6c84eca17af72d3d16d7f58a23608f76 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Thu, 24 Jul 2025 04:13:59 -0600 Subject: [PATCH 5/8] Adapting feedback --- dreadnode/cli/github.py | 2 +- dreadnode/util.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dreadnode/cli/github.py b/dreadnode/cli/github.py index 7d883585..6b33390d 100644 --- a/dreadnode/cli/github.py +++ b/dreadnode/cli/github.py @@ -191,7 +191,7 @@ def download_and_unzip_archive(url: str, *, headers: dict[str, str] | None = Non if file_path.startswith(os.path.realpath(temp_dir)): zf.extract(member, temp_dir) else: - raise RuntimeError("Attempted Path Traversal Attack Detected") + raise RuntimeError("Invalid file path detected in archive") finally: # always remove the zip file diff --git a/dreadnode/util.py b/dreadnode/util.py index cbd99b7e..fd026bb1 100644 --- a/dreadnode/util.py +++ b/dreadnode/util.py @@ -269,7 +269,7 @@ def resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str: for endpoint in strategies: if test_connection(endpoint): logger.warning( - f"Resolved Docker service for s3 connection '{parsed.hostname}' to '{endpoint}'." + f"Resolved Docker service endpoint '{parsed.hostname}' to '{endpoint}'." # noqa: G004 ) return str(endpoint) From d5c1f8ca11f3f32b4c4c0e3a9995d71d2d57d84a Mon Sep 17 00:00:00 2001 From: monoxgas Date: Thu, 24 Jul 2025 04:15:33 -0600 Subject: [PATCH 6/8] relock --- poetry.lock | 109 +++++++++++++++++++++------------------------------- 1 file changed, 44 insertions(+), 65 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1eac58b6..003108b3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -172,7 +172,7 @@ version = "5.0.1" description = "Timeout context manager for asyncio programs" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main", "dev"] markers = "python_version < \"3.11\"" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, @@ -949,7 +949,6 @@ files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -markers = {dev = "platform_system == \"Windows\" or sys_platform == \"win32\""} [[package]] name = "coolname" @@ -1125,66 +1124,6 @@ files = [ {file = "docutils-0.21.2.tar.gz", hash = "sha256:3a6b18732edf182daa3cd12775bbb338cf5691468f91eeeb109deff6ebfa986f"}, ] -[[package]] -name = "elastic-transport" -version = "8.17.1" -description = "Transport classes and utilities shared among Python Elastic client libraries" -optional = false -python-versions = ">=3.8" -groups = ["dev"] -files = [ - {file = "elastic_transport-8.17.1-py3-none-any.whl", hash = "sha256:192718f498f1d10c5e9aa8b9cf32aed405e469a7f0e9d6a8923431dbb2c59fb8"}, - {file = "elastic_transport-8.17.1.tar.gz", hash = "sha256:5edef32ac864dca8e2f0a613ef63491ee8d6b8cfb52881fa7313ba9290cac6d2"}, -] - -[package.dependencies] -certifi = "*" -urllib3 = ">=1.26.2,<3" - -[package.extras] -develop = ["aiohttp", "furo", "httpx", "opentelemetry-api", "opentelemetry-sdk", "orjson", "pytest", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "pytest-mock", "requests", "respx", "sphinx (>2)", "sphinx-autodoc-typehints", "trustme"] - -[[package]] -name = "elasticsearch" -version = "8.18.1" -description = "Python client for Elasticsearch" -optional = false -python-versions = ">=3.8" -groups = ["dev"] -files = [ - {file = "elasticsearch-8.18.1-py3-none-any.whl", hash = "sha256:1a8c8b5ec3ce5be88f96d2f898375671648e96272978bce0dee3137d9326aabb"}, - {file = "elasticsearch-8.18.1.tar.gz", hash = "sha256:998035f17a8c1fba7ae26b183dca797dcf95db86da6a7ecba56d31afc40f07c7"}, -] - -[package.dependencies] -elastic-transport = ">=8.15.1,<9" -python-dateutil = "*" -typing-extensions = "*" - -[package.extras] -async = ["aiohttp (>=3,<4)"] -dev = ["aiohttp", "black", "build", "coverage", "isort", "jinja2", "mapbox-vector-tile", "mypy", "nltk", "nox", "numpy", "orjson", "pandas", "pyarrow", "pyright", "pytest", "pytest-asyncio", "pytest-cov", "pytest-mock", "python-dateutil", "pyyaml (>=5.4)", "requests (>=2,<3)", "sentence-transformers", "simsimd", "tqdm", "twine", "types-python-dateutil", "types-tqdm", "unasync"] -docs = ["sphinx", "sphinx-autodoc-typehints", "sphinx-rtd-theme (>=2.0)"] -orjson = ["orjson (>=3)"] -pyarrow = ["pyarrow (>=1)"] -requests = ["requests (>=2.4.0,!=2.32.2,<3.0.0)"] -vectorstore-mmr = ["numpy (>=1)", "simsimd (>=3)"] - -[[package]] -name = "eval-type-backport" -version = "0.2.2" -description = "Like `typing._eval_type`, but lets older Python versions use newer typing features." -optional = false -python-versions = ">=3.8" -groups = ["dev"] -files = [ - {file = "eval_type_backport-0.2.2-py3-none-any.whl", hash = "sha256:cb6ad7c393517f476f96d456d0412ea80f0a8cf96f6892834cd9340149111b0a"}, - {file = "eval_type_backport-0.2.2.tar.gz", hash = "sha256:f0576b4cf01ebb5bd358d02314d31846af5e07678387486e2c798af0e7d849c1"}, -] - -[package.extras] -tests = ["pytest"] - [[package]] name = "exceptiongroup" version = "1.3.0" @@ -1700,7 +1639,7 @@ version = "3.1.6" description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"}, {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"}, @@ -2033,7 +1972,7 @@ version = "3.0.2" description = "Safely add untrusted strings to HTML/XML markup." optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8"}, {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158"}, @@ -4304,6 +4243,46 @@ files = [ {file = "soupsieve-2.7.tar.gz", hash = "sha256:ad282f9b6926286d2ead4750552c8a6142bc4c783fd66b0293547c8fe6ae126a"}, ] +[[package]] +name = "sse-starlette" +version = "2.4.1" +description = "SSE plugin for Starlette" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "sse_starlette-2.4.1-py3-none-any.whl", hash = "sha256:08b77ea898ab1a13a428b2b6f73cfe6d0e607a7b4e15b9bb23e4a37b087fd39a"}, + {file = "sse_starlette-2.4.1.tar.gz", hash = "sha256:7c8a800a1ca343e9165fc06bbda45c78e4c6166320707ae30b416c42da070926"}, +] + +[package.dependencies] +anyio = ">=4.7.0" + +[package.extras] +daphne = ["daphne (>=4.2.0)"] +examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio,examples] (>=2.0.41)", "starlette (>=0.41.3)", "uvicorn (>=0.34.0)"] +granian = ["granian (>=2.3.1)"] +uvicorn = ["uvicorn (>=0.34.0)"] + +[[package]] +name = "starlette" +version = "0.47.2" +description = "The little ASGI library that shines." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "starlette-0.47.2-py3-none-any.whl", hash = "sha256:c5847e96134e5c5371ee9fac6fdf1a67336d5815e09eb2a01fdb57a351ef915b"}, + {file = "starlette-0.47.2.tar.gz", hash = "sha256:6ae9aa5db235e4846decc1e7b79c4f346adf41e9777aebeb49dfd09bbd7023d8"}, +] + +[package.dependencies] +anyio = ">=3.6.2,<5" +typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""} + +[package.extras] +full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] + [[package]] name = "tiktoken" version = "0.9.0" @@ -5144,4 +5123,4 @@ training = ["transformers"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "b00184c3d067d9748c6866e7223d3d3a95bb24bc56c96837e30b51f4bf154f2e" +content-hash = "0f9e538475309634ca67a66835b23db97718351ceb13f5a835b47ad8b740908b" From e60c0e44ad0c895f2fa122f8cacd49a5b8d6c127 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Thu, 24 Jul 2025 04:18:29 -0600 Subject: [PATCH 7/8] Fix typecheck error --- dreadnode/object.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dreadnode/object.py b/dreadnode/object.py index f03495db..3f208ed3 100644 --- a/dreadnode/object.py +++ b/dreadnode/object.py @@ -1,8 +1,7 @@ import typing as t from dataclasses import dataclass -from litellm import ConfigDict -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from dreadnode.types import AnyDict From a90731611ceb0aa6ebefc579d2ad3af607e2d50c Mon Sep 17 00:00:00 2001 From: monoxgas Date: Thu, 24 Jul 2025 04:18:56 -0600 Subject: [PATCH 8/8] Adjust workflow --- .github/workflows/test.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 3b6cc796..c538cbb7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -50,7 +50,9 @@ jobs: run: poetry run ruff check --output-format=github . - name: Typecheck + if: always() run: poetry run mypy . - name: Test + if: always() run: poetry run pytest