diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 3b6cc796..c538cbb7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -50,7 +50,9 @@ jobs: run: poetry run ruff check --output-format=github . - name: Typecheck + if: always() run: poetry run mypy . - name: Test + if: always() run: poetry run pytest diff --git a/.secrets.baseline b/.secrets.baseline index 59ecd314..8b97ad83 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -151,9 +151,16 @@ "filename": "docs/usage/config.mdx", "hashed_secret": "3f4f9a14a2d4d72a7074c2969dd34c89f2cbe61a", "is_verified": false, - "line_number": 23 + "line_number": 33 + }, + { + "type": "Secret Keyword", + "filename": "docs/usage/config.mdx", + "hashed_secret": "01eddf49c6b18f99f87ac7ba45e81d4a227e8d3f", + "is_verified": false, + "line_number": 171 } ] }, - "generated_at": "2025-07-14T09:19:13Z" + "generated_at": "2025-07-24T10:02:58Z" } diff --git a/docs/intro.mdx b/docs/intro.mdx index ede0baa0..407d6a1e 100644 --- a/docs/intro.mdx +++ b/docs/intro.mdx @@ -16,15 +16,26 @@ Which means, in order to evaluate Offensive Security agents, we need to develop ## Basic Example +Before you start, ensure you have the `dreadnode` package installed (see [installation](/install)). You can authenticate to a platform using the CLI, which is the recommended way to get started. + +```bash +# Authenticate to platform.dreadnode.io +dreadnode login + +# For self-hosted platforms, specify the server URL +dreadnode login --server http://self-hosted +``` + + +For complete authentication and configuration guidance, see the [Configuration](/usage/config) documentation. 
+ + The most basic use of Strikes is a run with some logged data: ```python import asyncio import dreadnode -# Initialize with default settings -dreadnode.configure() - NAMES = ["Nick", "Will", "Brad", "Brian"] # Create a new task @@ -42,7 +53,7 @@ async def main() -> None: ) # Log inputs - dn.log_input("names", NAMES) + dreadnode.log_input("names", NAMES) # Run your tasks greetings = [ @@ -51,7 +62,7 @@ async def main() -> None: ] # Save outputs - dn.log_output("greetings", greetings) + dreadnode.log_output("greetings", greetings) # Track metrics dreadnode.log_metric("accuracy", 0.65, step=0) @@ -63,19 +74,6 @@ async def main() -> None: asyncio.run(main()) ``` - -We'll assume you have installed the `dreadnode` package and have your environment variables set up. Make sure you have `DREADNODE_API_KEY=...` set to your Platform API key. - -For more information on `dreadnode.configure()`, review the [Configuration](/usage/config) topic. - -If you call `dreadnode.configure()` without any token and your environment variables are not set, you'll receive a warning in the console, so keep an eye out! You can still run any of your code without sending data to the Dreadnode Platform. - - - -**Server Configuration** -By default, the SDK connects to the hosted Dreadnode platform at `https://platform.dreadnode.io`. If you're using a self-hosted instance, you must configure the server URL explicitly in your `dreadnode.configure()` call or via the `DREADNODE_SERVER` environment variable. See the [Configuration](/usage/config) guide for details. - - This code should be very familiar if you've used an ML-experimentation library before, and all the functions you're familiar with work exactly like you would expect. Under the hood, this code did a few things: @@ -114,8 +112,6 @@ Runs are the core unit of work in Strikes. 
They provide the context for all your ```python import dreadnode -dreadnode.configure() - with dreadnode.run("my-experiment"): # Everything that happens here is part of the run # All data collected is associated with this run @@ -147,8 +143,6 @@ Tasks are units of work within runs. They help you structure your code and provi ```python import dreadnode -dreadnode.configure() - @dreadnode.task() async def say_hello(name: str) -> str: return f"Hello, {name}!" diff --git a/docs/sdk/api.mdx b/docs/sdk/api.mdx index b5017b0f..46acf043 100644 --- a/docs/sdk/api.mdx +++ b/docs/sdk/api.mdx @@ -12,7 +12,11 @@ ApiClient ```python ApiClient( - base_url: str, api_key: str, *, debug: bool = False + base_url: str, + *, + api_key: str | None = None, + cookies: dict[str, str] | None = None, + debug: bool = False, ) ``` @@ -29,7 +33,9 @@ Initializes the API client. (`str`) –The base URL of the Dreadnode API. * **`api_key`** - (`str`) + (`str`, default: + `None` + ) –The API key for authentication. * **`debug`** (`bool`, default: @@ -42,11 +48,13 @@ Initializes the API client. def __init__( self, base_url: str, - api_key: str, *, + api_key: str | None = None, + cookies: dict[str, str] | None = None, debug: bool = False, ): - """Initializes the API client. + """ + Initializes the API client. Args: base_url (str): The base URL of the Dreadnode API. 
@@ -57,12 +65,28 @@ def __init__( if not self._base_url.endswith("/api"): self._base_url += "/api" + _cookies = httpx.Cookies() + cookie_domain = urlparse(base_url).hostname + if cookie_domain is None: + raise ValueError(f"Invalid URL: {base_url}") + + if cookie_domain == "localhost": + cookie_domain = "localhost.local" + + for key, value in (cookies or {}).items(): + _cookies.set(key, value, domain=cookie_domain) + + headers = { + "User-Agent": f"dreadnode-sdk/{VERSION}", + "Accept": "application/json", + } + + if api_key: + headers["X-Api-Key"] = api_key + self._client = httpx.Client( - headers={ - "User-Agent": f"dreadnode-sdk/{VERSION}", - "Accept": "application/json", - "X-API-Key": api_key, - }, + headers=headers, + cookies=_cookies, base_url=self._base_url, timeout=30, ) @@ -133,7 +157,8 @@ def export_metrics( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports metric data for a specific project. + """ + Exports metric data for a specific project. Args: project: The project identifier. @@ -224,7 +249,8 @@ def export_parameters( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports parameter data for a specific project. + """ + Exports parameter data for a specific project. Args: project: The project identifier. @@ -306,7 +332,8 @@ def export_runs( status: StatusFilter = "completed", aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports run data for a specific project. + """ + Exports run data for a specific project. Args: project: The project identifier. @@ -398,7 +425,8 @@ def export_timeseries( time_axis: TimeAxisType = "relative", aggregations: list[TimeAggregationType] | None = None, ) -> pd.DataFrame: - """Exports timeseries data for a specific project. + """ + Exports timeseries data for a specific project. Args: project: The project identifier. 
@@ -427,6 +455,47 @@ def export_timeseries( ``` + + +### get\_device\_codes + +```python +get_device_codes() -> DeviceCodeResponse +``` + +Start the authentication flow by requesting user and device codes. + + +```python +def get_device_codes(self) -> DeviceCodeResponse: + """Start the authentication flow by requesting user and device codes.""" + + response = self.request("POST", "/auth/device/code") + return DeviceCodeResponse(**response.json()) +``` + + + + +### get\_github\_access\_token + +```python +get_github_access_token( + repos: list[str], +) -> GithubTokenResponse +``` + +Try to get a GitHub access token for the given repositories. + + +```python +def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse: + """Try to get a GitHub access token for the given repositories.""" + response = self.request("POST", "/github/token", json_data={"repos": repos}) + return GithubTokenResponse(**response.json()) +``` + + ### get\_project @@ -634,6 +703,26 @@ def get_run_trace( ``` + + +### get\_user + +```python +get_user() -> UserResponse +``` + +Get the user email and username. + + +```python +def get_user(self) -> UserResponse: + """Get the user email and username.""" + + response = self.request("GET", "/user") + return UserResponse(**response.json()) +``` + + ### get\_user\_data\_credentials @@ -729,6 +818,47 @@ def list_runs(self, project: str) -> list[RunSummary]: ``` + + +### poll\_for\_token + +```python +poll_for_token( + device_code: str, + interval: int = DEFAULT_POLL_INTERVAL, + max_poll_time: int = DEFAULT_MAX_POLL_TIME, +) -> AccessRefreshTokenResponse +``` + +Poll for the access token with the given device code. 
+ + +```python +def poll_for_token( + self, + device_code: str, + interval: int = DEFAULT_POLL_INTERVAL, + max_poll_time: int = DEFAULT_MAX_POLL_TIME, +) -> AccessRefreshTokenResponse: + """Poll for the access token with the given device code.""" + + start_time = datetime.now(timezone.utc) + while (datetime.now(timezone.utc) - start_time).total_seconds() < max_poll_time: + response = self._request( + "POST", "/auth/device/token", json_data={"device_code": device_code} + ) + + if response.status_code == 200: # noqa: PLR2004 + return AccessRefreshTokenResponse(**response.json()) + if response.status_code != 401: # noqa: PLR2004 + raise RuntimeError(self._get_error_message(response)) + + time.sleep(interval) + + raise RuntimeError("Polling for token timed out") +``` + + ### request @@ -782,7 +912,8 @@ def request( params: dict[str, t.Any] | None = None, json_data: dict[str, t.Any] | None = None, ) -> httpx.Response: - """Makes an HTTP request to the API and raises exceptions for errors. + """ + Makes an HTTP request to the API and raises exceptions for errors. Args: method (str): The HTTP method (e.g., "GET", "POST"). @@ -808,6 +939,25 @@ def request( ``` + + +### url\_for\_user\_code + +```python +url_for_user_code(user_code: str) -> str +``` + +Get the URL to verify the user code. + + +```python +def url_for_user_code(self, user_code: str) -> str: + """Get the URL to verify the user code.""" + + return f"{self._base_url.removesuffix('/api')}/account/device?code={user_code}" +``` + + ExportFormat ------------ diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index 0a2ff169..7cd50324 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -118,7 +118,7 @@ def api(self, *, server: str | None = None, token: str | None = None) -> ApiClie An ApiClient instance. 
""" if server is not None and token is not None: - return ApiClient(server, token) + return ApiClient(server, api_key=token) if not self._initialized: raise RuntimeError("Call .configure() before accessing the API") @@ -139,6 +139,7 @@ configure( *, server: str | None = None, token: str | None = None, + profile: str | None = None, local_dir: str | Path | Literal[False] = False, project: str | None = None, service_name: str | None = None, @@ -154,11 +155,16 @@ Configure the Dreadnode SDK and call `initialize()`. This method should always be called before using the SDK. -If `server` and `token` are not provided, the SDK will look in -the associated environment variables: +If `server` and `token` are not provided, the SDK will look for them +in the following order: -* `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` -* `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` +1. Environment variables: +2. `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` +3. `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` +4. Dreadnode profile (from `dreadnode login`) +5. Uses `profile` parameter if provided +6. Falls back to `DREADNODE_PROFILE` environment variable +7. Defaults to active profile **Parameters:** @@ -172,6 +178,11 @@ the associated environment variables: `None` ) –The Dreadnode API token. +* **`profile`** + (`str | None`, default: + `None` + ) + –The Dreadnode profile name to use (only used if env vars are not set). * **`local_dir`** (`str | Path | Literal[False]`, default: `False` @@ -215,6 +226,7 @@ def configure( *, server: str | None = None, token: str | None = None, + profile: str | None = None, local_dir: str | Path | t.Literal[False] = False, project: str | None = None, service_name: str | None = None, @@ -228,15 +240,21 @@ def configure( This method should always be called before using the SDK. 
- If `server` and `token` are not provided, the SDK will look in - the associated environment variables: + If `server` and `token` are not provided, the SDK will look for them + in the following order: - - `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` - - `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` + 1. Environment variables: + - `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` + - `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` + 2. Dreadnode profile (from `dreadnode login`) + - Uses `profile` parameter if provided + - Falls back to `DREADNODE_PROFILE` environment variable + - Defaults to active profile Args: server: The Dreadnode server URL. token: The Dreadnode API token. + profile: The Dreadnode profile name to use (only used if env vars are not set). local_dir: The local directory to store data in. project: The default project name to associate all runs with. service_name: The service name to use for OpenTelemetry. @@ -248,8 +266,43 @@ def configure( self._initialized = False - self.server = server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) - self.token = token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + # Determine configuration source and active profile for logging + config_source = "explicit parameters" + active_profile = None + + if not server or not token: + # Check environment variables first + env_server = os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) + env_token = os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + + if env_server or env_token: + config_source = "environment vars" + else: + # Fall back to profile + config_source = "profile" + with contextlib.suppress(Exception): + user_config = UserConfig.read() + profile_name = profile or os.environ.get(ENV_PROFILE) + if profile_name: + active_profile = profile_name + else: + active_profile = user_config.active_profile_name + + if active_profile: + config_source = f"profile: {active_profile}" + + self.server = ( + server + or 
os.environ.get(ENV_SERVER_URL) + or os.environ.get(ENV_SERVER) + or self._get_profile_server(profile) + ) + self.token = ( + token + or os.environ.get(ENV_API_TOKEN) + or os.environ.get(ENV_API_KEY) + or self._get_profile_api_key(profile) + ) if local_dir is False and ENV_LOCAL_DIR in os.environ: env_local_dir = os.environ.get(ENV_LOCAL_DIR) @@ -267,6 +320,17 @@ def configure( self.send_to_logfire = send_to_logfire self.otel_scope = otel_scope + # Log config information for clarity + if self.server or self.token or self.local_dir: + destination = self.server or DEFAULT_SERVER_URL or "local storage" + rich.print(f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})") + + # Warn the user if the profile didn't resolve + elif active_profile and not (self.server or self.token): + rich.print( + f":exclamation: Dreadnode profile [orange_red1]{active_profile}[/] appears invalid." + ) + self.initialize() ``` @@ -305,7 +369,7 @@ def continue_run(self, run_context: RunContext) -> RunSpan: A RunSpan object that can be used as a context manager. """ if not self._initialized: - self.initialize() + self.configure() return RunSpan.from_context( context=run_context, @@ -398,7 +462,8 @@ def initialize(self) -> None: if not (self.server or self.token or self.local_dir): warn_at_user_stacklevel( "Your current configuration won't persist run data anywhere. 
" - "Use `dreadnode.init(server=..., token=...)`, `dreadnode.init(local_dir=...)`, " + "Login with `dreadnode login` to set up a server and token, " + "Use `dreadnode.configure(server=..., token=...)`, `dreadnode.configure(profile=...)`, " f"or use environment variables ({ENV_SERVER_URL}, {ENV_API_TOKEN}, {ENV_LOCAL_DIR}).", category=DreadnodeConfigWarning, ) @@ -422,7 +487,7 @@ def initialize(self) -> None: ) self.server = urlunparse(parsed_new) - self._api = ApiClient(self.server, self.token) + self._api = ApiClient(self.server, api_key=self.token) self._api.list_projects() except Exception as e: @@ -456,7 +521,7 @@ def initialize(self) -> None: # ) credentials = self._api.get_user_data_credentials() - resolved_endpoint = self._resolve_endpoint(credentials.endpoint) + resolved_endpoint = resolve_endpoint(credentials.endpoint) self._fs = S3FileSystem( key=credentials.access_key_id, secret=credentials.secret_access_key, @@ -1639,7 +1704,7 @@ def run( The run will automatically be completed when the context manager exits. """ if not self._initialized: - self.initialize() + self.configure() if name is None: name = f"{coolname.generate_slug(2)}-{random.randint(100, 999)}" # noqa: S311 # nosec diff --git a/docs/usage/cli.mdx b/docs/usage/cli.mdx new file mode 100644 index 00000000..58524026 --- /dev/null +++ b/docs/usage/cli.mdx @@ -0,0 +1,195 @@ +--- +title: "CLI" +description: "Use the native command-line interface" +public: true +--- + +The Dreadnode CLI provides a command-line interface for authenticating with Dreadnode platforms, managing profiles, and cloning repositories. It's installed automatically with the `dreadnode` package. + +## Quick Start + +After installing the package, authenticate with your platform: + +```bash +dreadnode login +``` + +This opens your browser to authenticate and stores your credentials locally. + +## Commands + +### Authentication + +#### `login` + +Authenticate to a Dreadnode platform server. 
+ +```bash +dreadnode login [--server URL] [--profile NAME] +``` + +**Options:** +- `--server`, `-s`: URL of the server (defaults to hosted platform) +- `--profile`, `-p`: Profile alias to assign or update + +**Examples:** +```bash +# Login to hosted platform +dreadnode login + +# Login to self-hosted server +dreadnode login --server https://my-server.com + +# Login with a specific profile name +dreadnode login --profile production +``` + +#### `refresh` + +Refresh data for the active server profile. + +```bash +dreadnode refresh +``` + +Updates your local profile with the latest user information from the server. + +### Profile Management + +#### `profile show` + +List all configured server profiles. + +```bash +dreadnode profile show +``` + +Shows a table with profile names, URLs, emails, usernames, and token expiration times. The active profile is marked with an asterisk. + +#### `profile switch` + +Set the active server profile. + +```bash +dreadnode profile switch PROFILE +``` + +**Arguments:** +- `PROFILE`: Name of the profile to switch to + +#### `profile forget` + +Remove a server profile. + +```bash +dreadnode profile forget PROFILE +``` + +**Arguments:** +- `PROFILE`: Name of the profile to remove + +### Repository Management + +#### `clone` + +Clone a GitHub repository. + +```bash +dreadnode clone REPO [TARGET] +``` + +**Arguments:** +- `REPO`: Repository name (e.g., `dreadnode/example-agents`) or full GitHub URL +- `TARGET`: Optional target directory (defaults to repository name) + +**Examples:** +```bash +# Clone a public repository +dreadnode clone dreadnode/example-agents + +# Clone to a specific directory +dreadnode clone dreadnode/example-agents ./my-agents + +# Clone a private dreadnode repository (requires authentication) +dreadnode clone dreadnode/private-repo +``` + + +The `clone` command can access privately shared `dreadnode/*` repositories using your authentication token. 
+ +**Server Validation:** Private `dreadnode/*` repositories require authentication via a Dreadnode SaaS server (ending with `.dreadnode.io`). If your current profile points to a self-hosted server, the CLI will: + +1. Warn you about the server mismatch +2. Offer to switch to an available SaaS profile if one exists +3. Allow you to continue with a warning if you choose + +For other private repositories, use standard Git authentication. + + +### Meta Commands + +#### `version` + +Show version information. + +```bash +dreadnode version +``` + +Displays platform, Python version, and Dreadnode package version. + +#### `--help` + +Show help information for any command. + +```bash +dreadnode --help +dreadnode login --help +dreadnode profile --help +``` + +## Profile Configuration + +The CLI stores authentication data in `~/.dreadnode/config`. Each profile contains: + +- Server URL +- User credentials (access/refresh tokens, API key) +- User information (email, username) + +You can have multiple profiles for different servers or accounts: + +```bash +# Add different server profiles +dreadnode login --profile public +dreadnode login --server https://self-hosted --profile self-hosted + +# Switch between them +dreadnode profile switch self-hosted +dreadnode profile switch public +``` + +### Environment Variable Profile Selection + +You can override the active profile using the `DREADNODE_PROFILE` environment variable: + +```bash +# Temporarily use a different profile +export DREADNODE_PROFILE=production +dreadnode clone dreadnode/private-repo # Uses production profile + +# Or for a single command +DREADNODE_PROFILE=staging dreadnode clone dreadnode/test-repo +``` + +This affects both CLI commands and SDK configuration when using `dreadnode.configure()` without explicit server/token parameters. + +### Which Profile Gets Used? + +The CLI picks a profile in this order: + +1. **`--profile` flag** (if provided) +2. **`DREADNODE_PROFILE` environment variable** +3. 
**Active profile** (set via `dreadnode profile switch`) +4. **"main" profile** (default) + +The CLI will remember your server URL for future commands within that profile. \ No newline at end of file diff --git a/docs/usage/config.mdx b/docs/usage/config.mdx index 79109adc..710cd4cf 100644 --- a/docs/usage/config.mdx +++ b/docs/usage/config.mdx @@ -4,34 +4,117 @@ description: "Set configuration values" public: true --- -The quickest way to configure Strikes is to set the `DREADNODE_API_KEY` environment variable and let the library handle the rest with `dreadnode.configure()`. However, there are quite a few additional options you can set as needed. +## Self-Hosted Platforms -## Self-Hosted Platform +If you're using a **self-hosted Dreadnode platform**, you must specify your server URL during authentication: -If you're using a **self-hosted Dreadnode platform**, you must always specify your server URL explicitly. The SDK defaults to `https://platform.dreadnode.io` otherwise. +```bash +dreadnode login --server https://your-server.com +``` + +This creates a profile for your self-hosted instance. You can manage multiple servers by creating profiles with custom names: + +```bash +# Create profiles for different environments +dreadnode login --server https://dev.company.com --profile dev +dreadnode login --server https://prod.company.com --profile production +``` + +Switch between profiles anytime: + +```bash +dreadnode profile switch production +``` + +Your code automatically uses the active profile - no changes needed. For automation and CI/CD with self-hosted platforms, use environment variables: + +```bash +export DREADNODE_SERVER="https://your-server.com" +export DREADNODE_API_KEY="your-api-token" +``` + +## When You Need `configure()` + +Most users never need to call `configure()` explicitly. The SDK auto-configures itself using CLI authentication or environment variables. 
+ +**You only need `configure()` if you want to:** + +### Customize Configuration + +```python +dreadnode.configure( + local_dir="./my-custom-storage", # Custom local storage + project="my-project", # Default project name + console=False, # Disable console logging +) +``` + +### Override Auto-Detection - -```python in code +```python dreadnode.configure( - server="https://hosted-server", # Your self-hosted server URL - token="your-api-token", + server="https://platform.dreadnode.io", + token="your-api-token", # Explicit credentials + profile="production" # Fallback profile; ignored here because server and token are given ) ``` -```bash environment variables -export DREADNODE_SERVER="https://hosted-server" +### Use Environment Variables (CI/CD) + +```bash export DREADNODE_API_KEY="your-api-token" +# No configure() call needed - SDK auto-detects +``` + +**💡 For most users:** Skip `configure()` entirely. Use `dreadnode login` once and you're set. + +## Advanced Configuration + +### Managing Multiple Environments with Profiles + +Profiles let you manage multiple Dreadnode servers (development, staging, production, etc.) and switch between them seamlessly: + +**Create profiles for different environments:** + +```bash +# Hosted environments +dreadnode login --profile dev +dreadnode login --profile staging +dreadnode login --profile production + +# Self-hosted environments +dreadnode login --server https://dev.company.com --profile dev-internal +dreadnode login --server https://prod.company.com --profile prod-internal ``` - -## Using `configure()` +**View and manage profiles:** + +```bash +dreadnode profile show # List all profiles +dreadnode profile switch staging # Switch active profile +dreadnode profile forget dev # Remove a profile +``` -Initialize and set up connections with `configure()`. 
+**Use specific profiles in code:** + +```python +# Use a specific profile programmatically +dreadnode.configure(profile="production") + +# Or with environment variable +# DREADNODE_PROFILE=production +dreadnode.configure() +``` + +**Profile priority:** Environment variable `DREADNODE_PROFILE` overrides the active CLI profile. + +### Full Configuration Options ```python dreadnode.configure( - server="https://platform.dreadnode.io", # Platform URL - token="your-api-token", # API token for authentication + server="https://platform.dreadnode.io", # Platform URL (optional if using CLI/env) + token="your-api-token", # API token (optional if using CLI/env) + profile="production", # Dreadnode profile (only used if server/token not provided) local_dir="./runs", # Directory for local span storage project="my-project", # Default project name console=True, # Enable console logging @@ -43,13 +126,84 @@ dreadnode.configure( ) ``` -## Using Environment Variables +## Environment Variables Reference + +Environment variables are a great alternative for automated deployments and CI/CD pipelines. They override CLI profiles but are overridden by explicit `configure()` parameters. -Set variables to call `.config()` more easily. 
+### Complete Reference ```bash -export DREADNODE_SERVER="https://platform.dreadnode.io" # or DREADNODE_SERVER_URL -export DREADNODE_API_KEY="your-api-token" # or DREADNODE_API_TOKEN -export DREADNODE_LOCAL_DIR="./runs" -export DREADNODE_PROJECT="my-project" +# Authentication (choose one) +export DREADNODE_API_KEY="your-api-token" # Recommended +export DREADNODE_API_TOKEN="your-api-token" # Alternative + +# Server configuration +export DREADNODE_SERVER="https://your-server.com" # Recommended +export DREADNODE_SERVER_URL="https://your-server.com" # Alternative + +# Profile selection (when not using explicit server/token) +export DREADNODE_PROFILE="production" + +# Optional settings +export DREADNODE_LOCAL_DIR="./runs" # Local storage directory +export DREADNODE_PROJECT="my-project" # Default project name ``` + +## Configuration Priority Order + +The SDK resolves configuration in this priority order: + +### 1. Explicit Parameters (Highest Priority) + +```python +# These always override everything else +dreadnode.configure( + server="https://override.com", + token="explicit-token" +) +``` + +### 2. Environment Variables + +```python +# Set via shell or CI/CD +# export DREADNODE_SERVER="https://env.com" +# export DREADNODE_API_KEY="env-token" + +dreadnode.configure() # Uses env vars +``` + +### 3. CLI Profiles + +```python +# After: dreadnode login --profile production +dreadnode.configure(profile="production") # Uses CLI profile + +# Or let SDK auto-detect active profile +dreadnode.configure() # Uses active CLI profile +``` + +### 4. 
Local-Only Mode (Fallback) + +```python +# No credentials found anywhere +dreadnode.configure() # ⚠️ Works locally with warning +``` + +**Examples demonstrating priority:** + +```python +# Environment overrides profile +# export DREADNODE_API_KEY="env-token" +dreadnode.configure(profile="production") # Uses env-token (not profile) + +# Explicit param overrides environment +# export DREADNODE_SERVER="https://env.com" +dreadnode.configure(server="https://explicit.com") # Uses explicit.com + +# Profile selection with DREADNODE_PROFILE env var +# export DREADNODE_PROFILE="staging" +dreadnode.configure() # Uses "staging" profile (not active profile) +``` + +**💡 Bottom line:** Just start coding. The SDK will find your credentials or work locally. \ No newline at end of file diff --git a/dreadnode/__main__.py b/dreadnode/__main__.py new file mode 100644 index 00000000..cab7e045 --- /dev/null +++ b/dreadnode/__main__.py @@ -0,0 +1,10 @@ +from dreadnode.cli import cli + + +def run() -> None: + """Run the Dreadnode CLI.""" + cli.meta() + + +if __name__ == "__main__": + run() diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index 789edc23..5b311482 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -1,22 +1,19 @@ import io import json +import time import typing as t +from datetime import datetime, timezone +from urllib.parse import urlparse import httpx import pandas as pd from pydantic import BaseModel from ulid import ULID -from dreadnode.api.util import ( - convert_flat_tasks_to_tree, - convert_flat_trace_to_tree, - process_run, - process_task, -) -from dreadnode.util import logger -from dreadnode.version import VERSION - -from .models import ( +from dreadnode.api.models import ( + AccessRefreshTokenResponse, + DeviceCodeResponse, + GithubTokenResponse, MetricAggregationType, Project, RawRun, @@ -31,7 +28,17 @@ TraceSpan, TraceTree, UserDataCredentials, + UserResponse, ) +from dreadnode.api.util import ( + convert_flat_tasks_to_tree, + 
convert_flat_trace_to_tree, + process_run, + process_task, +) +from dreadnode.constants import DEFAULT_MAX_POLL_TIME, DEFAULT_POLL_INTERVAL +from dreadnode.util import logger +from dreadnode.version import VERSION ModelT = t.TypeVar("ModelT", bound=BaseModel) @@ -47,11 +54,13 @@ class ApiClient: def __init__( self, base_url: str, - api_key: str, *, + api_key: str | None = None, + cookies: dict[str, str] | None = None, debug: bool = False, ): - """Initializes the API client. + """ + Initializes the API client. Args: base_url (str): The base URL of the Dreadnode API. @@ -62,12 +71,28 @@ def __init__( if not self._base_url.endswith("/api"): self._base_url += "/api" + _cookies = httpx.Cookies() + cookie_domain = urlparse(base_url).hostname + if cookie_domain is None: + raise ValueError(f"Invalid URL: {base_url}") + + if cookie_domain == "localhost": + cookie_domain = "localhost.local" + + for key, value in (cookies or {}).items(): + _cookies.set(key, value, domain=cookie_domain) + + headers = { + "User-Agent": f"dreadnode-sdk/{VERSION}", + "Accept": "application/json", + } + + if api_key: + headers["X-Api-Key"] = api_key + self._client = httpx.Client( - headers={ - "User-Agent": f"dreadnode-sdk/{VERSION}", - "Accept": "application/json", - "X-API-Key": api_key, - }, + headers=headers, + cookies=_cookies, base_url=self._base_url, timeout=30, ) @@ -77,7 +102,8 @@ def __init__( self._client.event_hooks["response"].append(self._log_response) def _log_request(self, request: httpx.Request) -> None: - """Logs HTTP requests if debug mode is enabled. + """ + Logs HTTP requests if debug mode is enabled. Args: request (httpx.Request): The HTTP request object. @@ -90,7 +116,8 @@ def _log_request(self, request: httpx.Request) -> None: logger.debug("-------------------------------------------") def _log_response(self, response: httpx.Response) -> None: - """Logs HTTP responses if debug mode is enabled. + """ + Logs HTTP responses if debug mode is enabled. 
Args: response (httpx.Response): The HTTP response object. @@ -103,7 +130,8 @@ def _log_response(self, response: httpx.Response) -> None: logger.debug("--------------------------------------------") def _get_error_message(self, response: httpx.Response) -> str: - """Extracts the error message from an HTTP response. + """ + Extracts the error message from an HTTP response. Args: response (httpx.Response): The HTTP response object. @@ -125,7 +153,8 @@ def _request( params: dict[str, t.Any] | None = None, json_data: dict[str, t.Any] | None = None, ) -> httpx.Response: - """Makes a raw HTTP request to the API. + """ + Makes a raw HTTP request to the API. Args: method (str): The HTTP method (e.g., "GET", "POST"). @@ -146,7 +175,8 @@ def request( params: dict[str, t.Any] | None = None, json_data: dict[str, t.Any] | None = None, ) -> httpx.Response: - """Makes an HTTP request to the API and raises exceptions for errors. + """ + Makes an HTTP request to the API and raises exceptions for errors. Args: method (str): The HTTP method (e.g., "GET", "POST"). 
@@ -170,6 +200,59 @@ def request( return response + # Auth + + def url_for_user_code(self, user_code: str) -> str: + """Get the URL to verify the user code.""" + + return f"{self._base_url.removesuffix('/api')}/account/device?code={user_code}" + + def get_device_codes(self) -> DeviceCodeResponse: + """Start the authentication flow by requesting user and device codes.""" + + response = self.request("POST", "/auth/device/code") + return DeviceCodeResponse(**response.json()) + + def poll_for_token( + self, + device_code: str, + interval: int = DEFAULT_POLL_INTERVAL, + max_poll_time: int = DEFAULT_MAX_POLL_TIME, + ) -> AccessRefreshTokenResponse: + """Poll for the access token with the given device code.""" + + start_time = datetime.now(timezone.utc) + while (datetime.now(timezone.utc) - start_time).total_seconds() < max_poll_time: + response = self._request( + "POST", "/auth/device/token", json_data={"device_code": device_code} + ) + + if response.status_code == 200: # noqa: PLR2004 + return AccessRefreshTokenResponse(**response.json()) + if response.status_code != 401: # noqa: PLR2004 + raise RuntimeError(self._get_error_message(response)) + + time.sleep(interval) + + raise RuntimeError("Polling for token timed out") + + # User + + def get_user(self) -> UserResponse: + """Get the user email and username.""" + + response = self.request("GET", "/user") + return UserResponse(**response.json()) + + # Github + + def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse: + """Try to get a GitHub access token for the given repositories.""" + response = self.request("POST", "/github/token", json_data={"repos": repos}) + return GithubTokenResponse(**response.json()) + + # Strikes + def list_projects(self) -> list[Project]: """Retrieves a list of projects. @@ -294,7 +377,8 @@ def export_runs( status: StatusFilter = "completed", aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports run data for a specific project. 
+ """ + Exports run data for a specific project. Args: project: The project identifier. @@ -327,7 +411,8 @@ def export_metrics( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports metric data for a specific project. + """ + Exports metric data for a specific project. Args: project: The project identifier. @@ -363,7 +448,8 @@ def export_parameters( metrics: list[str] | None = None, aggregations: list[MetricAggregationType] | None = None, ) -> pd.DataFrame: - """Exports parameter data for a specific project. + """ + Exports parameter data for a specific project. Args: project: The project identifier. @@ -401,7 +487,8 @@ def export_timeseries( time_axis: TimeAxisType = "relative", aggregations: list[TimeAggregationType] | None = None, ) -> pd.DataFrame: - """Exports timeseries data for a specific project. + """ + Exports timeseries data for a specific project. Args: project: The project identifier. diff --git a/dreadnode/api/models.py b/dreadnode/api/models.py index 53685180..61c52dda 100644 --- a/dreadnode/api/models.py +++ b/dreadnode/api/models.py @@ -32,6 +32,35 @@ class UserResponse(BaseModel): api_key: UserAPIKey +class UserDataCredentials(BaseModel): + access_key_id: str + secret_access_key: str + session_token: str + expiration: datetime + region: str + bucket: str + prefix: str + endpoint: str | None + + +# Auth + + +class DeviceCodeResponse(BaseModel): + id: UUID + completed: bool + device_code: str + expires_at: datetime + expires_in: int + user_code: str + verification_url: str + + +class AccessRefreshTokenResponse(BaseModel): + access_token: str + refresh_token: str + + # Strikes SpanStatus = t.Literal[ @@ -406,15 +435,10 @@ class TraceTree(BaseModel): """Children of this span, representing nested spans or tasks.""" -# User data credentials +# Github -class UserDataCredentials(BaseModel): - access_key_id: str - secret_access_key: str - session_token: str - expiration: datetime - 
region: str - bucket: str - prefix: str - endpoint: str | None +class GithubTokenResponse(BaseModel): + token: str + expires_at: datetime + repos: list[str] diff --git a/dreadnode/api/util.py b/dreadnode/api/util.py index 004e3a91..b5b42adb 100644 --- a/dreadnode/api/util.py +++ b/dreadnode/api/util.py @@ -1,6 +1,6 @@ from logging import getLogger -from .models import ( +from dreadnode.api.models import ( Object, ObjectUri, ObjectVal, diff --git a/dreadnode/cli/__init__.py b/dreadnode/cli/__init__.py new file mode 100644 index 00000000..b9e50854 --- /dev/null +++ b/dreadnode/cli/__init__.py @@ -0,0 +1,3 @@ +from dreadnode.cli.main import cli + +__all__ = ["cli"] diff --git a/dreadnode/cli/api.py b/dreadnode/cli/api.py new file mode 100644 index 00000000..96a59dfc --- /dev/null +++ b/dreadnode/cli/api.py @@ -0,0 +1,79 @@ +import atexit +import base64 +import json +from datetime import datetime, timezone + +from dreadnode.api.client import ApiClient +from dreadnode.config import UserConfig +from dreadnode.constants import ( + DEFAULT_TOKEN_MAX_TTL, +) + + +class Token: + """A JWT token with an expiration time.""" + + data: str + expires_at: datetime + + @staticmethod + def parse_jwt_token_expiration(token: str) -> datetime: + """Return the expiration date from a JWT token.""" + + _, b64payload, _ = token.split(".") + payload = base64.urlsafe_b64decode(b64payload + "==").decode("utf-8") + return datetime.fromtimestamp(json.loads(payload).get("exp"), tz=timezone.utc) + + def __init__(self, token: str): + self.data = token + self.expires_at = Token.parse_jwt_token_expiration(token) + + def ttl(self) -> int: + """Get number of seconds left until the token expires.""" + return int((self.expires_at - datetime.now(tz=timezone.utc)).total_seconds()) + + def is_expired(self) -> bool: + """Return True if the token is expired.""" + return self.ttl() <= 0 + + def is_close_to_expiry(self) -> bool: + """Return True if the token is close to expiry.""" + return self.ttl() <= 
DEFAULT_TOKEN_MAX_TTL + + +def create_api_client(*, profile: str | None = None) -> ApiClient: + """Create an authenticated API client using stored configuration data.""" + + user_config = UserConfig.read() + config = user_config.get_server_config(profile) + + client = ApiClient( + config.url, + cookies={"access_token": config.access_token, "refresh_token": config.refresh_token}, + ) + + # Preemptively check if the token is expired + if Token(config.refresh_token).is_expired(): + raise RuntimeError("Authentication expired, use [bold]dreadnode login[/]") + + def _flush_auth_changes() -> None: + """Flush the authentication data to disk if it has been updated.""" + + access_token = client._client.cookies.get("access_token") # noqa: SLF001 + refresh_token = client._client.cookies.get("refresh_token") # noqa: SLF001 + + changed: bool = False + if access_token and access_token != config.access_token: + changed = True + config.access_token = access_token + + if refresh_token and refresh_token != config.refresh_token: + changed = True + config.refresh_token = refresh_token + + if changed: + user_config.set_server_config(config, profile).write() + + atexit.register(_flush_auth_changes) + + return client diff --git a/dreadnode/cli/github.py b/dreadnode/cli/github.py new file mode 100644 index 00000000..6b33390d --- /dev/null +++ b/dreadnode/cli/github.py @@ -0,0 +1,273 @@ +import os +import pathlib +import re +import tempfile +import typing as t +import zipfile + +import httpx +import rich +from rich.prompt import Prompt + +from dreadnode.config import UserConfig, find_dreadnode_saas_profiles, is_dreadnode_saas_server + + +class GithubRepo(str): # noqa: SLOT000 + """ + A string subclass that normalizes various GitHub repository string formats. 
+ + Supported formats: + - Full URLs: https://github.com/owner/repo + - SSH URLs: git@github.com:owner/repo.git + - Simple format: owner/repo + - With ref: owner/repo/tree/main + - With complex ref: owner/repo/tree/feature/custom + - With ref (URL): https://github.com/owner/repo/tree/main + - With .git: owner/repo.git + - Raw URLs: https://raw.githubusercontent.com/owner/repo/main + - Release URLs: owner/repo/releases/tag/v1.0.0 + - ZIP URLs: https://github.com/owner/repo/zipball/main + - Simple with ref: owner/repo@ref + """ + + # Instance properties + namespace: str + repo: str + ref: str + + # Regex patterns + SSH_PATTERN = re.compile(r"git@github\.com:([^/]+)/([^/]+?)(\.git)?$") + SIMPLE_PATTERN = re.compile(r"^([^/]+)/([^/]+?)(\.git)?$") + URL_PATTERN = re.compile(r"github\.com/([^/]+)/([^/]+?)(?:\.git|/(?:tree|blob)/(.+?))?$") + RAW_PATTERN = re.compile(r"raw\.githubusercontent\.com/([^/]+)/([^/]+)/(.+)") + RELEASE_PATTERN = re.compile(r"([^/]+)/([^/]+)/releases/tag/(.+)$") + OWN_FORMAT_PATTERN = re.compile(r"^([^/]+)/([^/@:]+)@(.+)$") + ZIPBALL_PATTERN = re.compile(r"github\.com/([^/]+)/([^/]+?)/zipball/(.+)$") + + def __new__(cls, value: t.Any, *_: t.Any, **__: t.Any) -> "GithubRepo": # noqa: PLR0912, PLR0915 + if not isinstance(value, str): + return super().__new__(cls, str(value)) + + namespace = None + repo = None + ref = "main" + + value = value.strip() + + # Try our own format first (owner/repo@ref) + match = cls.OWN_FORMAT_PATTERN.match(value) + if match: + namespace = match.group(1) + repo = match.group(2) + ref = match.group(3) + + # Try as an SSH URL + elif value.startswith("git@"): + match = cls.SSH_PATTERN.search(value) + if match: + namespace, repo = match.group(1), match.group(2) + + # Try as a full URL + elif value.startswith(("http://", "https://")): + url_parts = value.split("//", 1)[1] + + # Try zipball pattern first + match = cls.ZIPBALL_PATTERN.search(url_parts) + if match: + namespace = match.group(1) + repo = match.group(2) + ref = 
match.group(3) + + # Try raw githubusercontent pattern + elif url_parts.startswith("raw.githubusercontent.com"): + match = cls.RAW_PATTERN.search(url_parts) + if match: + namespace, repo, ref = match.group(1), match.group(2), match.group(3) + + # Try standard GitHub URL pattern + else: + match = cls.URL_PATTERN.search(url_parts) + if match: + namespace = match.group(1) + repo = match.group(2) + ref = match.group(3) or ref + + # Try release tag format + elif "/releases/tag/" in value: + match = cls.RELEASE_PATTERN.match(value) + if match: + namespace, repo, ref = match.group(1), match.group(2), match.group(3) + + # Try simple owner/repo format + else: + # First try to extract any ref + tree_parts = value.split("/tree/") + blob_parts = value.split("/blob/") + + if len(tree_parts) > 1: + value, ref = tree_parts[0], tree_parts[1] + elif len(blob_parts) > 1: + value, ref = blob_parts[0], blob_parts[1] + + # Now check for owner/repo pattern + match = cls.SIMPLE_PATTERN.match(value) + if match: + namespace, repo = match.group(1), match.group(2) + + if not namespace or not repo: + raise ValueError(f"Invalid GitHub repository format: {value}") + + repo = repo.removesuffix(".git") + + obj = super().__new__(cls, f"{namespace}/{repo}@{ref}") + + obj.namespace = namespace + obj.repo = repo + obj.ref = ref + + return obj + + @property + def zip_url(self) -> str: + """ZIP archive URL for the repository.""" + return f"https://github.com/{self.namespace}/{self.repo}/zipball/{self.ref}" + + @property + def api_zip_url(self) -> str: + """API ZIP archive URL for the repository.""" + return f"https://api.github.com/repos/{self.namespace}/{self.repo}/zipball/{self.ref}" + + @property + def tree_url(self) -> str: + """URL to view the tree at this reference.""" + return f"https://github.com/{self.namespace}/{self.repo}/tree/{self.ref}" + + @property + def exists(self) -> bool: + """Check if a repo exists (or is private) on GitHub.""" + response = 
httpx.get(f"https://github.com/{self.namespace}/{self.repo}") + return response.status_code == 200 # noqa: PLR2004 + + def __repr__(self) -> str: + return f"GithubRepo(namespace='{self.namespace}', repo='{self.repo}', ref='{self.ref}')" + + +def get_repo_archive_source_path(source_dir: pathlib.Path) -> pathlib.Path: + """Return the actual source directory from a git repositoryZIP archive.""" + + if not (source_dir / "Dockerfile").exists() and not (source_dir / "Dockerfile.j2").exists(): + # if src has been downloaded from a ZIP archive, it may contain a single + # '--' folder, that is the actual source we want to use. + # Check if source_dir contains only one folder and update it if so. + children = list(source_dir.iterdir()) + if len(children) == 1 and children[0].is_dir(): + source_dir = children[0] + + return source_dir + + +def download_and_unzip_archive(url: str, *, headers: dict[str, str] | None = None) -> pathlib.Path: + """ + Downloads a ZIP archive from the given URL and unzips it into a temporary directory. 
+ """ + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + local_zip_path = temp_dir / "archive.zip" + + rich.print(f":arrow_double_down: Downloading {url} ...") + + # download to temporary file + with httpx.stream("GET", url, follow_redirects=True, verify=True, headers=headers) as response: + response.raise_for_status() + with local_zip_path.open("wb") as zip_file: + for chunk in response.iter_bytes(chunk_size=8192): + zip_file.write(chunk) + + # unzip to temporary directory + try: + with zipfile.ZipFile(local_zip_path, "r") as zf: + for member in zf.infolist(): + file_path = os.path.realpath(temp_dir / member.filename) + if file_path.startswith(os.path.realpath(temp_dir)): + zf.extract(member, temp_dir) + else: + raise RuntimeError("Invalid file path detected in archive") + + finally: + # always remove the zip file + if local_zip_path.exists(): + local_zip_path.unlink() + + return temp_dir + + +def validate_server_for_clone(user_config: UserConfig, current_profile: str | None) -> str | None: + """ + Validate the server configuration for git clone operations. + + Returns: + The profile name to use, or None if the user cancelled. 
+ """ + config = user_config.get_server_config(current_profile) + current_server = config.url + + # If current server is a Dreadnode SaaS server, all good + if is_dreadnode_saas_server(current_server): + return current_profile or user_config.active_profile_name + + # Current server is not a Dreadnode SaaS server - warn user + rich.print() + rich.print(":warning: [yellow]Warning: Current server is not a Dreadnode SaaS server[/]") + rich.print(f" Current server: [cyan]{current_server}[/]") + rich.print(f" Current profile: [cyan]{current_profile or user_config.active_profile_name}[/]") + rich.print() + rich.print("Git clone for private dreadnode repositories requires a Dreadnode SaaS server") + rich.print("(ending with '.dreadnode.io') for authentication to work properly.") + rich.print() + + # Check if there are any SaaS profiles available + saas_profiles = find_dreadnode_saas_profiles(user_config) + + if saas_profiles: + rich.print("Available Dreadnode SaaS profiles:") + for profile in saas_profiles: + server_url = user_config.servers[profile].url + rich.print(f" - [green]{profile}[/] ({server_url})") + rich.print() + + choices = ["continue", "switch", "cancel"] + choice = Prompt.ask( + "Choose an option", choices=choices, default="cancel", show_choices=True + ) + + if choice == "continue": + rich.print( + ":warning: [yellow]Continuing with current server - private repository access may fail[/]" + ) + return current_profile or user_config.active_profile_name + if choice == "cancel": + rich.print("Cancelled.") + return None + if choice == "switch": + # Let user pick a profile + profile_choice = Prompt.ask( + "Select profile to use", choices=saas_profiles, default=saas_profiles[0] + ) + rich.print( + f":arrows_counterclockwise: Using profile '[green]{profile_choice}[/]' for this operation" + ) + return profile_choice + else: + # No SaaS profiles available + choice = Prompt.ask("Continue anyway?", choices=["y", "n"], default="n") + + if choice == "y": + rich.print( + 
":warning: [yellow]Continuing with current server - private repository access may fail[/]" + ) + return current_profile or user_config.active_profile_name + rich.print( + "Cancelled. Use [bold]dreadnode login --server https://platform.dreadnode.io[/] to add a SaaS profile." + ) + + return None diff --git a/dreadnode/cli/main.py b/dreadnode/cli/main.py new file mode 100644 index 00000000..16dbef9e --- /dev/null +++ b/dreadnode/cli/main.py @@ -0,0 +1,200 @@ +import contextlib +import pathlib +import shutil +import sys +import typing as t +import webbrowser + +import cyclopts +import rich +from rich.panel import Panel +from rich.prompt import Prompt + +from dreadnode.api.client import ApiClient +from dreadnode.cli.api import create_api_client +from dreadnode.cli.github import GithubRepo, download_and_unzip_archive, validate_server_for_clone +from dreadnode.cli.profile import cli as profile_cli +from dreadnode.config import ServerConfig, UserConfig +from dreadnode.constants import DEBUG, PLATFORM_BASE_URL + +cli = cyclopts.App(help="Interact with Dreadnode platforms", version_flags=[], help_on_error=True) + +cli["--help"].group = "Meta" + +cli.command(profile_cli) + + +@cli.meta.default +def meta( + *tokens: t.Annotated[str, cyclopts.Parameter(show=False, allow_leading_hyphen=True)], +) -> None: + try: + rich.print() + cli(tokens) + except Exception as e: + if DEBUG: + raise + + rich.print() + rich.print(Panel(str(e), title="Error", title_align="left", border_style="red")) + sys.exit(1) + + +@cli.command(help="Authenticate to a platform server.", group="Auth") +def login( + *, + server: t.Annotated[ + str | None, cyclopts.Parameter(name=["--server", "-s"], help="URL of the server") + ] = None, + profile: t.Annotated[ + str | None, + cyclopts.Parameter(name=["--profile", "-p"], help="Profile alias to assign / update"), + ] = None, +) -> None: + if not server: + server = PLATFORM_BASE_URL + with contextlib.suppress(Exception): + existing_config = 
UserConfig.read().get_server_config(profile) + server = existing_config.url + + # create client with no auth data + client = ApiClient(base_url=server) + + rich.print(":laptop_computer: Requesting device code ...") + + # request user and device codes + codes = client.get_device_codes() + + # present verification URL to user + verification_url = client.url_for_user_code(codes.user_code) + verification_url_base = verification_url.split("?")[0] + + rich.print() + rich.print( + f"""\ +Attempting to automatically open the authorization page in your default browser. +If the browser does not open or you wish to use a different device, open the following URL: + +:link: [bold]{verification_url_base}[/] + +Then enter the code: [bold]{codes.user_code}[/] +""" + ) + + webbrowser.open(verification_url) + + # poll for the access token after user verification + tokens = client.poll_for_token(codes.device_code) + + client = ApiClient( + server, cookies={"refresh_token": tokens.refresh_token, "access_token": tokens.access_token} + ) + user = client.get_user() + + UserConfig.read().set_server_config( + ServerConfig( + url=server, + access_token=tokens.access_token, + refresh_token=tokens.refresh_token, + email=user.email_address, + username=user.username, + api_key=user.api_key.key, + ), + profile, + ).write() + + rich.print(f":white_check_mark: Authenticated as {user.email_address} ({user.username})") + + +@cli.command(help="Refresh data for the active server profile.", group="Auth") +def refresh() -> None: + user_config = UserConfig.read() + server_config = user_config.get_server_config() + + client = create_api_client() + user = client.get_user() + + server_config.email = user.email_address + server_config.username = user.username + server_config.api_key = user.api_key.key + + user_config.set_server_config(server_config).write() + + rich.print( + f":white_check_mark: Refreshed '[bold]{user_config.active}[/bold]' ([magenta]{user.email_address}[/] / [cyan]{user.username}[/])" + ) + 
+ +@cli.command(help="Clone a github repository, typically privately shared dreadnode repositories.") +def clone( + repo: t.Annotated[str, cyclopts.Parameter(help="Repository name or URL")], + target: t.Annotated[ + pathlib.Path | None, + cyclopts.Parameter(help="The target directory"), + ] = None, +) -> None: + github_repo = GithubRepo(repo) + + # Check if the target directory exists + target = target or pathlib.Path(github_repo.repo) + if target.exists(): + if ( + Prompt.ask(f":axe: Overwrite {target.absolute()}?", choices=["y", "n"], default="n") + == "n" + ): + return + rich.print() + shutil.rmtree(target) + + # Check if the repo is accessible + if github_repo.exists: + temp_dir = download_and_unzip_archive(github_repo.zip_url) + + # This could be a private repo that the user can access + # by getting an access token from our API + elif github_repo.namespace == "dreadnode": + # Validate server configuration for private repository access + user_config = UserConfig.read() + profile_to_use = validate_server_for_clone(user_config, None) + + if profile_to_use is None: + return # User cancelled + + github_access_token = create_api_client(profile=profile_to_use).get_github_access_token( + [github_repo.repo] + ) + rich.print(":key: Accessed private repository") + temp_dir = download_and_unzip_archive( + github_repo.api_zip_url, + headers={"Authorization": f"Bearer {github_access_token.token}"}, + ) + + else: + raise RuntimeError(f"Repository '{github_repo}' not found or inaccessible") + + # We assume the repo download results in a single + # child folder which is the real target + sub_dirs = list(temp_dir.iterdir()) + if len(sub_dirs) == 1 and sub_dirs[0].is_dir(): + temp_dir = sub_dirs[0] + + shutil.move(temp_dir, target) + + rich.print() + rich.print(f":tada: Cloned [b]{repo}[/] to [b]{target.absolute()}[/]") + + +@cli.command(help="Show versions and exit.", group="Meta") +def version() -> None: + import importlib.metadata + import platform + import sys + + version = 
importlib.metadata.version("dreadnode") + python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + + os_name = platform.system() + arch = platform.machine() + rich.print(f"Platform: {os_name} ({arch})") + rich.print(f"Python: {python_version}") + rich.print(f"Dreadnode: {version}") diff --git a/dreadnode/cli/profile/__init__.py b/dreadnode/cli/profile/__init__.py new file mode 100644 index 00000000..77af6edd --- /dev/null +++ b/dreadnode/cli/profile/__init__.py @@ -0,0 +1,3 @@ +from dreadnode.cli.profile.cli import cli + +__all__ = ["cli"] diff --git a/dreadnode/cli/profile/cli.py b/dreadnode/cli/profile/cli.py new file mode 100644 index 00000000..95e736ad --- /dev/null +++ b/dreadnode/cli/profile/cli.py @@ -0,0 +1,101 @@ +import typing as t + +import cyclopts +import rich +from rich import box +from rich.table import Table + +from dreadnode.cli.api import Token +from dreadnode.config import UserConfig +from dreadnode.util import time_to + +cli = cyclopts.App(name="profile", help="Manage server profiles") + + +@cli.command(name=["show", "list"], help="List all server profiles") +def show() -> None: + config = UserConfig.read() + if not config.servers: + rich.print(":exclamation: No server profiles are configured") + return + + table = Table(box=box.ROUNDED) + table.add_column("Profile", style="orange_red1") + table.add_column("URL", style="cyan") + table.add_column("Email") + table.add_column("Username") + table.add_column("Valid Until") + + for profile, server in config.servers.items(): + active = profile == config.active + refresh_token = Token(server.refresh_token) + + table.add_row( + profile + ("*" if active else ""), + server.url, + server.email, + server.username, + "[red]expired[/]" + if refresh_token.is_expired() + else f"{refresh_token.expires_at.astimezone().strftime('%c')} ({time_to(refresh_token.expires_at)})", + style="bold" if active else None, + ) + + rich.print(table) + + +@cli.command(help="Set the active 
server profile") +def switch( + profile: t.Annotated[str | None, cyclopts.Parameter(help="Profile to switch to")] = None, +) -> None: + config = UserConfig.read() + + if not config.servers: + rich.print(":exclamation: No server profiles are configured") + return + + # If no profile provided, prompt user to choose + if profile is None: + from rich.prompt import Prompt + + profiles = list(config.servers.keys()) + rich.print("\nAvailable profiles:") + for i, p in enumerate(profiles, 1): + active_marker = " (current)" if p == config.active else "" + rich.print(f" {i}. [bold orange_red1]{p}[/]{active_marker}") + + choice = Prompt.ask( + "\nSelect a profile", + choices=[str(i) for i in range(1, len(profiles) + 1)] + profiles, + show_choices=False, + ) + + profile = profiles[int(choice) - 1] if choice.isdigit() else choice + + if profile not in config.servers: + rich.print(f":exclamation: Profile [bold]{profile}[/] does not exist") + return + + config.active = profile + config.write() + + rich.print(f":laptop_computer: Switched to [bold orange_red1]{profile}[/]") + rich.print(f"|- email: [bold]{config.servers[profile].email}[/]") + rich.print(f"|- username: {config.servers[profile].username}") + rich.print(f"|- url: {config.servers[profile].url}") + rich.print() + + +@cli.command(help="Remove a server profile") +def forget( + profile: t.Annotated[str, cyclopts.Parameter(help="Profile of the server to remove")], +) -> None: + config = UserConfig.read() + if profile not in config.servers: + rich.print(f":exclamation: Profile [bold]{profile}[/] does not exist") + return + + del config.servers[profile] + config.write() + + rich.print(f":axe: Forgot about [bold]{profile}[/]") diff --git a/dreadnode/config.py b/dreadnode/config.py new file mode 100644 index 00000000..f1daa806 --- /dev/null +++ b/dreadnode/config.py @@ -0,0 +1,108 @@ +import rich +from pydantic import BaseModel +from ruamel.yaml import YAML + +from dreadnode.constants import DEFAULT_PROFILE_NAME, 
USER_CONFIG_PATH + + +class ServerConfig(BaseModel): + """Server specific authentication data and API URL.""" + + url: str + email: str + username: str + api_key: str + access_token: str + refresh_token: str + + +class UserConfig(BaseModel): + """User configuration supporting multiple server profiles.""" + + active: str | None = None + servers: dict[str, ServerConfig] = {} + + def _update_active(self) -> None: + """If active is not set, set it to the first available server and raise an error if no servers are configured.""" + + if self.active not in self.servers: + self.active = next(iter(self.servers)) if self.servers else None + + def _update_urls(self) -> bool: + updated = False + for search, replace in { + "//staging-crucible.dreadnode.io": "//staging-platform.dreadnode.io", + "//dev-crucible.dreadnode.io": "//dev-platform.dreadnode.io", + "//crucible.dreadnode.io": "//platform.dreadnode.io", + }.items(): + for server in self.servers.values(): + if search in server.url: + server.url = server.url.replace(search, replace) + updated = True + return updated + + @classmethod + def read(cls) -> "UserConfig": + """Read the user configuration from the file system or return an empty instance.""" + + if not USER_CONFIG_PATH.exists(): + return cls() + + with USER_CONFIG_PATH.open("r") as f: + self = cls.model_validate(YAML().load(f)) + + if self._update_urls(): + self.write() + + return self + + def write(self) -> None: + """Write the user configuration to the file system.""" + + self._update_active() + + if not USER_CONFIG_PATH.parent.exists(): + rich.print(f":rocket: Creating config at {USER_CONFIG_PATH.parent}") + USER_CONFIG_PATH.parent.mkdir(parents=True) + + with USER_CONFIG_PATH.open("w") as f: + YAML().dump(self.model_dump(mode="json"), f) + + @property + def active_profile_name(self) -> str | None: + """Get the name of the active profile.""" + self._update_active() + return self.active + + def get_server_config(self, profile: str | None = None) -> ServerConfig: + 
"""Get the server configuration for the given profile or None if not set.""" + + profile = profile or self.active + if not profile: + raise RuntimeError("No profile is set, use [bold]dreadnode login[/] to authenticate") + + if profile not in self.servers: + raise RuntimeError(f"No server configuration for profile: {profile}") + + return self.servers[profile] + + def set_server_config(self, config: ServerConfig, profile: str | None = None) -> "UserConfig": + """Set the server configuration for the given profile.""" + + profile = profile or self.active or DEFAULT_PROFILE_NAME + self.servers[profile] = config + return self + + +def is_dreadnode_saas_server(url: str) -> bool: + """Check if the server URL is a Dreadnode SaaS server (ends with dreadnode.io).""" + return url.rstrip("/").endswith(".dreadnode.io") + + +def find_dreadnode_saas_profiles(user_config: UserConfig) -> list[str]: + """Find all profiles that point to Dreadnode SaaS servers.""" + saas_profiles = [] + for profile_name, server_config in user_config.servers.items(): + if is_dreadnode_saas_server(server_config.url): + saas_profiles.append(profile_name) + return saas_profiles diff --git a/dreadnode/constants.py b/dreadnode/constants.py index ae9dc730..e8df3b79 100644 --- a/dreadnode/constants.py +++ b/dreadnode/constants.py @@ -1,4 +1,36 @@ -# Environment variable names +import os +import pathlib + +# +# Defaults +# + +# name of the default server profile +DEFAULT_PROFILE_NAME = "main" +# default poll interval for the authentication flow +DEFAULT_POLL_INTERVAL = 5 +# default maximum poll time for the authentication flow +DEFAULT_MAX_POLL_TIME = 300 +# default maximum token TTL in seconds +DEFAULT_TOKEN_MAX_TTL = 60 +# Default values for the S3 storage +DEFAULT_MAX_INLINE_OBJECT_BYTES = 10 * 1024 # 10KB +# default platform domain +DEFAULT_PLATFORM_BASE_DOMAIN = "dreadnode.io" +# default server URL +DEFAULT_SERVER_URL = f"https://platform.{DEFAULT_PLATFORM_BASE_DOMAIN}" +# default local directory for 
dreadnode objects +DEFAULT_LOCAL_OBJECT_DIR = ".dreadnode/objects" +# default docker registry subdomain +DEFAULT_DOCKER_REGISTRY_SUBDOMAIN = "registry" +# default docker registry local port +DEFAULT_DOCKER_REGISTRY_LOCAL_PORT = 5005 +# default docker registry image tag +DEFAULT_DOCKER_REGISTRY_IMAGE_TAG = "registry" + +# +# Environment Variable Names +# ENV_SERVER_URL = "DREADNODE_SERVER_URL" ENV_SERVER = "DREADNODE_SERVER" # alternative to SERVER_URL @@ -6,11 +38,20 @@ ENV_API_KEY = "DREADNODE_API_KEY" # pragma: allowlist secret (alternative to API_TOKEN) ENV_LOCAL_DIR = "DREADNODE_LOCAL_DIR" ENV_PROJECT = "DREADNODE_PROJECT" +ENV_PROFILE = "DREADNODE_PROFILE" -# Default values +# +# Environment +# -DEFAULT_SERVER_URL = "https://platform.dreadnode.io" -DEFAULT_LOCAL_OBJECT_DIR = ".dreadnode/objects" +# enable debugging +DEBUG = bool(os.getenv("DREADNODE_DEBUG")) or False -# Default values for the S3 storage -MAX_INLINE_OBJECT_BYTES = 10 * 1024 # 10KB +# server url +PLATFORM_BASE_URL = os.getenv(ENV_SERVER, os.getenv(ENV_SERVER_URL, DEFAULT_SERVER_URL)) + +# path to the user configuration file +USER_CONFIG_PATH = pathlib.Path( + # allow overriding the user config file via env variable + os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config" +) diff --git a/dreadnode/data_types/__init__.py b/dreadnode/data_types/__init__.py index 04a95f21..11eac1d4 100644 --- a/dreadnode/data_types/__init__.py +++ b/dreadnode/data_types/__init__.py @@ -1,9 +1,9 @@ -from .audio import Audio -from .base import WithMeta -from .image import Image -from .object_3d import Object3D -from .table import Table -from .text import Code, Markdown, Text -from .video import Video +from dreadnode.data_types.audio import Audio +from dreadnode.data_types.base import WithMeta +from dreadnode.data_types.image import Image +from dreadnode.data_types.object_3d import Object3D +from dreadnode.data_types.table import Table +from dreadnode.data_types.text import Code, 
Markdown, Text +from dreadnode.data_types.video import Video __all__ = ["Audio", "Code", "Image", "Markdown", "Object3D", "Table", "Text", "Video", "WithMeta"] diff --git a/dreadnode/main.py b/dreadnode/main.py index 321ba373..da813707 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -2,15 +2,15 @@ import inspect import os import random -import socket import typing as t from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from urllib.parse import ParseResult, urljoin, urlparse, urlunparse +from urllib.parse import urljoin, urlparse, urlunparse import coolname # type: ignore [import-untyped] import logfire +import rich from fsspec.implementations.local import ( # type: ignore [import-untyped] LocalFileSystem, ) @@ -24,11 +24,13 @@ from s3fs import S3FileSystem # type: ignore [import-untyped] from dreadnode.api.client import ApiClient +from dreadnode.config import UserConfig from dreadnode.constants import ( DEFAULT_SERVER_URL, ENV_API_KEY, ENV_API_TOKEN, ENV_LOCAL_DIR, + ENV_PROFILE, ENV_PROJECT, ENV_SERVER, ENV_SERVER_URL, @@ -61,7 +63,7 @@ Inherited, JsonValue, ) -from dreadnode.util import clean_str, handle_internal_errors, logger +from dreadnode.util import clean_str, handle_internal_errors, resolve_endpoint from dreadnode.version import VERSION if t.TYPE_CHECKING: @@ -135,102 +137,32 @@ def __init__( self._initialized = False - @staticmethod - def _resolve_endpoint(endpoint: str | None) -> str | None: - """Automatically resolve endpoints based on environment + def _get_profile_server(self, profile: str | None = None) -> str | None: + with contextlib.suppress(Exception): + user_config = UserConfig.read() + profile = profile or os.environ.get(ENV_PROFILE) + server_config = user_config.get_server_config(profile) + return server_config.url - Args: - endpoint: The endpoint URL to resolve. - - Returns: - str: The resolved endpoint URL. - - Raises: - ValueError: If the endpoint URL is invalid. 
- """ - if not endpoint: - return None - parsed = urlparse(endpoint) - - # If it's a real domain (has dots), use as-is - if not parsed.hostname: - raise ValueError(f"Invalid endpoint URL: {endpoint}") - - if "." in parsed.hostname: - return endpoint - - # If it's a service name, try to resolve it - if Dreadnode._is_docker_service_name(parsed.hostname): - return Dreadnode._resolve_docker_service(endpoint, parsed) - - return endpoint + # Silently fail if profile config is not available or invalid + return None - @staticmethod - def _is_docker_service_name(hostname: str) -> bool: - """Check if this looks like a Docker service name + def _get_profile_api_key(self, profile: str | None = None) -> str | None: + with contextlib.suppress(Exception): + user_config = UserConfig.read() + profile = profile or os.environ.get(ENV_PROFILE) + server_config = user_config.get_server_config(profile) + return server_config.api_key - Args: - hostname: The hostname to check. - - Returns: - bool: True if the hostname looks like a Docker service name, False otherwise. - """ - return bool(hostname and "." not in hostname and hostname != "localhost") - - @staticmethod - def _resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str: - """Try different resolution strategies for Docker services - - Args: - original_endpoint: The original endpoint URL. - parsed: The parsed URL object. - - Returns: - str: The resolved endpoint URL. - - Raises: - RuntimeError: If no valid endpoint is found. - """ - strategies = [ - original_endpoint, # Try original first (works if running in same network) - f"{parsed.scheme}://localhost:{parsed.port}", # Try localhost - f"{parsed.scheme}://host.docker.internal:{parsed.port}", # Docker Desktop - f"{parsed.scheme}://172.17.0.1:{parsed.port}", # Docker bridge IP - ] - - for endpoint in strategies: - if Dreadnode._test_connection(endpoint): - logger.warning( - f"Resolved Docker service for s3 connection '{parsed.hostname}' to '{endpoint}'." 
- ) - return str(endpoint) - - # If nothing works, return original and let it fail with a helpful error - raise RuntimeError(f"Failed to connect to the Dreadnode Artifact storage at {endpoint}.") - - @staticmethod - def _test_connection(endpoint: str) -> bool: - """Quick connectivity test - - Args: - endpoint: The endpoint URL to test. - - Returns: - bool: True if the connection is successful, False otherwise. - """ - try: - parsed = urlparse(endpoint) - socket.create_connection((parsed.hostname, parsed.port or 443), timeout=1) - except Exception: # noqa: BLE001 - return False - - return True + # Silently fail if profile config is not available or invalid + return None def configure( self, *, server: str | None = None, token: str | None = None, + profile: str | None = None, local_dir: str | Path | t.Literal[False] = False, project: str | None = None, service_name: str | None = None, @@ -244,15 +176,21 @@ def configure( This method should always be called before using the SDK. - If `server` and `token` are not provided, the SDK will look in - the associated environment variables: + If `server` and `token` are not provided, the SDK will look for them + in the following order: - - `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` - - `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` + 1. Environment variables: + - `DREADNODE_SERVER_URL` or `DREADNODE_SERVER` + - `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY` + 2. Dreadnode profile (from `dreadnode login`) + - Uses `profile` parameter if provided + - Falls back to `DREADNODE_PROFILE` environment variable + - Defaults to active profile Args: server: The Dreadnode server URL. token: The Dreadnode API token. + profile: The Dreadnode profile name to use (only used if env vars are not set). local_dir: The local directory to store data in. project: The default project name to associate all runs with. service_name: The service name to use for OpenTelemetry. 
@@ -264,8 +202,43 @@ def configure( self._initialized = False - self.server = server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) - self.token = token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + # Determine configuration source and active profile for logging + config_source = "explicit parameters" + active_profile = None + + if not server or not token: + # Check environment variables first + env_server = os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) + env_token = os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + + if env_server or env_token: + config_source = "environment vars" + else: + # Fall back to profile + config_source = "profile" + with contextlib.suppress(Exception): + user_config = UserConfig.read() + profile_name = profile or os.environ.get(ENV_PROFILE) + if profile_name: + active_profile = profile_name + else: + active_profile = user_config.active_profile_name + + if active_profile: + config_source = f"profile: {active_profile}" + + self.server = ( + server + or os.environ.get(ENV_SERVER_URL) + or os.environ.get(ENV_SERVER) + or self._get_profile_server(profile) + ) + self.token = ( + token + or os.environ.get(ENV_API_TOKEN) + or os.environ.get(ENV_API_KEY) + or self._get_profile_api_key(profile) + ) if local_dir is False and ENV_LOCAL_DIR in os.environ: env_local_dir = os.environ.get(ENV_LOCAL_DIR) @@ -283,6 +256,17 @@ def configure( self.send_to_logfire = send_to_logfire self.otel_scope = otel_scope + # Log config information for clarity + if self.server or self.token or self.local_dir: + destination = self.server or DEFAULT_SERVER_URL or "local storage" + rich.print(f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})") + + # Warn the user if the profile didn't resolve + elif active_profile and not (self.server or self.token): + rich.print( + f":exclamation: Dreadnode profile [orange_red1]{active_profile}[/] appears invalid." 
+ ) + self.initialize() def initialize(self) -> None: @@ -301,7 +285,8 @@ def initialize(self) -> None: if not (self.server or self.token or self.local_dir): warn_at_user_stacklevel( "Your current configuration won't persist run data anywhere. " - "Use `dreadnode.init(server=..., token=...)`, `dreadnode.init(local_dir=...)`, " + "Login with `dreadnode login` to set up a server and token, " + "Use `dreadnode.configure(server=..., token=...)`, `dreadnode.configure(profile=...)`, " f"or use environment variables ({ENV_SERVER_URL}, {ENV_API_TOKEN}, {ENV_LOCAL_DIR}).", category=DreadnodeConfigWarning, ) @@ -325,7 +310,7 @@ def initialize(self) -> None: ) self.server = urlunparse(parsed_new) - self._api = ApiClient(self.server, self.token) + self._api = ApiClient(self.server, api_key=self.token) self._api.list_projects() except Exception as e: @@ -359,7 +344,7 @@ def initialize(self) -> None: # ) credentials = self._api.get_user_data_credentials() - resolved_endpoint = self._resolve_endpoint(credentials.endpoint) + resolved_endpoint = resolve_endpoint(credentials.endpoint) self._fs = S3FileSystem( key=credentials.access_key_id, secret=credentials.secret_access_key, @@ -406,7 +391,7 @@ def api(self, *, server: str | None = None, token: str | None = None) -> ApiClie An ApiClient instance. """ if server is not None and token is not None: - return ApiClient(server, token) + return ApiClient(server, api_key=token) if not self._initialized: raise RuntimeError("Call .configure() before accessing the API") @@ -773,7 +758,7 @@ def run( The run will automatically be completed when the context manager exits. """ if not self._initialized: - self.initialize() + self.configure() if name is None: name = f"{coolname.generate_slug(2)}-{random.randint(100, 999)}" # noqa: S311 # nosec @@ -827,7 +812,7 @@ def continue_run(self, run_context: RunContext) -> RunSpan: A RunSpan object that can be used as a context manager. 
""" if not self._initialized: - self.initialize() + self.configure() return RunSpan.from_context( context=run_context, diff --git a/dreadnode/object.py b/dreadnode/object.py index f03495db..3f208ed3 100644 --- a/dreadnode/object.py +++ b/dreadnode/object.py @@ -1,8 +1,7 @@ import typing as t from dataclasses import dataclass -from litellm import ConfigDict -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from dreadnode.types import AnyDict diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 47baf284..3ed93910 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -31,16 +31,12 @@ from dreadnode.artifact.merger import ArtifactMerger from dreadnode.artifact.storage import ArtifactStorage from dreadnode.artifact.tree_builder import ArtifactTreeBuilder, DirectoryNode -from dreadnode.constants import MAX_INLINE_OBJECT_BYTES +from dreadnode.constants import DEFAULT_MAX_INLINE_OBJECT_BYTES from dreadnode.convert import run_span_to_graph from dreadnode.metric import Metric, MetricAggMode, MetricsDict from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal from dreadnode.serialization import Serialized, serialize -from dreadnode.types import UNSET, AnyDict, JsonDict, Unset -from dreadnode.util import clean_str -from dreadnode.version import VERSION - -from .constants import ( +from dreadnode.tracing.constants import ( EVENT_ATTRIBUTE_LINK_HASH, EVENT_ATTRIBUTE_OBJECT_HASH, EVENT_ATTRIBUTE_OBJECT_LABEL, @@ -69,6 +65,9 @@ SPAN_ATTRIBUTE_VERSION, SpanType, ) +from dreadnode.types import UNSET, AnyDict, JsonDict, Unset +from dreadnode.util import clean_str +from dreadnode.version import VERSION if t.TYPE_CHECKING: import networkx as nx # type: ignore [import-untyped] @@ -631,7 +630,7 @@ def _create_object_by_hash(self, serialized: Serialized, object_hash: str) -> Ob data_hash = serialized.data_hash schema_hash = serialized.schema_hash - if data is None or data_bytes is None or data_len 
<= MAX_INLINE_OBJECT_BYTES: + if data is None or data_bytes is None or data_len <= DEFAULT_MAX_INLINE_OBJECT_BYTES: return ObjectVal( hash=object_hash, value=data, diff --git a/dreadnode/util.py b/dreadnode/util.py index f9883d6a..77558d88 100644 --- a/dreadnode/util.py +++ b/dreadnode/util.py @@ -5,11 +5,14 @@ import logging import os import re +import socket import sys import typing as t from contextlib import contextmanager +from datetime import datetime from pathlib import Path from types import TracebackType +from urllib.parse import ParseResult, urlparse from logfire import suppress_instrumentation from logfire._internal.stack_info import add_non_user_code_prefix, is_user_code @@ -60,6 +63,29 @@ def safe_repr(obj: t.Any) -> str: return "" +def time_to(future_datetime: datetime) -> str: + """Get a string describing the time difference between a future datetime and now.""" + + now = datetime.now(tz=future_datetime.tzinfo) + time_difference = future_datetime - now + + days = time_difference.days + seconds = time_difference.seconds + hours = seconds // 3600 + minutes = (seconds % 3600) // 60 + seconds = seconds % 60 + + result = [] + if days > 0: + result.append(f"{days}d") + if hours > 0: + result.append(f"{hours}hr") + if minutes > 0: + result.append(f"{minutes}m") + + return ", ".join(result) if result else "Just now" + + def log_internal_error() -> None: try: current_test = os.environ.get("PYTEST_CURRENT_TEST", "") @@ -162,3 +188,96 @@ def handle_internal_errors() -> t.Iterator[None]: _HANDLE_INTERNAL_ERRORS_CODE = inspect.unwrap(handle_internal_errors).__code__ + + +def is_docker_service_name(hostname: str) -> bool: + """Check if this looks like a Docker service name + + Args: + hostname: The hostname to check. + + Returns: + bool: True if the hostname looks like a Docker service name, False otherwise. + """ + return bool(hostname and "." 
not in hostname and hostname != "localhost") + + +def resolve_endpoint(endpoint: str | None) -> str | None: + """Automatically resolve endpoints based on environment + + Args: + endpoint: The endpoint URL to resolve. + + Returns: + str: The resolved endpoint URL. + + Raises: + ValueError: If the endpoint URL is invalid. + """ + if not endpoint: + return None + parsed = urlparse(endpoint) + + # If it's a real domain (has dots), use as-is + if not parsed.hostname: + raise ValueError(f"Invalid endpoint URL: {endpoint}") + + if "." in parsed.hostname: + return endpoint + + # If it's a service name, try to resolve it + if is_docker_service_name(parsed.hostname): + return resolve_docker_service(endpoint, parsed) + + return endpoint + + +def test_connection(endpoint: str) -> bool: + """ + Simple test to check if the endpoint is reachable. + + Args: + endpoint: The endpoint URL to test. + + Returns: + bool: True if the endpoint is reachable, False otherwise. + """ + try: + parsed = urlparse(endpoint) + socket.create_connection((parsed.hostname, parsed.port or 443), timeout=1) + except Exception: # noqa: BLE001 + return False + + return True + + +def resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str: + """ + Try different resolution strategies for Docker services + + Args: + original_endpoint: The original endpoint URL. + parsed: The parsed URL object. + + Returns: + str: The resolved endpoint URL. + + Raises: + RuntimeError: If no valid endpoint is found. + """ + strategies = [ + original_endpoint, # Try original first (works if running in same network) + f"{parsed.scheme}://localhost:{parsed.port}", # Try localhost + f"{parsed.scheme}://host.docker.internal:{parsed.port}", # Docker Desktop + f"{parsed.scheme}://172.17.0.1:{parsed.port}", # Docker bridge IP + ] + + for endpoint in strategies: + if test_connection(endpoint): + logger.warning( + f"Resolved Docker service endpoint '{parsed.hostname}' to '{endpoint}'." 
# noqa: G004 + ) + return str(endpoint) + + # If nothing works, return original and let it fail with a helpful error + raise RuntimeError(f"Failed to connect to the Dreadnode Artifact storage at {endpoint}.") diff --git a/poetry.lock b/poetry.lock index c676fd20..003108b3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -199,6 +199,29 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""] +[[package]] +name = "beautifulsoup4" +version = "4.13.4" +description = "Screen-scraping library" +optional = false +python-versions = ">=3.7.0" +groups = ["dev"] +files = [ + {file = "beautifulsoup4-4.13.4-py3-none-any.whl", hash = "sha256:9bbbb14bfde9d79f38b8cd5f8c7c85f4b8f2523190ebed90e950a8dea4cb1c4b"}, + {file = "beautifulsoup4-4.13.4.tar.gz", hash = "sha256:dbb3c4e1ceae6aefebdaf2423247260cd062430a410e38c66f2baa50a8437195"}, +] + +[package.dependencies] +soupsieve = ">1.2" +typing-extensions = ">=4.0.0" + +[package.extras] +cchardet = ["cchardet"] +chardet = ["chardet"] +charset-normalizer = ["charset-normalizer"] +html5lib = ["html5lib"] +lxml = ["lxml"] + [[package]] name = "boto3" version = "1.38.14" @@ -926,7 +949,6 @@ files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -markers = {dev = 
"platform_system == \"Windows\" or sys_platform == \"win32\""} [[package]] name = "coolname" @@ -940,6 +962,30 @@ files = [ {file = "coolname-2.2.0.tar.gz", hash = "sha256:6c5d5731759104479e7ca195a9b64f7900ac5bead40183c09323c7d0be9e75c7"}, ] +[[package]] +name = "cyclopts" +version = "3.22.2" +description = "Intuitive, easy CLIs based on type hints." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "cyclopts-3.22.2-py3-none-any.whl", hash = "sha256:6681b0815fa2de2bccc364468fd25b15aa9617cb505c0b16ca62e2b18a57619e"}, + {file = "cyclopts-3.22.2.tar.gz", hash = "sha256:d3495231af6ae86479579777d212ddf77b113200f828badeaf401162ed87227d"}, +] + +[package.dependencies] +attrs = ">=23.1.0" +docstring-parser = {version = ">=0.15", markers = "python_version < \"4.0\""} +rich = ">=13.6.0" +rich-rst = ">=1.3.1,<2.0.0" +typing-extensions = {version = ">=4.8.0", markers = "python_version < \"3.11\""} + +[package.extras] +toml = ["tomli (>=2.0.0) ; python_version < \"3.11\""] +trio = ["trio (>=0.10.0)"] +yaml = ["pyyaml (>=6.0.1)"] + [[package]] name = "datasets" version = "3.6.0" @@ -1054,6 +1100,30 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "docstring-parser" +version = "0.16" +description = "Parse Python docstrings in reST, Google and Numpydoc format" +optional = false +python-versions = ">=3.6,<4.0" +groups = ["main"] +files = [ + {file = "docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637"}, + {file = "docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e"}, +] + +[[package]] +name = "docutils" +version = "0.21.2" +description = "Docutils -- Python Documentation Utilities" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "docutils-0.21.2-py3-none-any.whl", hash = 
"sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2"}, + {file = "docutils-0.21.2.tar.gz", hash = "sha256:3a6b18732edf182daa3cd12775bbb338cf5691468f91eeeb109deff6ebfa986f"}, +] + [[package]] name = "exceptiongroup" version = "1.3.0" @@ -1263,6 +1333,24 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] +[[package]] +name = "ghp-import" +version = "2.1.0" +description = "Copy your docs directly to the gh-pages branch." +optional = false +python-versions = "*" +groups = ["dev"] +files = [ + {file = "ghp-import-2.1.0.tar.gz", hash = "sha256:9c535c4c61193c2df8871222567d7fd7e5014d835f97dc7b7439069e2413d343"}, + {file = "ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619"}, +] + +[package.dependencies] +python-dateutil = ">=2.8.1" + +[package.extras] +dev = ["flake8", "markdown", "twine", "wheel"] + [[package]] name = "googleapis-common-protos" version = "1.70.0" @@ -1281,6 +1369,21 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0)"] +[[package]] +name = "griffe" +version = "1.7.3" +description = "Signatures for entire Python programs. Extract the structure, the frame, the skeleton of your project, to generate API documentation or find breaking changes in your API." 
+optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "griffe-1.7.3-py3-none-any.whl", hash = "sha256:c6b3ee30c2f0f17f30bcdef5068d6ab7a2a4f1b8bf1a3e74b56fffd21e1c5f75"}, + {file = "griffe-1.7.3.tar.gz", hash = "sha256:52ee893c6a3a968b639ace8015bec9d36594961e156e23315c8e8e51401fa50b"}, +] + +[package.dependencies] +colorama = ">=0.4" + [[package]] name = "h11" version = "0.16.0" @@ -1536,7 +1639,7 @@ version = "3.1.6" description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"}, {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"}, @@ -1806,6 +1909,22 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} [package.extras] dev = ["Sphinx (==8.1.3) ; python_version >= \"3.11\"", "build (==1.2.2) ; python_version >= \"3.11\"", "colorama (==0.4.5) ; python_version < \"3.8\"", "colorama (==0.4.6) ; python_version >= \"3.8\"", "exceptiongroup (==1.1.3) ; python_version >= \"3.7\" and python_version < \"3.11\"", "freezegun (==1.1.0) ; python_version < \"3.8\"", "freezegun (==1.5.0) ; python_version >= \"3.8\"", "mypy (==v0.910) ; python_version < \"3.6\"", "mypy (==v0.971) ; python_version == \"3.6\"", "mypy (==v1.13.0) ; python_version >= \"3.8\"", "mypy (==v1.4.1) ; python_version == \"3.7\"", "myst-parser (==4.0.0) ; python_version >= \"3.11\"", "pre-commit (==4.0.1) ; python_version >= \"3.9\"", "pytest (==6.1.2) ; python_version < \"3.8\"", "pytest (==8.3.2) ; python_version >= \"3.8\"", "pytest-cov (==2.12.1) ; python_version < \"3.8\"", "pytest-cov (==5.0.0) ; python_version == \"3.8\"", "pytest-cov (==6.0.0) ; python_version >= \"3.9\"", "pytest-mypy-plugins (==1.9.3) ; python_version >= \"3.6\" and python_version < \"3.8\"", 
"pytest-mypy-plugins (==3.1.0) ; python_version >= \"3.8\"", "sphinx-rtd-theme (==3.0.2) ; python_version >= \"3.11\"", "tox (==3.27.1) ; python_version < \"3.8\"", "tox (==4.23.2) ; python_version >= \"3.8\"", "twine (==6.0.1) ; python_version >= \"3.11\""] +[[package]] +name = "markdown" +version = "3.8.2" +description = "Python implementation of John Gruber's Markdown." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "markdown-3.8.2-py3-none-any.whl", hash = "sha256:5c83764dbd4e00bdd94d85a19b8d55ccca20fe35b2e678a1422b380324dd5f24"}, + {file = "markdown-3.8.2.tar.gz", hash = "sha256:247b9a70dd12e27f67431ce62523e675b866d254f900c4fe75ce3dda62237c45"}, +] + +[package.extras] +docs = ["mdx_gh_links (>=0.2)", "mkdocs (>=1.6)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -1831,13 +1950,29 @@ profiling = ["gprof2dot"] rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] +[[package]] +name = "markdownify" +version = "1.1.0" +description = "Convert HTML to markdown." +optional = false +python-versions = "*" +groups = ["dev"] +files = [ + {file = "markdownify-1.1.0-py3-none-any.whl", hash = "sha256:32a5a08e9af02c8a6528942224c91b933b4bd2c7d078f9012943776fc313eeef"}, + {file = "markdownify-1.1.0.tar.gz", hash = "sha256:449c0bbbf1401c5112379619524f33b63490a8fa479456d41de9dc9e37560ebd"}, +] + +[package.dependencies] +beautifulsoup4 = ">=4.9,<5" +six = ">=1.15,<2" + [[package]] name = "markupsafe" version = "3.0.2" description = "Safely add untrusted strings to HTML/XML markup." 
optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8"}, {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158"}, @@ -1944,6 +2079,126 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mergedeep" +version = "1.3.4" +description = "A deep merge function for 🐍." +optional = false +python-versions = ">=3.6" +groups = ["dev"] +files = [ + {file = "mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307"}, + {file = "mergedeep-1.3.4.tar.gz", hash = "sha256:0096d52e9dad9939c3d975a774666af186eda617e6ca84df4c94dec30004f2a8"}, +] + +[[package]] +name = "mkdocs" +version = "1.6.1" +description = "Project documentation with Markdown." 
+optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "mkdocs-1.6.1-py3-none-any.whl", hash = "sha256:db91759624d1647f3f34aa0c3f327dd2601beae39a366d6e064c03468d35c20e"}, + {file = "mkdocs-1.6.1.tar.gz", hash = "sha256:7b432f01d928c084353ab39c57282f29f92136665bdd6abf7c1ec8d822ef86f2"}, +] + +[package.dependencies] +click = ">=7.0" +colorama = {version = ">=0.4", markers = "platform_system == \"Windows\""} +ghp-import = ">=1.0" +jinja2 = ">=2.11.1" +markdown = ">=3.3.6" +markupsafe = ">=2.0.1" +mergedeep = ">=1.3.4" +mkdocs-get-deps = ">=0.2.0" +packaging = ">=20.5" +pathspec = ">=0.11.1" +pyyaml = ">=5.1" +pyyaml-env-tag = ">=0.1" +watchdog = ">=2.0" + +[package.extras] +i18n = ["babel (>=2.9.0)"] +min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4) ; platform_system == \"Windows\"", "ghp-import (==1.0)", "importlib-metadata (==4.4) ; python_version < \"3.10\"", "jinja2 (==2.11.1)", "markdown (==3.3.6)", "markupsafe (==2.0.1)", "mergedeep (==1.3.4)", "mkdocs-get-deps (==0.2.0)", "packaging (==20.5)", "pathspec (==0.11.1)", "pyyaml (==5.1)", "pyyaml-env-tag (==0.1)", "watchdog (==2.0)"] + +[[package]] +name = "mkdocs-autorefs" +version = "1.4.2" +description = "Automatically link across pages in MkDocs." 
+optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "mkdocs_autorefs-1.4.2-py3-none-any.whl", hash = "sha256:83d6d777b66ec3c372a1aad4ae0cf77c243ba5bcda5bf0c6b8a2c5e7a3d89f13"}, + {file = "mkdocs_autorefs-1.4.2.tar.gz", hash = "sha256:e2ebe1abd2b67d597ed19378c0fff84d73d1dbce411fce7a7cc6f161888b6749"}, +] + +[package.dependencies] +Markdown = ">=3.3" +markupsafe = ">=2.0.1" +mkdocs = ">=1.1" + +[[package]] +name = "mkdocs-get-deps" +version = "0.2.0" +description = "MkDocs extension that lists all dependencies according to a mkdocs.yml file" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "mkdocs_get_deps-0.2.0-py3-none-any.whl", hash = "sha256:2bf11d0b133e77a0dd036abeeb06dec8775e46efa526dc70667d8863eefc6134"}, + {file = "mkdocs_get_deps-0.2.0.tar.gz", hash = "sha256:162b3d129c7fad9b19abfdcb9c1458a651628e4b1dea628ac68790fb3061c60c"}, +] + +[package.dependencies] +mergedeep = ">=1.3.4" +platformdirs = ">=2.2.0" +pyyaml = ">=5.1" + +[[package]] +name = "mkdocstrings" +version = "0.29.1" +description = "Automatic documentation from sources, for MkDocs." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "mkdocstrings-0.29.1-py3-none-any.whl", hash = "sha256:37a9736134934eea89cbd055a513d40a020d87dfcae9e3052c2a6b8cd4af09b6"}, + {file = "mkdocstrings-0.29.1.tar.gz", hash = "sha256:8722f8f8c5cd75da56671e0a0c1bbed1df9946c0cef74794d6141b34011abd42"}, +] + +[package.dependencies] +Jinja2 = ">=2.11.1" +Markdown = ">=3.6" +MarkupSafe = ">=1.1" +mkdocs = ">=1.6" +mkdocs-autorefs = ">=1.4" +pymdown-extensions = ">=6.3" + +[package.extras] +crystal = ["mkdocstrings-crystal (>=0.3.4)"] +python = ["mkdocstrings-python (>=1.16.2)"] +python-legacy = ["mkdocstrings-python-legacy (>=0.2.1)"] + +[[package]] +name = "mkdocstrings-python" +version = "1.16.12" +description = "A Python handler for mkdocstrings." 
+optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "mkdocstrings_python-1.16.12-py3-none-any.whl", hash = "sha256:22ded3a63b3d823d57457a70ff9860d5a4de9e8b1e482876fc9baabaf6f5f374"}, + {file = "mkdocstrings_python-1.16.12.tar.gz", hash = "sha256:9b9eaa066e0024342d433e332a41095c4e429937024945fea511afe58f63175d"}, +] + +[package.dependencies] +griffe = ">=1.6.2" +mkdocs-autorefs = ">=1.4" +mkdocstrings = ">=0.28.3" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} + [[package]] name = "moviepy" version = "2.2.1" @@ -2532,6 +2787,18 @@ files = [ numpy = ">=1.23.5" types-pytz = ">=2022.1.1" +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + [[package]] name = "pillow" version = "11.3.0" @@ -3128,6 +3395,25 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pymdown-extensions" +version = "10.16" +description = "Extension pack for Python Markdown." 
+optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pymdown_extensions-10.16-py3-none-any.whl", hash = "sha256:f5dd064a4db588cb2d95229fc4ee63a1b16cc8b4d0e6145c0899ed8723da1df2"}, + {file = "pymdown_extensions-10.16.tar.gz", hash = "sha256:71dac4fca63fabeffd3eb9038b756161a33ec6e8d230853d3cecf562155ab3de"}, +] + +[package.dependencies] +markdown = ">=3.6" +pyyaml = "*" + +[package.extras] +extra = ["pygments (>=2.19.1)"] + [[package]] name = "pytest" version = "8.4.1" @@ -3334,6 +3620,21 @@ files = [ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] +[[package]] +name = "pyyaml-env-tag" +version = "1.1" +description = "A custom YAML tag for referencing environment variables in YAML files." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pyyaml_env_tag-1.1-py3-none-any.whl", hash = "sha256:17109e1a528561e32f026364712fee1264bc2ea6715120891174ed1b980d2e04"}, + {file = "pyyaml_env_tag-1.1.tar.gz", hash = "sha256:2eb38b75a2d21ee0475d6d97ec19c63287a7e140231e4214969d0eac923cd7ff"}, +] + +[package.dependencies] +pyyaml = "*" + [[package]] name = "referencing" version = "0.36.2" @@ -3497,6 +3798,22 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.1 [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "rich-rst" +version = "1.3.1" +description = "A beautiful reStructuredText renderer for rich" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "rich_rst-1.3.1-py3-none-any.whl", hash = "sha256:498a74e3896507ab04492d326e794c3ef76e7cda078703aa592d1853d91098c1"}, + {file = "rich_rst-1.3.1.tar.gz", hash = "sha256:fad46e3ba42785ea8c1785e2ceaa56e0ffa32dbe5410dec432f37e4107c4f383"}, +] + +[package.dependencies] +docutils = "*" +rich = ">=12.0.0" + [[package]] name = "rigging" version = "3.2.1" @@ -3914,6 +4231,18 @@ files = [ cffi = ">=1.0" numpy = "*" 
+[[package]] +name = "soupsieve" +version = "2.7" +description = "A modern CSS selector implementation for Beautiful Soup." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "soupsieve-2.7-py3-none-any.whl", hash = "sha256:6e60cc5c1ffaf1cebcc12e8188320b72071e922c2e897f737cadce79ad5d30c4"}, + {file = "soupsieve-2.7.tar.gz", hash = "sha256:ad282f9b6926286d2ead4750552c8a6142bc4c783fd66b0293547c8fe6ae126a"}, +] + [[package]] name = "sse-starlette" version = "2.4.1" @@ -4354,6 +4683,49 @@ platformdirs = ">=3.9.1,<5" docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"GraalVM\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] +[[package]] +name = "watchdog" +version = "6.0.0" +description = "Filesystem events monitoring" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "watchdog-6.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d1cdb490583ebd691c012b3d6dae011000fe42edb7a82ece80965b42abd61f26"}, + {file = "watchdog-6.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc64ab3bdb6a04d69d4023b29422170b74681784ffb9463ed4870cf2f3e66112"}, + {file = "watchdog-6.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c897ac1b55c5a1461e16dae288d22bb2e412ba9807df8397a635d88f671d36c3"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = 
"sha256:6eb11feb5a0d452ee41f824e271ca311a09e250441c262ca2fd7ebcf2461a06c"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ef810fbf7b781a5a593894e4f439773830bdecb885e6880d957d5b9382a960d2"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:afd0fe1b2270917c5e23c2a65ce50c2a4abb63daafb0d419fde368e272a76b7c"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bdd4e6f14b8b18c334febb9c4425a878a2ac20efd1e0b231978e7b150f92a948"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c7c15dda13c4eb00d6fb6fc508b3c0ed88b9d5d374056b239c4ad1611125c860"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f10cb2d5902447c7d0da897e2c6768bca89174d0c6e1e30abec5421af97a5b0"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:490ab2ef84f11129844c23fb14ecf30ef3d8a6abafd3754a6f75ca1e6654136c"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:76aae96b00ae814b181bb25b1b98076d5fc84e8a53cd8885a318b42b6d3a5134"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a175f755fc2279e0b7312c0035d52e27211a5bc39719dd529625b1930917345b"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e6f0e77c9417e7cd62af82529b10563db3423625c5fce018430b249bf977f9e8"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:90c8e78f3b94014f7aaae121e6b909674df5b46ec24d6bebc45c44c56729af2a"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e7631a77ffb1f7d2eefa4445ebbee491c720a5661ddf6df3498ebecae5ed375c"}, + {file = "watchdog-6.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:c7ac31a19f4545dd92fc25d200694098f42c9a8e391bc00bdd362c5736dbf881"}, + {file = "watchdog-6.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9513f27a1a582d9808cf21a07dae516f0fab1cf2d7683a742c498b93eedabb11"}, + 
{file = "watchdog-6.0.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7a0e56874cfbc4b9b05c60c8a1926fedf56324bb08cfbc188969777940aef3aa"}, + {file = "watchdog-6.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:e6439e374fc012255b4ec786ae3c4bc838cd7309a540e5fe0952d03687d8804e"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:212ac9b8bf1161dc91bd09c048048a95ca3a4c4f5e5d4a7d1b1a7d5752a7f96f"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:e3df4cbb9a450c6d49318f6d14f4bbc80d763fa587ba46ec86f99f9e6876bb26"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:2cce7cfc2008eb51feb6aab51251fd79b85d9894e98ba847408f662b3395ca3c"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:20ffe5b202af80ab4266dcd3e91aae72bf2da48c0d33bdb15c66658e685e94e2"}, + {file = "watchdog-6.0.0-py3-none-win32.whl", hash = "sha256:07df1fdd701c5d4c8e55ef6cf55b8f0120fe1aef7ef39a1c6fc6bc2e606d517a"}, + {file = "watchdog-6.0.0-py3-none-win_amd64.whl", hash = "sha256:cbafb470cf848d93b5d013e2ecb245d4aa1c8fd0504e863ccefa32445359d680"}, + {file = "watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f"}, + {file = "watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282"}, +] + +[package.extras] +watchmedo = ["PyYAML (>=3.10)"] + [[package]] name = "win32-setctime" version = "1.2.0" @@ -4751,4 +5123,4 @@ training = ["transformers"] 
[metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "80a2443f683d537f28b8a17f90ba6c7741c02bc5acfbf6cf9c06d79385f3c145" +content-hash = "0f9e538475309634ca67a66835b23db97718351ceb13f5a835b47ad8b740908b" diff --git a/pyproject.toml b/pyproject.toml index c90cef71..e24efc7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ pandas = "^2.2.3" fsspec = { version = ">=2023.1.0,<=2025.3.0", extras = [ "s3", ] } # Pinned for datasets compatibility +cyclopts = "^3.22.2" transformers = { version = "^4.41.0", optional = true } soundfile = { version = "^0.13.1", optional = true } @@ -47,6 +48,9 @@ types-requests = "^2.32.0.20250306" typer = "^0.15.2" datasets = "^3.5.0" pyarrow = "^19.0.1" +markdown = "^3.8.2" +markdownify = "^1.1.0" +mkdocstrings-python = "^1.16.12" [build-system] requires = ["poetry-core>=1.0.0", "setuptools>=42", "wheel"] @@ -58,6 +62,13 @@ packages = ["src"] [tool.hatch.build.targets.sdist] packages = ["src"] +[project.scripts] +dreadnode = 'dreadnode.__main__:run' +dn = 'dreadnode.__main__:run' + +[tool.poetry.plugins."pipx.run"] +dreadnode = 'dreadnode.__main__:run' + [tool.pytest.ini_options] asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function"