diff --git a/README.md b/README.md index 1e924e78..98fe0730 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ PyPI version Python versions License - Documentation + Documentation Units TestsUnits Tests

@@ -217,10 +217,10 @@ Key platform features accessible through this Client and CLI: ## Documentation -- **Command Line Interface (CLI)**: [https://docs.datalayer.app/cli/](https://docs.datalayer.app/cli/) +- **Command Line Interface (CLI)**: [https://datalayer.ai/docs/cli/](https://datalayer.ai/docs/cli/) - **Core Python Client**: [core.datalayer.tech/python/](https://core.datalayer.tech/python/) -- **Platform Documentation**: [docs.datalayer.app](https://docs.datalayer.app/) -- **API Reference**: [API documentation](https://docs.datalayer.app/api/) +- **Platform Documentation**: [docs.datalayer.app](https://datalayer.ai/docs/) +- **API Reference**: [API documentation](https://datalayer.ai/docs/api/) ## Development @@ -317,7 +317,7 @@ This project is licensed under the [BSD 3-Clause License](https://github.com/dat ## Support -- **Documentation**: [Datalayer Platform Documentation](https://docs.datalayer.app/) +- **Documentation**: [Datalayer Platform Documentation](https://datalayer.ai/docs/) - **Issues**: [GitHub Issues](https://github.com/datalayer/core/issues) - **Community**: [Datalayer Platform](https://datalayer.app/) diff --git a/datalayer_core/__version__.py b/datalayer_core/__version__.py index 227c016c..0bad1d00 100644 --- a/datalayer_core/__version__.py +++ b/datalayer_core/__version__.py @@ -3,4 +3,4 @@ """Datalayer Core version information.""" -__version__ = "1.1.22" +__version__ = "1.1.24" diff --git a/datalayer_core/assets/about.md b/datalayer_core/assets/about.md index 46eafccc..5cf726bd 100644 --- a/datalayer_core/assets/about.md +++ b/datalayer_core/assets/about.md @@ -1,5 +1,5 @@ ## About -Datalayer provides a command line tool allowing to list, create, terminate and open a console with runtimes. +Datalayer is a managed AI agents platform for collaborative data analysis, designed to eliminate vendor lock-in. -Read more on https://docs.datalayer.app +Read more on https://datalayer.ai/docs diff --git a/datalayer_core/base/serverapplication.py b/datalayer_core/base/serverapplication.py index 2d88c0bc..0a00ee97 100644 --- a/datalayer_core/base/serverapplication.py +++ b/datalayer_core/base/serverapplication.py @@ -129,7 +129,7 @@ class Brand(Configurable): ) docs_url = Unicode( - "https://docs.datalayer.app", + "https://datalayer.ai/docs", config=True, help=("Documentation URL."), ) diff --git a/datalayer_core/cli/__main__.py b/datalayer_core/cli/__main__.py index a14b4807..8413fcd8 100644 --- a/datalayer_core/cli/__main__.py +++ b/datalayer_core/cli/__main__.py @@ -3,10 +3,17 @@ """Command line interface for Datalayer based on Typer.""" +import os +import sys + import typer from datalayer_core.__version__ import __version__ from datalayer_core.cli.commands.about import app as about_app +from datalayer_core.cli.commands.agents import agents_ls +from datalayer_core.cli.commands.agents import app as agents_app +from datalayer_core.cli.commands.agent_nodes import app as agent_nodes_app +from datalayer_core.cli.commands.agent_nodes import agent_nodes_ls from datalayer_core.cli.commands.authn import ( app as auth_app, ) @@ -16,29 +23,35 @@ whoami_root, ) from datalayer_core.cli.commands.benchmarks import app as benchmarks_app +from datalayer_core.cli.commands.cluster import app as cluster_app from datalayer_core.cli.commands.config import app as config_app from datalayer_core.cli.commands.console import app as console_app from datalayer_core.cli.commands.envs import app as envs_app -from datalayer_core.cli.commands.envs import envs_list, envs_ls +from datalayer_core.cli.commands.envs import envs_ls +from datalayer_core.cli.commands.evals import app as evals_app from datalayer_core.cli.commands.exec import main as exec_main +from datalayer_core.cli.commands.memberships import app as memberships_app from datalayer_core.cli.commands.otel import app as otel_app +from datalayer_core.cli.commands.pools import app as pools_app +from datalayer_core.cli.commands.ray import app as ray_app from datalayer_core.cli.commands.runtime_checkpoints import app as checkpoints_app from datalayer_core.cli.commands.runtime_checkpoints import ( - checkpoints_list, checkpoints_ls, ) -from datalayer_core.cli.commands.runtime_snapshots import app as snapshots_app -from datalayer_core.cli.commands.runtime_snapshots import snapshots_list, snapshots_ls +from datalayer_core.cli.commands.sandbox_snapshots import app as snapshots_app +from datalayer_core.cli.commands.sandbox_snapshots import snapshots_ls from datalayer_core.cli.commands.runtimes import app as runtimes_app -from datalayer_core.cli.commands.runtimes import runtimes_list, runtimes_ls +from datalayer_core.cli.commands.runtimes import runtimes_ls from datalayer_core.cli.commands.secrets import app as secrets_app -from datalayer_core.cli.commands.secrets import secrets_list, secrets_ls +from datalayer_core.cli.commands.secrets import secrets_ls from datalayer_core.cli.commands.subscription import app as subscription_app from datalayer_core.cli.commands.subscription import subscription_root from datalayer_core.cli.commands.tokens import app as tokens_app -from datalayer_core.cli.commands.tokens import tokens_list, tokens_ls +from datalayer_core.cli.commands.tokens import tokens_ls from datalayer_core.cli.commands.usage import app as usage_app from datalayer_core.cli.commands.usage import usage_root +from datalayer_core.cli.commands.plans import app as plans_app +from datalayer_core.cli.commands.plans import plans_root from datalayer_core.cli.commands.users import app as users_app from datalayer_core.cli.commands.web import app as web_app @@ -68,20 +81,116 @@ def main_callback( is_eager=True, help="Show version and exit", ), + run_url: str | None = typer.Option( + None, + "--run-url", + help="Override DATALAYER_RUN_URL for this CLI invocation.", + ), + iam_url: str | None = typer.Option( + None, + "--iam-url", + help="Override DATALAYER_IAM_URL for this CLI invocation.", + ), + runtimes_url: str | None = typer.Option( + None, + "--runtimes-url", + help="Override DATALAYER_RUNTIMES_URL for this CLI invocation.", + ), + spacer_url: str | None = typer.Option( + None, + "--spacer-url", + "--space-url", + help="Override DATALAYER_SPACER_URL for this CLI invocation.", + ), + library_url: str | None = typer.Option( + None, + "--library-url", + help="Override DATALAYER_LIBRARY_URL for this CLI invocation.", + ), + manager_url: str | None = typer.Option( + None, + "--manager-url", + help="Override DATALAYER_MANAGER_URL for this CLI invocation.", + ), + ai_agents_url: str | None = typer.Option( + None, + "--ai-agents-url", + help="Override DATALAYER_AI_AGENTS_URL for this CLI invocation.", + ), + ai_inference_url: str | None = typer.Option( + None, + "--ai-inference-url", + help="Override DATALAYER_AI_INFERENCE_URL for this CLI invocation.", + ), + growth_url: str | None = typer.Option( + None, + "--growth-url", + help="Override DATALAYER_GROWTH_URL for this CLI invocation.", + ), + otel_url: str | None = typer.Option( + None, + "--otel-url", + help="Override DATALAYER_OTEL_URL for this CLI invocation.", + ), + success_url: str | None = typer.Option( + None, + "--success-url", + help="Override DATALAYER_SUCCESS_URL for this CLI invocation.", + ), + status_url: str | None = typer.Option( + None, + "--status-url", + help="Override DATALAYER_STATUS_URL for this CLI invocation.", + ), + support_url: str | None = typer.Option( + None, + "--support-url", + help="Override DATALAYER_SUPPORT_URL for this CLI invocation.", + ), + mcp_server_url: str | None = typer.Option( + None, + "--mcp-server-url", + help="Override DATALAYER_MCP_SERVER_URL for this CLI invocation.", + ), ) -> None: """Main callback to handle global options.""" - pass + overrides = { + "DATALAYER_RUN_URL": run_url, + "DATALAYER_IAM_URL": iam_url, + "DATALAYER_RUNTIMES_URL": runtimes_url, + "DATALAYER_SPACER_URL": spacer_url, + "DATALAYER_LIBRARY_URL": library_url, + "DATALAYER_MANAGER_URL": manager_url, + "DATALAYER_AI_AGENTS_URL": ai_agents_url, + "DATALAYER_AI_INFERENCE_URL": ai_inference_url, + "DATALAYER_GROWTH_URL": growth_url, + "DATALAYER_OTEL_URL": otel_url, + "DATALAYER_SUCCESS_URL": success_url, + "DATALAYER_STATUS_URL": status_url, + "DATALAYER_SUPPORT_URL": support_url, + "DATALAYER_MCP_SERVER_URL": mcp_server_url, + } + for env_name, value in overrides.items(): + if value is not None: + os.environ[env_name] = value.rstrip("/") # Register commands (without name to add them at the top level) app.add_typer(about_app) +app.add_typer(agents_app) +app.add_typer(agent_nodes_app) app.add_typer(auth_app) app.add_typer(benchmarks_app) app.add_typer(checkpoints_app) +app.add_typer(cluster_app) app.add_typer(config_app) app.add_typer(console_app) app.add_typer(envs_app) +app.add_typer(evals_app) +app.add_typer(memberships_app) app.add_typer(otel_app) +app.add_typer(pools_app) +app.add_typer(ray_app) app.add_typer(runtimes_app) app.add_typer(secrets_app) app.add_typer(snapshots_app) @@ -89,6 +198,7 @@ def main_callback( app.add_typer(tokens_app) app.add_typer(users_app) app.add_typer(usage_app) +app.add_typer(plans_app) app.add_typer(web_app) # Add exec command directly to root level @@ -99,23 +209,91 @@ def main_callback( app.command(name="logout")(logout_root) app.command(name="whoami")(whoami_root) app.command(name="usage")(usage_root) +app.command(name="plans")(plans_root) app.command(name="subscription")(subscription_root) # Add convenient aliases at root level -app.command(name="envs-list")(envs_list) app.command(name="envs-ls")(envs_ls) -app.command(name="runtimes-list")(runtimes_list) app.command(name="runtimes-ls")(runtimes_ls) -app.command(name="secrets-list")(secrets_list) app.command(name="secrets-ls")(secrets_ls) -app.command(name="snapshots-list")(snapshots_list) app.command(name="snapshots-ls")(snapshots_ls) -app.command(name="checkpoints-list")(checkpoints_list) app.command(name="checkpoints-ls")(checkpoints_ls) -app.command(name="tokens-list")(tokens_list) app.command(name="tokens-ls")(tokens_ls) +app.command(name="agent-nodes-ls")(agent_nodes_ls) +app.command(name="agents-ls")(agents_ls) + + +_GLOBAL_OPTIONS_WITH_VALUES = { + "--run-url", + "--iam-url", + "--runtimes-url", + "--spacer-url", + "--space-url", + "--library-url", + "--manager-url", + "--ai-agents-url", + "--ai-inference-url", + "--growth-url", + "--otel-url", + "--success-url", + "--status-url", + "--support-url", + "--mcp-server-url", +} + +_GLOBAL_OPTIONS_NO_VALUES = { + "--version", +} + + +def _normalize_global_options(argv: list[str]) -> list[str]: + """Hoist supported global options so they work at any argument position.""" + if len(argv) <= 1: + return argv + + extracted: list[str] = [] + remaining: list[str] = [] + i = 1 + while i < len(argv): + token = argv[i] + + if token == "--": + remaining.extend(argv[i:]) + break + + if token in _GLOBAL_OPTIONS_NO_VALUES: + extracted.append(token) + i += 1 + continue + + matched_equals = next( + ( + option + for option in _GLOBAL_OPTIONS_WITH_VALUES + if token.startswith(f"{option}=") + ), + None, + ) + if matched_equals: + extracted.append(token) + i += 1 + continue + + if token in _GLOBAL_OPTIONS_WITH_VALUES: + extracted.append(token) + if i + 1 < len(argv): + extracted.append(argv[i + 1]) + i += 2 + else: + i += 1 + continue + + remaining.append(token) + i += 1 + + return [argv[0], *extracted, *remaining] def main() -> None: """Main entry point for the Datalayer Typer CLI.""" - app() + app(args=_normalize_global_options(sys.argv)[1:]) diff --git a/datalayer_core/cli/commands/about.py b/datalayer_core/cli/commands/about.py index 823ef578..a6a47c2e 100644 --- a/datalayer_core/cli/commands/about.py +++ b/datalayer_core/cli/commands/about.py @@ -8,10 +8,17 @@ import typer from rich.console import Console from rich.markdown import Markdown +from rich.text import Text # Create a Typer app for the about command app = typer.Typer() +FOOTER_ANSI = ( + "\n" + "\033[0;32m☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷☷\033[0m " + "\033[1;93m☰ DATA\033[0m\033[1;92mLAYER\033[0m" +) + @app.command() def about() -> None: @@ -24,6 +31,7 @@ def about() -> None: with open(about_file_path) as readme: markdown = Markdown(readme.read()) console.print(markdown) + console.print(Text.from_ansi(FOOTER_ANSI)) except FileNotFoundError: console.print(f"[red]Error: Could not find about.md at {about_file_path}[/red]") raise typer.Exit(1) diff --git a/datalayer_core/cli/commands/agent_nodes.py b/datalayer_core/cli/commands/agent_nodes.py new file mode 100644 index 00000000..1e3afc81 --- /dev/null +++ b/datalayer_core/cli/commands/agent_nodes.py @@ -0,0 +1,125 @@ +# Copyright (c) 2023-2025 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Agent node commands for Datalayer CLI.""" + +import os +from typing import Any, Optional + +import requests +import typer +from rich.console import Console + +from datalayer_core.displays.agent_nodes import display_agent_nodes +from datalayer_core.utils.urls import DatalayerURLs + +app = typer.Typer( + name="agent-nodes", + help="Agent Node management commands", + invoke_without_command=True, +) + +console = Console() + + +def _resolve_token(token: Optional[str] = None) -> str: + if token: + return token + env_token = os.environ.get("DATALAYER_API_KEY") + if env_token: + return env_token + try: + from datalayer_core.client.client import DatalayerClient + + client = DatalayerClient() + return client._get_token() or "" + except Exception: + return "" + + +def _fetch_api( + path: str, + *, + method: str = "GET", + token: Optional[str] = None, + runtimes_url: Optional[str] = None, +) -> Any: + resolved_token = _resolve_token(token) + if not resolved_token: + raise RuntimeError( + "No authentication token found. Pass --token, set DATALAYER_API_KEY, or run 'datalayer login'." + ) + urls = DatalayerURLs.from_environment(runtimes_url=runtimes_url) + url = f"{urls.runtimes_url}/api/runtimes/v1{path}" + headers = {"Authorization": f"Bearer {resolved_token}"} + + response = requests.request(method, url, headers=headers, timeout=30) + response.raise_for_status() + return response.json() + + +@app.callback() +def agent_nodes_callback(ctx: typer.Context) -> None: + """Agent Node management commands.""" + if ctx.invoked_subcommand is None: + typer.echo(ctx.get_help()) + + +@app.command(name="ls") +def list_agent_nodes( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), +) -> None: + """List registered agent nodes.""" + try: + data = _fetch_api("/agent-nodes", token=token, runtimes_url=runtimes_url) + nodes = data.get("agent_nodes", []) + if not nodes: + console.print("[yellow]No agent nodes found.[/yellow]") + raise typer.Exit(0) + display_agent_nodes(nodes) + except typer.Exit: + raise + except Exception as e: + console.print(f"[red]Error listing agent nodes: {e}[/red]") + raise typer.Exit(1) + + +def agent_nodes_list( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), +) -> None: + """List registered agent nodes (root command).""" + list_agent_nodes(token=token, runtimes_url=runtimes_url) + + +def agent_nodes_ls( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), +) -> None: + """List registered agent nodes (root alias).""" + list_agent_nodes(token=token, runtimes_url=runtimes_url) diff --git a/datalayer_core/cli/commands/agents.py b/datalayer_core/cli/commands/agents.py new file mode 100644 index 00000000..83798aef --- /dev/null +++ b/datalayer_core/cli/commands/agents.py @@ -0,0 +1,667 @@ +# Copyright (c) 2023-2026 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Agent runtime commands for Datalayer CLI.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Optional + +import requests +import typer +import yaml +from rich.console import Console + +from datalayer_core.client.client import DatalayerClient +from datalayer_core.displays.runtimes import display_runtimes +from datalayer_core.runtimes.local import ( + DEFAULT_LOCAL_AGENT_NAME, + DEFAULT_LOCAL_HOST, + DEFAULT_LOCAL_LOG_LEVEL, + DEFAULT_LOCAL_PROTOCOL, + ensure_local_agent, + start_local_agent_runtime, + terminate_local_agent_runtime, +) +from datalayer_core.utils.urls import DatalayerURLs + +DEFAULT_AGENT_SPEC_ID = "example-simple" + +app = typer.Typer( + name="agents", + help="Agent runtime management commands.", + invoke_without_command=True, +) + +console = Console() + + +@app.callback() +def agents_callback(ctx: typer.Context) -> None: + """Agent runtime management commands.""" + if ctx.invoked_subcommand is None: + typer.echo(ctx.get_help()) + + +def _make_client( + token: Optional[str] = None, + iam_url: Optional[str] = None, + runtimes_url: Optional[str] = None, +) -> DatalayerClient: + urls = DatalayerURLs.from_environment(iam_url=iam_url, runtimes_url=runtimes_url) + return DatalayerClient(urls=urls, token=token) + + +def _is_url(value: str) -> bool: + lowered = value.lower() + return lowered.startswith("http://") or lowered.startswith("https://") + + +def _load_agent_spec(spec_source: str) -> dict[str, Any]: + source = spec_source.strip() + if not source: + raise typer.BadParameter("--agentspec must be a non-empty URL or file path.") + + raw_text = "" + if _is_url(source): + try: + response = requests.get(source, timeout=30) + except Exception as exc: + raise RuntimeError( + f"Failed to fetch --agentspec URL '{source}': {exc}" + ) from exc + if response.status_code >= 400: + preview = (response.text or "")[:500] + raise RuntimeError( + f"--agentspec URL returned HTTP {response.status_code}: {source}\n{preview}" + ) + raw_text = response.text or "" + else: + path = Path(source) + if not path.exists(): + raise RuntimeError(f"--agentspec file does not exist: {path}") + if not path.is_file(): + raise RuntimeError(f"--agentspec path is not a file: {path}") + raw_text = path.read_text(encoding="utf-8") + + try: + parsed = yaml.safe_load(raw_text) + except Exception as exc: + raise RuntimeError(f"Failed to parse --agentspec as YAML/JSON: {exc}") from exc + + if not isinstance(parsed, dict): + raise RuntimeError("--agentspec must decode to an object (mapping).") + if not parsed: + raise RuntimeError("--agentspec decoded to an empty object.") + return parsed + + +def _create_local_agent_runtime( + *, + agent_spec_id: str, + agent_name: str, + host: str, + port: Optional[int], + protocol: str, + log_level: str, + token: Optional[str], + raw: bool, +) -> None: + """Launch a local agent-runtimes server and serve until interrupted.""" + runtime = start_local_agent_runtime( + agent_spec_id=agent_spec_id, + agent_name=agent_name, + host=host, + port=port, + protocol=protocol, + log_level=log_level, + ) + + resolved_token = (token or "").strip() + if resolved_token: + try: + ensure_local_agent( + base_url=runtime.base_url, + agent_name=agent_name, + token=resolved_token, + agent_spec_id=agent_spec_id, + transport=protocol, + ) + except Exception as exc: + terminate_local_agent_runtime(runtime) + raise RuntimeError(f"Failed to register local agent: {exc}") from exc + + if raw: + payload = { + "success": True, + "local": True, + "runtime": { + "base_url": runtime.base_url, + "agent_name": runtime.agent_name, + "agent_spec_id": runtime.agent_spec_id, + "chat_endpoint": runtime.chat_endpoint, + }, + } + console.print(json.dumps(payload, ensure_ascii=False)) + else: + console.print( + f"[green]Local agent runtime '{agent_name}' started![/green]" + ) + console.print(f"Base URL: {runtime.base_url}") + console.print(f"Agent spec id: {agent_spec_id}") + console.print(f"Chat endpoint: {runtime.chat_endpoint}") + console.print("[dim]Press Ctrl+C to stop the local runtime.[/dim]") + + process = runtime.process + try: + if process is not None: + process.wait() + except KeyboardInterrupt: + console.print("\n[yellow]Stopping local agent runtime...[/yellow]") + finally: + terminate_local_agent_runtime(runtime) + + +@app.command(name="ls") +def list_agents( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), +) -> None: + """List running agent runtimes.""" + try: + client = _make_client(token=token, iam_url=iam_url, runtimes_url=runtimes_url) + runtimes = client.list_runtimes() + runtime_dicts: list[dict[str, Any]] = [] + for runtime in runtimes: + runtime_dicts.append( + { + "given_name": runtime.name, + "environment_name": runtime.environment, + "pod_name": runtime.pod_name, + "ingress": runtime.ingress, + "reservation_id": runtime.reservation_id, + "uid": runtime.uid, + "burning_rate": runtime.burning_rate, + "token": runtime.jupyter_token, + "started_at": runtime.started_at, + "expired_at": runtime.expired_at, + } + ) + display_runtimes(runtime_dicts) + except Exception as exc: + console.print(f"[red]Error listing agent runtimes: {exc}[/red]") + raise typer.Exit(1) + + +@app.command(name="create") +def create_agent_runtime( + environment: Optional[str] = typer.Argument(None, help="Environment name."), + given_name: Optional[str] = typer.Option( + None, + "--given-name", + help="Custom name for the runtime.", + ), + spec_id: Optional[str] = typer.Option( + None, + "--agentspec-id", + help=( + "Agent spec id for runtime bootstrap. " + f"Defaults to {DEFAULT_AGENT_SPEC_ID} when --agentspec is omitted." + ), + ), + spec: Optional[str] = typer.Option( + None, + "--agentspec", + help="Agent spec source as YAML/JSON URL or local file path.", + ), + time_reservation: Optional[float] = typer.Option( + 10.0, + "--time-reservation", + help="Time reservation in minutes for the runtime.", + ), + billable_account_uid: Optional[str] = typer.Option( + None, + "--billable-account-uid", + help="Account UID to bill the runtime to (org/team).", + ), + billable_account_type: Optional[str] = typer.Option( + None, + "--billable-account-type", + help="Billable account type: user, organization, or team.", + ), + billable_account_handle: Optional[str] = typer.Option( + None, + "--billable-account-handle", + help="Billable account handle (informational).", + ), + raw: bool = typer.Option( + False, + "--raw", + help="Print machine-readable JSON payload.", + ), + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), + local: bool = typer.Option( + False, + "--local", + help="Launch the agent as a local agent-runtimes server instead of a cloud runtime.", + ), + host: str = typer.Option( + DEFAULT_LOCAL_HOST, + "--host", + help="Host interface for the local runtime (only with --local).", + ), + port: Optional[int] = typer.Option( + None, + "--port", + help="Port for the local runtime (random free port when omitted, only with --local).", + ), + protocol: str = typer.Option( + DEFAULT_LOCAL_PROTOCOL, + "--protocol", + help="Transport protocol for the local runtime (only with --local).", + ), + log_level: str = typer.Option( + DEFAULT_LOCAL_LOG_LEVEL, + "--log-level", + help="Log level for the local runtime process (only with --local).", + ), +) -> None: + """Create a new runtime preloaded with an agent spec. + + By default creates a cloud runtime. With ``--local`` it launches a local + ``agent-runtimes`` server and serves until interrupted (Ctrl+C). + """ + import questionary + + try: + if spec and spec_id: + raise typer.BadParameter( + "Use either --agentspec-id or --agentspec, not both." + ) + + if local: + if spec: + raise typer.BadParameter( + "--agentspec is not supported with --local; use --agentspec-id." + ) + _create_local_agent_runtime( + agent_spec_id=(spec_id or "").strip() or DEFAULT_AGENT_SPEC_ID, + agent_name=(given_name or "").strip() or DEFAULT_LOCAL_AGENT_NAME, + host=host, + port=port, + protocol=protocol, + log_level=log_level, + token=token, + raw=raw, + ) + return + + client = _make_client(token=token, iam_url=iam_url, runtimes_url=runtimes_url) + + if environment is None: + environments = client.list_environments() + if not environments: + console.print("[yellow]No environments available.[/yellow]") + raise typer.Exit(0) + choices = [] + for env in environments: + label = env.name + if env.title: + label += f" ({env.title})" + choices.append(questionary.Choice(title=label, value=env.name)) + + selected = questionary.select( + "Select the environment for the new agent runtime:", + choices=choices, + ).ask() + if selected is None: + raise typer.Exit(0) + environment = selected + + agent_spec_payload: dict[str, Any] | None = None + resolved_spec_id: str | None = None + if spec: + agent_spec_payload = _load_agent_spec(spec) + else: + resolved_spec_id = (spec_id or "").strip() or DEFAULT_AGENT_SPEC_ID + + final_time_reservation = time_reservation or 10.0 + runtime = client.create_runtime( + name=given_name, + environment=environment, + time_reservation=final_time_reservation, + agent_spec_id=resolved_spec_id, + agent_spec=agent_spec_payload, + billable_account_uid=billable_account_uid, + billable_account_type=billable_account_type, + billable_account_handle=billable_account_handle, + ) + + if raw: + payload = { + "success": True, + "runtime": { + "given_name": runtime.name, + "environment_name": runtime.environment, + "pod_name": runtime.pod_name, + "uid": runtime.uid, + "ingress": runtime.ingress, + "reservation_id": runtime.reservation_id, + "burning_rate": runtime.burning_rate, + "started_at": runtime.started_at, + "expired_at": runtime.expired_at, + }, + "agent_spec_id": resolved_spec_id, + "agent_spec_source": spec or "", + } + console.print(json.dumps(payload, ensure_ascii=False)) + return + + console.print(f"[green]Agent runtime '{runtime.name}' created successfully![/green]") + if runtime.pod_name: + console.print(f"Pod: {runtime.pod_name}") + if runtime.ingress: + console.print(f"Ingress: {runtime.ingress}") + if resolved_spec_id: + console.print(f"Agent spec id: {resolved_spec_id}") + elif spec: + console.print(f"Agent spec source: {spec}") + + except typer.Exit: + raise + except Exception as exc: + console.print("[red]Error creating agent runtime.[/red]") + console.print(f"[red]{exc}[/red]") + raise typer.Exit(1) + + +@app.command(name="get") +def get_agent_runtime( + pod_name: Optional[str] = typer.Argument( + None, + help="Pod name of the agent runtime to read.", + ), + raw: bool = typer.Option( + False, + "--raw", + help="Print machine-readable JSON payload.", + ), + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), +) -> None: + """Read a single agent runtime by pod name.""" + import questionary + + try: + client = _make_client(token=token, iam_url=iam_url, runtimes_url=runtimes_url) + + if pod_name is None: + runtimes = client.list_runtimes() + if not runtimes: + console.print("[yellow]No running runtimes found.[/yellow]") + raise typer.Exit(0) + choices = [] + for runtime in runtimes: + label = runtime.pod_name or "" + if runtime.name: + label = f"{runtime.pod_name} ({runtime.name})" + if runtime.environment: + label += f" [{runtime.environment}]" + choices.append(questionary.Choice(title=label, value=runtime.pod_name)) + + selected = questionary.select( + "Select the agent runtime to read:", + choices=choices, + ).ask() + if selected is None: + raise typer.Exit(0) + pod_name = selected + + runtime = client.get_runtime(pod_name) + runtime_dict = { + "given_name": runtime.name, + "environment_name": runtime.environment, + "pod_name": runtime.pod_name, + "ingress": runtime.ingress, + "reservation_id": runtime.reservation_id, + "uid": runtime.uid, + "burning_rate": runtime.burning_rate, + "token": runtime.jupyter_token, + "started_at": runtime.started_at, + "expired_at": runtime.expired_at, + } + + if raw: + console.print( + json.dumps( + {"success": True, "runtime": runtime_dict}, ensure_ascii=False + ) + ) + return + + display_runtimes([runtime_dict]) + + except typer.Exit: + raise + except Exception as exc: + console.print(f"[red]Error reading agent runtime: {exc}[/red]") + raise typer.Exit(1) + + +@app.command(name="update") +def update_agent_runtime( + pod_name: Optional[str] = typer.Argument( + None, + help="Pod name of the agent runtime to update.", + ), + capability: list[str] = typer.Option( + [], + "--capability", + help="Capability to apply (repeatable). Replaces existing capabilities.", + ), + raw: bool = typer.Option( + False, + "--raw", + help="Print machine-readable JSON payload.", + ), + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), +) -> None: + """Update an agent runtime's capabilities.""" + import questionary + + try: + client = _make_client(token=token, iam_url=iam_url, runtimes_url=runtimes_url) + + if pod_name is None: + runtimes = client.list_runtimes() + if not runtimes: + console.print("[yellow]No running runtimes found.[/yellow]") + raise typer.Exit(0) + choices = [] + for runtime in runtimes: + label = runtime.pod_name or "" + if runtime.name: + label = f"{runtime.pod_name} ({runtime.name})" + if runtime.environment: + label += f" [{runtime.environment}]" + choices.append(questionary.Choice(title=label, value=runtime.pod_name)) + + selected = questionary.select( + "Select the agent runtime to update:", + choices=choices, + ).ask() + if selected is None: + raise typer.Exit(0) + pod_name = selected + + client.update_runtime(pod_name, list(capability)) + + if raw: + console.print( + json.dumps( + { + "success": True, + "pod_name": pod_name, + "capabilities": list(capability), + }, + ensure_ascii=False, + ) + ) + return + + console.print( + f"[green]Agent runtime '{pod_name}' updated successfully![/green]" + ) + if capability: + console.print(f"Capabilities: {', '.join(capability)}") + + except typer.Exit: + raise + except Exception as exc: + console.print(f"[red]Error updating agent runtime: {exc}[/red]") + raise typer.Exit(1) + + +@app.command(name="delete") +@app.command(name="terminate") +def terminate_agent_runtime( + pod_name: Optional[str] = typer.Argument( + None, + help="Pod name of the runtime to terminate.", + ), + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), +) -> None: + """Terminate a running agent runtime.""" + import questionary + + try: + client = _make_client(token=token, iam_url=iam_url, runtimes_url=runtimes_url) + + if pod_name is None: + runtimes = client.list_runtimes() + if not runtimes: + console.print("[yellow]No running runtimes found.[/yellow]") + raise typer.Exit(0) + + choices = [] + for runtime in runtimes: + label = runtime.pod_name or "" + if runtime.name: + label = f"{runtime.pod_name} ({runtime.name})" + if runtime.environment: + label += f" [{runtime.environment}]" + choices.append(questionary.Choice(title=label, value=runtime.pod_name)) + + selected = questionary.select( + "Select the agent runtime to terminate:", + choices=choices, + ).ask() + if selected is None: + raise typer.Exit(0) + pod_name = selected + + success = client.terminate_runtime(pod_name) + if success: + console.print( + f"[green]Agent runtime '{pod_name}' terminated successfully![/green]" + ) + else: + console.print(f"[red]Failed to terminate agent runtime '{pod_name}'[/red]") + raise typer.Exit(1) + + except typer.Exit: + raise + except Exception as exc: + console.print(f"[red]Error terminating agent runtime: {exc}[/red]") + raise typer.Exit(1) + + +def agents_ls( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), +) -> None: + """List running agent runtimes (root command alias).""" + list_agents(token=token, iam_url=iam_url, runtimes_url=runtimes_url) \ No newline at end of file diff --git a/datalayer_core/cli/commands/authn.py b/datalayer_core/cli/commands/authn.py index ccbf25d0..7bfc2b09 100644 --- a/datalayer_core/cli/commands/authn.py +++ b/datalayer_core/cli/commands/authn.py @@ -4,12 +4,16 @@ """Authentication commands for Datalayer CLI - Refactored to use Client.""" import asyncio +import base64 +import json import os import threading import time -from typing import Optional +from datetime import datetime, timezone +from typing import Optional, Any import questionary +import requests import typer from rich.console import Console @@ -33,6 +37,95 @@ def auth_callback(ctx: typer.Context) -> None: typer.echo(ctx.get_help()) +def _fetch_memberships(iam_url: str, token: Optional[str]) -> Optional[list[dict]]: + """Fetch the authenticated user's organization/team memberships.""" + if not token: + return None + try: + response = requests.get( + f"{iam_url}/api/iam/v1/memberships", + headers={"Authorization": f"Bearer {token}"}, + timeout=10, + ) + if response.status_code != 200: + return None + data = response.json() + if not data.get("success", True): + return None + return data.get("memberships") or [] + except Exception: + return None + + +def _decode_jwt_claims(token: str) -> Optional[dict]: + """Decode JWT claims without verifying signature (display purpose only).""" + try: + parts = token.split(".") + if len(parts) < 2: + return None + payload = parts[1] + padding = "=" * (-len(payload) % 4) + decoded = base64.urlsafe_b64decode(payload + padding) + claims = json.loads(decoded.decode("utf-8")) + return claims if isinstance(claims, dict) else None + except Exception: + return None + + +def _coerce_unix_timestamp(value: Any) -> Optional[int]: + try: + if value is None: + return None + if isinstance(value, bool): + return None + if isinstance(value, (int, float)): + return int(value) + if isinstance(value, str): + return int(float(value.strip())) + except Exception: + return None + return None + + +def _format_unix_timestamp(ts: Optional[int]) -> str: + if ts is None: + return "unknown" + try: + return datetime.fromtimestamp(ts, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + except Exception: + return "unknown" + + +def _format_duration(seconds: int) -> str: + seconds = max(0, seconds) + days, remainder = divmod(seconds, 86400) + hours, remainder = divmod(remainder, 3600) + minutes, _ = divmod(remainder, 60) + chunks = [] + if days: + chunks.append(f"{days}d") + if hours: + chunks.append(f"{hours}h") + if minutes or not chunks: + chunks.append(f"{minutes}m") + return " ".join(chunks) + + +def _expiration_status(exp_ts: Optional[int]) -> str: + if exp_ts is None: + return "[red]unknown[/red]" + + now = int(time.time()) + remaining = exp_ts - now + if remaining <= 0: + return f"[red]expired { _format_duration(abs(remaining)) } ago[/red]" + if remaining <= 900: + return f"[red]{_format_duration(remaining)} remaining[/red]" + if remaining <= 86400: + return f"[yellow]{_format_duration(remaining)} remaining[/yellow]" + return f"[green]{_format_duration(remaining)} remaining[/green]" + + @app.command() def login( run_url: Optional[str] = typer.Option( @@ -408,6 +501,30 @@ def whoami( if user.get("last_update_ts_dt"): console.print(f"🔄 Last Updated: {user.get('last_update_ts_dt')}") + # JWT token details + token_for_details = access_token or auth.current_token or auth.get_stored_token() + if token_for_details: + claims = _decode_jwt_claims(token_for_details) + if claims: + subject = claims.get("sub") + if isinstance(subject, dict): + subject = subject.get("uid") or subject + exp_ts = _coerce_unix_timestamp(claims.get("exp")) + iat_ts = _coerce_unix_timestamp(claims.get("iat")) + + console.print("\n[bold]JWT Token:[/bold]") + if claims.get("jti"): + console.print(f" 🪪 JTI: {claims.get('jti')}") + if subject is not None: + console.print(f" 👤 Subject: {subject}") + if claims.get("iss"): + console.print(f" 🏷️ Issuer: {claims.get('iss')}") + if iat_ts is not None: + console.print(f" 🕒 Issued At: {_format_unix_timestamp(iat_ts)}") + if exp_ts is not None: + console.print(f" ⏰ Expires At: {_format_unix_timestamp(exp_ts)}") + console.print(f" ⌛ Time to Expiration: {_expiration_status(exp_ts)}") + # IAM Providers iam_providers = user.get("iam_providers", []) if iam_providers: @@ -429,10 +546,50 @@ def whoami( console.print(f" 🔗 {provider_name.capitalize()}") # Customer UID - if user.get("credits_customer_uid"): + if user.get("stripe_customer_id_s"): console.print( - f"\n💳 Credits Customer: {user.get('credits_customer_uid')}" + f"\n💳 Credits Customer: {user.get('stripe_customer_id_s')}" ) + + # Memberships (organizations + teams) + memberships = _fetch_memberships(urls.iam_url, access_token) + if memberships is not None: + orgs = [m for m in memberships if (m.get("type") or "").lower() == "organization"] + teams = [m for m in memberships if (m.get("type") or "").lower() == "team"] + org_by_uid = {m.get("uid"): m for m in orgs} + + if orgs: + console.print("\n[bold]🏢 Organizations:[/bold]") + for org in orgs: + handle = org.get("handle") or org.get("uid") or "unknown" + name = org.get("name") or "" + roles = ", ".join(org.get("roles_ss") or []) or "-" + label = f" • [cyan]{handle}[/cyan]" + if name and name != handle: + label += f" ({name})" + label += f" uid={org.get('uid')} roles={roles}" + console.print(label) + + if teams: + console.print("\n[bold]👥 Teams:[/bold]") + for team in teams: + handle = team.get("handle") or team.get("uid") or "unknown" + name = team.get("name") or "" + roles = ", ".join(team.get("roles_ss") or []) or "-" + org_uid = team.get("organization_uid") + parent = org_by_uid.get(org_uid) if org_uid else None + parent_label = ( + parent.get("handle") if parent else (org_uid or "unknown") + ) + label = f" • [cyan]{handle}[/cyan]" + if name and name != handle: + label += f" ({name})" + label += f" in [magenta]{parent_label}[/magenta]" + label += f" uid={team.get('uid')} roles={roles}" + console.print(label) + + if not orgs and not teams: + console.print("\n[dim]No organization or team memberships.[/dim]") else: console.print("[yellow]Not authenticated[/yellow]") console.print("Run 'datalayer login' to authenticate") diff --git a/datalayer_core/cli/commands/cluster.py b/datalayer_core/cli/commands/cluster.py new file mode 100644 index 00000000..61e6973b --- /dev/null +++ b/datalayer_core/cli/commands/cluster.py @@ -0,0 +1,280 @@ +# Copyright (c) 2023-2025 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Cluster visibility commands for Datalayer CLI.""" + +import os +from typing import Any, Optional + +import requests +import typer +from rich.console import Console +from rich.panel import Panel +from rich.text import Text +from rich.tree import Tree + +from datalayer_core.utils.urls import DatalayerURLs + + +app = typer.Typer( + name="cluster", + help="Cluster visibility commands", + invoke_without_command=True, +) + +console = Console() + + +def _resolve_token(token: Optional[str] = None) -> str: + if token: + return token + env_token = os.environ.get("DATALAYER_API_KEY") + if env_token: + return env_token + try: + from datalayer_core.client.client import DatalayerClient + + client = DatalayerClient() + return client._get_token() or "" + except Exception: + return "" + + +def _fetch_api( + path: str, + *, + token: Optional[str] = None, + runtimes_url: Optional[str] = None, + params: Optional[dict[str, str]] = None, +) -> Any: + resolved_token = _resolve_token(token) + if not resolved_token: + raise RuntimeError( + "No authentication token found. Pass --token, set DATALAYER_API_KEY, or run 'datalayer login'." + ) + + urls = DatalayerURLs.from_environment(runtimes_url=runtimes_url) + url = f"{urls.runtimes_url}/api/runtimes/v1{path}" + headers = {"Authorization": f"Bearer {resolved_token}"} + + response = requests.get(url, headers=headers, params=params, timeout=30) + response.raise_for_status() + return response.json() + + +def _status_style(status: str) -> str: + normalized = (status or "").lower() + if normalized in {"running", "ready", "succeeded"}: + return "green" + if normalized in {"pending", "unknown"}: + return "yellow" + if normalized in {"failed", "crashloopbackoff", "not_ready"}: + return "red" + return "white" + + +def _build_anomalies_panel(nodes_with_pods: list[Any], unassigned: list[Any]) -> Panel: + pending_pods = 0 + unschedulable_pods = 0 + failed_pods = 0 + pending_scale_up_nodes = 0 + pending_scale_down_nodes = 0 + not_ready_nodes = 0 + + for item in nodes_with_pods: + node = item.get("node", {}) if isinstance(item, dict) else {} + node_status = str(node.get("status") or "").lower() + ready = bool(node.get("ready")) + + if node_status == "pending_scale_up": + pending_scale_up_nodes += 1 + elif node_status == "pending_scale_down": + pending_scale_down_nodes += 1 + elif not ready: + not_ready_nodes += 1 + + node_pods = item.get("pods", []) if isinstance(item, dict) else [] + for pod in node_pods: + phase = str((pod or {}).get("phase") or "").lower() + if phase == "pending": + pending_pods += 1 + if phase in {"failed", "crashloopbackoff"}: + failed_pods += 1 + if bool((pod or {}).get("unschedulable")): + unschedulable_pods += 1 + + for pod in unassigned: + phase = str((pod or {}).get("phase") or "").lower() + if phase == "pending": + pending_pods += 1 + if phase in {"failed", "crashloopbackoff"}: + failed_pods += 1 + if bool((pod or {}).get("unschedulable")): + unschedulable_pods += 1 + + lines = Text() + lines.append("Pods\n", style="bold") + lines.append(f"pending pods: {pending_pods}\n", style="yellow") + lines.append(f"unschedulable pods: {unschedulable_pods}\n", style="red") + lines.append(f"unassigned pods: {len(unassigned)}\n", style="yellow") + lines.append(f"failed/crashloop pods: {failed_pods}\n", style="red") + lines.append("----------------------------------------\n", style="dim") + lines.append("Nodes\n", style="bold") + lines.append(f"not-ready nodes: {not_ready_nodes}\n", style="red") + lines.append(f"pending scale-up nodes: {pending_scale_up_nodes}\n", style="cyan") + lines.append(f"pending scale-down nodes: {pending_scale_down_nodes}", style="cyan") + + return Panel(lines, title="Anomalies", border_style="yellow") + + +@app.callback() +def cluster_callback(ctx: typer.Context) -> None: + """Cluster visibility commands.""" + if ctx.invoked_subcommand is None: + typer.echo(ctx.get_help()) + + +@app.command(name="show") +def show_cluster( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), + phase: Optional[str] = typer.Option( + None, + "--phase", + help="Filter pods by phase (for example: Running, Pending, Failed).", + ), + no_anomalies: bool = typer.Option( + False, + "--no-anomalies", + help="Hide the anomaly summary panel.", + ), + anomalies_only: bool = typer.Option( + False, + "--anomalies-only", + help="Show only the anomaly summary panel (skip topology tree).", + ), +) -> None: + """Show cluster details with pods grouped by node and status.""" + try: + state_payload = _fetch_api( + "/cluster/state", + token=token, + runtimes_url=runtimes_url, + params={"phase": phase} if phase else None, + ) + nodes_with_pods = state_payload.get("nodes_with_pods", []) + unassigned = state_payload.get("unassigned_pods", []) + node_requests = state_payload.get("node_requests", []) + + if not anomalies_only: + root = Tree("[bold]Cluster Topology[/bold]") + + if not nodes_with_pods: + root.add("[yellow]No nodes returned by API.[/yellow]") + else: + for item in nodes_with_pods: + node = item.get("node", {}) if isinstance(item, dict) else {} + node_pods = item.get("pods", []) if isinstance(item, dict) else [] + node_name = str(node.get("name") or "") + node_status = str(node.get("status") or "unknown") + ready = "true" if bool(node.get("ready")) else "false" + schedulable = "true" if bool(node.get("schedulable")) else "false" + + node_line = Text() + node_line.append(node_name, style="bold") + node_line.append(" ") + node_line.append(f"[{node_status}]", style=_status_style(node_status)) + node_line.append(f" ready={ready} schedulable={schedulable}", style="dim") + node_line.append(f" pods={len(node_pods)}", style="cyan") + + node_branch = root.add(node_line) + + if not node_pods: + node_branch.add("[dim]No pods on this node.[/dim]") + continue + + for pod in node_pods: + pod_name = str(pod.get("name") or "") + namespace = str(pod.get("namespace") or "") + pod_phase = str(pod.get("phase") or "Unknown") + unsched = bool(pod.get("unschedulable")) + + pod_line = Text() + pod_line.append(f"{namespace}/{pod_name}" if namespace else pod_name) + pod_line.append(" ") + pod_line.append(f"[{pod_phase}]", style=_status_style(pod_phase)) + if unsched: + pod_line.append(" unschedulable", style="red") + + node_branch.add(pod_line) + if unassigned: + branch = root.add(f"[bold yellow]unassigned[/bold yellow] pods={len(unassigned)}") + for pod in unassigned: + pod_name = str(pod.get("name") or "") + namespace = str(pod.get("namespace") or "") + pod_phase = str(pod.get("phase") or "Unknown") + line = Text() + line.append(f"{namespace}/{pod_name}" if namespace else pod_name) + line.append(" ") + line.append(f"[{pod_phase}]", style=_status_style(pod_phase)) + if bool(pod.get("unschedulable")): + line.append(" unschedulable", style="red") + branch.add(line) + + console.print(root) + + if not no_anomalies: + console.print(_build_anomalies_panel(nodes_with_pods, unassigned)) + + if node_requests: + requests_text = Text() + for req in node_requests: + action_id = str((req or {}).get("action_id") or "") + operation = str((req or {}).get("operation") or "-") + status = str((req or {}).get("status") or "-") + phase = str((req or {}).get("phase") or "") + elapsed = (req or {}).get("elapsed_seconds") + requested = (req or {}).get("requested_delta_nodes") + applied = (req or {}).get("applied_delta_nodes") + target_workers = (req or {}).get("target_workers") + reason = str((req or {}).get("reason") or "") + if len(reason) > 120: + reason = reason[:117] + "..." + + requests_text.append(f"{action_id} ", style="bold") + requests_text.append(f"{operation} ", style="cyan") + requests_text.append(f"[{status}] ", style=_status_style(status)) + requests_text.append( + f"requested={requested if requested is not None else '-'} ", + style="yellow", + ) + requests_text.append( + f"applied={applied if applied is not None else '-'} ", + style="yellow", + ) + requests_text.append( + f"target_workers={target_workers if target_workers is not None else '-'}\n", + style="yellow", + ) + if phase or elapsed is not None: + requests_text.append( + " state: " + + (phase if phase else "-") + + " " + + f"elapsed={elapsed if elapsed is not None else '-'}s\n", + style="magenta", + ) + if reason: + requests_text.append(f" reason: {reason}\n", style="dim") + console.print(Panel(requests_text, title="Node Requests", border_style="cyan")) + except Exception as e: + console.print(f"[red]Error showing cluster details: {e}[/red]") + raise typer.Exit(1) diff --git a/datalayer_core/cli/commands/envs.py b/datalayer_core/cli/commands/envs.py index 92ca5246..fbc8d71e 100644 --- a/datalayer_core/cli/commands/envs.py +++ b/datalayer_core/cli/commands/envs.py @@ -22,10 +22,11 @@ def _make_client( token: Optional[str] = None, + iam_url: Optional[str] = None, runtimes_url: Optional[str] = None, ) -> DatalayerClient: """Create a DatalayerClient with optional runtimes URL override.""" - urls = DatalayerURLs.from_environment(runtimes_url=runtimes_url) + urls = DatalayerURLs.from_environment(iam_url=iam_url, runtimes_url=runtimes_url) return DatalayerClient(urls=urls, token=token) @@ -36,13 +37,18 @@ def envs_callback(ctx: typer.Context) -> None: typer.echo(ctx.get_help()) -@app.command(name="list") +@app.command(name="ls") def list_environments( token: Optional[str] = typer.Option( None, "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), runtimes_url: Optional[str] = typer.Option( None, "--runtimes-url", @@ -51,7 +57,11 @@ def list_environments( ) -> None: """List available environments.""" try: - client = _make_client(token=token, runtimes_url=runtimes_url) + client = _make_client( + token=token, + iam_url=iam_url, + runtimes_url=runtimes_url, + ) environments = client.list_environments() # Convert to dict format for display_environments @@ -84,23 +94,6 @@ def list_environments( raise typer.Exit(1) -@app.command(name="ls") -def list_environments_alias( - token: Optional[str] = typer.Option( - None, - "--token", - help="Authentication token (Bearer token for API requests).", - ), - runtimes_url: Optional[str] = typer.Option( - None, - "--runtimes-url", - help="Datalayer Runtimes server URL", - ), -) -> None: - """List available environments (alias for list).""" - list_environments(token=token, runtimes_url=runtimes_url) - - # Root level commands for convenience def envs_list( token: Optional[str] = typer.Option( @@ -108,6 +101,11 @@ def envs_list( "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), runtimes_url: Optional[str] = typer.Option( None, "--runtimes-url", @@ -115,7 +113,7 @@ def envs_list( ), ) -> None: """List available environments (root command).""" - list_environments(token=token, runtimes_url=runtimes_url) + list_environments(token=token, iam_url=iam_url, runtimes_url=runtimes_url) def envs_ls( @@ -124,6 +122,11 @@ def envs_ls( "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), runtimes_url: Optional[str] = typer.Option( None, "--runtimes-url", @@ -131,4 +134,4 @@ def envs_ls( ), ) -> None: """List available environments (root command alias).""" - list_environments(token=token, runtimes_url=runtimes_url) + list_environments(token=token, iam_url=iam_url, runtimes_url=runtimes_url) diff --git a/datalayer_core/cli/commands/evals.py b/datalayer_core/cli/commands/evals.py new file mode 100644 index 00000000..72f27732 --- /dev/null +++ b/datalayer_core/cli/commands/evals.py @@ -0,0 +1,1883 @@ +# Copyright (c) 2023-2026 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Evals commands for Datalayer CLI.""" + +from __future__ import annotations + +from datetime import datetime, timezone +import csv +import json +import math +import re +import time +from pathlib import Path +from typing import Any, Optional + +import typer +from rich.console import Console +from rich.table import Table +from rich.tree import Tree + +from datalayer_core.client.client import DatalayerClient +from datalayer_core.utils.urls import DatalayerURLs + +app = typer.Typer( + name="evals", + help="Launch and monitor SaaS evalsets, experiments, runs, and live monitoring.", + invoke_without_command=True, +) + +evals_app = typer.Typer(name="evalsets", help="Manage evalsets.") +experiments_app = typer.Typer(name="experiments", help="Manage evalset experiments.") +runs_app = typer.Typer(name="runs", help="Launch and monitor evalset runs.") +live_app = typer.Typer(name="live", help="Inspect live evalset monitoring.") + +console = Console() + + +def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + + +def _timestamp_slug(raw_iso: str) -> str: + cleaned = raw_iso.replace("-", "").replace(":", "").replace(".", "") + cleaned = cleaned.replace("+0000", "Z").replace("+00:00", "Z") + cleaned = cleaned.replace("T", "T") + if cleaned.endswith("Z"): + return cleaned + return f"{cleaned}Z" + + +def _parse_json_value(raw: Optional[str], flag_name: str) -> dict[str, Any]: + if not raw: + return {} + try: + parsed = json.loads(raw) + except Exception as exc: + raise typer.BadParameter(f"Invalid JSON for {flag_name}: {exc}") from exc + if not isinstance(parsed, dict): + raise typer.BadParameter(f"{flag_name} must decode to an object") + return parsed + + +def _parse_json_file(path_value: Optional[str], flag_name: str) -> dict[str, Any]: + if not path_value: + return {} + path = Path(path_value) + if not path.exists(): + raise typer.BadParameter(f"File not found for {flag_name}: {path}") + text = path.read_text(encoding="utf-8") + return _parse_json_value(text, flag_name) + + +def _merge_dicts(*parts: dict[str, Any]) -> dict[str, Any]: + merged: dict[str, Any] = {} + for part in parts: + merged.update(part) + return merged + + +def _make_client( + token: Optional[str] = None, + ai_agents_url: Optional[str] = None, +) -> DatalayerClient: + urls = DatalayerURLs.from_environment(ai_agents_url=ai_agents_url) + return DatalayerClient(urls=urls, token=token) + + +def _status_style(status: str) -> str: + normalized = status.lower() + if normalized in {"completed", "success", "passed"}: + return "green" + if normalized in {"running", "queued", "pending"}: + return "yellow" + if normalized in {"failed", "error"}: + return "red" + return "white" + + +def _run_pass_rate(run: dict[str, Any]) -> float | None: + metrics = run.get("metrics") or {} + raw = metrics.get("pass_rate") + if isinstance(raw, (int, float)): + value = float(raw) + if value < 0: + return 0.0 + if value > 1: + return 1.0 + return value + return None + + +def _fmt_pct(raw: float | None) -> str: + if raw is None: + return "n/a" + return f"{raw * 100:.1f}%" + + +def _style_text(value: str, style: str | None, colorize: bool) -> str: + if not colorize or not style: + return value + return f"[{style}]{value}[/{style}]" + + +def _compute_baseline_and_drift(runs: list[dict[str, Any]]) -> tuple[float | None, float | None, float | None]: + pass_rates = [rate for rate in (_run_pass_rate(run) for run in runs) if rate is not None] + if not pass_rates: + return None, None, None + baseline_size = min(3, max(1, len(pass_rates) // 2)) + baseline_slice = pass_rates[:baseline_size] + baseline = sum(baseline_slice) / baseline_size + latest = pass_rates[-1] + drift = latest - baseline + return baseline, latest, drift + + +def _classify_legacy_failure(message: str) -> dict[str, Any]: + """Infer a structured stage/type/url from a free-form legacy error message. + + Older runs (and any path that only persisted a plain error string) lack a + structured ``failure_cause``. Rather than rendering ``unknown`` / + ``legacy_error`` with an empty detail excerpt, classify the most common + error shapes so the report stays actionable. + """ + text = message.strip() + lowered = text.lower() + + url_match = re.search(r"https?://[^\s]+", text) + execution_url = url_match.group(0).rstrip(".,)") if url_match else "" + + stage = "unknown" + failure_type = "legacy_error" + if "all connection attempts failed" in lowered or "connection refused" in lowered or "request failed" in lowered: + stage = "runtime_execution" + failure_type = "runtime_unreachable" + elif "returned http" in lowered or re.search(r"\bhttp\s*[45]\d\d\b", lowered): + stage = "runtime_execution" + failure_type = "runtime_http_error" + elif "traceback" in lowered: + stage = "runtime_execution" + failure_type = "runtime_traceback" + elif "no submitted code" in lowered or "missing" in lowered and "code" in lowered: + stage = "run_preparation" + failure_type = "missing_submitted_code" + elif "no interactive runtime url" in lowered or "not configured" in lowered: + stage = "runtime_resolution" + failure_type = "no_runtime_url" + + cause: dict[str, Any] = { + "stage": stage, + "type": failure_type, + "message": text, + "detail_excerpt": text, + } + if execution_url: + cause["execution_url"] = execution_url + return cause + + +def _extract_failure_cause(run: dict[str, Any]) -> dict[str, Any] | None: + """Extract a structured failure cause from a run's report/summary payload.""" + for container_key in ("report", "summary"): + container = run.get(container_key) + if isinstance(container, dict): + cause = container.get("failure_cause") + if isinstance(cause, dict) and cause: + return cause + # Fallback: synthesize a structured cause from legacy error fields. + summary = run.get("summary") if isinstance(run.get("summary"), dict) else {} + report = run.get("report") if isinstance(run.get("report"), dict) else {} + message = ( + summary.get("failure_reason") + or summary.get("execution_error") + or report.get("error") + ) + if isinstance(message, str) and message.strip(): + return _classify_legacy_failure(message) + return None + + +def _format_failure_cause(cause: dict[str, Any] | None) -> str: + """Render a failure cause as a concise single-line string.""" + if not isinstance(cause, dict) or not cause: + return "" + failure_type = str(cause.get("type") or "").strip() + message = str(cause.get("message") or "").strip() + parts: list[str] = [] + if failure_type: + parts.append(f"[{failure_type}]") + if message: + parts.append(message) + return " ".join(parts).strip() + + +def _failure_cause_detail_lines(cause: dict[str, Any]) -> list[str]: + """Render the full failure cause (message, context, diagnostics, attempts) as markdown lines.""" + lines: list[str] = [] + message = str(cause.get("message") or "").strip() + if message: + lines.append(f"- Message: {message}") + for key, label in ( + ("stage", "Stage"), + ("type", "Type"), + ("runtime_pod_name", "Runtime pod"), + ("environment_name", "Environment"), + ("execution_url", "Execution URL"), + ): + value = str(cause.get(key) or "").strip() + if value: + lines.append(f"- {label}: `{value}`") + + detail = str(cause.get("detail_excerpt") or "").strip() + if detail: + lines.append("- Detail excerpt:") + lines.append("") + lines.append("```text") + lines.extend(detail.splitlines() or [detail]) + lines.append("```") + + diagnostics = cause.get("diagnostics") + if isinstance(diagnostics, dict) and diagnostics: + for key, label in ( + ("agent_runtimes_url", "Agent runtimes URL"), + ("run_url", "Run URL"), + ): + value = diagnostics.get(key) + if value: + lines.append(f"- {label}: `{value}`") + for key, label in ( + ("route_ids", "Route IDs tried"), + ("discovered_agent_ids", "Discovered agent IDs"), + ("candidate_urls", "Candidate URLs"), + ): + value = diagnostics.get(key) + if isinstance(value, list) and value: + rendered = ", ".join(f"`{item}`" for item in value) + lines.append(f"- {label}: {rendered}") + + attempts = diagnostics.get("attempts") + if isinstance(attempts, list) and attempts: + lines.append("- Connection attempts:") + attempt_rows: list[list[str]] = [] + for attempt in attempts: + if not isinstance(attempt, dict): + continue + status_code = attempt.get("status_code") + attempt_rows.append( + [ + str(attempt.get("url") or "-"), + "ok" if attempt.get("ok") else "failed", + "-" if status_code is None else str(status_code), + str(attempt.get("error") or "-"), + ] + ) + if attempt_rows: + lines.append("") + lines.extend( + _markdown_table( + ["URL", "Result", "HTTP", "Error"], + attempt_rows, + ["left", "left", "right", "left"], + ) + ) + return lines + + +def _run_detail_record(run: dict[str, Any]) -> dict[str, Any]: + metrics = run.get("metrics") if isinstance(run.get("metrics"), dict) else {} + summary = run.get("summary") if isinstance(run.get("summary"), dict) else {} + report = run.get("report") if isinstance(run.get("report"), dict) else {} + return { + "id": str(run.get("id", "")), + "status": str(run.get("status", "")), + "created_at": str(run.get("created_at", "")), + "updated_at": str(run.get("updated_at", "")), + "pass_rate": _run_pass_rate(run), + "metrics": metrics, + "summary": summary, + "report": report, + "failure_cause": _extract_failure_cause(run), + } + + +def _report_data( + client: DatalayerClient, + evalset_id: str, + run_limit: int, + account_uid: Optional[str], +) -> dict[str, Any]: + experiments_payload = client.evals_list_experiments( + evalset_id=evalset_id, + limit=200, + offset=0, + account_uid=account_uid, + ) + experiments = experiments_payload.get("experiments") or [] + + report: dict[str, Any] = { + "evalset_id": evalset_id, + "generated_at": _now_iso(), + "experiments": [], + } + + for experiment in experiments: + experiment_id = str(experiment.get("id", "")) + experiment_name = str(experiment.get("name", experiment_id)) + + runs_payload = client.evals_list_runs( + experiment_id, + limit=run_limit, + offset=0, + account_uid=account_uid, + ) + runs = runs_payload.get("runs") or [] + total_runs = int(runs_payload.get("total") or len(runs)) + baseline, latest, drift = _compute_baseline_and_drift(runs) + + latest_two_delta: float | None = None + latest_two_run_ids: list[str] = [] + latest_two_compare: dict[str, Any] | None = None + if len(runs) >= 2: + latest_two_run_ids = [str(runs[0].get("id", "")), str(runs[1].get("id", ""))] + compare_payload = client.evals_compare_runs( + latest_two_run_ids, + account_uid=account_uid, + ) + compared_runs = compare_payload.get("runs") or [] + compared_by_id = { + str(run.get("id", "")): run + for run in compared_runs + if isinstance(run, dict) + } + run_a = compared_by_id.get(latest_two_run_ids[0], runs[0]) + run_b = compared_by_id.get(latest_two_run_ids[1], runs[1]) + pass_a = _run_pass_rate(run_a) + pass_b = _run_pass_rate(run_b) + if pass_a is not None and pass_b is not None: + latest_two_delta = pass_a - pass_b + latest_two_compare = { + "run_ids": latest_two_run_ids, + "run_a": _run_detail_record(run_a), + "run_b": _run_detail_record(run_b), + "delta_pass_rate": latest_two_delta, + } + + consecutive_comparisons: list[dict[str, Any]] = [] + for idx in range(max(0, len(runs) - 1)): + run_a = runs[idx] + run_b = runs[idx + 1] + pass_a = _run_pass_rate(run_a) + pass_b = _run_pass_rate(run_b) + delta = None + if pass_a is not None and pass_b is not None: + delta = pass_a - pass_b + consecutive_comparisons.append( + { + "run_a_id": str(run_a.get("id", "")), + "run_b_id": str(run_b.get("id", "")), + "run_a_status": str(run_a.get("status", "")), + "run_b_status": str(run_b.get("status", "")), + "run_a_pass_rate": pass_a, + "run_b_pass_rate": pass_b, + "delta_pass_rate": delta, + } + ) + + pass_rates = [ + _run_pass_rate(run) + for run in runs + if isinstance(_run_pass_rate(run), (int, float)) + ] + numeric_pass_rates = [float(value) for value in pass_rates if isinstance(value, (int, float))] + mean_pass = sum(numeric_pass_rates) / len(numeric_pass_rates) if numeric_pass_rates else None + stddev_pass = None + if numeric_pass_rates: + variance = sum((value - mean_pass) ** 2 for value in numeric_pass_rates) / len(numeric_pass_rates) + stddev_pass = math.sqrt(variance) + + report["experiments"].append( + { + "id": experiment_id, + "name": experiment_name, + "runs_total": total_runs, + "runs_fetched": len(runs), + "latest_pass_rate": latest, + "baseline_pass_rate": baseline, + "drift_delta": drift, + "latest_two_run_ids": latest_two_run_ids, + "latest_two_delta": latest_two_delta, + "latest_two_comparison": latest_two_compare, + "mean_pass_rate": mean_pass, + "stddev_pass_rate": stddev_pass, + "runs": [_run_detail_record(run) for run in runs], + "consecutive_comparisons": consecutive_comparisons, + } + ) + return report + + +def _ascii_bar( + value: float | None, + width: int = 28, + *, + full_blocks: bool = True, + colorize: bool = False, +) -> str: + if value is None: + return "-" + bounded = max(0.0, min(1.0, float(value))) + filled = int(round(bounded * width)) + fill_char = "█" if full_blocks else "#" + empty_char = "░" if full_blocks else "." + filled_part = fill_char * filled + empty_part = empty_char * (width - filled) + if not colorize: + return filled_part + empty_part + if bounded >= 0.85: + style = "green" + elif bounded >= 0.75: + style = "yellow" + else: + style = "red" + return _style_text(filled_part, style, True) + _style_text(empty_part, "grey39", True) + + +def _fmt_pts(value: float) -> str: + return f"{value * 100:.1f}" + + +def _ascii_histogram( + values: list[float], + *, + bins: int = 8, + width: int = 22, + min_value: float | None = None, + max_value: float | None = None, + full_blocks: bool = True, + colorize: bool = False, + drift_palette: bool = False, +) -> list[str]: + if not values: + return ["n/a"] + + lo = min_value if isinstance(min_value, (int, float)) else min(values) + hi = max_value if isinstance(max_value, (int, float)) else max(values) + if hi <= lo: + hi = lo + 1e-9 + + bins = max(2, bins) + counts = [0 for _ in range(bins)] + span = hi - lo + for value in values: + ratio = (value - lo) / span + idx = int(ratio * bins) + idx = max(0, min(bins - 1, idx)) + counts[idx] += 1 + + peak = max(counts) if counts else 1 + fill_char = "█" if full_blocks else "#" + empty_char = "░" if full_blocks else "." + lines: list[str] = [] + for idx, count in enumerate(counts): + left = lo + (span * idx / bins) + right = lo + (span * (idx + 1) / bins) + filled = int(round((count / peak) * width)) if peak > 0 else 0 + filled_part = fill_char * filled + empty_part = empty_char * (width - filled) + if colorize: + if drift_palette: + if right <= 0: + bar_style = "red" + elif left >= 0: + bar_style = "green" + else: + bar_style = "yellow" + elif peak > 0 and count / peak >= 0.67: + bar_style = "cyan" + elif peak > 0 and count / peak >= 0.34: + bar_style = "blue" + else: + bar_style = "magenta" + bar = _style_text(filled_part, bar_style, True) + _style_text(empty_part, "grey39", True) + else: + bar = filled_part + empty_part + lines.append( + f"{_fmt_pts(left):>6} to {_fmt_pts(right):>6} pts |{bar}| {count}" + ) + return lines + + +def _fmt_delta(value: float | None, *, colorize: bool = False) -> str: + if value is None: + return "n/a" + rendered = f"{value * 100:+.1f} pts" + if value > 0: + return _style_text(rendered, "green", colorize) + if value < 0: + return _style_text(rendered, "red", colorize) + return _style_text(rendered, "yellow", colorize) + + +def _sparkline(values: list[float], *, colorize: bool = False) -> str: + if not values: + return "n/a" + ticks = "▁▂▃▄▅▆▇█" + lo = min(values) + hi = max(values) + if hi <= lo: + base = ticks[-2] * len(values) + else: + span = hi - lo + chars = [] + for value in values: + idx = int(round(((value - lo) / span) * (len(ticks) - 1))) + idx = max(0, min(len(ticks) - 1, idx)) + chars.append(ticks[idx]) + base = "".join(chars) + if not colorize: + return base + if values[-1] >= 0.85: + style = "green" + elif values[-1] >= 0.75: + style = "yellow" + else: + style = "red" + return _style_text(base, style, True) + + +def _pairwise_latest_deltas(experiments: list[dict[str, Any]]) -> list[dict[str, Any]]: + pairs: list[dict[str, Any]] = [] + for idx, left in enumerate(experiments): + left_latest = left.get("latest_pass_rate") + if not isinstance(left_latest, (int, float)): + continue + for right in experiments[idx + 1 :]: + right_latest = right.get("latest_pass_rate") + if not isinstance(right_latest, (int, float)): + continue + pairs.append( + { + "left": str(left.get("name", "")), + "right": str(right.get("name", "")), + "left_latest": float(left_latest), + "right_latest": float(right_latest), + "delta": float(left_latest) - float(right_latest), + } + ) + pairs.sort(key=lambda item: abs(item["delta"]), reverse=True) + return pairs + + +def _markdown_table(headers: list[str], rows: list[list[str]], aligns: list[str]) -> list[str]: + widths = [len(header) for header in headers] + for row in rows: + for idx, cell in enumerate(row): + widths[idx] = max(widths[idx], len(cell)) + + def _pad(cell: str, width: int, align: str) -> str: + if align == "right": + return cell.rjust(width) + return cell.ljust(width) + + header_line = "| " + " | ".join(headers[idx].ljust(widths[idx]) for idx in range(len(headers))) + " |" + + sep_parts: list[str] = [] + for idx, align in enumerate(aligns): + width = max(3, widths[idx]) + if align == "right": + sep_parts.append("-" * (width - 1) + ":") + else: + sep_parts.append(":" + "-" * (width - 1)) + sep_line = "| " + " | ".join(sep_parts) + " |" + + body_lines = [ + "| " + " | ".join(_pad(row[idx], widths[idx], aligns[idx]) for idx in range(len(headers))) + " |" + for row in rows + ] + return [header_line, sep_line, *body_lines] + + +def _report_markdown(report: dict[str, Any], run_limit: int, *, colorize: bool = False) -> str: + evalset_id = str(report.get("evalset_id", "")) + generated_at = str(report.get("generated_at", "")) + experiments = [item for item in (report.get("experiments") or []) if isinstance(item, dict)] + + lines: list[str] = [] + lines.append(f"# Evals Report: {evalset_id}") + lines.append("") + lines.append(f"- Generated at: {generated_at}") + lines.append(f"- Experiments: {len(experiments)}") + lines.append(f"- Run window per experiment: {run_limit}") + lines.append("") + + lines.append("## Experiment Overview") + lines.append("") + overview_rows: list[list[str]] = [] + for experiment in experiments: + runs_fetched = int(experiment.get("runs_fetched") or 0) + runs_total = int(experiment.get("runs_total") or 0) + overview_rows.append( + [ + f"{experiment.get('name', '')}", + f"{runs_fetched}/{runs_total}", + _fmt_pct(experiment.get('latest_pass_rate') if isinstance(experiment.get('latest_pass_rate'), (int, float)) else None), + _fmt_pct(experiment.get('baseline_pass_rate') if isinstance(experiment.get('baseline_pass_rate'), (int, float)) else None), + _fmt_delta(experiment.get('drift_delta') if isinstance(experiment.get('drift_delta'), (int, float)) else None, colorize=colorize), + _fmt_delta(experiment.get('latest_two_delta') if isinstance(experiment.get('latest_two_delta'), (int, float)) else None, colorize=colorize), + ] + ) + lines.extend( + _markdown_table( + ["Experiment", "Runs (fetched/total)", "Latest", "Baseline", "Drift", "Latest-2 Delta"], + overview_rows, + ["left", "right", "right", "right", "right", "right"], + ) + ) + lines.append("") + + lines.append("## Comparison Combinations") + lines.append("") + + ranked_latest = sorted( + [item for item in experiments if isinstance(item.get("latest_pass_rate"), (int, float))], + key=lambda item: float(item.get("latest_pass_rate") or 0.0), + reverse=True, + ) + lines.append("### By Latest Pass Rate") + lines.append("") + latest_rows: list[list[str]] = [] + for idx, item in enumerate(ranked_latest, start=1): + latest_rows.append([str(idx), f"{item.get('name', '')}", _fmt_pct(float(item.get('latest_pass_rate') or 0.0))]) + lines.extend(_markdown_table(["Rank", "Experiment", "Latest"], latest_rows, ["right", "left", "right"])) + latest_values = [ + float(item.get("latest_pass_rate")) + for item in ranked_latest + if isinstance(item.get("latest_pass_rate"), (int, float)) + ] + lines.append("") + lines.append("Latest pass-rate histogram (pts):") + for hist_line in _ascii_histogram( + latest_values, + bins=8, + width=20, + min_value=0.0, + max_value=1.0, + full_blocks=True, + colorize=colorize, + ): + lines.append(f"`{hist_line}`") + lines.append("") + + ranked_drift = sorted( + [item for item in experiments if isinstance(item.get("drift_delta"), (int, float))], + key=lambda item: float(item.get("drift_delta") or 0.0), + ) + lines.append("### By Drift (Most Negative To Most Positive)") + lines.append("") + drift_rows: list[list[str]] = [] + for idx, item in enumerate(ranked_drift, start=1): + drift_rows.append([str(idx), f"{item.get('name', '')}", _fmt_delta(float(item.get('drift_delta') or 0.0), colorize=colorize)]) + lines.extend(_markdown_table(["Rank", "Experiment", "Drift"], drift_rows, ["right", "left", "right"])) + drift_values = [ + float(item.get("drift_delta")) + for item in ranked_drift + if isinstance(item.get("drift_delta"), (int, float)) + ] + lines.append("") + lines.append("Drift histogram (delta pts):") + for hist_line in _ascii_histogram( + drift_values, + bins=8, + width=20, + full_blocks=True, + colorize=colorize, + drift_palette=True, + ): + lines.append(f"`{hist_line}`") + lines.append("") + + ranked_stability = sorted( + [item for item in experiments if isinstance(item.get("stddev_pass_rate"), (int, float))], + key=lambda item: float(item.get("stddev_pass_rate") or 0.0), + ) + lines.append("### By Stability (Lowest Pass-Rate StdDev)") + lines.append("") + stability_rows: list[list[str]] = [] + for idx, item in enumerate(ranked_stability, start=1): + stddev = item.get("stddev_pass_rate") + mean = item.get("mean_pass_rate") + stability_rows.append( + [ + str(idx), + f"{item.get('name', '')}", + (f"{float(stddev) * 100:.2f} pts" if isinstance(stddev, (int, float)) else "n/a"), + (_fmt_pct(float(mean)) if isinstance(mean, (int, float)) else "n/a"), + ] + ) + lines.extend(_markdown_table(["Rank", "Experiment", "StdDev", "Mean"], stability_rows, ["right", "left", "right", "right"])) + lines.append("") + + pairwise = _pairwise_latest_deltas(experiments) + lines.append("### Pairwise Latest-Pass Deltas") + lines.append("") + pair_rows: list[list[str]] = [] + for pair in pairwise: + pair_rows.append( + [ + f"{pair['left']} vs {pair['right']}", + _fmt_pct(pair['left_latest']), + _fmt_pct(pair['right_latest']), + _fmt_delta(pair['delta'], colorize=colorize), + ] + ) + if not pairwise: + pair_rows.append(["n/a", "n/a", "n/a", "n/a"]) + lines.extend( + _markdown_table( + ["Pair", "Left Latest", "Right Latest", "Delta (Left-Right)"], + pair_rows, + ["left", "right", "right", "right"], + ) + ) + pair_deltas = [float(pair["delta"]) for pair in pairwise if isinstance(pair.get("delta"), (int, float))] + lines.append("") + lines.append("Pairwise latest-delta histogram (pts):") + for hist_line in _ascii_histogram( + pair_deltas, + bins=8, + width=20, + full_blocks=True, + colorize=colorize, + drift_palette=True, + ): + lines.append(f"`{hist_line}`") + lines.append("") + + lines.append("### Insight Highlights") + lines.append("") + best_latest = ranked_latest[0] if ranked_latest else None + worst_latest = ranked_latest[-1] if ranked_latest else None + most_negative = ranked_drift[0] if ranked_drift else None + most_positive = ranked_drift[-1] if ranked_drift else None + most_stable = ranked_stability[0] if ranked_stability else None + if best_latest: + lines.append( + "- Top latest pass-rate: " + + f"{best_latest.get('name', '')} ({_fmt_pct(float(best_latest.get('latest_pass_rate') or 0.0))})." + ) + if worst_latest: + lines.append( + "- Lowest latest pass-rate: " + + f"{worst_latest.get('name', '')} ({_fmt_pct(float(worst_latest.get('latest_pass_rate') or 0.0))})." + ) + if most_positive: + drift_pos = float(most_positive.get("drift_delta") or 0.0) + lines.append( + "- Strongest positive drift: " + + f"{most_positive.get('name', '')} ({_fmt_delta(drift_pos, colorize=colorize)})." + ) + if most_negative: + drift_neg = float(most_negative.get("drift_delta") or 0.0) + lines.append( + "- Strongest negative drift: " + + f"{most_negative.get('name', '')} ({_fmt_delta(drift_neg, colorize=colorize)})." + ) + if most_stable: + std = most_stable.get("stddev_pass_rate") + mean = most_stable.get("mean_pass_rate") + lines.append( + "- Stability leader: " + + f"{most_stable.get('name', '')} " + + f"(stddev={(float(std) * 100):.2f} pts, mean={_fmt_pct(float(mean)) if isinstance(mean, (int, float)) else 'n/a'})." + ) + + drift_neg_count = len([value for value in drift_values if value < 0]) + drift_flat_count = len([value for value in drift_values if value == 0]) + drift_pos_count = len([value for value in drift_values if value > 0]) + total = max(1, drift_neg_count + drift_flat_count + drift_pos_count) + neg_meter = "█" * int(round((drift_neg_count / total) * 14)) + flat_meter = "█" * int(round((drift_flat_count / total) * 14)) + pos_meter = "█" * int(round((drift_pos_count / total) * 14)) + neg_meter = neg_meter or "·" + flat_meter = flat_meter or "·" + pos_meter = pos_meter or "·" + lines.append("") + lines.append("Drift balance meter:") + lines.append( + "`NEG " + + _style_text(neg_meter, "red", colorize) + + f" ({drift_neg_count}) | FLAT " + + _style_text(flat_meter, "yellow", colorize) + + f" ({drift_flat_count}) | POS " + + _style_text(pos_meter, "green", colorize) + + f" ({drift_pos_count})`" + ) + lines.append("") + + lines.append("## Per-Experiment Details") + lines.append("") + for experiment in experiments: + lines.append(f"### {experiment.get('name', '')}") + lines.append("") + lines.append("#### Run Timeline") + lines.append("") + run_rows: list[list[str]] = [] + runs = [run for run in (experiment.get("runs") or []) if isinstance(run, dict)] + for idx, run in enumerate(runs, start=1): + pass_rate = run.get("pass_rate") if isinstance(run.get("pass_rate"), (int, float)) else None + cause_text = _format_failure_cause(run.get("failure_cause")) + run_rows.append( + [ + str(idx), + str(run.get('id', '')), + str(run.get('status', '')), + _fmt_pct(float(pass_rate)) if isinstance(pass_rate, (int, float)) else 'n/a', + f"`{_ascii_bar(float(pass_rate), full_blocks=True, colorize=colorize) if isinstance(pass_rate, (int, float)) else '-'}`", + cause_text or "-", + ] + ) + if not runs: + run_rows.append(["1", "n/a", "n/a", "n/a", "`-`", "-"]) + lines.extend(_markdown_table(["#", "Run ID", "Status", "Pass Rate", "ASCII Trend", "Failure Cause"], run_rows, ["right", "left", "left", "right", "left", "left"])) + lines.append("") + failure_rows: list[list[str]] = [] + for idx, run in enumerate(runs, start=1): + cause = run.get("failure_cause") + if not isinstance(cause, dict) or not cause: + continue + detail = str(cause.get("detail_excerpt") or "").strip() + detail_single = " ".join(detail.split()) + if len(detail_single) > 240: + detail_single = detail_single[:237] + "..." + failure_rows.append( + [ + str(idx), + str(run.get("id", "")), + str(cause.get("stage") or "-"), + str(cause.get("type") or "-"), + str(cause.get("message") or "-"), + detail_single or "-", + ] + ) + if failure_rows: + lines.append("#### Failure Causes") + lines.append("") + lines.extend( + _markdown_table( + ["#", "Run ID", "Stage", "Type", "Message", "Detail Excerpt"], + failure_rows, + ["right", "left", "left", "left", "left", "left"], + ) + ) + lines.append("") + for idx, run in enumerate(runs, start=1): + cause = run.get("failure_cause") + if not isinstance(cause, dict) or not cause: + continue + detail_lines = _failure_cause_detail_lines(cause) + if not detail_lines: + continue + lines.append(f"
Run {idx} failure detail ({run.get('id', '')})") + lines.append("") + lines.extend(detail_lines) + lines.append("") + lines.append("
") + lines.append("") + timeline_values = [ + float(run.get("pass_rate")) + for run in runs + if isinstance(run.get("pass_rate"), (int, float)) + ] + lines.append( + "Pass-rate sparkline: " + + f"`{_sparkline(timeline_values, colorize=colorize) if timeline_values else 'n/a'}`" + ) + lines.append("") + + comparisons = [ + item for item in (experiment.get("consecutive_comparisons") or []) + if isinstance(item, dict) + ] + lines.append("#### Consecutive Run Deltas (A-B)") + lines.append("") + comparison_rows: list[list[str]] = [] + for item in comparisons: + run_a = item.get("run_a_pass_rate") if isinstance(item.get("run_a_pass_rate"), (int, float)) else None + run_b = item.get("run_b_pass_rate") if isinstance(item.get("run_b_pass_rate"), (int, float)) else None + delta = item.get("delta_pass_rate") if isinstance(item.get("delta_pass_rate"), (int, float)) else None + comparison_rows.append( + [ + str(item.get('run_a_id', '')), + str(item.get('run_b_id', '')), + _fmt_pct(float(run_a)) if isinstance(run_a, (int, float)) else 'n/a', + _fmt_pct(float(run_b)) if isinstance(run_b, (int, float)) else 'n/a', + _fmt_delta(float(delta), colorize=colorize) if isinstance(delta, (int, float)) else 'n/a', + ] + ) + if not comparisons: + comparison_rows.append(["n/a", "n/a", "n/a", "n/a", "n/a"]) + lines.extend(_markdown_table(["Run A", "Run B", "A Pass", "B Pass", "Delta"], comparison_rows, ["left", "left", "right", "right", "right"])) + lines.append("") + + lines.append("## Notes") + lines.append("") + lines.append("- Drift is computed as latest - baseline.") + lines.append("- Baseline uses the first half of fetched runs (minimum 1, maximum 3).") + lines.append("- Latest-2 delta uses the latest two runs returned in the fetched window.") + lines.append("") + + return "\n".join(lines) + + +def _write_report_csv(report: dict[str, Any], output_path: Path) -> None: + experiments = [item for item in (report.get("experiments") or []) if isinstance(item, dict)] + fieldnames = [ + "row_type", + "evalset_id", + "experiment_id", + "experiment_name", + "run_index", + "run_id", + "run_status", + "run_pass_rate", + "runs_fetched", + "runs_total", + "baseline_pass_rate", + "latest_pass_rate", + "drift_delta", + "latest_two_delta", + "mean_pass_rate", + "stddev_pass_rate", + "failure_stage", + "failure_type", + "failure_message", + "generated_at", + ] + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8", newline="") as stream: + writer = csv.DictWriter(stream, fieldnames=fieldnames) + writer.writeheader() + for experiment in experiments: + writer.writerow( + { + "row_type": "experiment", + "evalset_id": str(report.get("evalset_id", "")), + "experiment_id": str(experiment.get("id", "")), + "experiment_name": str(experiment.get("name", "")), + "run_index": "", + "run_id": "", + "run_status": "", + "run_pass_rate": "", + "runs_fetched": int(experiment.get("runs_fetched") or 0), + "runs_total": int(experiment.get("runs_total") or 0), + "baseline_pass_rate": experiment.get("baseline_pass_rate"), + "latest_pass_rate": experiment.get("latest_pass_rate"), + "drift_delta": experiment.get("drift_delta"), + "latest_two_delta": experiment.get("latest_two_delta"), + "mean_pass_rate": experiment.get("mean_pass_rate"), + "stddev_pass_rate": experiment.get("stddev_pass_rate"), + "failure_stage": "", + "failure_type": "", + "failure_message": "", + "generated_at": str(report.get("generated_at", "")), + } + ) + runs = [run for run in (experiment.get("runs") or []) if isinstance(run, dict)] + for idx, run in enumerate(runs, start=1): + cause = run.get("failure_cause") if isinstance(run.get("failure_cause"), dict) else {} + writer.writerow( + { + "row_type": "run", + "evalset_id": str(report.get("evalset_id", "")), + "experiment_id": str(experiment.get("id", "")), + "experiment_name": str(experiment.get("name", "")), + "run_index": idx, + "run_id": str(run.get("id", "")), + "run_status": str(run.get("status", "")), + "run_pass_rate": run.get("pass_rate"), + "runs_fetched": int(experiment.get("runs_fetched") or 0), + "runs_total": int(experiment.get("runs_total") or 0), + "baseline_pass_rate": experiment.get("baseline_pass_rate"), + "latest_pass_rate": experiment.get("latest_pass_rate"), + "drift_delta": experiment.get("drift_delta"), + "latest_two_delta": experiment.get("latest_two_delta"), + "mean_pass_rate": experiment.get("mean_pass_rate"), + "stddev_pass_rate": experiment.get("stddev_pass_rate"), + "failure_stage": str(cause.get("stage", "")), + "failure_type": str(cause.get("type", "")), + "failure_message": str(cause.get("message", "")), + "generated_at": str(report.get("generated_at", "")), + } + ) + + +def _print_report_console(report: dict[str, Any], run_limit: int) -> None: + evalset_id = str(report.get("evalset_id", "")) + generated_at = str(report.get("generated_at", "")) + experiments = [item for item in (report.get("experiments") or []) if isinstance(item, dict)] + + console.rule(f"[bold cyan]Evals Report[/bold cyan] {evalset_id}") + console.print(f"Generated at: {generated_at}") + console.print(f"Experiments: {len(experiments)} | Run window per experiment: {run_limit}") + console.print("") + + overview = Table(title="Experiment Overview") + overview.add_column("Experiment", style="white") + overview.add_column("Runs", justify="right") + overview.add_column("Latest", justify="right") + overview.add_column("Baseline", justify="right") + overview.add_column("Drift", justify="right") + overview.add_column("Latest-2", justify="right") + for experiment in experiments: + overview.add_row( + str(experiment.get("name", "")), + f"{int(experiment.get('runs_fetched') or 0)}/{int(experiment.get('runs_total') or 0)}", + _fmt_pct(experiment.get("latest_pass_rate") if isinstance(experiment.get("latest_pass_rate"), (int, float)) else None), + _fmt_pct(experiment.get("baseline_pass_rate") if isinstance(experiment.get("baseline_pass_rate"), (int, float)) else None), + _fmt_delta(experiment.get("drift_delta") if isinstance(experiment.get("drift_delta"), (int, float)) else None, colorize=True), + _fmt_delta(experiment.get("latest_two_delta") if isinstance(experiment.get("latest_two_delta"), (int, float)) else None, colorize=True), + ) + console.print(overview) + + ranked_latest = sorted( + [item for item in experiments if isinstance(item.get("latest_pass_rate"), (int, float))], + key=lambda item: float(item.get("latest_pass_rate") or 0.0), + reverse=True, + ) + latest_table = Table(title="By Latest Pass Rate") + latest_table.add_column("Rank", justify="right", no_wrap=True) + latest_table.add_column("Experiment", style="white") + latest_table.add_column("Latest", justify="right", no_wrap=True) + for idx, item in enumerate(ranked_latest, start=1): + latest_table.add_row(str(idx), str(item.get("name", "")), _fmt_pct(float(item.get("latest_pass_rate") or 0.0))) + console.print(latest_table) + latest_values = [ + float(item.get("latest_pass_rate")) + for item in ranked_latest + if isinstance(item.get("latest_pass_rate"), (int, float)) + ] + console.print("Latest histogram:") + for hist_line in _ascii_histogram( + latest_values, + bins=8, + width=20, + min_value=0.0, + max_value=1.0, + full_blocks=True, + colorize=True, + ): + console.print(hist_line) + + ranked_drift = sorted( + [item for item in experiments if isinstance(item.get("drift_delta"), (int, float))], + key=lambda item: float(item.get("drift_delta") or 0.0), + ) + drift_table = Table(title="By Drift (Negative To Positive)") + drift_table.add_column("Rank", justify="right", no_wrap=True) + drift_table.add_column("Experiment", style="white") + drift_table.add_column("Drift", justify="right", no_wrap=True) + for idx, item in enumerate(ranked_drift, start=1): + drift_table.add_row( + str(idx), + str(item.get("name", "")), + _fmt_delta(float(item.get("drift_delta") or 0.0), colorize=True), + ) + console.print(drift_table) + drift_values = [ + float(item.get("drift_delta")) + for item in ranked_drift + if isinstance(item.get("drift_delta"), (int, float)) + ] + console.print("Drift histogram:") + for hist_line in _ascii_histogram( + drift_values, + bins=8, + width=20, + full_blocks=True, + colorize=True, + drift_palette=True, + ): + console.print(hist_line) + + pairwise = _pairwise_latest_deltas(experiments) + pairwise_table = Table(title="Pairwise Latest-Pass Deltas") + pairwise_table.add_column("Pair", style="white") + pairwise_table.add_column("Left", justify="right", no_wrap=True) + pairwise_table.add_column("Right", justify="right", no_wrap=True) + pairwise_table.add_column("Delta", justify="right", no_wrap=True) + for pair in pairwise: + pairwise_table.add_row( + f"{pair['left']} vs {pair['right']}", + _fmt_pct(pair["left_latest"]), + _fmt_pct(pair["right_latest"]), + _fmt_delta(pair["delta"], colorize=True), + ) + if not pairwise: + pairwise_table.add_row("n/a", "n/a", "n/a", "n/a") + console.print(pairwise_table) + + if ranked_latest: + console.print( + "[bold]Insight:[/bold] top latest " + f"[green]{ranked_latest[0].get('name', '')}[/green] " + f"({_fmt_pct(float(ranked_latest[0].get('latest_pass_rate') or 0.0))})" + ) + if ranked_drift: + console.print( + "[bold]Insight:[/bold] strongest drift " + f"{ranked_drift[-1].get('name', '')} " + f"({_fmt_delta(float(ranked_drift[-1].get('drift_delta') or 0.0), colorize=True)})" + ) + console.print("") + + for experiment in experiments: + console.print("") + console.print(f"[bold]Run Timeline:[/bold] {experiment.get('name', '')}") + run_table = Table() + run_table.add_column("#", justify="right", style="cyan", no_wrap=True) + run_table.add_column("Run ID", style="white", no_wrap=True) + run_table.add_column("Status", no_wrap=True) + run_table.add_column("Pass Rate", justify="right", no_wrap=True) + run_table.add_column("Trend", style="white", no_wrap=True) + run_table.add_column("Failure Cause", style="red", overflow="fold") + + runs = [run for run in (experiment.get("runs") or []) if isinstance(run, dict)] + for idx, run in enumerate(runs, start=1): + status_value = str(run.get("status", "")) + pass_rate = float(run.get("pass_rate")) if isinstance(run.get("pass_rate"), (int, float)) else None + cause_text = _format_failure_cause(run.get("failure_cause")) + run_table.add_row( + str(idx), + str(run.get("id", "")), + f"[{_status_style(status_value)}]{status_value}[/{_status_style(status_value)}]", + _fmt_pct(pass_rate), + _ascii_bar(pass_rate, width=28, full_blocks=True, colorize=True) if pass_rate is not None else "-", + cause_text or "-", + ) + if not runs: + run_table.add_row("1", "n/a", "n/a", "n/a", "-", "-") + console.print(run_table) + + for idx, run in enumerate(runs, start=1): + cause = run.get("failure_cause") + if not isinstance(cause, dict) or not cause: + continue + console.print( + f"[red bold]Run {idx} failure:[/red bold] " + f"[red]{str(cause.get('message') or 'Unknown failure.')}[/red]" + ) + for key, label in ( + ("stage", "stage"), + ("type", "type"), + ("execution_url", "execution url"), + ): + value = str(cause.get(key) or "").strip() + if value: + console.print(f" {label}: {value}") + diagnostics = cause.get("diagnostics") + if isinstance(diagnostics, dict): + for key, label in ( + ("agent_runtimes_url", "agent runtimes url"), + ("run_url", "run url"), + ): + value = diagnostics.get(key) + if value: + console.print(f" {label}: {value}") + candidate_urls = diagnostics.get("candidate_urls") + if isinstance(candidate_urls, list) and candidate_urls: + console.print(f" candidate urls: {', '.join(str(u) for u in candidate_urls)}") + attempts = diagnostics.get("attempts") + if isinstance(attempts, list) and attempts: + for attempt in attempts: + if not isinstance(attempt, dict): + continue + outcome = "ok" if attempt.get("ok") else "failed" + console.print( + f" attempt: {attempt.get('url', '')} -> {outcome} " + f"{attempt.get('error') or ''}".rstrip() + ) + detail = str(cause.get("detail_excerpt") or "").strip() + if detail: + console.print(f" detail: {detail}") + + deltas_table = Table(title="Consecutive Run Deltas") + deltas_table.add_column("Run A", style="white", no_wrap=True) + deltas_table.add_column("Run B", style="white", no_wrap=True) + deltas_table.add_column("A Pass", justify="right", no_wrap=True) + deltas_table.add_column("B Pass", justify="right", no_wrap=True) + deltas_table.add_column("Delta", justify="right", no_wrap=True) + comparisons = [ + item for item in (experiment.get("consecutive_comparisons") or []) + if isinstance(item, dict) + ] + for item in comparisons: + run_a = item.get("run_a_pass_rate") if isinstance(item.get("run_a_pass_rate"), (int, float)) else None + run_b = item.get("run_b_pass_rate") if isinstance(item.get("run_b_pass_rate"), (int, float)) else None + delta = item.get("delta_pass_rate") if isinstance(item.get("delta_pass_rate"), (int, float)) else None + deltas_table.add_row( + str(item.get("run_a_id", "")), + str(item.get("run_b_id", "")), + _fmt_pct(float(run_a)) if isinstance(run_a, (int, float)) else "n/a", + _fmt_pct(float(run_b)) if isinstance(run_b, (int, float)) else "n/a", + _fmt_delta(float(delta), colorize=True) if isinstance(delta, (int, float)) else "n/a", + ) + if not comparisons: + deltas_table.add_row("n/a", "n/a", "n/a", "n/a", "n/a") + console.print(deltas_table) + + +@app.callback() +def evals_callback(ctx: typer.Context) -> None: + """Evals command group.""" + if ctx.invoked_subcommand is None: + typer.echo(ctx.get_help()) + + +@app.command(name="ls") +def evals_ls( + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), + run_environment: Optional[str] = typer.Option(None, "--run-environment", help="Filter by run environment (ui/sdk)."), + kind: Optional[str] = typer.Option(None, "--kind", help="Filter by kind (batch/interactive)."), + q: Optional[str] = typer.Option(None, "--q", help="Search query."), + limit: int = typer.Option(50, "--limit", min=1, max=200), + offset: int = typer.Option(0, "--offset", min=0), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON output."), +) -> None: + """List all evalsets and their experiments.""" + client = _make_client(token=token, ai_agents_url=ai_agents_url) + evalsets_payload = client.evals_list_evals( + run_environment=run_environment, + kind=kind, + q=q, + limit=limit, + offset=offset, + account_uid=account_uid, + ) + evalsets = [item for item in (evalsets_payload.get("evalsets") or []) if isinstance(item, dict)] + + experiments_by_evalset: dict[str, list[dict[str, Any]]] = {} + for evalset in evalsets: + evalset_id = str(evalset.get("id", "")) + if not evalset_id: + continue + experiments_payload = client.evals_list_experiments( + evalset_id=evalset_id, + limit=200, + offset=0, + account_uid=account_uid, + ) + experiments_by_evalset[evalset_id] = [ + item + for item in (experiments_payload.get("experiments") or []) + if isinstance(item, dict) + ] + + if raw: + console.print( + { + "evalsets": evalsets, + "experiments": experiments_by_evalset, + } + ) + return + + total_experiments = sum(len(items) for items in experiments_by_evalset.values()) + tree = Tree( + f"[bold]Evals[/bold] ([cyan]{len(evalsets)}[/cyan] evalsets, " + f"[cyan]{total_experiments}[/cyan] experiments)" + ) + for evalset in evalsets: + evalset_id = str(evalset.get("id", "")) + evalset_node = tree.add( + f"[cyan]{evalset_id}[/cyan] [white]{evalset.get('name', '')}[/white] " + f"(env={evalset.get('run_environment', '')}, " + f"kind={evalset.get('kind', '')}, " + f"cases={len(evalset.get('cases') or [])})" + ) + experiments = experiments_by_evalset.get(evalset_id, []) + if not experiments: + evalset_node.add("[dim]no experiments[/dim]") + continue + for experiment in experiments: + status_value = str(experiment.get("status", "")) + evalset_node.add( + f"[cyan]{experiment.get('id', '')}[/cyan] " + f"[white]{experiment.get('name', '')}[/white] " + f"[{_status_style(status_value)}]{status_value}[/{_status_style(status_value)}]" + ) + console.print(tree) + + +@app.command(name="delete") +def evals_delete_top( + evalset_id: str = typer.Argument(..., help="Evalset UID to delete."), + yes: bool = typer.Option(False, "--yes", "-y", help="Skip the confirmation prompt."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), +) -> None: + """Delete an evalset and its associated experiments, runs, and cases.""" + if not yes: + typer.confirm( + f"Delete evalset {evalset_id} and all associated experiments, runs, and cases?", + abort=True, + ) + client = _make_client(token=token, ai_agents_url=ai_agents_url) + payload = client.evals_delete_eval(evalset_id, account_uid=account_uid) + cascade = payload.get("cascade") or {} + console.print( + f"[green]Eval deleted:[/green] {evalset_id} " + f"(experiments={cascade.get('experiments_deleted', 0)}, " + f"runs={cascade.get('runs_deleted', 0)}, " + f"cases={cascade.get('cases_deleted', 0)})" + ) + + +@evals_app.command(name="ls") +def evals_list( + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), + run_environment: Optional[str] = typer.Option(None, "--run-environment", help="Filter by run environment (ui/sdk)."), + kind: Optional[str] = typer.Option(None, "--kind", help="Filter by kind (batch/interactive)."), + q: Optional[str] = typer.Option(None, "--q", help="Search query."), + limit: int = typer.Option(50, "--limit", min=1, max=200), + offset: int = typer.Option(0, "--offset", min=0), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON output."), +) -> None: + """List evalsets.""" + client = _make_client(token=token, ai_agents_url=ai_agents_url) + payload = client.evals_list_evals( + run_environment=run_environment, + kind=kind, + q=q, + limit=limit, + offset=offset, + account_uid=account_uid, + ) + if raw: + console.print(payload) + return + + evalsets = payload.get("evalsets") or [] + table = Table(title=f"Evals ({len(evalsets)})") + table.add_column("ID", style="cyan") + table.add_column("Name", style="white") + table.add_column("Run Environment", style="white") + table.add_column("Kind", style="white") + table.add_column("Cases", style="white") + table.add_column("Updated", style="white") + for item in evalsets: + table.add_row( + str(item.get("id", "")), + str(item.get("name", "")), + str(item.get("run_environment", "")), + str(item.get("kind", "")), + str(len(item.get("cases") or [])), + str(item.get("updated_at", "")), + ) + console.print(table) + + +@evals_app.command(name="create") +def evals_create( + name: Optional[str] = typer.Argument(None, help="Evalset name."), + description: Optional[str] = typer.Option(None, "--description", help="Evalset description."), + run_environment: Optional[str] = typer.Option(None, "--run-environment", help="Evalset run environment (ui/sdk)."), + kind: Optional[str] = typer.Option(None, "--kind", help="Evalset kind (batch/interactive)."), + spec_file: Optional[str] = typer.Option(None, "--spec-file", help="Path to evalset spec JSON file."), + schema_json: Optional[str] = typer.Option(None, "--schema-json", help="Schema JSON object."), + metadata_json: Optional[str] = typer.Option(None, "--metadata-json", help="Metadata JSON object."), + cases_file: Optional[str] = typer.Option(None, "--cases-file", help="Path to JSON array of cases."), + tags: list[str] = typer.Option([], "--tag", help="Repeatable tag."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON output."), +) -> None: + """Create an evalset.""" + spec = _parse_json_file(spec_file, "--spec-file") + schema = _merge_dicts( + spec.get("schema") if isinstance(spec.get("schema"), dict) else {}, + _parse_json_value(schema_json, "--schema-json"), + ) + metadata = _merge_dicts( + spec.get("metadata") if isinstance(spec.get("metadata"), dict) else {}, + _parse_json_value(metadata_json, "--metadata-json"), + ) + + cases: list[dict[str, Any]] = [] + if isinstance(spec.get("cases"), list): + cases = [case for case in spec.get("cases") if isinstance(case, dict)] + if cases_file: + text = Path(cases_file).read_text(encoding="utf-8") + decoded = json.loads(text) + if not isinstance(decoded, list): + raise typer.BadParameter("--cases-file must contain a JSON array") + cases = [case for case in decoded if isinstance(case, dict)] + + resolved_name = str(name or spec.get("name") or "").strip() + if not resolved_name: + raise typer.BadParameter("name argument is required unless provided in --spec-file") + resolved_description = str(description if description is not None else spec.get("description") or "") + resolved_run_environment = str(run_environment if run_environment is not None else spec.get("run_environment") or "sdk") + resolved_kind = str(kind if kind is not None else spec.get("kind") or "batch") + + spec_tags = spec.get("tags") if isinstance(spec.get("tags"), list) else [] + resolved_tags = tags if tags else [str(tag) for tag in spec_tags if str(tag).strip()] + + client = _make_client(token=token, ai_agents_url=ai_agents_url) + payload = client.evals_create_eval( + name=resolved_name, + description=resolved_description, + run_environment=resolved_run_environment, + kind=resolved_kind, + schema=schema, + metadata=metadata, + tags=resolved_tags, + cases=cases, + account_uid=account_uid, + ) + if raw: + typer.echo(json.dumps(payload)) + return + eval_record = payload.get("evalset") or {} + console.print(f"[green]Eval created:[/green] {eval_record.get('id', '')} ({eval_record.get('name', '')})") + + +@evals_app.command(name="delete") +def evals_delete( + evalset_id: str = typer.Argument(..., help="Evalset ID."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), +) -> None: + """Delete an evalset (cascade delete runs/experiments).""" + client = _make_client(token=token, ai_agents_url=ai_agents_url) + payload = client.evals_delete_eval(evalset_id, account_uid=account_uid) + cascade = payload.get("cascade") or {} + console.print( + "[green]Eval deleted.[/green] " + f"experiments={cascade.get('experiments_deleted', 0)} " + f"runs={cascade.get('runs_deleted', 0)} " + f"cases={cascade.get('cases_deleted', 0)}" + ) + + +def _render_report( + evalset_id: Optional[str], + run_limit: int = typer.Option(50, "--run-limit", min=2, max=200, help="Runs fetched per experiment."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), + output_file: Optional[str] = typer.Option(None, "--output", help="Write markdown report to file."), + export: bool = typer.Option(False, "--export", help="Export timestamped report files report-.md and report-.csv."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON report output."), +) -> None: + """Generate a full evalset report with cross-experiment comparisons.""" + client = _make_client(token=token, ai_agents_url=ai_agents_url) + resolved_evalset_id = (evalset_id or "").strip() + if not resolved_evalset_id: + payload = client.evals_list_evals( + limit=200, + offset=0, + account_uid=account_uid, + ) + evalsets = [item for item in (payload.get("evalsets") or []) if isinstance(item, dict)] + if not evalsets: + raise typer.BadParameter("No evalsets found. Provide explicitly.") + + def _updated_key(item: dict[str, Any]) -> str: + return str(item.get("updated_at") or item.get("created_at") or "") + + latest_evalset = max(evalsets, key=_updated_key) + resolved_evalset_id = str(latest_evalset.get("id") or "").strip() + if not resolved_evalset_id: + raise typer.BadParameter("Latest evalset does not contain an id.") + console.print( + f"[yellow]No evalset id provided.[/yellow] Using latest evalset: " + f"[cyan]{resolved_evalset_id}[/cyan]" + ) + + report = _report_data( + client=client, + evalset_id=resolved_evalset_id, + run_limit=run_limit, + account_uid=account_uid, + ) + experiments = report.get("experiments") or [] + if not experiments: + console.print(f"[yellow]No experiments found for evalset[/yellow] {resolved_evalset_id}") + raise typer.Exit(0) + + if raw: + console.print(report) + return + + markdown_report = _report_markdown(report, run_limit=run_limit, colorize=False) + if export: + timestamp = _timestamp_slug(str(report.get("generated_at", _now_iso()))) + export_markdown_path = Path(f"report-{timestamp}.md") + export_csv_path = Path(f"report-{timestamp}.csv") + export_markdown_path.write_text(markdown_report + "\n", encoding="utf-8") + _write_report_csv(report, export_csv_path) + console.print(f"[green]Markdown export written:[/green] {export_markdown_path}") + console.print(f"[green]CSV export written:[/green] {export_csv_path}") + if output_file: + output_path = Path(output_file) + output_path.write_text(markdown_report + "\n", encoding="utf-8") + console.print(f"[green]Report written:[/green] {output_path}") + _print_report_console(report, run_limit=run_limit) + + +@app.command(name="report") +def evals_report( + evalset_id: Optional[str] = typer.Argument(None, help="Evalset ID to report. Defaults to latest updated evalset."), + run_limit: int = typer.Option(50, "--run-limit", min=2, max=200, help="Runs fetched per experiment."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), + output_file: Optional[str] = typer.Option(None, "--output", help="Write markdown report to file."), + export: bool = typer.Option(False, "--export", help="Export timestamped report files report-.md and report-.csv."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON report output."), +) -> None: + """Generate an evalset report in markdown with comparison combinations and ASCII plots.""" + _render_report( + evalset_id=evalset_id, + run_limit=run_limit, + token=token, + ai_agents_url=ai_agents_url, + account_uid=account_uid, + output_file=output_file, + export=export, + raw=raw, + ) + + +@evals_app.command(name="compare-report") +def evals_compare_report_compat( + evalset_id: Optional[str] = typer.Argument(None, help="Evalset ID to report. Defaults to latest updated evalset."), + run_limit: int = typer.Option(50, "--run-limit", min=2, max=200, help="Runs fetched per experiment."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), + output_file: Optional[str] = typer.Option(None, "--output", help="Write markdown report to file."), + export: bool = typer.Option(False, "--export", help="Export timestamped report files report-.md and report-.csv."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON report output."), +) -> None: + """Compatibility alias for report. Prefer: datalayer evals report .""" + console.print("[yellow]Deprecated:[/yellow] use [bold]datalayer evals report [/bold].") + _render_report( + evalset_id=evalset_id, + run_limit=run_limit, + token=token, + ai_agents_url=ai_agents_url, + account_uid=account_uid, + output_file=output_file, + export=export, + raw=raw, + ) + + +@experiments_app.command(name="ls") +def experiments_list( + evalset_id: Optional[str] = typer.Option(None, "--evalset-id", help="Filter by evalset ID."), + status: Optional[str] = typer.Option(None, "--status", help="Filter by status."), + limit: int = typer.Option(50, "--limit", min=1, max=200), + offset: int = typer.Option(0, "--offset", min=0), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON output."), +) -> None: + """List evalset experiments.""" + client = _make_client(token=token, ai_agents_url=ai_agents_url) + payload = client.evals_list_experiments( + evalset_id=evalset_id, + status=status, + limit=limit, + offset=offset, + account_uid=account_uid, + ) + if raw: + console.print(payload) + return + experiments = payload.get("experiments") or [] + table = Table(title=f"Eval Experiments ({len(experiments)})") + table.add_column("ID", style="cyan") + table.add_column("Name", style="white") + table.add_column("Eval", style="white") + table.add_column("Status", style="white") + table.add_column("Updated", style="white") + for item in experiments: + status_value = str(item.get("status", "")) + table.add_row( + str(item.get("id", "")), + str(item.get("name", "")), + str(item.get("evalset_id", "")), + f"[{_status_style(status_value)}]{status_value}[/{_status_style(status_value)}]", + str(item.get("updated_at", "")), + ) + console.print(table) + + +@experiments_app.command(name="create") +def experiments_create( + name: Optional[str] = typer.Argument(None, help="Experiment name."), + evalset_id: Optional[str] = typer.Option(None, "--evalset-id", help="Evalset ID."), + description: Optional[str] = typer.Option(None, "--description", help="Description."), + status: Optional[str] = typer.Option(None, "--status", help="Initial status."), + spec_file: Optional[str] = typer.Option(None, "--spec-file", help="Path to experiment spec JSON file."), + config_json: Optional[str] = typer.Option(None, "--config-json", help="Config JSON object."), + summary_json: Optional[str] = typer.Option(None, "--summary-json", help="Summary JSON object."), + tags: list[str] = typer.Option([], "--tag", help="Repeatable tag."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON output."), +) -> None: + """Create an evalset experiment.""" + spec = _parse_json_file(spec_file, "--spec-file") + + resolved_name = str(name or spec.get("name") or "").strip() + if not resolved_name: + raise typer.BadParameter("name argument is required unless provided in --spec-file") + resolved_evalset_id = str(evalset_id or spec.get("evalset_id") or "").strip() or None + resolved_description = str(description if description is not None else spec.get("description") or "") + resolved_status = str(status if status is not None else spec.get("status") or "draft") + resolved_config = _merge_dicts( + spec.get("config") if isinstance(spec.get("config"), dict) else {}, + _parse_json_value(config_json, "--config-json"), + ) + resolved_summary = _merge_dicts( + spec.get("summary") if isinstance(spec.get("summary"), dict) else {}, + _parse_json_value(summary_json, "--summary-json"), + ) + spec_tags = spec.get("tags") if isinstance(spec.get("tags"), list) else [] + resolved_tags = tags if tags else [str(tag) for tag in spec_tags if str(tag).strip()] + + client = _make_client(token=token, ai_agents_url=ai_agents_url) + payload = client.evals_create_experiment( + name=resolved_name, + evalset_id=resolved_evalset_id, + description=resolved_description, + status=resolved_status, + config=resolved_config, + summary=resolved_summary, + tags=resolved_tags, + account_uid=account_uid, + ) + if raw: + typer.echo(json.dumps(payload)) + return + experiment = payload.get("experiment") or {} + console.print(f"[green]Experiment created:[/green] {experiment.get('id', '')} ({experiment.get('name', '')})") + + +@runs_app.command(name="ls") +def runs_list( + experiment_id: str = typer.Option(..., "--experiment-id", help="Experiment ID."), + limit: int = typer.Option(50, "--limit", min=1, max=200), + offset: int = typer.Option(0, "--offset", min=0), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON output."), +) -> None: + """List runs for an experiment.""" + client = _make_client(token=token, ai_agents_url=ai_agents_url) + payload = client.evals_list_runs( + experiment_id, + limit=limit, + offset=offset, + account_uid=account_uid, + ) + if raw: + console.print(payload) + return + runs = payload.get("runs") or [] + table = Table(title=f"Eval Runs ({len(runs)})") + table.add_column("Run", style="cyan") + table.add_column("Status", style="white") + table.add_column("Pass Rate", style="white") + table.add_column("Run Environment", style="white") + table.add_column("Created", style="white") + for run in runs: + status_value = str(run.get("status", "")) + metrics = run.get("metrics") or {} + summary = run.get("summary") or {} + pass_rate = metrics.get("pass_rate") + if isinstance(pass_rate, (float, int)): + pass_rate_text = f"{float(pass_rate) * 100:.1f}%" + else: + pass_rate_text = "n/a" + run_environment = str(summary.get("run_environment") or summary.get("launch_source") or "") + table.add_row( + str(run.get("id", "")), + f"[{_status_style(status_value)}]{status_value}[/{_status_style(status_value)}]", + pass_rate_text, + run_environment, + str(run.get("created_at", "")), + ) + console.print(table) + + +@runs_app.command(name="launch") +def runs_launch( + experiment_id: str = typer.Option(..., "--experiment-id", help="Experiment ID."), + status: str = typer.Option("queued", "--status", help="Initial run status."), + run_mode: Optional[str] = typer.Option(None, "--run-mode", help="Run mode hint (batch/interactive)."), + runtime_pod_name: Optional[str] = typer.Option(None, "--runtime-pod-name", help="Runtime pod for interactive execution."), + submitted_code_file: Optional[str] = typer.Option(None, "--submitted-code-file", help="Python file to execute in interactive mode."), + metrics_json: Optional[str] = typer.Option(None, "--metrics-json", help="Inline metrics JSON object."), + summary_json: Optional[str] = typer.Option(None, "--summary-json", help="Inline summary JSON object."), + report_json: Optional[str] = typer.Option(None, "--report-json", help="Inline report JSON object."), + metrics_file: Optional[str] = typer.Option(None, "--metrics-file", help="Path to metrics JSON object."), + summary_file: Optional[str] = typer.Option(None, "--summary-file", help="Path to summary JSON object."), + report_file: Optional[str] = typer.Option(None, "--report-file", help="Path to report JSON object."), + started_at: Optional[str] = typer.Option(None, "--started-at", help="ISO timestamp override."), + ended_at: Optional[str] = typer.Option(None, "--ended-at", help="ISO timestamp override."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), +) -> None: + """Launch an evalset run on SaaS and tag it as CLI-launched.""" + cli_summary: dict[str, Any] = { + "launch_source": "datalayer-cli", + "launched_at": _now_iso(), + } + if run_mode: + cli_summary["run_mode"] = run_mode + if runtime_pod_name: + cli_summary["runtime_pod_name"] = runtime_pod_name + if submitted_code_file: + path = Path(submitted_code_file) + if not path.exists(): + raise typer.BadParameter(f"submitted code file not found: {submitted_code_file}") + cli_summary["submitted_code"] = path.read_text(encoding="utf-8") + + metrics = _merge_dicts( + _parse_json_file(metrics_file, "--metrics-file"), + _parse_json_value(metrics_json, "--metrics-json"), + ) + summary = _merge_dicts( + _parse_json_file(summary_file, "--summary-file"), + _parse_json_value(summary_json, "--summary-json"), + cli_summary, + ) + report = _merge_dicts( + _parse_json_file(report_file, "--report-file"), + _parse_json_value(report_json, "--report-json"), + ) + + client = _make_client(token=token, ai_agents_url=ai_agents_url) + payload = client.evals_create_run( + experiment_id, + status=status, + started_at=started_at, + ended_at=ended_at, + metrics=metrics, + summary=summary, + report=report, + account_uid=account_uid, + ) + run = payload.get("run") or {} + run_id = str(run.get("id", "")) + ui_url = f"{client.urls.ai_agents_url}/agents/evals" + console.print(f"[green]Run launched:[/green] {run_id}") + console.print(f"Track in UI: {ui_url}") + + +@runs_app.command(name="watch") +def runs_watch( + run_id: str = typer.Argument(..., help="Run ID."), + interval_seconds: float = typer.Option(3.0, "--interval", min=0.5, help="Polling interval."), + timeout_seconds: int = typer.Option(600, "--timeout", min=5, help="Timeout in seconds."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), +) -> None: + """Watch a run until completion/failure.""" + client = _make_client(token=token, ai_agents_url=ai_agents_url) + started = time.time() + last_status = "" + + while True: + payload = client.evals_get_run(run_id, account_uid=account_uid) + run = payload.get("run") or {} + status = str(run.get("status", "unknown")) + if status != last_status: + metrics = run.get("metrics") or {} + pass_rate = metrics.get("pass_rate") + pass_rate_text = ( + f"{float(pass_rate) * 100:.1f}%" + if isinstance(pass_rate, (int, float)) + else "n/a" + ) + console.print( + f"[{_status_style(status)}]{status}[/{_status_style(status)}] " + f"pass_rate={pass_rate_text} updated={run.get('updated_at', '')}" + ) + last_status = status + + if status.lower() in {"completed", "failed", "cancelled", "error"}: + return + + if time.time() - started >= timeout_seconds: + raise typer.Exit(1) + + time.sleep(interval_seconds) + + +@live_app.command(name="targets") +def live_targets( + window: str = typer.Option("24h", "--window", help="Window: 1h, 6h, 24h, 7d, 30d."), + limit: int = typer.Option(50, "--limit", min=1, max=200), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + ai_agents_url: Optional[str] = typer.Option(None, "--ai-agents-url", help="AI Agents base URL."), + account_uid: Optional[str] = typer.Option(None, "--account-uid", help="Organization/account UID context."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON output."), +) -> None: + """List live monitoring targets.""" + client = _make_client(token=token, ai_agents_url=ai_agents_url) + payload = client.evals_list_live_targets( + window=window, + limit=limit, + account_uid=account_uid, + ) + if raw: + console.print(payload) + return + targets = payload.get("targets") or [] + table = Table(title=f"Live Eval Targets ({len(targets)})") + table.add_column("Target", style="cyan") + table.add_column("Type", style="white") + table.add_column("Events", style="white") + table.add_column("Pass Rate", style="white") + table.add_column("Avg Value", style="white") + table.add_column("Last Event", style="white") + for item in targets: + pass_rate = item.get("pass_rate") + pass_rate_text = ( + f"{float(pass_rate) * 100:.1f}%" + if isinstance(pass_rate, (int, float)) + else "n/a" + ) + table.add_row( + str(item.get("target_id", "")), + str(item.get("target_type", "")), + str(item.get("event_count", 0)), + pass_rate_text, + str(item.get("avg_value", "n/a")), + str(item.get("last_event_at", "")), + ) + console.print(table) + + +app.add_typer(evals_app) +app.add_typer(experiments_app) +app.add_typer(runs_app) +app.add_typer(live_app) diff --git a/datalayer_core/cli/commands/memberships.py b/datalayer_core/cli/commands/memberships.py new file mode 100644 index 00000000..f3710bc1 --- /dev/null +++ b/datalayer_core/cli/commands/memberships.py @@ -0,0 +1,127 @@ +# Copyright (c) 2023-2025 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Memberships command: list the authenticated user's organization and team memberships.""" + +import json as _json +import os +from typing import Optional + +import typer +from rich.console import Console +from rich.table import Table + +from datalayer_core.cli.commands.authn import _fetch_memberships +from datalayer_core.utils.urls import DatalayerURLs + +app = typer.Typer( + name="memberships", + help="List organization and team memberships for the authenticated user.", + invoke_without_command=True, +) + +console = Console() + + +def _print_memberships( + memberships: list[dict], + *, + only: Optional[str] = None, +) -> None: + orgs = [m for m in memberships if (m.get("type") or "").lower() == "organization"] + teams = [m for m in memberships if (m.get("type") or "").lower() == "team"] + org_by_uid = {m.get("uid"): m for m in orgs} + + if only in (None, "organization", "organizations", "org", "orgs"): + if orgs: + table = Table(title="🏢 Organizations") + table.add_column("Handle", style="cyan") + table.add_column("Name") + table.add_column("UID") + table.add_column("Roles") + for org in orgs: + table.add_row( + str(org.get("handle") or ""), + str(org.get("name") or ""), + str(org.get("uid") or ""), + ", ".join(org.get("roles_ss") or []) or "-", + ) + console.print(table) + elif only is not None: + console.print("[dim]No organization memberships.[/dim]") + + if only in (None, "team", "teams"): + if teams: + table = Table(title="👥 Teams") + table.add_column("Handle", style="cyan") + table.add_column("Name") + table.add_column("Organization", style="magenta") + table.add_column("UID") + table.add_column("Roles") + for team in teams: + org_uid = team.get("organization_uid") + parent = org_by_uid.get(org_uid) if org_uid else None + parent_label = ( + parent.get("handle") if parent else (org_uid or "unknown") + ) + table.add_row( + str(team.get("handle") or ""), + str(team.get("name") or ""), + str(parent_label or ""), + str(team.get("uid") or ""), + ", ".join(team.get("roles_ss") or []) or "-", + ) + console.print(table) + elif only is not None: + console.print("[dim]No team memberships.[/dim]") + + if only is None and not orgs and not teams: + console.print("[dim]No organization or team memberships.[/dim]") + + +@app.callback(invoke_without_command=True) +def memberships_root( + ctx: typer.Context, + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), + token: Optional[str] = typer.Option( + None, + "--token", + help="User access token", + ), + only: Optional[str] = typer.Option( + None, + "--only", + help="Restrict output to one type: 'organizations' or 'teams'.", + ), + as_json: bool = typer.Option( + False, + "--json", + help="Print raw JSON memberships response.", + ), +) -> None: + """List the authenticated user's organization and team memberships.""" + if ctx.invoked_subcommand is not None: + return + + urls = DatalayerURLs.from_environment(iam_url=iam_url) + access_token = token or os.environ.get("DATALAYER_API_KEY") + if not access_token: + console.print( + "[red]No access token available. Use --token or set DATALAYER_API_KEY.[/red]" + ) + raise typer.Exit(1) + + memberships = _fetch_memberships(urls.iam_url, access_token) + if memberships is None: + console.print("[red]Failed to fetch memberships from IAM service.[/red]") + raise typer.Exit(1) + + if as_json: + typer.echo(_json.dumps(memberships, indent=2, sort_keys=True)) + return + + _print_memberships(memberships, only=only) diff --git a/datalayer_core/cli/commands/plans.py b/datalayer_core/cli/commands/plans.py new file mode 100644 index 00000000..db55ed4b --- /dev/null +++ b/datalayer_core/cli/commands/plans.py @@ -0,0 +1,394 @@ +# Copyright (c) 2023-2025 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Plans commands for Datalayer CLI.""" + +from typing import Any, Optional + +import typer +from rich.console import Console +from rich.table import Table + +from datalayer_core.client.client import DatalayerClient +from datalayer_core.utils.urls import DatalayerURLs + +app = typer.Typer( + name="plans", help="Plan and subscription details", invoke_without_command=True +) +console = Console(width=200) + + +def _normalize_value(value: Any, fallback: str = "n/a") -> str: + if value is None: + return fallback + text = str(value).strip() + return text if text else fallback + + +def _iam_get(client: DatalayerClient, path: str) -> dict[str, Any]: + return client._fetch(f"{client.urls.iam_url}{path}", method="GET").json() + + +def _iam_post( + client: DatalayerClient, path: str, body: dict[str, Any] +) -> dict[str, Any]: + return client._fetch( + f"{client.urls.iam_url}{path}", + method="POST", + json=body, + ).json() + + +def _make_client( + token: Optional[str] = None, + iam_url: Optional[str] = None, +) -> DatalayerClient: + urls = DatalayerURLs.from_environment(iam_url=iam_url) + return DatalayerClient(urls=urls, token=token) + + +@app.callback() +def plans_callback(ctx: typer.Context) -> None: + """Plans and subscription commands.""" + if ctx.invoked_subcommand is None: + ctx.invoke(plans_show) + + +def _format_number(value: Any, fallback: str = "-") -> str: + if value is None: + return fallback + try: + number = float(value) + except (TypeError, ValueError): + return _normalize_value(value, fallback=fallback) + if number.is_integer(): + return f"{int(number)}" + return f"{number:.4f}".rstrip("0").rstrip(".") or "0" + + +def _format_period(start: Any, end: Any) -> str: + start_text = _normalize_value(start, fallback="") + end_text = _normalize_value(end, fallback="") + if not start_text and not end_text: + return "-" + # Trim ISO timestamps to a date for readability. + start_short = start_text[:10] if start_text else "…" + end_short = end_text[:10] if end_text else "…" + return f"{start_short} → {end_short}" + + +def _format_runs(plan: dict[str, Any]) -> str: + included = plan.get("included_runs") + used = plan.get("used_credits") + remaining = plan.get("remaining_runs") + used_text = _format_number(used, fallback="0") + if included in (None, "", 0): + return f"{used_text} / ∞" + included_text = _format_number(included) + if remaining is not None: + remaining_text = _format_number(remaining) + return f"{used_text} / {included_text} (left {remaining_text})" + return f"{used_text} / {included_text}" + + +def _format_wallet( + plan: dict[str, Any], + wallet_balance: Any = None, +) -> str: + balance = ( + wallet_balance + if wallet_balance is not None + else plan.get("wallet_balance") + ) + quota = plan.get("wallet_quota") + is_quota = bool(plan.get("wallet_is_quota")) + balance_text = _format_number(balance, fallback="0") + if is_quota and quota not in (None, ""): + return f"{balance_text} / {_format_number(quota)}" + return balance_text + + +def _render_plan_row( + table: Table, + scope_label: str, + handle: str, + name: str, + account_uid: str, + plan: dict[str, Any], + wallet_balance: Any = None, + is_eligible: Any = None, + parent: str = "", +) -> None: + plan_name = plan.get("plan_name") or plan.get("plan_code") or "Free" + status = plan.get("status") or "unknown" + eligible = ( + "yes" if is_eligible is True else ("no" if is_eligible is False else "-") + ) + handle_text = _normalize_value(handle, fallback="-") + if name and name != handle: + handle_text = f"{handle_text} ({name})" + table.add_row( + scope_label, + handle_text, + _normalize_value(parent, fallback="-"), + _normalize_value(plan_name), + _normalize_value(status), + _format_wallet(plan, wallet_balance=wallet_balance), + _format_number(plan.get("current_credits"), fallback="0"), + _format_runs(plan), + _format_period( + plan.get("current_period_start"), plan.get("current_period_end") + ), + eligible, + _normalize_value(account_uid), + ) + + +def _add_plan_columns(table: Table) -> None: + table.add_column("Scope", style="cyan", no_wrap=True) + table.add_column("Handle", style="white", no_wrap=True) + table.add_column("Parent Org", style="magenta", no_wrap=True) + table.add_column("Plan", style="green", no_wrap=True) + table.add_column("Status", style="white", no_wrap=True) + table.add_column( + "Wallet (balance/quota)", style="yellow", justify="right", no_wrap=True + ) + table.add_column( + "Current Credits", style="white", justify="right", no_wrap=True + ) + table.add_column( + "Runs (used/included)", style="white", justify="right", no_wrap=True + ) + table.add_column("Period", style="white", no_wrap=True) + table.add_column("Eligible", style="white", no_wrap=True) + table.add_column("Account UID", style="dim", no_wrap=True) + + + +@app.command(name="show") +def plans_show( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), + raw: bool = typer.Option( + False, + "--raw", + help="Print raw JSON payload from IAM.", + ), +) -> None: + """Show the authenticated user's plan plus plans of org/team memberships.""" + try: + client = _make_client(token=token, iam_url=iam_url) + + # 1. Authenticated user plan. + self_plan_response = _iam_get(client, "/api/iam/v1/plans") + if not self_plan_response.get("success", True): + console.print( + f"[red]Error: {self_plan_response.get('message', 'Unknown error')}[/red]" + ) + raise typer.Exit(1) + + # 2. Memberships (organizations + teams). + memberships_response = _iam_get(client, "/api/iam/v1/memberships") + memberships = ( + memberships_response.get("memberships") or [] + if memberships_response.get("success", True) + else [] + ) + + # 3. Resolve plans for all org/team memberships in one batch. + membership_uids = [ + m.get("uid") for m in memberships if m.get("uid") + ] + accounts_details: list[dict[str, Any]] = [] + if membership_uids: + details_response = _iam_post( + client, + "/api/iam/v1/plans/accounts/details", + {"account_uids": membership_uids}, + ) + if details_response.get("success", True): + accounts_details = details_response.get("accounts") or [] + + if raw: + console.print( + { + "self_plan": self_plan_response, + "memberships": memberships_response, + "accounts_details": accounts_details, + } + ) + return + + table = Table(title="Plans") + _add_plan_columns(table) + + # Self row. + self_plan = self_plan_response.get("plan") or {} + self_account_uid = self_plan_response.get("account_uid") or self_plan.get( + "account_uid" + ) or "" + self_handle = self_plan.get("account_handle") or "-" + _render_plan_row( + table, + scope_label="user (self)", + handle=self_handle, + name=self_handle, + account_uid=self_account_uid, + plan=self_plan, + wallet_balance=self_plan.get("wallet_balance"), + is_eligible=None, + parent="", + ) + + # Memberships rows. + details_by_uid: dict[str, dict[str, Any]] = { + entry.get("account_uid"): entry for entry in accounts_details + } + orgs_by_uid = { + m.get("uid"): m + for m in memberships + if (m.get("type") or "").lower() == "organization" + } + + # Organizations first, then teams (with parent label). + for membership in memberships: + mtype = (membership.get("type") or "").lower() + if mtype != "organization": + continue + uid = membership.get("uid") or "" + detail = details_by_uid.get(uid) or {} + plan = detail.get("subscription") or {} + _render_plan_row( + table, + scope_label="organization", + handle=membership.get("handle") or "-", + name=membership.get("name") or membership.get("handle") or "-", + account_uid=uid, + plan=plan, + wallet_balance=detail.get("wallet_balance"), + is_eligible=detail.get("is_eligible"), + parent="", + ) + + for membership in memberships: + mtype = (membership.get("type") or "").lower() + if mtype != "team": + continue + uid = membership.get("uid") or "" + detail = details_by_uid.get(uid) or {} + plan = detail.get("subscription") or {} + parent_uid = membership.get("organization_uid") or "" + parent_org = orgs_by_uid.get(parent_uid) + parent_label = ( + parent_org.get("handle") if parent_org else (parent_uid or "-") + ) + _render_plan_row( + table, + scope_label="team", + handle=membership.get("handle") or "-", + name=membership.get("name") or membership.get("handle") or "-", + account_uid=uid, + plan=plan, + wallet_balance=detail.get("wallet_balance"), + is_eligible=detail.get("is_eligible"), + parent=parent_label or "-", + ) + + console.print(table) + except typer.Exit: + raise + except Exception as e: + console.print(f"[red]Error fetching plans: {e}[/red]") + raise typer.Exit(1) + + +@app.command(name="catalog") +def plans_catalog( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), + billable_account_uid: Optional[str] = typer.Option( + None, + "--billable-account-uid", + help="Optional billable account UID scope.", + ), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON payload."), +) -> None: + """List available plans from the catalog.""" + try: + client = _make_client(token=token, iam_url=iam_url) + suffix = ( + f"?billable_account_uid={billable_account_uid}" + if billable_account_uid + else "" + ) + response = _iam_get(client, f"/api/iam/v1/plans/catalog{suffix}") + if not response.get("success", True): + console.print( + f"[red]Error: {response.get('message', 'Unknown error')}[/red]" + ) + raise typer.Exit(1) + + if raw: + console.print(response) + return + + plans = response.get("plans") or response.get("available_plans") or [] + table = Table(title="Available Plans") + table.add_column("ID", style="cyan") + table.add_column("Name", style="white") + table.add_column("Code", style="white") + table.add_column("Price", style="white", justify="right") + table.add_column("Currency", style="white") + table.add_column("Included Runs", style="white", justify="right") + for plan in plans: + if not isinstance(plan, dict): + continue + table.add_row( + _normalize_value(plan.get("id")), + _normalize_value(plan.get("name")), + _normalize_value(plan.get("code") or plan.get("plan_code")), + _normalize_value(plan.get("price"), fallback="-"), + _normalize_value(plan.get("currency"), fallback="-"), + _normalize_value(plan.get("included_runs"), fallback="-"), + ) + console.print(table) + except typer.Exit: + raise + except Exception as e: + console.print(f"[red]Error fetching plans catalog: {e}[/red]") + raise typer.Exit(1) + + +# Root-level command for convenience. + + +def plans_root( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), +) -> None: + """Show plans for the authenticated user and memberships (root command).""" + plans_show(token=token, iam_url=iam_url) diff --git a/datalayer_core/cli/commands/pools.py b/datalayer_core/cli/commands/pools.py new file mode 100644 index 00000000..6d19244e --- /dev/null +++ b/datalayer_core/cli/commands/pools.py @@ -0,0 +1,158 @@ +# Copyright (c) 2023-2026 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Runtime pool administration commands for Datalayer CLI.""" + +import os +from typing import Any, Optional + +import requests +import typer +from rich.console import Console +from rich.table import Table + +from datalayer_core.client.client import DatalayerClient +from datalayer_core.utils.urls import DatalayerURLs + + +app = typer.Typer( + name="pools", + help="Runtime pool administration commands", + invoke_without_command=True, +) + +console = Console() + + +def _resolve_token(token: Optional[str] = None) -> str: + if token: + return token + env_token = os.environ.get("DATALAYER_API_KEY") + if env_token: + return env_token + try: + client = DatalayerClient() + return client._get_token() or "" + except Exception: + return "" + + +def _runtimes_base_url(runtimes_url: Optional[str] = None) -> str: + urls = DatalayerURLs.from_environment(runtimes_url=runtimes_url) + return urls.runtimes_url.rstrip("/") + + +def _api_get(path: str, *, token: Optional[str], runtimes_url: Optional[str]) -> Any: + resolved_token = _resolve_token(token) + if not resolved_token: + raise RuntimeError( + "No authentication token found. Pass --token, set DATALAYER_API_KEY, or run 'datalayer login'." + ) + url = f"{_runtimes_base_url(runtimes_url)}/api/runtimes/v1{path}" + headers = {"Authorization": f"Bearer {resolved_token}"} + response = requests.get(url, headers=headers, timeout=30) + response.raise_for_status() + return response.json() + + +def _api_post(path: str, payload: dict[str, Any], *, token: Optional[str], runtimes_url: Optional[str]) -> Any: + resolved_token = _resolve_token(token) + if not resolved_token: + raise RuntimeError( + "No authentication token found. Pass --token, set DATALAYER_API_KEY, or run 'datalayer login'." + ) + url = f"{_runtimes_base_url(runtimes_url)}/api/runtimes/v1{path}" + headers = {"Authorization": f"Bearer {resolved_token}"} + response = requests.post(url, headers=headers, json=payload, timeout=30) + response.raise_for_status() + return response.json() + + +@app.callback() +def pools_callback(ctx: typer.Context) -> None: + """Runtime pool administration commands.""" + if ctx.invoked_subcommand is None: + typer.echo(ctx.get_help()) + + +@app.command(name="ls") +def show_pools( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), +) -> None: + """List runtime pools with details (admin-only).""" + + try: + payload = _api_get( + "/cluster/admin/pools", + token=token, + runtimes_url=runtimes_url, + ) + pools = payload.get("pools", []) if isinstance(payload, dict) else [] + + table = Table(title="Runtime Pools") + table.add_column("Pool", style="bold") + table.add_column("Desired", justify="right") + table.add_column("Available", justify="right") + table.add_column("Pending", justify="right") + table.add_column("Assigned", justify="right") + + for pool in pools: + table.add_row( + str(pool.get("name") or "-"), + str(pool.get("desired") if pool.get("desired") is not None else "-"), + str(pool.get("available") if pool.get("available") is not None else "-"), + str(pool.get("pending") if pool.get("pending") is not None else "-"), + str(pool.get("assigned") if pool.get("assigned") is not None else "-"), + ) + + console.print(table) + except Exception as exc: + console.print(f"[red]Error listing pools: {exc}[/red]") + raise typer.Exit(1) + + +@app.command(name="set-size") +def set_pool_size( + size: int = typer.Argument(..., help="Desired pool size (>= 0)."), + pool: str = typer.Option(..., "--pool", help="Runtime pool name."), + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + runtimes_url: Optional[str] = typer.Option( + None, + "--runtimes-url", + help="Datalayer Runtimes server URL", + ), +) -> None: + """Update runtime pool size (admin-only).""" + + if size < 0: + console.print("[red]Size must be >= 0.[/red]") + raise typer.Exit(1) + + try: + payload = _api_post( + "/cluster/admin/pools/set-size", + {"pool": pool, "size": int(size)}, + token=token, + runtimes_url=runtimes_url, + ) + updated_pool = str(payload.get("pool") or pool) + updated_size = payload.get("size", size) + console.print( + f"[green]Updated pool '{updated_pool}' size to {updated_size}.[/green]" + ) + except Exception as exc: + console.print(f"[red]Error updating pool size: {exc}[/red]") + raise typer.Exit(1) diff --git a/datalayer_core/cli/commands/ray.py b/datalayer_core/cli/commands/ray.py new file mode 100644 index 00000000..b9060c9e --- /dev/null +++ b/datalayer_core/cli/commands/ray.py @@ -0,0 +1,513 @@ +# Copyright (c) 2023-2026 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Ray commands for Datalayer CLI.""" + +from __future__ import annotations + +import ast +import json +from pathlib import Path +import re +import shlex +import sys +import time +from typing import Any, Optional + +import typer +from rich.console import Console +from rich.table import Table + +from datalayer_core.client.client import DatalayerClient +from datalayer_core.utils.urls import DatalayerURLs + +app = typer.Typer( + name="ray", + help="Manage Ray clusters and Ray jobs through the Datalayer runtimes service.", + invoke_without_command=True, +) + +clusters_app = typer.Typer( + name="clusters", + help="Manage Ray clusters.", + invoke_without_command=True, +) +jobs_app = typer.Typer( + name="jobs", + help="Manage Ray jobs.", + invoke_without_command=True, +) + +console = Console() + +_ANSI_ESCAPE_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]") + + +@app.callback() +def ray_callback(ctx: typer.Context) -> None: + """Ray management commands.""" + if ctx.invoked_subcommand is None: + typer.echo(ctx.get_help()) + + +@clusters_app.callback() +def clusters_callback(ctx: typer.Context) -> None: + """Ray cluster commands.""" + if ctx.invoked_subcommand is None: + typer.echo(ctx.get_help()) + + +@jobs_app.callback() +def jobs_callback(ctx: typer.Context) -> None: + """Ray job commands.""" + if ctx.invoked_subcommand is None: + typer.echo(ctx.get_help()) + + +def _make_client( + token: Optional[str] = None, +) -> DatalayerClient: + urls = DatalayerURLs.from_environment() + # Ray CLI is intentionally routed via runtimes, never directly to ray_url. + urls.ray_url = urls.runtimes_url + return DatalayerClient(urls=urls, token=token) + + +def _print_json(payload: dict[str, Any]) -> None: + console.print_json(data=payload) + + +def _load_json(raw: Optional[str], flag_name: str) -> dict[str, Any]: + if not raw: + return {} + try: + value = json.loads(raw) + except Exception as exc: + raise typer.BadParameter(f"Invalid JSON for {flag_name}: {exc}") from exc + if not isinstance(value, dict): + raise typer.BadParameter(f"{flag_name} must decode to a JSON object") + return value + + +def _resolve_python_inline(raw: Optional[str]) -> Optional[str]: + """Resolve inline Python payload, supporting stdin/file references. + + Supported syntaxes for --python-inline/--py: + - raw source text + - @- : read from stdin (supports multiline heredoc pipelines) + - @ : read from local file path + """ + if raw is None: + return None + + value = str(raw) + if value == "@-": + return sys.stdin.read() + + if value.startswith("@") and len(value) > 1: + path = Path(value[1:]).expanduser() + try: + return path.read_text() + except Exception as exc: + raise typer.BadParameter( + f"Unable to read inline Python source from {path}: {exc}" + ) from exc + + return value + + +def _normalize_logs_text(value: Any) -> str: + """Normalize logs payloads into readable plain text. + + Handles legacy payloads where logs are serialized as Python bytes repr, + e.g. `b"..."`, and strips ANSI terminal escape sequences. + """ + + if value is None: + return "" + + text: str + if isinstance(value, bytes): + text = value.decode("utf-8", errors="replace") + else: + text = str(value) + + stripped = text.strip() + if stripped.startswith(("b'", 'b"')): + try: + literal = ast.literal_eval(stripped) + if isinstance(literal, bytes): + text = literal.decode("utf-8", errors="replace") + else: + text = str(literal) + except Exception: + pass + + text = _ANSI_ESCAPE_RE.sub("", text) + return text + + +def _format_scope_label(kind: str, handle: str, uid: str, fallback_kind: str) -> str: + scope_kind = (kind or fallback_kind).strip() + scope_handle = (handle or "").strip() + scope_uid = (uid or "").strip() + if scope_handle: + return f"{scope_kind}: @{scope_handle}" + if scope_uid: + return f"{scope_kind}: {scope_uid}" + return "" + + +@clusters_app.command(name="list") +@clusters_app.command(name="ls") +def clusters_list( + namespace: str = typer.Option("default", "--namespace", help="Kubernetes namespace."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON."), +) -> None: + client = _make_client(token=token) + payload = client.ray_list_clusters(namespace=namespace) + if raw: + _print_json(payload) + return + + items = payload.get("clusters") or [] + table = Table(title=f"Ray Clusters ({len(items)})") + table.add_column("Name", style="cyan") + table.add_column("Namespace") + table.add_column("State") + table.add_column("Workers") + table.add_column("Principal") + table.add_column("Billable") + + for item in items: + metadata = item.get("metadata") or {} + status = item.get("status") or {} + ownership = item.get("ownership") or {} + desired = status.get("desiredWorkerReplicas") + available = status.get("availableWorkerReplicas") + workers = f"{available}/{desired}" if desired is not None else str(available or "") + principal = _format_scope_label( + str(item.get("principal_kind") or ownership.get("principal_kind") or ""), + str(item.get("principal_handle") or ownership.get("principal_handle") or ""), + str(item.get("principal_uid") or ownership.get("principal_uid") or ""), + "principal", + ) + billable = _format_scope_label( + str(item.get("billable_account_kind") or ownership.get("billable_account_kind") or ""), + str(item.get("billable_account_handle") or ownership.get("billable_account_handle") or ""), + str(item.get("billable_account_uid") or ownership.get("billable_account_uid") or ""), + "account", + ) + table.add_row( + str(metadata.get("name", "")), + str(metadata.get("namespace", namespace)), + str(status.get("state", "")), + workers, + principal, + billable, + ) + + console.print(table) + + +@clusters_app.command(name="create") +def clusters_create( + name: str = typer.Argument(..., help="RayCluster name."), + namespace: str = typer.Option("default", "--namespace", help="Kubernetes namespace."), + image: str = typer.Option("rayproject/ray:2.38.0", "--image", help="Ray container image."), + ray_version: str = typer.Option("2.38.0", "--ray-version", help="Ray version in CR spec."), + worker_replicas: int = typer.Option(1, "--worker-replicas", min=0), + worker_min_replicas: int = typer.Option(1, "--worker-min-replicas", min=0), + worker_max_replicas: int = typer.Option(3, "--worker-max-replicas", min=0), + custom_spec_json: Optional[str] = typer.Option( + None, + "--custom-spec-json", + help="Optional full RayCluster spec JSON object.", + ), + token: Optional[str] = typer.Option(None, "--token", help="API token."), +) -> None: + custom_spec = _load_json(custom_spec_json, "--custom-spec-json") + payload: dict[str, Any] = { + "name": name, + "namespace": namespace, + "image": image, + "ray_version": ray_version, + "worker_replicas": worker_replicas, + "worker_min_replicas": worker_min_replicas, + "worker_max_replicas": worker_max_replicas, + } + if custom_spec: + payload["custom_spec"] = custom_spec + + client = _make_client(token=token) + result = client.ray_create_cluster(payload) + if result.get("success") is False: + reason = str(result.get("message") or result.get("reason") or "Unable to create cluster") + console.print(f"[red]Cluster creation failed:[/red] {reason}") + raise typer.Exit(code=1) + + cluster = result.get("cluster") or {} + metadata = cluster.get("metadata") or {} + console.print( + f"[green]Cluster created:[/green] {metadata.get('name', '')} " + f"(ns={metadata.get('namespace', namespace)})" + ) + console.print("[dim]Next: dla ray clusters ls --namespace {0}[/dim]".format(namespace)) + + +@clusters_app.command(name="get") +def clusters_get( + name: str = typer.Argument(..., help="RayCluster name."), + namespace: str = typer.Option("default", "--namespace", help="Kubernetes namespace."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), +) -> None: + client = _make_client(token=token) + payload = client.ray_get_cluster(name, namespace=namespace) + _print_json(payload) + + +@clusters_app.command(name="delete") +def clusters_delete( + name: str = typer.Argument(..., help="RayCluster name."), + namespace: str = typer.Option("default", "--namespace", help="Kubernetes namespace."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), +) -> None: + client = _make_client(token=token) + client.ray_delete_cluster(name, namespace=namespace) + console.print(f"[green]Cluster deleted:[/green] {name} (ns={namespace})") + + +@jobs_app.command(name="submit") +def jobs_submit( + cluster_name: str = typer.Argument(..., help="Target RayCluster name."), + entrypoint: Optional[str] = typer.Option( + None, + "--entrypoint", + help="Ray job entrypoint command.", + ), + python_inline: Optional[str] = typer.Option( + None, + "--python-inline", + "--py", + help="Inline Python source; supports @- (stdin) and @ for multiline input.", + ), + namespace: str = typer.Option("default", "--namespace", help="Kubernetes namespace."), + job_name: Optional[str] = typer.Option(None, "--job-name", help="Optional RayJob name."), + runtime_env_yaml: Optional[str] = typer.Option(None, "--runtime-env-yaml", help="Raw runtimeEnvYAML string."), + shutdown_after_job_finishes: bool = typer.Option(True, "--shutdown-after-job-finishes/--keep-cluster"), + ttl_seconds_after_finished: Optional[int] = typer.Option(3600, "--ttl-seconds-after-finished", min=0), + token: Optional[str] = typer.Option(None, "--token", help="API token."), +) -> None: + resolved_python_inline = _resolve_python_inline(python_inline) + + if bool(entrypoint) == bool(resolved_python_inline): + raise typer.BadParameter( + "Provide exactly one of --entrypoint or --python-inline/--py." + ) + + payload: dict[str, Any] = { + "namespace": namespace, + "shutdown_after_job_finishes": shutdown_after_job_finishes, + "ttl_seconds_after_finished": ttl_seconds_after_finished, + } + if entrypoint: + payload["entrypoint"] = entrypoint + if resolved_python_inline: + # Backward compatibility: older ray addon APIs require `entrypoint`. + # Keep sending a concrete entrypoint while also passing python_inline + # for newer servers that natively support it. + payload["entrypoint"] = f"python -c {shlex.quote(resolved_python_inline)}" + payload["python_inline"] = resolved_python_inline + if job_name: + payload["job_name"] = job_name + if runtime_env_yaml: + payload["runtime_env_yaml"] = runtime_env_yaml + + client = _make_client(token=token) + result = client.ray_submit_job(cluster_name, payload) + job = result.get("job") or {} + metadata = job.get("metadata") or {} + console.print( + f"[green]Job submitted:[/green] {metadata.get('name', '')} " + f"(cluster={cluster_name}, ns={namespace})" + ) + if metadata.get("name"): + console.print( + "[dim]Next: dla ray jobs monitor {0} --namespace {1}[/dim]".format( + metadata.get("name"), namespace + ) + ) + + +@jobs_app.command(name="list") +@jobs_app.command(name="ls") +def jobs_list( + namespace: str = typer.Option("default", "--namespace", help="Kubernetes namespace."), + cluster_name: Optional[str] = typer.Option(None, "--cluster-name", help="Filter by cluster label."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON."), +) -> None: + client = _make_client(token=token) + payload = client.ray_list_jobs(namespace=namespace, cluster_name=cluster_name) + if raw: + _print_json(payload) + return + + items = payload.get("jobs") or [] + table = Table(title=f"Ray Jobs ({len(items)})") + table.add_column("Name", style="cyan") + table.add_column("Namespace") + table.add_column("Cluster") + table.add_column("Status") + + for item in items: + metadata = item.get("metadata") or {} + labels = metadata.get("labels") or {} + status = item.get("status") or {} + table.add_row( + str(metadata.get("name", "")), + str(metadata.get("namespace", namespace)), + str(labels.get("ray.io/cluster", "")), + str(status.get("jobStatus", "")), + ) + + console.print(table) + + +@jobs_app.command(name="status") +def jobs_status( + name: str = typer.Argument(..., help="RayJob name."), + namespace: str = typer.Option("default", "--namespace", help="Kubernetes namespace."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), +) -> None: + client = _make_client(token=token) + payload = client.ray_get_job(name, namespace=namespace) + _print_json(payload) + + +@jobs_app.command(name="delete") +def jobs_delete( + name: str = typer.Argument(..., help="RayJob name."), + namespace: str = typer.Option("default", "--namespace", help="Kubernetes namespace."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), +) -> None: + client = _make_client(token=token) + client.ray_delete_job(name, namespace=namespace) + console.print(f"[green]Job deleted:[/green] {name} (ns={namespace})") + + +@jobs_app.command(name="logs") +def jobs_logs( + name: str = typer.Argument(..., help="RayJob name."), + namespace: str = typer.Option("default", "--namespace", help="Kubernetes namespace."), + pod_name: Optional[str] = typer.Option(None, "--pod-name", help="Optional explicit pod name."), + container: Optional[str] = typer.Option(None, "--container", help="Optional pod container name."), + tail_lines: int = typer.Option(200, "--tail-lines", min=1, max=5000), + token: Optional[str] = typer.Option(None, "--token", help="API token."), +) -> None: + client = _make_client(token=token) + payload = client.ray_get_job_logs( + name, + namespace=namespace, + pod_name=pod_name, + container=container, + tail_lines=tail_lines, + ) + console.print( + f"[bold]Logs[/bold] job={payload.get('job_name', name)} " + f"pod={payload.get('pod_name', '')}" + ) + console.print(_normalize_logs_text(payload.get("logs", ""))) + + +@jobs_app.command(name="events") +def jobs_events( + name: str = typer.Argument(..., help="RayJob name."), + namespace: str = typer.Option("default", "--namespace", help="Kubernetes namespace."), + limit: int = typer.Option(100, "--limit", min=1, max=1000), + token: Optional[str] = typer.Option(None, "--token", help="API token."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON."), +) -> None: + client = _make_client(token=token) + payload = client.ray_get_job_events(name, namespace=namespace, limit=limit) + if raw: + _print_json(payload) + return + + events = payload.get("events") or [] + table = Table(title=f"Ray Job Events ({len(events)})") + table.add_column("Type") + table.add_column("Reason") + table.add_column("Target") + table.add_column("Time") + table.add_column("Message") + + for event in events: + table.add_row( + str(event.get("type") or ""), + str(event.get("reason") or ""), + str(event.get("involved_object_name") or ""), + str( + event.get("event_time") + or event.get("last_timestamp") + or event.get("first_timestamp") + or "" + ), + str(event.get("message") or ""), + ) + + console.print(table) + + +@jobs_app.command(name="monitor") +def jobs_monitor( + name: str = typer.Argument(..., help="RayJob name."), + namespace: str = typer.Option("default", "--namespace", help="Kubernetes namespace."), + interval_seconds: int = typer.Option(5, "--interval-seconds", min=1, help="Polling interval in seconds."), + timeout_seconds: int = typer.Option(600, "--timeout-seconds", min=1, help="Maximum time to wait before exiting."), + show_events: bool = typer.Option(False, "--show-events", help="Show latest events on each poll."), + token: Optional[str] = typer.Option(None, "--token", help="API token."), +) -> None: + """Monitor RayJob status until it reaches a terminal state.""" + client = _make_client(token=token) + started = time.time() + last_status: Optional[str] = None + terminal_statuses = {"SUCCEEDED", "FAILED", "STOPPED"} + + while True: + payload = client.ray_get_job(name, namespace=namespace) + status = str(payload.get("status") or "UNKNOWN").upper() + if status != last_status: + console.print(f"[bold]job={name}[/bold] ns={namespace} status={status}") + last_status = status + + if show_events: + events_payload = client.ray_get_job_events(name, namespace=namespace, limit=5) + events = events_payload.get("events") or [] + for event in events[:3]: + console.print( + "[dim]{0} {1}: {2}[/dim]".format( + event.get("type") or "", + event.get("reason") or "", + event.get("message") or "", + ) + ) + + if status in terminal_statuses: + console.print(f"[green]Job reached terminal status:[/green] {status}") + if status != "SUCCEEDED": + raise typer.Exit(1) + return + + if (time.time() - started) >= timeout_seconds: + console.print( + f"[red]Timed out after {timeout_seconds}s while waiting for job status.[/red]" + ) + raise typer.Exit(1) + + time.sleep(interval_seconds) + + +app.add_typer(clusters_app) +app.add_typer(jobs_app) diff --git a/datalayer_core/cli/commands/runtime_checkpoints.py b/datalayer_core/cli/commands/runtime_checkpoints.py index 3681d25d..16face1c 100644 --- a/datalayer_core/cli/commands/runtime_checkpoints.py +++ b/datalayer_core/cli/commands/runtime_checkpoints.py @@ -70,7 +70,7 @@ def checkpoints_callback(ctx: typer.Context) -> None: typer.echo(ctx.get_help()) -@app.command(name="list") +@app.command(name="ls") def checkpoints_list( runtime_uid: Optional[str] = typer.Option( None, @@ -108,7 +108,6 @@ def checkpoints_list( raise typer.Exit(1) -@app.command(name="ls") def checkpoints_ls( runtime_uid: Optional[str] = typer.Option( None, @@ -127,7 +126,7 @@ def checkpoints_ls( help="Datalayer Runtimes server URL.", ), ) -> None: - """List runtime checkpoints (alias for list).""" + """List runtime checkpoints (root command alias).""" checkpoints_list(runtime_uid=runtime_uid, token=token, runtimes_url=runtimes_url) diff --git a/datalayer_core/cli/commands/runtimes.py b/datalayer_core/cli/commands/runtimes.py index eb9d4ac4..7a0de637 100644 --- a/datalayer_core/cli/commands/runtimes.py +++ b/datalayer_core/cli/commands/runtimes.py @@ -29,20 +29,26 @@ def runtimes_callback(ctx: typer.Context) -> None: def _make_client( token: Optional[str] = None, + iam_url: Optional[str] = None, runtimes_url: Optional[str] = None, ) -> DatalayerClient: """Create a DatalayerClient with optional runtimes URL override.""" - urls = DatalayerURLs.from_environment(runtimes_url=runtimes_url) + urls = DatalayerURLs.from_environment(iam_url=iam_url, runtimes_url=runtimes_url) return DatalayerClient(urls=urls, token=token) -@app.command(name="list") +@app.command(name="ls") def list_runtimes( token: Optional[str] = typer.Option( None, "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), runtimes_url: Optional[str] = typer.Option( None, "--runtimes-url", @@ -51,7 +57,11 @@ def list_runtimes( ) -> None: """List running runtimes.""" try: - client = _make_client(token=token, runtimes_url=runtimes_url) + client = _make_client( + token=token, + iam_url=iam_url, + runtimes_url=runtimes_url, + ) runtimes = client.list_runtimes() # Convert to dict format for display_runtimes @@ -79,23 +89,6 @@ def list_runtimes( raise typer.Exit(1) -@app.command(name="ls") -def list_runtimes_alias( - token: Optional[str] = typer.Option( - None, - "--token", - help="Authentication token (Bearer token for API requests).", - ), - runtimes_url: Optional[str] = typer.Option( - None, - "--runtimes-url", - help="Datalayer Runtimes server URL", - ), -) -> None: - """List running runtimes (alias for list).""" - list_runtimes(token=token, runtimes_url=runtimes_url) - - @app.command(name="create") def create_runtime( environment: Optional[str] = typer.Argument(None, help="Environment name"), @@ -114,11 +107,31 @@ def create_runtime( "--time-reservation", help="Time reservation in minutes for the runtime", ), + billable_account_uid: Optional[str] = typer.Option( + None, + "--billable-account-uid", + help="Account UID to bill the runtime to (org/team). Defaults to the authenticated user.", + ), + billable_account_type: Optional[str] = typer.Option( + None, + "--billable-account-type", + help="Billable account type: user, organization, or team.", + ), + billable_account_handle: Optional[str] = typer.Option( + None, + "--billable-account-handle", + help="Billable account handle (informational).", + ), token: Optional[str] = typer.Option( None, "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), runtimes_url: Optional[str] = typer.Option( None, "--runtimes-url", @@ -129,7 +142,11 @@ def create_runtime( import questionary try: - client = _make_client(token=token, runtimes_url=runtimes_url) + client = _make_client( + token=token, + iam_url=iam_url, + runtimes_url=runtimes_url, + ) if environment is None: # List environments and let the user pick one @@ -160,6 +177,9 @@ def create_runtime( name=given_name, environment=environment, time_reservation=final_time_reservation, + billable_account_uid=billable_account_uid, + billable_account_type=billable_account_type, + billable_account_handle=billable_account_handle, ) console.print( @@ -185,6 +205,11 @@ def terminate_runtime( "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), runtimes_url: Optional[str] = typer.Option( None, "--runtimes-url", @@ -195,7 +220,11 @@ def terminate_runtime( import questionary try: - client = _make_client(token=token, runtimes_url=runtimes_url) + client = _make_client( + token=token, + iam_url=iam_url, + runtimes_url=runtimes_url, + ) if pod_name is None: # List runtimes and let the user pick one @@ -247,6 +276,11 @@ def runtimes_list( "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), runtimes_url: Optional[str] = typer.Option( None, "--runtimes-url", @@ -254,7 +288,7 @@ def runtimes_list( ), ) -> None: """List running runtimes (root command).""" - list_runtimes(token=token, runtimes_url=runtimes_url) + list_runtimes(token=token, iam_url=iam_url, runtimes_url=runtimes_url) def runtimes_ls( @@ -263,6 +297,11 @@ def runtimes_ls( "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), runtimes_url: Optional[str] = typer.Option( None, "--runtimes-url", @@ -270,4 +309,4 @@ def runtimes_ls( ), ) -> None: """List running runtimes (root command alias).""" - list_runtimes(token=token, runtimes_url=runtimes_url) + list_runtimes(token=token, iam_url=iam_url, runtimes_url=runtimes_url) diff --git a/datalayer_core/cli/commands/runtime_snapshots.py b/datalayer_core/cli/commands/sandbox_snapshots.py similarity index 89% rename from datalayer_core/cli/commands/runtime_snapshots.py rename to datalayer_core/cli/commands/sandbox_snapshots.py index 0b63bbf1..bdc0caa3 100644 --- a/datalayer_core/cli/commands/runtime_snapshots.py +++ b/datalayer_core/cli/commands/sandbox_snapshots.py @@ -9,11 +9,11 @@ from rich.console import Console from datalayer_core.client.client import DatalayerClient -from datalayer_core.displays.runtime_snapshots import display_runtime_snapshots +from datalayer_core.displays.sandbox_snapshots import display_code_sandbox_snapshots # Create a Typer app for snapshot commands app = typer.Typer( - name="runtime-snapshots", + name="sandbox-snapshots", help="Runtime snapshots management commands", invoke_without_command=True, ) @@ -28,7 +28,7 @@ def snapshots_callback(ctx: typer.Context) -> None: typer.echo(ctx.get_help()) -@app.command(name="list") +@app.command(name="ls") def list_snapshots( token: Optional[str] = typer.Option( None, @@ -54,25 +54,13 @@ def list_snapshots( } ) - display_runtime_snapshots(snapshot_dicts) + display_code_sandbox_snapshots(snapshot_dicts) except Exception as e: console.print(f"[red]Error listing snapshots: {e}[/red]") raise typer.Exit(1) -@app.command(name="ls") -def list_snapshots_alias( - token: Optional[str] = typer.Option( - None, - "--token", - help="Authentication token (Bearer token for API requests).", - ), -) -> None: - """List all snapshots (alias for list).""" - list_snapshots(token=token) - - @app.command(name="create") def create_snapshot( pod_name: Optional[str] = typer.Option( @@ -121,7 +109,7 @@ def create_snapshot( "metadata": snapshot.metadata, } - display_runtime_snapshots([snapshot_dict]) + display_code_sandbox_snapshots([snapshot_dict]) console.print( f"[green]Snapshot '{snapshot.name}' created successfully![/green]" ) diff --git a/datalayer_core/cli/commands/secrets.py b/datalayer_core/cli/commands/secrets.py index 061310e3..1acee689 100644 --- a/datalayer_core/cli/commands/secrets.py +++ b/datalayer_core/cli/commands/secrets.py @@ -27,7 +27,7 @@ def secrets_callback(ctx: typer.Context) -> None: typer.echo(ctx.get_help()) -@app.command(name="list") +@app.command(name="ls") def list_secrets( token: Optional[str] = typer.Option( None, @@ -59,18 +59,6 @@ def list_secrets( raise typer.Exit(1) -@app.command(name="ls") -def list_secrets_alias( - token: Optional[str] = typer.Option( - None, - "--token", - help="Authentication token (Bearer token for API requests).", - ), -) -> None: - """List all secrets (alias for list).""" - list_secrets(token=token) - - @app.command(name="create") def create_secret( name: str = typer.Argument(..., help="Name of the secret"), diff --git a/datalayer_core/cli/commands/subscription.py b/datalayer_core/cli/commands/subscription.py index c4d85ce7..be73efe9 100644 --- a/datalayer_core/cli/commands/subscription.py +++ b/datalayer_core/cli/commands/subscription.py @@ -21,7 +21,7 @@ def _extract_subscription(payload: dict[str, Any]) -> dict[str, Any]: - return payload.get("subscription") or {} + return payload.get("plan") or {} def _normalize_value(value: Any, fallback: str = "Not available") -> str: @@ -71,12 +71,8 @@ def _as_plan_list(value: Any) -> list[dict[str, Any]]: def _extract_available_plans(payload: dict[str, Any]) -> list[dict[str, Any]]: subscription = _extract_subscription(payload) candidates = [ - payload.get("available_subscriptions"), payload.get("available_plans"), payload.get("plans"), - subscription.get("available_subscriptions") - if isinstance(subscription, dict) - else None, subscription.get("available_plans") if isinstance(subscription, dict) else None, subscription.get("plans") if isinstance(subscription, dict) else None, ] @@ -572,8 +568,8 @@ def subscription_stats( paid_count = 0 for user in users: - status = str(user.get("subscription_status_s") or "none").lower() - plan = str(user.get("subscription_plan_s") or "none") + status = str(user.get("plan_status_s") or "none").lower() + plan = str(user.get("plan_name_s") or "none") status_counter[status] += 1 plan_counter[plan] += 1 @@ -663,9 +659,9 @@ def subscription_admin_users( for user in users: table.add_row( _normalize_value(user.get("handle_s")), - _normalize_value(user.get("subscription_plan_s"), fallback="none"), - _normalize_value(user.get("subscription_status_s"), fallback="none"), - _normalize_value(user.get("credits_customer_uid"), fallback="none"), + _normalize_value(user.get("plan_name_s"), fallback="none"), + _normalize_value(user.get("plan_status_s"), fallback="none"), + _normalize_value(user.get("stripe_customer_id_s"), fallback="none"), ) console.print(table) @@ -740,13 +736,13 @@ def subscription_dry_run( if sub_resp.get("success", True): sub = _extract_subscription(sub_resp) console.print( - "[green]OK[/green] /api/iam/v1/subscription " + "[green]OK[/green] /api/iam/v1/plans " f"plan={_normalize_value(sub.get('plan_name'), 'unknown')} " f"status={_normalize_value(sub.get('status'), 'unknown')}" ) else: console.print( - "[red]FAILED[/red] /api/iam/v1/subscription " + "[red]FAILED[/red] /api/iam/v1/plans " f"{sub_resp.get('message', 'Unknown error')}" ) diff --git a/datalayer_core/cli/commands/tokens.py b/datalayer_core/cli/commands/tokens.py index 28f73d6c..3d7d50f4 100644 --- a/datalayer_core/cli/commands/tokens.py +++ b/datalayer_core/cli/commands/tokens.py @@ -27,7 +27,7 @@ def tokens_callback(ctx: typer.Context) -> None: typer.echo(ctx.get_help()) -@app.command(name="list") +@app.command(name="ls") def list_tokens( token: Optional[str] = typer.Option( None, @@ -59,18 +59,6 @@ def list_tokens( raise typer.Exit(1) -@app.command(name="ls") -def list_tokens_alias( - token: Optional[str] = typer.Option( - None, - "--token", - help="Authentication token (Bearer token for API requests).", - ), -) -> None: - """List all tokens (alias for list).""" - list_tokens(token=token) - - @app.command(name="create") def create_token( name: str = typer.Argument(..., help="Name of the token"), diff --git a/datalayer_core/cli/commands/usage.py b/datalayer_core/cli/commands/usage.py index cca86e1e..accd4316 100644 --- a/datalayer_core/cli/commands/usage.py +++ b/datalayer_core/cli/commands/usage.py @@ -3,6 +3,7 @@ """Usage/credits commands for Datalayer CLI.""" +from datetime import datetime, timezone from typing import Any, Optional import typer @@ -11,11 +12,12 @@ from datalayer_core.client.client import DatalayerClient from datalayer_core.displays.usage import display_usage +from datalayer_core.utils.urls import DatalayerURLs app = typer.Typer( name="usage", help="Usage and credits commands", invoke_without_command=True ) -console = Console() +console = Console(width=200) def _normalize_value(value: Any, fallback: str = "n/a") -> str: @@ -39,6 +41,39 @@ def _iam_post( ).json() +def _make_client( + token: Optional[str] = None, + iam_url: Optional[str] = None, +) -> DatalayerClient: + urls = DatalayerURLs.from_environment(iam_url=iam_url) + return DatalayerClient(urls=urls, token=token) + + +def _parse_iso_dt(value: Any) -> datetime | None: + if not value: + return None + text = str(value).strip() + if not text: + return None + try: + normalized = text.replace("Z", "+00:00") + parsed = datetime.fromisoformat(normalized) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) + except Exception: + return None + + +def _format_duration_seconds(start: Any, end: Any) -> str: + start_dt = _parse_iso_dt(start) + end_dt = _parse_iso_dt(end) + if start_dt is None or end_dt is None: + return "n/a" + duration = max(0.0, (end_dt - start_dt).total_seconds()) + return f"{duration:.3f}" + + @app.callback() def usage_callback(ctx: typer.Context) -> None: """Usage and credits commands.""" @@ -53,6 +88,11 @@ def usage_show( "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), raw: bool = typer.Option( False, "--raw", @@ -61,7 +101,7 @@ def usage_show( ) -> None: """Show credits usage and reservations.""" try: - client = DatalayerClient(token=token) + client = _make_client(token=token, iam_url=iam_url) usage = client.get_usage_credits() if not usage.get("success", True): console.print(f"[red]Error: {usage.get('message', 'Unknown error')}[/red]") @@ -77,6 +117,246 @@ def usage_show( raise typer.Exit(1) +@app.command(name="records") +def usage_records( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), + billable_account_uid: Optional[str] = typer.Option( + None, + "--billable-account-uid", + help="Optional account UID scope. Defaults to the authenticated account.", + ), + billable_account_kind: Optional[str] = typer.Option( + None, + "--billable-account-kind", + help="Optional account kind scope: user or organization.", + ), + limit: int = typer.Option(20, "--limit", help="Maximum number of usage records."), + group_by_billable: bool = typer.Option( + False, + "--group-by-billable", + help="Render one table per billable account.", + ), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON payload from IAM."), +) -> None: + """Show detailed usage records for the authenticated account scope.""" + try: + client = _make_client(token=token, iam_url=iam_url) + params: list[str] = [] + if billable_account_uid: + params.append(f"billable_account_uid={billable_account_uid}") + if billable_account_kind: + params.append(f"billable_account_kind={billable_account_kind}") + query_suffix = f"?{'&'.join(params)}" if params else "" + response = _iam_get(client, f"/api/iam/v1/usage/user{query_suffix}") + if not response.get("success", True): + console.print( + f"[red]Error: {response.get('message', 'Unknown error')}[/red]" + ) + raise typer.Exit(1) + + usages = (response.get("usages") or [])[: max(1, limit)] + if raw: + console.print(response) + return + + def _add_columns(table: Table) -> None: + table.add_column("Resource", style="cyan", no_wrap=True) + table.add_column("Type", style="white", no_wrap=True) + table.add_column("State", style="white", no_wrap=True) + table.add_column("Creator", style="dim", no_wrap=True) + table.add_column("Billable", style="dim", no_wrap=True) + table.add_column("Start", style="white", no_wrap=True) + table.add_column("End", style="white", no_wrap=True) + table.add_column("Duration(s)", style="white", justify="right", no_wrap=True) + table.add_column("Credits", style="yellow", justify="right", no_wrap=True) + table.add_column("Burn/s", style="white", justify="right", no_wrap=True) + + def _row_for(usage: dict[str, Any]) -> tuple[str, ...]: + metadata = usage.get("metadata") or {} + resource = ( + usage.get("resource_given_name") + or usage.get("resource_uid") + or usage.get("id") + or "-" + ) + start = usage.get("start_date") + end = usage.get("end_date") + creator = usage.get("account_uid") + billable = ( + usage.get("billable_account_uid") + or usage.get("account_uid") + ) + return ( + _normalize_value(resource), + _normalize_value(usage.get("resource_type")), + _normalize_value( + usage.get("resource_state") + or usage.get("state") + or metadata.get("resource_state") + ), + _normalize_value(creator), + _normalize_value(billable), + _normalize_value(start), + _normalize_value(end), + _format_duration_seconds(start, end), + _normalize_value(usage.get("credits"), fallback="0"), + _normalize_value(usage.get("burning_rate"), fallback="0"), + ) + + if group_by_billable: + groups: dict[str, list[dict[str, Any]]] = {} + for usage in usages: + key = ( + usage.get("billable_account_uid") + or usage.get("account_uid") + or "unknown" + ) + groups.setdefault(key, []).append(usage) + for billable_uid, group_usages in sorted(groups.items()): + total_credits = 0.0 + for u in group_usages: + try: + total_credits += float(u.get("credits") or 0) + except (TypeError, ValueError): + pass + table = Table( + title=( + f"Billable Account [bold]{billable_uid}[/bold] " + f"— {len(group_usages)} record(s), " + f"{total_credits:.4f} credits" + ) + ) + _add_columns(table) + for usage in group_usages: + table.add_row(*_row_for(usage)) + console.print(table) + else: + table = Table(title="Usage Records") + _add_columns(table) + for usage in usages: + table.add_row(*_row_for(usage)) + console.print(table) + except Exception as e: + console.print(f"[red]Error fetching usage records: {e}[/red]") + raise typer.Exit(1) + + +@app.command(name="reservations") +def usage_reservations( + token: Optional[str] = typer.Option( + None, + "--token", + help="Authentication token (Bearer token for API requests).", + ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), + reservation_type: Optional[str] = typer.Option( + None, + "--type", + help="Optional reservation type filter.", + ), + billable_account_uid: Optional[str] = typer.Option( + None, + "--billable-account-uid", + help="Optional account UID scope for fallback credits view.", + ), + billable_account_kind: Optional[str] = typer.Option( + None, + "--billable-account-kind", + help="Optional account kind scope for fallback credits view: user or organization.", + ), + limit: int = typer.Option(20, "--limit", help="Maximum number of reservations."), + raw: bool = typer.Option(False, "--raw", help="Print raw JSON payload from IAM."), +) -> None: + """Show reservations from IAM reservations endpoint.""" + try: + client = _make_client(token=token, iam_url=iam_url) + query_suffix = f"?type={reservation_type}" if reservation_type else "" + response = _iam_get(client, f"/api/iam/v1/usage/reservations{query_suffix}") + if not response.get("success", True): + console.print( + f"[red]Error: {response.get('message', 'Unknown error')}[/red]" + ) + raise typer.Exit(1) + + data = response.get("data") or {} + reservations = data.get("reservations") or [] + source = "usage/reservations" + + if not reservations: + params: list[str] = [] + if billable_account_uid: + params.append(f"billable_account_uid={billable_account_uid}") + if billable_account_kind: + params.append(f"billable_account_kind={billable_account_kind}") + credits_query = f"?{'&'.join(params)}" if params else "" + credits_response = _iam_get( + client, + f"/api/iam/v1/usage/credits{credits_query}", + ) + if credits_response.get("success", True): + reservations = credits_response.get("reservations") or [] + source = "usage/credits" + + reservations = reservations[: max(1, limit)] + if raw: + console.print(response) + return + + if source == "usage/credits": + console.print( + "[yellow]No reservations from /usage/reservations; showing active reservations from /usage/credits.[/yellow]" + ) + + table = Table(title="Reservations") + table.add_column("Reservation", style="cyan") + table.add_column("Resource", style="white") + table.add_column("Type", style="white") + table.add_column("Credits", style="white", justify="right") + table.add_column("Burn/s", style="white", justify="right") + table.add_column("Start", style="white") + table.add_column("Last Update", style="white") + + for reservation in reservations: + table.add_row( + _normalize_value(reservation.get("id")), + _normalize_value( + reservation.get("resource") + or reservation.get("resource_uid") + or reservation.get("resource_given_name") + ), + _normalize_value(reservation.get("resource_type")), + _normalize_value( + reservation.get("credits") + or reservation.get("credits_limit"), + fallback="0", + ), + _normalize_value(reservation.get("burning_rate"), fallback="0"), + _normalize_value(reservation.get("start_date")), + _normalize_value( + reservation.get("last_update") + or reservation.get("updated_at") + or reservation.get("last_update_ts_dt") + ), + ) + console.print(table) + except Exception as e: + console.print(f"[red]Error fetching reservations: {e}[/red]") + raise typer.Exit(1) + + @app.command(name="org-overview") def usage_org_overview( organization_uid: str = typer.Option( @@ -89,11 +369,16 @@ def usage_org_overview( "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), raw: bool = typer.Option(False, "--raw", help="Print raw JSON payload."), ) -> None: """Show organization/team credits allocation overview.""" try: - client = DatalayerClient(token=token) + client = _make_client(token=token, iam_url=iam_url) response = _iam_get( client, f"/api/iam/v1/usage/credits/allocations/organizations/{organization_uid}/overview", @@ -155,11 +440,16 @@ def usage_team_overview( "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), raw: bool = typer.Option(False, "--raw", help="Print raw JSON payload."), ) -> None: """Show team/member credits allocation overview.""" try: - client = DatalayerClient(token=token) + client = _make_client(token=token, iam_url=iam_url) response = _iam_get( client, f"/api/iam/v1/usage/credits/allocations/teams/{team_uid}/overview", @@ -211,11 +501,16 @@ def usage_org_history( ..., "--organization-uid", help="Organization UID." ), token: Optional[str] = typer.Option(None, "--token", help="Authentication token."), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), limit: int = typer.Option(20, "--limit", help="Max events to print."), ) -> None: """Show organization/team credits transfer history.""" try: - client = DatalayerClient(token=token) + client = _make_client(token=token, iam_url=iam_url) response = _iam_get( client, f"/api/iam/v1/usage/credits/allocations/organizations/{organization_uid}/history", @@ -249,11 +544,16 @@ def usage_org_history( def usage_team_history( team_uid: str = typer.Option(..., "--team-uid", help="Team UID."), token: Optional[str] = typer.Option(None, "--token", help="Authentication token."), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), limit: int = typer.Option(20, "--limit", help="Max events to print."), ) -> None: """Show team/member credits transfer history.""" try: - client = DatalayerClient(token=token) + client = _make_client(token=token, iam_url=iam_url) response = _iam_get( client, f"/api/iam/v1/usage/credits/allocations/teams/{team_uid}/history", @@ -289,13 +589,18 @@ def usage_org_monitor( ..., "--organization-uid", help="Organization UID." ), token: Optional[str] = typer.Option(None, "--token", help="Authentication token."), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), window_hours: int = typer.Option( 24, "--window-hours", help="Monitoring window in hours." ), ) -> None: """Show organization/team credits monitoring metrics and recommendations.""" try: - client = DatalayerClient(token=token) + client = _make_client(token=token, iam_url=iam_url) response = _iam_get( client, f"/api/iam/v1/usage/credits/allocations/organizations/{organization_uid}/monitoring?window_hours={max(1, window_hours)}", @@ -372,13 +677,18 @@ def usage_org_monitor( def usage_team_monitor( team_uid: str = typer.Option(..., "--team-uid", help="Team UID."), token: Optional[str] = typer.Option(None, "--token", help="Authentication token."), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), window_hours: int = typer.Option( 24, "--window-hours", help="Monitoring window in hours." ), ) -> None: """Show team/member credits monitoring metrics and recommendations.""" try: - client = DatalayerClient(token=token) + client = _make_client(token=token, iam_url=iam_url) response = _iam_get( client, f"/api/iam/v1/usage/credits/allocations/teams/{team_uid}/monitoring?window_hours={max(1, window_hours)}", @@ -458,10 +768,15 @@ def usage_org_allocate_team( ..., "--amount", help="Amount of credits to allocate." ), token: Optional[str] = typer.Option(None, "--token", help="Authentication token."), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), ) -> None: """Allocate credits from organization to team.""" try: - client = DatalayerClient(token=token) + client = _make_client(token=token, iam_url=iam_url) response = _iam_post( client, f"/api/iam/v1/usage/credits/allocations/organizations/{organization_uid}/teams/{team_uid}", @@ -487,10 +802,15 @@ def usage_org_revoke_team( team_uid: str = typer.Option(..., "--team-uid", help="Team UID."), amount: float = typer.Option(..., "--amount", help="Amount of credits to revoke."), token: Optional[str] = typer.Option(None, "--token", help="Authentication token."), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), ) -> None: """Revoke credits from team back to organization.""" try: - client = DatalayerClient(token=token) + client = _make_client(token=token, iam_url=iam_url) response = _iam_post( client, f"/api/iam/v1/usage/credits/allocations/organizations/{organization_uid}/teams/{team_uid}/revoke", @@ -516,10 +836,15 @@ def usage_team_allocate_member( ..., "--amount", help="Amount of credits to allocate." ), token: Optional[str] = typer.Option(None, "--token", help="Authentication token."), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), ) -> None: """Allocate credits from team to member.""" try: - client = DatalayerClient(token=token) + client = _make_client(token=token, iam_url=iam_url) response = _iam_post( client, f"/api/iam/v1/usage/credits/allocations/teams/{team_uid}/members/{member_uid}", @@ -543,10 +868,15 @@ def usage_team_revoke_member( member_uid: str = typer.Option(..., "--member-uid", help="Member UID."), amount: float = typer.Option(..., "--amount", help="Amount of credits to revoke."), token: Optional[str] = typer.Option(None, "--token", help="Authentication token."), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), ) -> None: """Revoke credits from member back to team.""" try: - client = DatalayerClient(token=token) + client = _make_client(token=token, iam_url=iam_url) response = _iam_post( client, f"/api/iam/v1/usage/credits/allocations/teams/{team_uid}/members/{member_uid}/revoke", @@ -573,6 +903,11 @@ def usage_root( "--token", help="Authentication token (Bearer token for API requests).", ), + iam_url: Optional[str] = typer.Option( + None, + "--iam-url", + help="Datalayer IAM server URL", + ), ) -> None: """Show credits usage and reservations (root command).""" - usage_show(token=token) + usage_show(token=token, iam_url=iam_url) diff --git a/datalayer_core/client/client.py b/datalayer_core/client/client.py index a1f59033..8bd226fa 100644 --- a/datalayer_core/client/client.py +++ b/datalayer_core/client/client.py @@ -16,8 +16,10 @@ from datalayer_core.mixins.authn import AuthnMixin from datalayer_core.mixins.environments import EnvironmentsMixin +from datalayer_core.mixins.evals import EvalsMixin from datalayer_core.mixins.events import EventsMixin -from datalayer_core.mixins.runtime_snapshots import RuntimeSnapshotsMixin +from datalayer_core.mixins.ray import RayMixin +from datalayer_core.mixins.sandbox_snapshots import SandboxSnapshotsMixin from datalayer_core.mixins.runtimes import RuntimesMixin from datalayer_core.mixins.secrets import SecretsMixin from datalayer_core.mixins.tokens import TokensMixin @@ -25,12 +27,12 @@ from datalayer_core.mixins.whoami import WhoamiAppMixin from datalayer_core.models import UserModel from datalayer_core.models.environment import EnvironmentModel -from datalayer_core.models.runtime_snapshot import RuntimeSnapshotModel +from datalayer_core.models.sandbox_snapshot import SandboxSnapshotModel from datalayer_core.models.secret import SecretModel, SecretVariant from datalayer_core.models.token import TokenModel, TokenType from datalayer_core.runtimes.runtime import RuntimeService -from datalayer_core.runtimes.runtime_snapshot import ( - as_runtime_snapshots, +from datalayer_core.runtimes.sandbox_snapshot import ( + as_code_sandbox_snapshots, create_snapshot, ) from datalayer_core.utils.defaults import ( @@ -47,9 +49,11 @@ class DatalayerClient( AuthnMixin, RuntimesMixin, EnvironmentsMixin, + EvalsMixin, EventsMixin, + RayMixin, SecretsMixin, - RuntimeSnapshotsMixin, + SandboxSnapshotsMixin, TokensMixin, UsageMixin, WhoamiAppMixin, @@ -260,6 +264,11 @@ def create_runtime( environment: str = DEFAULT_ENVIRONMENT, time_reservation: Minutes = DEFAULT_TIME_RESERVATION, snapshot_name: Optional[str] = None, + agent_spec_id: Optional[str] = None, + agent_spec: Optional[dict[str, Any]] = None, + billable_account_uid: Optional[str] = None, + billable_account_type: Optional[str] = None, + billable_account_handle: Optional[str] = None, ) -> RuntimeService: """ Create a new runtime (kernel) for code execution. @@ -320,20 +329,43 @@ def create_runtime( given_name=name, environment_name=environment, from_snapshot_uid=snapshot_uid, + agent_spec_id=agent_spec_id, + agent_spec=agent_spec, credits_limit=credits_limit, + billable_account_uid=billable_account_uid, + billable_account_type=billable_account_type, + billable_account_handle=billable_account_handle, ) else: # Create runtime without snapshot response = self._create_runtime( given_name=name, environment_name=environment, + agent_spec_id=agent_spec_id, + agent_spec=agent_spec, credits_limit=credits_limit, + billable_account_uid=billable_account_uid, + billable_account_type=billable_account_type, + billable_account_handle=billable_account_handle, ) # Process the response and create RuntimesService object if not response.get("success", True): + message = response.get("message", "Unknown error") + context_parts = [f"environment='{environment}'"] + if agent_spec_id: + context_parts.append(f"agent_spec_id='{agent_spec_id}'") + if agent_spec: + context_parts.append("agent_spec=") + reason = response.get("reason") + if reason: + context_parts.append(f"reason='{reason}'") + retry_after = response.get("retry_after_seconds") + if retry_after: + context_parts.append(f"retry_after_seconds={retry_after}") + context = ", ".join(context_parts) raise RuntimeError( - f"Runtime creation failed: {response.get('message', 'Unknown error')}" + f"Runtime creation failed ({context}): {message}" ) runtime_data = response["runtime"] @@ -422,6 +454,91 @@ def terminate_runtime(self, runtime: Union[RuntimeService, str]) -> bool: else: return False + def get_runtime(self, runtime: Union[RuntimeService, str]) -> RuntimeService: + """ + Get a single running Runtime by pod name. + + Parameters + ---------- + runtime : Union[Runtime, str] + Runtime object or pod name string to fetch. + + Returns + ------- + Runtime + The Runtime object matching the pod name. + + Raises + ------ + RuntimeError + If the runtime cannot be retrieved. + """ + pod_name = runtime.pod_name if isinstance(runtime, RuntimeService) else runtime + if not pod_name: + raise RuntimeError("A pod name is required to get a runtime.") + + response = self._get_runtime(pod_name) + if not response.get("success", True): + message = response.get("message", "Unknown error") + raise RuntimeError(f"Failed to get runtime '{pod_name}': {message}") + + runtime_data = response.get("runtime") + if not isinstance(runtime_data, dict): + raise RuntimeError( + f"Failed to get runtime '{pod_name}': missing 'runtime' field in response" + ) + + return RuntimeService( + name=runtime_data.get("given_name", pod_name), + environment=runtime_data.get("environment_name", ""), + pod_name=runtime_data.get("pod_name", pod_name), + token=self._token, + ingress=runtime_data.get("ingress"), + reservation_id=runtime_data.get("reservation_id"), + uid=runtime_data.get("uid"), + burning_rate=runtime_data.get("burning_rate"), + jupyter_token=runtime_data.get("token"), + run_url=self._urls.run_url, + iam_url=self._urls.iam_url, + started_at=runtime_data.get("started_at"), + expired_at=runtime_data.get("expired_at"), + ) + + def update_runtime( + self, + runtime: Union[RuntimeService, str], + capabilities: list[str], + ) -> bool: + """ + Update a running Runtime's capabilities. + + Parameters + ---------- + runtime : Union[Runtime, str] + Runtime object or pod name string to update. + capabilities : list[str] + New capabilities to apply to the runtime. + + Returns + ------- + bool + True if the update succeeded. + + Raises + ------ + RuntimeError + If the update fails. + """ + pod_name = runtime.pod_name if isinstance(runtime, RuntimeService) else runtime + if not pod_name: + raise RuntimeError("A pod name is required to update a runtime.") + + response = self._update_runtime(pod_name, capabilities) + if not response.get("success", True): + message = response.get("message", "Unknown error") + raise RuntimeError(f"Failed to update runtime '{pod_name}': {message}") + return True + def list_secrets(self) -> list[SecretModel]: """ List all secrets available in the Datalayer environment. @@ -511,7 +628,7 @@ def create_snapshot( name: Optional[str] = None, description: Optional[str] = None, stop: bool = True, - ) -> "RuntimeSnapshotModel": + ) -> "SandboxSnapshotModel": """ Create a snapshot of the current runtime state. @@ -530,7 +647,7 @@ def create_snapshot( Returns ------- - RuntimeSnapshotModel + SandboxSnapshotModel The created snapshot object. """ if pod_name is None and runtime is None: @@ -556,7 +673,7 @@ def create_snapshot( raise RuntimeError( f"Failed to create snapshot '{name}': {response.get('message', 'unknown error')}" ) - snapshot: Optional[RuntimeSnapshotModel] = None + snapshot: Optional[SandboxSnapshotModel] = None max_poll_attempts = max( 1, int(os.getenv("DATALAYER_SNAPSHOT_POLL_ATTEMPTS", "30")), @@ -577,7 +694,7 @@ def create_snapshot( f"Snapshot '{name}' was created but not found in snapshot listing" ) - return RuntimeSnapshotModel( + return SandboxSnapshotModel( uid=snapshot.uid, name=name, description=description, @@ -585,28 +702,28 @@ def create_snapshot( metadata=response, ) - def list_snapshots(self) -> list[RuntimeSnapshotModel]: + def list_snapshots(self) -> list[SandboxSnapshotModel]: """ List all snapshots. Returns ------- - list[RuntimeSnapshotModel] + list[SandboxSnapshotModel] A list of snapshots associated with the user. """ response = self._list_snapshots() - snapshot_objects = as_runtime_snapshots(response) + snapshot_objects = as_code_sandbox_snapshots(response) return snapshot_objects def delete_snapshot( - self, snapshot: Union[str, RuntimeSnapshotModel] + self, snapshot: Union[str, SandboxSnapshotModel] ) -> dict[str, str]: """ Delete a specific snapshot. Parameters ---------- - snapshot : Union[str, RuntimeSnapshotModel] + snapshot : Union[str, SandboxSnapshotModel] Snapshot object or UID string to delete. Returns @@ -615,7 +732,7 @@ def delete_snapshot( The result of the deletion operation. """ snapshot_uid = ( - snapshot.uid if isinstance(snapshot, RuntimeSnapshotModel) else snapshot + snapshot.uid if isinstance(snapshot, SandboxSnapshotModel) else snapshot ) return self._delete_snapshot(snapshot_uid) diff --git a/datalayer_core/decorators/datalayer.py b/datalayer_core/decorators/datalayer.py index c6f47c1c..13301ee6 100644 --- a/datalayer_core/decorators/datalayer.py +++ b/datalayer_core/decorators/datalayer.py @@ -48,7 +48,7 @@ def datalayer( output : str, optional The name of the output variable for the function. snapshot_name : str, optional - The name of the runtime snapshot to use. + The name of the code sandbox snapshot to use. token : str, optional Authentication token. If not provided, will be resolved from env/keyring. debug : bool diff --git a/datalayer_core/displays/agent_nodes.py b/datalayer_core/displays/agent_nodes.py new file mode 100644 index 00000000..c627be54 --- /dev/null +++ b/datalayer_core/displays/agent_nodes.py @@ -0,0 +1,33 @@ +# Copyright (c) 2023-2025 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Rich display helpers for agent nodes.""" + +from typing import Any + +from rich.console import Console +from rich.table import Table + +console = Console() + + +def display_agent_nodes(agent_nodes: list[dict[str, Any]]) -> None: + """Display agent nodes in a Rich table.""" + table = Table(title="Agent Nodes") + table.add_column("Node ID", style="cyan", no_wrap=True) + table.add_column("Name") + table.add_column("Mode") + table.add_column("Status") + table.add_column("Last Seen") + + for node in agent_nodes: + configuration = node.get("configuration") or {} + table.add_row( + str(node.get("node_id") or ""), + str(node.get("node_name") or ""), + str(configuration.get("mode") or "sleep"), + str(node.get("status") or "stale"), + str(node.get("last_seen_at") or ""), + ) + + console.print(table) diff --git a/datalayer_core/displays/runtime_snapshots.py b/datalayer_core/displays/sandbox_snapshots.py similarity index 71% rename from datalayer_core/displays/runtime_snapshots.py rename to datalayer_core/displays/sandbox_snapshots.py index 5fd6c692..0b9ac30a 100644 --- a/datalayer_core/displays/runtime_snapshots.py +++ b/datalayer_core/displays/sandbox_snapshots.py @@ -11,9 +11,9 @@ from rich.table import Table -def _new_runtime_snapshots_table(title: str = "Snapshots") -> Table: +def _new_code_sandbox_snapshots_table(title: str = "Snapshots") -> Table: """ - Create a new runtime snapshots table. + Create a new code sandbox snapshots table. Parameters ---------- @@ -33,9 +33,9 @@ def _new_runtime_snapshots_table(title: str = "Snapshots") -> Table: return table -def _add_runtime_snapshot_to_table(table: Table, snapshot: dict[str, Any]) -> None: +def _add_code_sandbox_snapshot_to_table(table: Table, snapshot: dict[str, Any]) -> None: """ - Add a runtime snapshot row to the table. + Add a code sandbox snapshot row to the table. Parameters ---------- @@ -52,17 +52,17 @@ def _add_runtime_snapshot_to_table(table: Table, snapshot: dict[str, Any]) -> No ) -def display_runtime_snapshots(snapshots: list[dict[str, Any]]) -> None: +def display_code_sandbox_snapshots(snapshots: list[dict[str, Any]]) -> None: """ - Display a list of runtime snapshots in the console. + Display a list of code sandbox snapshots in the console. Parameters ---------- snapshots : list[dict[str, Any]] List of snapshot dictionaries to display. """ - table = _new_runtime_snapshots_table(title="Runtime Snapshots") + table = _new_code_sandbox_snapshots_table(title="Runtime Snapshots") for snapshot in snapshots: - _add_runtime_snapshot_to_table(table, snapshot) + _add_code_sandbox_snapshot_to_table(table, snapshot) console = Console() console.print(table) diff --git a/datalayer_core/mixins/__init__.py b/datalayer_core/mixins/__init__.py index cdd27246..8370f351 100644 --- a/datalayer_core/mixins/__init__.py +++ b/datalayer_core/mixins/__init__.py @@ -2,7 +2,7 @@ # Distributed under the terms of the Modified BSD License. from .authn import AuthnMixin from .environments import EnvironmentsMixin -from .runtime_snapshots import RuntimeSnapshotsMixin +from .sandbox_snapshots import SandboxSnapshotsMixin from .runtimes import RuntimesMixin from .secrets import SecretsMixin from .tokens import TokensMixin @@ -12,7 +12,7 @@ __all__ = [ "AuthnMixin", "EnvironmentsMixin", - "RuntimeSnapshotsMixin", + "SandboxSnapshotsMixin", "RuntimesMixin", "SecretsMixin", "TokensMixin", diff --git a/datalayer_core/mixins/evals.py b/datalayer_core/mixins/evals.py new file mode 100644 index 00000000..6cc27043 --- /dev/null +++ b/datalayer_core/mixins/evals.py @@ -0,0 +1,307 @@ +# Copyright (c) 2023-2026 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Evals management mixin for Datalayer Core.""" + +from __future__ import annotations + +from typing import Any, Optional + + +class EvalsMixin: + """Mixin for managing evals, experiments, runs, and live monitoring.""" + + def _evals_request( + self, + path: str, + *, + method: str, + account_uid: Optional[str] = None, + params: Optional[dict[str, Any]] = None, + json_body: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + query: dict[str, Any] = dict(params or {}) + if account_uid: + query["account_uid"] = account_uid + response = self._fetch( # type: ignore + f"{self.urls.ai_agents_url}/api/ai-agents/v1/evals{path}", # type: ignore + method=method, + params=query, + json=json_body, + ) + return response.json() + + def evals_list_evals( + self, + *, + kind: Optional[str] = None, + run_environment: Optional[str] = None, + q: Optional[str] = None, + limit: int = 50, + offset: int = 0, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + params: dict[str, Any] = {"limit": limit, "offset": offset} + if kind: + params["kind"] = kind + if run_environment: + params["run_environment"] = run_environment + if q: + params["q"] = q + return self._evals_request( + "/evalsets", + method="GET", + params=params, + account_uid=account_uid, + ) + + def evals_create_eval( + self, + *, + name: str, + description: str = "", + run_environment: str = "sdk", + kind: str = "batch", + schema: Optional[dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + cases: Optional[list[dict[str, Any]]] = None, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + body = { + "name": name, + "description": description, + "run_environment": run_environment, + "kind": kind, + "schema": schema or {}, + "tags": tags or [], + "metadata": metadata or {}, + "cases": cases or [], + } + return self._evals_request( + "/evalsets", + method="POST", + json_body=body, + account_uid=account_uid, + ) + + def evals_delete_eval( + self, + evalset_id: str, + *, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + return self._evals_request( + f"/evalsets/{evalset_id}", + method="DELETE", + account_uid=account_uid, + ) + + def evals_list_experiments( + self, + *, + evalset_id: Optional[str] = None, + status: Optional[str] = None, + limit: int = 50, + offset: int = 0, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + params: dict[str, Any] = {"limit": limit, "offset": offset} + if evalset_id: + params["evalset_id"] = evalset_id + if status: + params["status"] = status + return self._evals_request( + "/experiments", + method="GET", + params=params, + account_uid=account_uid, + ) + + def evals_create_experiment( + self, + *, + name: str, + evalset_id: Optional[str] = None, + description: str = "", + status: str = "draft", + config: Optional[dict[str, Any]] = None, + summary: Optional[dict[str, Any]] = None, + tags: Optional[list[str]] = None, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + body = { + "name": name, + "evalset_id": evalset_id, + "description": description, + "status": status, + "config": config or {}, + "summary": summary or {}, + "tags": tags or [], + } + return self._evals_request( + "/experiments", + method="POST", + json_body=body, + account_uid=account_uid, + ) + + def evals_delete_experiment( + self, + experiment_id: str, + *, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + return self._evals_request( + f"/experiments/{experiment_id}", + method="DELETE", + account_uid=account_uid, + ) + + def evals_list_runs( + self, + experiment_id: str, + *, + limit: int = 50, + offset: int = 0, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + return self._evals_request( + f"/experiments/{experiment_id}/runs", + method="GET", + params={"limit": limit, "offset": offset}, + account_uid=account_uid, + ) + + def evals_create_run( + self, + experiment_id: str, + *, + status: str = "queued", + started_at: Optional[str] = None, + ended_at: Optional[str] = None, + metrics: Optional[dict[str, Any]] = None, + summary: Optional[dict[str, Any]] = None, + report: Optional[dict[str, Any]] = None, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + body: dict[str, Any] = { + "status": status, + "metrics": metrics or {}, + "summary": summary or {}, + "report": report or {}, + } + if started_at: + body["started_at"] = started_at + if ended_at: + body["ended_at"] = ended_at + return self._evals_request( + f"/experiments/{experiment_id}/runs", + method="POST", + json_body=body, + account_uid=account_uid, + ) + + def evals_get_run( + self, + run_id: str, + *, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + return self._evals_request( + f"/runs/{run_id}", + method="GET", + account_uid=account_uid, + ) + + def evals_compare_runs( + self, + run_ids: list[str], + *, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + return self._evals_request( + "/runs/compare", + method="POST", + json_body={"run_ids": run_ids}, + account_uid=account_uid, + ) + + def evals_create_live_event( + self, + *, + target_id: str, + target_type: str = "agent", + evaluator_name: Optional[str] = None, + metric_name: Optional[str] = None, + value_num: Optional[float] = None, + label: Optional[str] = None, + passed: Optional[bool] = None, + attributes: Optional[dict[str, Any]] = None, + created_at: Optional[str] = None, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + body: dict[str, Any] = { + "target_id": target_id, + "target_type": target_type, + "attributes": attributes or {}, + } + if evaluator_name is not None: + body["evaluator_name"] = evaluator_name + if metric_name is not None: + body["metric_name"] = metric_name + if value_num is not None: + body["value_num"] = value_num + if label is not None: + body["label"] = label + if passed is not None: + body["passed"] = passed + if created_at is not None: + body["created_at"] = created_at + return self._evals_request( + "/live/events", + method="POST", + json_body=body, + account_uid=account_uid, + ) + + def evals_list_live_targets( + self, + *, + window: str = "24h", + limit: int = 50, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + return self._evals_request( + "/live/targets", + method="GET", + params={"window": window, "limit": limit}, + account_uid=account_uid, + ) + + def evals_list_live_events( + self, + *, + target_id: str, + target_type: str = "agent", + window: str = "24h", + evaluator_name: Optional[str] = None, + limit: int = 50, + offset: int = 0, + account_uid: Optional[str] = None, + ) -> dict[str, Any]: + params: dict[str, Any] = { + "target_id": target_id, + "target_type": target_type, + "window": window, + "limit": limit, + "offset": offset, + } + if evaluator_name: + params["evaluator_name"] = evaluator_name + return self._evals_request( + "/live/events", + method="GET", + params=params, + account_uid=account_uid, + ) \ No newline at end of file diff --git a/datalayer_core/mixins/ray.py b/datalayer_core/mixins/ray.py new file mode 100644 index 00000000..7de8b647 --- /dev/null +++ b/datalayer_core/mixins/ray.py @@ -0,0 +1,145 @@ +# Copyright (c) 2023-2026 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Ray management mixin for Datalayer Core.""" + +from __future__ import annotations + +from typing import Any, Optional + + +class RayMixin: + """Mixin for managing Ray clusters and Ray jobs through the Ray addon API.""" + + _RAY_API_PREFIXES_RUNTIMES = ("/api/runtimes/v1/ray",) + _RAY_API_PREFIXES_ADDON = ("/api/ray/v1",) + + def _get_ray_api_prefixes(self) -> tuple[str, ...]: + if bool(getattr(self, "_ray_direct_addon", False)): # type: ignore[attr-defined] + return self._RAY_API_PREFIXES_ADDON + return self._RAY_API_PREFIXES_RUNTIMES + + def _ray_request( + self, + path: str, + *, + method: str, + params: Optional[dict[str, Any]] = None, + json_body: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + prefixes = self._get_ray_api_prefixes() + prefix = prefixes[0] + response = self._fetch( # type: ignore + f"{self.urls.ray_url}{prefix}{path}", # type: ignore + method=method, + params=params, + json=json_body, + ) + return response.json() + + def ray_list_clusters(self, *, namespace: str = "default") -> dict[str, Any]: + return self._ray_request( + "/clusters", + method="GET", + params={"namespace": namespace}, + ) + + def ray_create_cluster(self, payload: dict[str, Any]) -> dict[str, Any]: + return self._ray_request( + "/clusters", + method="POST", + json_body=payload, + ) + + def ray_get_cluster(self, name: str, *, namespace: str = "default") -> dict[str, Any]: + return self._ray_request( + f"/clusters/{name}", + method="GET", + params={"namespace": namespace}, + ) + + def ray_delete_cluster(self, name: str, *, namespace: str = "default") -> dict[str, Any]: + return self._ray_request( + f"/clusters/{name}", + method="DELETE", + params={"namespace": namespace}, + ) + + def ray_submit_job( + self, + cluster_name: str, + payload: dict[str, Any], + ) -> dict[str, Any]: + return self._ray_request( + f"/clusters/{cluster_name}/jobs", + method="POST", + json_body=payload, + ) + + def ray_list_jobs( + self, + *, + namespace: str = "default", + cluster_name: Optional[str] = None, + ) -> dict[str, Any]: + params: dict[str, Any] = {"namespace": namespace} + if cluster_name: + params["cluster_name"] = cluster_name + return self._ray_request( + "/jobs", + method="GET", + params=params, + ) + + def ray_get_job(self, name: str, *, namespace: str = "default") -> dict[str, Any]: + return self._ray_request( + f"/jobs/{name}", + method="GET", + params={"namespace": namespace}, + ) + + def ray_get_job_logs( + self, + name: str, + *, + namespace: str = "default", + pod_name: Optional[str] = None, + container: Optional[str] = None, + tail_lines: int = 200, + ) -> dict[str, Any]: + params: dict[str, Any] = { + "namespace": namespace, + "tail_lines": tail_lines, + } + if pod_name: + params["pod_name"] = pod_name + if container: + params["container"] = container + return self._ray_request( + f"/jobs/{name}/logs", + method="GET", + params=params, + ) + + def ray_get_job_events( + self, + name: str, + *, + namespace: str = "default", + limit: int = 100, + ) -> dict[str, Any]: + return self._ray_request( + f"/jobs/{name}/events", + method="GET", + params={ + "namespace": namespace, + "limit": limit, + }, + ) + + def ray_delete_job(self, name: str, *, namespace: str = "default") -> dict[str, Any]: + return self._ray_request( + f"/jobs/{name}", + method="DELETE", + params={"namespace": namespace}, + ) diff --git a/datalayer_core/mixins/runtimes.py b/datalayer_core/mixins/runtimes.py index 6037084e..e721f3e0 100644 --- a/datalayer_core/mixins/runtimes.py +++ b/datalayer_core/mixins/runtimes.py @@ -39,6 +39,11 @@ def _create_runtime( given_name: Optional[str] = None, credits_limit: Optional[float] = None, from_snapshot_uid: Optional[str] = None, + agent_spec_id: Optional[str] = None, + agent_spec: Optional[dict[str, Any]] = None, + billable_account_uid: Optional[str] = None, + billable_account_type: Optional[str] = None, + billable_account_handle: Optional[str] = None, ) -> dict[str, Any]: """ Create a Runtime with the given environment name. @@ -108,6 +113,18 @@ def _create_runtime( if from_snapshot_uid: body["from"] = from_snapshot_uid + if agent_spec_id: + body["agent_spec_id"] = agent_spec_id + if agent_spec: + body["agent_spec"] = agent_spec + + if billable_account_uid: + body["billable_account_uid"] = billable_account_uid + if billable_account_type: + body["billable_account_type"] = billable_account_type + if billable_account_handle: + body["billable_account_handle"] = billable_account_handle + runtime_url = "{}/api/runtimes/v1/runtimes".format(self.urls.runtimes_url) # type: ignore logger.debug( "Creating runtime via %s with payload keys=%s", @@ -277,9 +294,127 @@ def _terminate_runtime(self: Any, pod_name: str) -> dict[str, Any]: return {"success": False, "message": error_msg} +class RuntimesGetMixin: + """Mixin for reading a single Datalayer runtime.""" + + def _get_runtime(self: Any, pod_name: str) -> dict[str, Any]: + """ + Get a single Runtime by pod name. + + Parameters + ---------- + pod_name : str + The pod name of the runtime to fetch. + + Returns + ------- + dict[str, Any] + Response containing the runtime payload. + """ + try: + response = self._fetch( + "{}/api/runtimes/v1/runtimes/{}".format( + self.urls.runtimes_url, pod_name + ), + ) + + if response.status_code != 200: + error_msg = f"Failed to get runtime: HTTP {response.status_code}" + logger.error(error_msg) + try: + error_details = response.json() + if "message" in error_details: + error_msg += f" - {error_details['message']}" + except Exception: + pass + return {"success": False, "message": error_msg} + + try: + result = response.json() + if "success" in result and not result["success"]: + error_msg = f"Get runtime failed: {result.get('message', 'Unknown error')}" + logger.error(error_msg) + return {"success": False, "message": error_msg} + return result + except Exception as e: + error_msg = f"Failed to parse runtime response: {str(e)}" + logger.error(error_msg) + return {"success": False, "message": error_msg} + + except Exception as e: + error_msg = f"Unexpected error getting runtime {pod_name}: {str(e)}" + logger.error(error_msg) + return {"success": False, "message": error_msg} + + +class RuntimesUpdateMixin: + """Mixin for updating a Datalayer runtime.""" + + def _update_runtime( + self: Any, + pod_name: str, + capabilities: list[str], + ) -> dict[str, Any]: + """ + Update a Runtime's capabilities. + + Parameters + ---------- + pod_name : str + The pod name of the runtime to update. + capabilities : list[str] + New capabilities to apply to the runtime. + + Returns + ------- + dict[str, Any] + Response containing the update status. + """ + try: + response = self._fetch( + "{}/api/runtimes/v1/runtimes/{}".format( + self.urls.runtimes_url, pod_name + ), + method="PUT", + json={"capabilities": capabilities}, + ) + + if response.status_code not in [200, 201, 202]: + error_msg = f"Failed to update runtime: HTTP {response.status_code}" + logger.error(error_msg) + try: + error_details = response.json() + if "message" in error_details: + error_msg += f" - {error_details['message']}" + elif "detail" in error_details: + error_msg += f" - {error_details['detail']}" + except Exception: + pass + return {"success": False, "message": error_msg} + + try: + result = response.json() + if "success" in result and not result["success"]: + error_msg = f"Update runtime failed: {result.get('message', 'Unknown error')}" + logger.error(error_msg) + return {"success": False, "message": error_msg} + return result + except Exception as e: + error_msg = f"Failed to parse runtime update response: {str(e)}" + logger.error(error_msg) + return {"success": False, "message": error_msg} + + except Exception as e: + error_msg = f"Unexpected error updating runtime {pod_name}: {str(e)}" + logger.error(error_msg) + return {"success": False, "message": error_msg} + + class RuntimesMixin( RuntimesCreateMixin, RuntimesListMixin, + RuntimesGetMixin, + RuntimesUpdateMixin, RuntimesTerminateMixin, ): """ diff --git a/datalayer_core/mixins/runtime_snapshots.py b/datalayer_core/mixins/sandbox_snapshots.py similarity index 87% rename from datalayer_core/mixins/runtime_snapshots.py rename to datalayer_core/mixins/sandbox_snapshots.py index 20881caf..0c082d6c 100644 --- a/datalayer_core/mixins/runtime_snapshots.py +++ b/datalayer_core/mixins/sandbox_snapshots.py @@ -4,7 +4,7 @@ from typing import Any -class RuntimeSnapshotsCreateMixin: +class SandboxSnapshotsCreateMixin: """Mixin class for creating snapshots.""" def _create_snapshot( @@ -37,7 +37,7 @@ def _create_snapshot( } try: response = self._fetch( # type: ignore - "{}/api/runtimes/v1/runtime-snapshots".format(self.urls.runtimes_url), # type: ignore + "{}/api/runtimes/v1/sandbox-snapshots".format(self.urls.runtimes_url), # type: ignore method="POST", json=body, ) @@ -46,7 +46,7 @@ def _create_snapshot( return {"success": False, "message": str(e)} -class RuntimeSnapshotsDeleteMixin: +class SandboxSnapshotsDeleteMixin: """ Mixin class that provides snapshot deletion functionality. """ @@ -67,7 +67,7 @@ def _delete_snapshot(self, snapshot_uid: str) -> dict[str, Any]: """ try: response = self._fetch( # type: ignore - "{}/api/runtimes/v1/runtime-snapshots/{}".format( + "{}/api/runtimes/v1/sandbox-snapshots/{}".format( self.urls.runtimes_url, # type: ignore snapshot_uid, ), @@ -81,7 +81,7 @@ def _delete_snapshot(self, snapshot_uid: str) -> dict[str, Any]: return {"success": False, "message": str(e)} -class RuntimeSnapshotsListMixin: +class SandboxSnapshotsListMixin: """ Mixin class to provide functionality for listing snapshots. """ @@ -97,15 +97,15 @@ def _list_snapshots(self) -> dict[str, Any]: """ try: response = self._fetch( # type: ignore - "{}/api/runtimes/v1/runtime-snapshots".format(self.urls.runtimes_url), # type: ignore + "{}/api/runtimes/v1/sandbox-snapshots".format(self.urls.runtimes_url), # type: ignore ) return response.json() except RuntimeError as e: return {"success": False, "message": str(e)} -class RuntimeSnapshotsMixin( - RuntimeSnapshotsCreateMixin, RuntimeSnapshotsDeleteMixin, RuntimeSnapshotsListMixin +class SandboxSnapshotsMixin( + SandboxSnapshotsCreateMixin, SandboxSnapshotsDeleteMixin, SandboxSnapshotsListMixin ): """ Mixin class that provides snapshot management functionality. diff --git a/datalayer_core/mixins/usage.py b/datalayer_core/mixins/usage.py index 80bc8f43..ae5856f3 100644 --- a/datalayer_core/mixins/usage.py +++ b/datalayer_core/mixins/usage.py @@ -37,7 +37,7 @@ def _get_subscription(self) -> dict[str, Any]: """ try: response = self._fetch( # type: ignore - "{}/api/iam/v1/subscription".format(self.urls.iam_url), # type: ignore + "{}/api/iam/v1/plans".format(self.urls.iam_url), # type: ignore ) return response.json() except RuntimeError as e: @@ -54,7 +54,7 @@ def _cancel_subscription(self) -> dict[str, Any]: """ try: response = self._fetch( # type: ignore - "{}/api/iam/v1/subscription/cancel".format(self.urls.iam_url), # type: ignore + "{}/api/iam/v1/plans/cancel".format(self.urls.iam_url), # type: ignore method="POST", ) return response.json() @@ -72,7 +72,7 @@ def _get_subscription_plans(self) -> dict[str, Any]: """ try: response = self._fetch( # type: ignore - "{}/api/iam/v1/subscription/plans".format(self.urls.iam_url), # type: ignore + "{}/api/iam/v1/plans/catalog".format(self.urls.iam_url), # type: ignore ) return response.json() except RuntimeError as e: diff --git a/datalayer_core/models/__init__.py b/datalayer_core/models/__init__.py index 74e3d0cc..c128d2c2 100644 --- a/datalayer_core/models/__init__.py +++ b/datalayer_core/models/__init__.py @@ -81,7 +81,7 @@ UserSettingsModel, ) from .runtime import RuntimeModel -from .runtime_snapshot import RuntimeSnapshotModel +from .sandbox_snapshot import SandboxSnapshotModel from .secret import SecretModel, SecretVariant from .token import TokenModel, TokenType @@ -137,7 +137,7 @@ "ResourceRequirements", "Response", "RuntimeModel", - "RuntimeSnapshotModel", + "SandboxSnapshotModel", "SecretModel", "SecretModel", "SecretVariant", diff --git a/datalayer_core/models/iam.py b/datalayer_core/models/iam.py index c10c05f3..c29eee33 100644 --- a/datalayer_core/models/iam.py +++ b/datalayer_core/models/iam.py @@ -449,6 +449,40 @@ def from_solr_results( return cls(memberships=memberships) +# Shareable Principals Models +class ShareablePrincipalModel(BaseModel): + """Principal a user can share artifacts with. + + Always one of: self (user), an organization the user is a member of, + or a team the user is a member of (with its parent organization info). + """ + + kind: str = Field(..., description="Principal kind: 'user' | 'organization' | 'team'") + uid: str = Field(..., description="Principal UID") + handle: str = Field(..., description="Principal handle") + name: Optional[str] = Field(None, description="Display name") + description: Optional[str] = Field(None, description="Description (org/team)") + email: Optional[str] = Field(None, description="Email (user only)") + avatar_url: Optional[str] = Field(None, description="Avatar URL") + organization_uid: Optional[str] = Field( + None, description="Parent organization UID (team only)" + ) + organization_handle: Optional[str] = Field( + None, description="Parent organization handle (team only)" + ) + + +class ShareablePrincipalsResponse(BaseModel): + """Response for principals-shareable-with endpoint.""" + + success: bool = Field(default=True) + message: Optional[str] = Field(default=None) + principals: List[ShareablePrincipalModel] = Field( + default_factory=list, + description="Self + member organizations + member teams", + ) + + # Credits and Reservations Models class ResourceRequirements(BaseModel): """Kubernetes pod resource requirements.""" diff --git a/datalayer_core/models/runtime_snapshot.py b/datalayer_core/models/sandbox_snapshot.py similarity index 78% rename from datalayer_core/models/runtime_snapshot.py rename to datalayer_core/models/sandbox_snapshot.py index 775eb4ff..77946bbd 100644 --- a/datalayer_core/models/runtime_snapshot.py +++ b/datalayer_core/models/sandbox_snapshot.py @@ -4,7 +4,7 @@ """ Runtime snapshot model for Datalayer. -Provides data structures for runtime snapshot management in Datalayer environments. +Provides data structures for code sandbox snapshot management in Datalayer environments. """ from typing import Any, Dict @@ -12,12 +12,12 @@ from pydantic import BaseModel, Field -class RuntimeSnapshotModel(BaseModel): +class SandboxSnapshotModel(BaseModel): """ Pydantic model representing a snapshot of a Datalayer runtime state. This model contains all the data fields and configuration parameters - for a runtime snapshot, separate from the service logic. + for a code sandbox snapshot, separate from the service logic. """ uid: str = Field(..., description="Unique identifier for the snapshot") @@ -32,6 +32,6 @@ class RuntimeSnapshotModel(BaseModel): def __repr__(self) -> str: return ( - f"RuntimeSnapshotModel(uid='{self.uid}', name='{self.name}', " + f"SandboxSnapshotModel(uid='{self.uid}', name='{self.name}', " f"description='{self.description}', environment='{self.environment}')" ) diff --git a/datalayer_core/runtimes/agent_runtime.py b/datalayer_core/runtimes/agent_runtime.py new file mode 100644 index 00000000..27856a57 --- /dev/null +++ b/datalayer_core/runtimes/agent_runtime.py @@ -0,0 +1,300 @@ +# Copyright (c) 2023-2026 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Cloud agent runtime provisioning helpers. + +Reusable logic for launching cloud ``agent-runtimes`` from a +:class:`~datalayer_core.client.client.DatalayerClient`. Shared by the eval +examples and the GitHub Actions integration so credit/time-reservation math, +environment burning-rate lookup, and ``create_runtime`` error handling are not +duplicated across consumers. +""" + +from __future__ import annotations + +import math +from typing import Any, Optional + + +def resolve_environment_burning_rate( + client: Any, + environment_name: str, +) -> float: + """Return the positive burning rate for an environment. + + Parameters + ---------- + client : DatalayerClient + An authenticated client able to list environments. + environment_name : str + The environment to look up. + + Returns + ------- + float + The environment's positive burning rate. + + Raises + ------ + RuntimeError + If the environment cannot be listed, is not found, or has no positive + burning rate. + """ + + def _to_float(value: Any) -> Optional[float]: + try: + if value is None: + return None + parsed = float(value) + if parsed > 0: + return parsed + except (TypeError, ValueError): + return None + return None + + response = client._list_environments() + if not response.get("success", True): + raise RuntimeError( + f"Failed to list environments: {response.get('message', 'Unknown error')}" + ) + environments = response.get("environments") + if not isinstance(environments, list): + raise RuntimeError( + "Failed to list environments: invalid environments payload." + ) + + matched_environment: Optional[dict[str, Any]] = None + for raw_env in environments: + if ( + isinstance(raw_env, dict) + and str(raw_env.get("name") or "") == environment_name + ): + matched_environment = raw_env + break + + if matched_environment is None: + available = [ + str(env.get("name") or "") + for env in environments + if isinstance(env, dict) + ] + raise RuntimeError( + f"Environment '{environment_name}' not found for cloud runtime launch. " + f"Available environments: {available}" + ) + + parsed = _to_float(matched_environment.get("burning_rate")) + if parsed is not None: + return parsed + + available_keys = sorted(matched_environment.keys()) + raise RuntimeError( + f"Environment '{environment_name}' is missing a positive burning rate " + "in backend payload. Checked key: burning_rate. " + f"Environment keys: {available_keys}" + ) + + +def compute_time_reservation_minutes( + *, + credits_limit: float, + burning_rate: float, +) -> int: + """Compute a time reservation (minutes) from a credits budget. + + ``create_runtime`` charges ``burning_rate * 60 * time_reservation`` credits, + so this returns the smallest whole-minute reservation whose cost is at least + ``credits_limit`` (minimum 1 minute). + + Raises + ------ + ValueError + If ``burning_rate`` is not positive. + """ + if burning_rate <= 0: + raise ValueError("burning_rate must be positive.") + return max(1, int(math.ceil(float(credits_limit) / (burning_rate * 60.0)))) + + +def create_cloud_agent_runtime( + client: Any, + *, + environment_name: str, + name: Optional[str] = None, + agent_spec_id: Optional[str] = None, + agent_spec: Optional[dict[str, Any]] = None, + credits_limit: Optional[float] = None, + time_reservation: Optional[int] = None, + billable_account_uid: Optional[str] = None, + billable_account_type: Optional[str] = None, + billable_account_handle: Optional[str] = None, +) -> Any: + """Create a cloud agent runtime via the core client. + + Either ``time_reservation`` (in minutes) or ``credits_limit`` must be + provided. When only ``credits_limit`` is given, the time reservation is + derived from the environment's burning rate. + + Parameters + ---------- + client : DatalayerClient + An authenticated client. + environment_name : str + The runtime environment to launch in. + name : Optional[str] + Optional runtime name. + agent_spec_id : Optional[str] + Registered agent spec id (ignored when ``agent_spec`` is provided). + agent_spec : Optional[dict[str, Any]] + Inline agent spec payload (takes precedence over ``agent_spec_id``). + credits_limit : Optional[float] + Target credits budget used to derive ``time_reservation`` when the + latter is not supplied. + time_reservation : Optional[int] + Explicit time reservation in minutes. + billable_account_uid : Optional[str] + Optional billable account UID used for runtime billing attribution. + billable_account_type : Optional[str] + Optional billable account type (user, organization, team). + billable_account_handle : Optional[str] + Optional billable account handle. + + Returns + ------- + Any + The created runtime object (exposes ``pod_name`` and ``ingress``). + + Raises + ------ + ValueError + If neither ``time_reservation`` nor ``credits_limit`` is provided. + RuntimeError + If runtime creation fails or returns no ``pod_name``. + """ + if time_reservation is None: + if credits_limit is None: + raise ValueError( + "Provide either time_reservation or credits_limit." + ) + burning_rate = resolve_environment_burning_rate(client, environment_name) + time_reservation = compute_time_reservation_minutes( + credits_limit=credits_limit, + burning_rate=burning_rate, + ) + + try: + runtime = client.create_runtime( + name=name, + environment=environment_name, + time_reservation=int(time_reservation), + agent_spec_id=None if agent_spec else agent_spec_id, + agent_spec=agent_spec, + billable_account_uid=billable_account_uid, + billable_account_type=billable_account_type, + billable_account_handle=billable_account_handle, + ) + except Exception as exc: + spec_hint = "inline spec payload" if agent_spec else (agent_spec_id or "") + raise RuntimeError( + "Cloud runtime creation failed. " + f"environment={environment_name}, agent_spec={spec_hint}, error={exc}" + ) from exc + + pod_name = str(getattr(runtime, "pod_name", "") or "").strip() + if not pod_name: + raise RuntimeError("Runtime creation succeeded but pod_name is missing.") + return runtime + + +def terminate_cloud_agent_runtime( + client: Any, + runtime_or_pod_name: Any, + *, + raise_on_error: bool = False, +) -> bool: + """Terminate a cloud runtime created for agent execution. + + Parameters + ---------- + client : DatalayerClient + An authenticated client exposing ``terminate_runtime``. + runtime_or_pod_name : Any + Runtime object (with ``pod_name``) or raw pod-name string. + raise_on_error : bool + When ``True``, raise :class:`RuntimeError` if termination fails. + + Returns + ------- + bool + ``True`` when the runtime was terminated, otherwise ``False``. + """ + if isinstance(runtime_or_pod_name, str): + pod_name = runtime_or_pod_name.strip() + else: + pod_name = str(getattr(runtime_or_pod_name, "pod_name", "") or "").strip() + + if not pod_name: + if raise_on_error: + raise RuntimeError("Cannot terminate cloud runtime: pod_name is missing.") + return False + + try: + success = bool(client.terminate_runtime(pod_name)) + except Exception as exc: + if raise_on_error: + raise RuntimeError( + f"Cloud runtime termination failed for pod {pod_name}: {exc}" + ) from exc + return False + + if not success and raise_on_error: + raise RuntimeError(f"Cloud runtime termination returned unsuccessful for pod {pod_name}.") + return success + + +def teardown_agent_execution_resources( + client: Any, + *, + execution_target: str, + cloud_runtime_or_pod_name: Any = None, + local_base_url: Optional[str] = None, + local_agent_name: Optional[str] = None, + token: Optional[str] = None, + local_runtime: Any = None, +) -> dict[str, bool]: + """Teardown resources used by agent execution. + + Handles both cloud and local cleanup using a single API so consumers + (examples, GitHub Actions) don't duplicate teardown logic. + """ + result = { + "cloud_runtime_terminated": False, + "local_agent_deleted": False, + "local_runtime_terminated": False, + } + + target = str(execution_target or "").strip().lower() + if target == "cloud": + if cloud_runtime_or_pod_name: + result["cloud_runtime_terminated"] = terminate_cloud_agent_runtime( + client, + cloud_runtime_or_pod_name, + ) + return result + + if target == "local": + if local_base_url and token and local_agent_name: + from datalayer_core.runtimes.local import delete_local_agent + + result["local_agent_deleted"] = delete_local_agent( + base_url=local_base_url, + token=token, + agent_name=local_agent_name, + ) + if local_runtime is not None: + from datalayer_core.runtimes.local import terminate_local_agent_runtime + + terminate_local_agent_runtime(local_runtime) + result["local_runtime_terminated"] = True + + return result diff --git a/datalayer_core/runtimes/local.py b/datalayer_core/runtimes/local.py new file mode 100644 index 00000000..3ab44ca4 --- /dev/null +++ b/datalayer_core/runtimes/local.py @@ -0,0 +1,684 @@ +# Copyright (c) 2023-2026 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Local agent runtime lifecycle helpers. + +Provides a reusable API to launch, register, interact with, and tear down a +local ``agent-runtimes`` server. Shared by the ``datalayer agents`` CLI +(``--local`` flag) and by examples so the same logic is not duplicated. +""" + +from __future__ import annotations + +import json +import logging +import os +import socket +import subprocess +import time +from dataclasses import dataclass, field +from typing import Any, Optional +from urllib.parse import urlparse + +import requests + +logger = logging.getLogger(__name__) + +DEFAULT_LOCAL_HOST = "127.0.0.1" +DEFAULT_LOCAL_AGENT_NAME = "default" +DEFAULT_LOCAL_PROTOCOL = "vercel-ai" +DEFAULT_LOCAL_LOG_LEVEL = "info" + +# Map Datalayer Bedrock credentials onto the AWS variables the local +# agent-runtimes server expects. +_BEDROCK_ENV_MAPPINGS = { + "DATALAYER_BEDROCK_AWS_ACCESS_KEY_ID": "AWS_ACCESS_KEY_ID", + "DATALAYER_BEDROCK_AWS_SECRET_ACCESS_KEY": "AWS_SECRET_ACCESS_KEY", + "DATALAYER_BEDROCK_AWS_DEFAULT_REGION": "AWS_DEFAULT_REGION", +} + + +@dataclass +class LocalAgentRuntime: + """Handle to a running local ``agent-runtimes`` server.""" + + base_url: str + agent_name: str + agent_spec_id: str + process: Optional[subprocess.Popen[Any]] = field(default=None, repr=False) + + @property + def chat_endpoint(self) -> str: + """Vercel AI chat endpoint for this runtime's agent.""" + return f"{self.base_url.rstrip('/')}/api/v1/vercel-ai/{self.agent_name}" + + def terminate(self) -> None: + """Terminate the underlying server process (if any).""" + terminate_local_agent_runtime(self) + + +def find_free_port(host: str = DEFAULT_LOCAL_HOST) -> int: + """Return a free TCP port bound on ``host``.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind((host, 0)) + return int(sock.getsockname()[1]) + + +def build_agent_runtime_env() -> tuple[dict[str, str], list[str]]: + """Build the subprocess environment with Bedrock -> AWS variable mapping. + + Returns + ------- + tuple[dict[str, str], list[str]] + The environment mapping and the list of AWS targets that were mapped. + """ + runtime_env = os.environ.copy() + mapped_targets: list[str] = [] + for source, target in _BEDROCK_ENV_MAPPINGS.items(): + value = (runtime_env.get(source) or "").strip() + if value: + runtime_env[target] = value + mapped_targets.append(target) + return runtime_env, mapped_targets + + +def wait_for_local_runtime(base_url: str, timeout_seconds: int = 25) -> None: + """Block until the local runtime ``/health`` endpoint responds. + + Parameters + ---------- + base_url : str + Base URL of the local agent-runtimes server. + timeout_seconds : int + Maximum number of seconds to wait. + + Raises + ------ + RuntimeError + If the server does not become ready before the timeout. + """ + endpoint = f"{base_url.rstrip('/')}/health" + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + response = requests.get(endpoint, timeout=2) + if response.status_code < 500: + return + except Exception: + pass + time.sleep(0.5) + raise RuntimeError( + f"Local agent-runtimes server did not become ready at {endpoint} " + f"within {timeout_seconds}s." + ) + + +def start_local_agent_runtime( + *, + agent_spec_id: str, + agent_name: str = DEFAULT_LOCAL_AGENT_NAME, + host: str = DEFAULT_LOCAL_HOST, + port: Optional[int] = None, + protocol: str = DEFAULT_LOCAL_PROTOCOL, + log_level: str = DEFAULT_LOCAL_LOG_LEVEL, + wait: bool = True, +) -> LocalAgentRuntime: + """Launch a local ``agent-runtimes`` server as a subprocess. + + Parameters + ---------- + agent_spec_id : str + Agent spec id to boot the runtime with. + agent_name : str + Registered agent name/id served by the runtime. + host : str + Host interface to bind to. + port : Optional[int] + Port to bind to. A free port is selected when omitted. + protocol : str + Transport protocol exposed by the runtime (e.g. ``vercel-ai``). + log_level : str + Log level for the runtime process. + wait : bool + Whether to block until the runtime reports healthy. + + Returns + ------- + LocalAgentRuntime + Handle pointing at the running server. + + Raises + ------ + RuntimeError + If the runtime cannot be started or does not become ready. + """ + resolved_port = port or find_free_port(host) + scheme = "http" + base_url = f"{scheme}://{host}:{resolved_port}" + + command = [ + "agent-runtimes", + "serve", + "--host", + host, + "--port", + str(resolved_port), + "--protocol", + protocol, + "--agent-id", + agent_spec_id, + "--agent-name", + agent_name, + "--log-level", + log_level, + ] + + runtime_env, mapped_targets = build_agent_runtime_env() + if mapped_targets: + logger.info( + "Launching local agent-runtimes with Bedrock env mapping: " + "DATALAYER_BEDROCK_* -> %s", + ", ".join(mapped_targets), + ) + else: + logger.info( + "Launching local agent-runtimes without DATALAYER_BEDROCK_* mapping " + "(no DATALAYER_BEDROCK_AWS_* variables detected)." + ) + + try: + process = subprocess.Popen(command, env=runtime_env) + except FileNotFoundError as exc: + raise RuntimeError( + "Could not start local agent runtime: the 'agent-runtimes' command " + "was not found on PATH. Install the agent-runtimes package first." + ) from exc + except Exception as exc: + raise RuntimeError( + f"Failed to start local agent runtime: {exc}" + ) from exc + + runtime = LocalAgentRuntime( + base_url=base_url, + agent_name=agent_name, + agent_spec_id=agent_spec_id, + process=process, + ) + + if wait: + try: + wait_for_local_runtime(base_url) + except Exception: + terminate_local_agent_runtime(runtime) + raise + + return runtime + + +def terminate_local_agent_runtime(runtime: LocalAgentRuntime) -> None: + """Terminate a local runtime process, escalating to kill if needed.""" + process = runtime.process + if process is None or process.poll() is not None: + return + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + + +def ensure_local_agent( + *, + base_url: str, + agent_name: str, + token: str, + agent_spec_id: str, + agent_library: str = "pydantic-ai", + transport: str = DEFAULT_LOCAL_PROTOCOL, + enable_skills: bool = True, + description: Optional[str] = None, + timeout: int = 120, +) -> None: + """Ensure a local agent with the expected transport is registered. + + Lists existing agents, replaces a mismatched-transport registration when + needed, and creates the agent if it is missing. + + Raises + ------ + RuntimeError + If the agent cannot be registered. + """ + base = base_url.rstrip("/") + headers = {"Authorization": f"Bearer {token}"} + + try: + response = requests.get(f"{base}/api/v1/agents", headers=headers, timeout=30) + payload = response.json() if response.content else {} + except Exception: + payload = {} + + existing_agents = payload.get("agents") if isinstance(payload, dict) else [] + if not isinstance(existing_agents, list): + existing_agents = [] + + for agent in existing_agents: + if not isinstance(agent, dict): + continue + existing_id = str(agent.get("id") or "").strip() + existing_name = str(agent.get("name") or "").strip() + if agent_name and (existing_id == agent_name or existing_name == agent_name): + existing_transport = str(agent.get("transport") or "").strip().lower() + if existing_transport in {"vercel-ai", "vercel_ai"}: + return + + # Replace mismatched transport registration so local interactions + # use the Vercel AI chat endpoint. + delete_target = existing_id or agent_name + try: + requests.delete( + f"{base}/api/v1/agents/{delete_target}", + headers=headers, + timeout=30, + ) + except Exception as exc: + raise RuntimeError( + "Local agent exists with incompatible transport " + f"'{existing_transport or 'unknown'}' and could not be " + f"replaced: {exc}" + ) from exc + break + + body = { + "name": agent_name, + "description": description + or f"Local agent '{agent_name}' registered by datalayer-core.", + "agent_library": agent_library, + "transport": transport, + "agent_spec_id": agent_spec_id, + "enable_skills": enable_skills, + "tools": [], + } + try: + response = requests.post( + f"{base}/api/v1/agents", + json=body, + headers=headers, + timeout=timeout, + ) + except requests.exceptions.RequestException as exc: + parsed = urlparse(base_url) + host = parsed.hostname or DEFAULT_LOCAL_HOST + port = parsed.port or 8000 + scheme = parsed.scheme or "http" + raise RuntimeError( + "Local agent bootstrap request failed: " + f"{exc}. Start agent-runtimes first, for example: " + f"agent-runtimes serve --host {host} --port {port} " + f"--agent-id {agent_spec_id} --agent-name {agent_name} " + f"(base URL: {scheme}://{host}:{port})." + ) from exc + + if response.status_code < 400: + return + body_text = response.text or "" + if response.status_code == 409 and "already exists" in body_text.lower(): + return + raise RuntimeError( + f"Local agent bootstrap failed ({response.status_code}): " + f"{body_text or 'unknown error'}" + ) + + +def delete_local_agents(*, base_url: str, token: str) -> tuple[int, int]: + """Delete all locally-registered agents. + + Returns + ------- + tuple[int, int] + ``(total_agents, deleted_agents)``. + """ + base = base_url.rstrip("/") + headers = {"Authorization": f"Bearer {token}"} + try: + response = requests.get(f"{base}/api/v1/agents", headers=headers, timeout=30) + payload = response.json() if response.content else {} + except Exception as exc: + logger.warning("Unable to list local agents for cleanup: %s", exc) + return (0, 0) + + agents = payload.get("agents") if isinstance(payload, dict) else [] + if not isinstance(agents, list): + agents = [] + + deleted = 0 + for agent in agents: + if not isinstance(agent, dict): + continue + agent_id = str(agent.get("id") or "").strip() + if not agent_id: + continue + try: + requests.delete( + f"{base}/api/v1/agents/{agent_id}", + headers=headers, + timeout=30, + ) + deleted += 1 + except Exception as exc: + logger.warning("Unable to delete local agent %s: %s", agent_id, exc) + + return (len(agents), deleted) + + +def delete_local_agent(*, base_url: str, token: str, agent_name: str) -> bool: + """Delete a single locally-registered agent by id or name. + + Parameters + ---------- + base_url : str + Local agent-runtimes base URL. + token : str + Bearer token used for local API calls. + agent_name : str + Agent id or name to delete. + + Returns + ------- + bool + ``True`` when a matching agent was found and delete accepted. + """ + target_name = str(agent_name or "").strip() + if not target_name: + return False + + base = base_url.rstrip("/") + headers = {"Authorization": f"Bearer {token}"} + try: + response = requests.get(f"{base}/api/v1/agents", headers=headers, timeout=30) + payload = response.json() if response.content else {} + except Exception as exc: + logger.warning("Unable to list local agents for cleanup: %s", exc) + return False + + agents = payload.get("agents") if isinstance(payload, dict) else [] + if not isinstance(agents, list): + return False + + for agent in agents: + if not isinstance(agent, dict): + continue + agent_id = str(agent.get("id") or "").strip() + name = str(agent.get("name") or "").strip() + if target_name not in {agent_id, name}: + continue + delete_target = agent_id or target_name + try: + response = requests.delete( + f"{base}/api/v1/agents/{delete_target}", + headers=headers, + timeout=30, + ) + return response.status_code < 400 + except Exception as exc: + logger.warning("Unable to delete local agent %s: %s", delete_target, exc) + return False + + return False + + +def extract_vercel_stream_text(raw: str) -> str: + """Extract concatenated text deltas from a Vercel AI SSE stream.""" + text_parts: list[str] = [] + for line in raw.splitlines(): + if not line.startswith("data: "): + continue + payload = line[6:].strip() + if not payload or payload == "[DONE]": + continue + try: + event = json.loads(payload) + except json.JSONDecodeError: + continue + + if isinstance(event, str): + if event.strip(): + text_parts.append(event) + continue + if not isinstance(event, dict): + continue + + for key in ("delta", "text", "content", "outputText", "textDelta"): + value = event.get(key) + if isinstance(value, str) and value: + text_parts.append(value) + + return "".join(text_parts).strip() + + +def _post_vercel_ai_chat( + *, + endpoint: str, + token: str, + prompt: str, + timeout: int, + source_label: str, +) -> dict[str, Any]: + """POST a single prompt to a Vercel AI chat endpoint. + + Shared by local and cloud chat helpers. Failures are captured into a + structured ``failure_cause`` (matching the eval report schema) instead of + raising. + + Returns + ------- + dict[str, Any] + On success: ``{"status": "completed", "output": {...}}``. + On failure: ``{"status": "failed", "output": {...}, + "failure_cause": {"stage", "type", "message", "detail_excerpt", + "execution_url"}}``. + """ + message_id = f"chat-{int(time.time() * 1000)}" + parts = [{"type": "text", "text": prompt}] + message = {"id": message_id, "role": "user", "parts": parts} + body = { + "trigger": "submit-message", + "id": f"chat-{message_id}", + "message": message, + "messages": [message], + } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {token}", + } + try: + response = requests.post( + endpoint, + json=body, + headers=headers, + timeout=timeout, + ) + except requests.exceptions.RequestException as exc: + message_text = f"{source_label} chat request failed: {exc}" + return { + "status": "failed", + "output": {"text": "", "raw_stream_excerpt": ""}, + "failure_cause": { + "stage": "runtime_execution", + "type": "runtime_unreachable", + "message": message_text, + "detail_excerpt": message_text, + "execution_url": endpoint, + }, + } + + raw = response.text or "" + if response.status_code >= 400: + message_text = f"{source_label} chat failed (HTTP {response.status_code})" + return { + "status": "failed", + "output": {"text": "", "raw_stream_excerpt": raw[:2000]}, + "failure_cause": { + "stage": "runtime_execution", + "type": "runtime_http_error", + "message": message_text, + "detail_excerpt": raw[:2000] or message_text, + "execution_url": endpoint, + }, + } + + output_text = extract_vercel_stream_text(raw) + return { + "status": "completed", + "output": { + "text": output_text, + "raw_stream_excerpt": raw[:2000], + }, + } + + +def run_local_agent_chat( + *, + base_url: str, + agent_name: str, + token: str, + prompt: str, + timeout: int = 300, +) -> dict[str, Any]: + """Send a single prompt to a local agent via the Vercel AI endpoint. + + Failures are captured into a structured ``failure_cause`` (matching the + eval report schema) instead of raising, so callers can persist failed runs + and have them surfaced in reports. + + Returns + ------- + dict[str, Any] + On success: ``{"status": "completed", "output": {...}}``. + On failure: ``{"status": "failed", "output": {...}, + "failure_cause": {"stage", "type", "message", "detail_excerpt", + "execution_url"}}``. + """ + endpoint = f"{base_url.rstrip('/')}/api/v1/vercel-ai/{agent_name}" + return _post_vercel_ai_chat( + endpoint=endpoint, + token=token, + prompt=prompt, + timeout=timeout, + source_label="Local agent", + ) + + +def build_agent_runtimes_base_url(ingress: str) -> str: + """Derive the cloud ``agent-runtimes`` base URL from a runtime ingress. + + A runtime's ``ingress`` (returned by :meth:`DatalayerClient.create_runtime`) + points at the Jupyter server path on the runtimes host, e.g. + ``https://r1.datalayer.run/jupyter/server//``. The + ``agent-runtimes`` container is exposed under the sibling path + ``/agent-runtimes//`` on the **same** host. Using the + runtime's own ingress guarantees the correct runtimes host (e.g. ``r1``) + rather than the IAM/control-plane host (e.g. ``prod1``). + + Parameters + ---------- + ingress : str + The runtime ingress URL. + + Returns + ------- + str + The agent-runtimes base URL (without a trailing slash). + """ + base = (ingress or "").rstrip("/") + if "/jupyter/server/" in base: + base = base.replace("/jupyter/server/", "/agent-runtimes/", 1) + return base + + +def runtime_route_candidates( + *, + agent_name: Optional[str] = None, + agent_spec_id: Optional[str] = None, + pod_name: Optional[str] = None, +) -> list[str]: + """Build an ordered, de-duplicated list of Vercel AI route candidates. + + The ``agent-runtimes`` server inside a cloud runtime may register its agent + under different names depending on how it was launched. Trying a few known + candidates (explicit agent name, agent spec id, pod name, then the default + route) makes cloud execution resilient. + """ + candidates: list[str] = [] + for value in (agent_name, agent_spec_id, pod_name, DEFAULT_LOCAL_AGENT_NAME): + token = str(value or "").strip() + if token and token not in candidates: + candidates.append(token) + return candidates + + +def run_cloud_agent_chat( + *, + ingress: str, + token: str, + prompt: str, + route_candidates: list[str], + timeout: int = 300, +) -> dict[str, Any]: + """Send a single prompt to a cloud runtime agent via the Vercel AI endpoint. + + The execution URL is derived from the runtime's ``ingress`` (via + :func:`build_agent_runtimes_base_url`) so the request targets the correct + runtimes host (e.g. ``r1.datalayer.run``). Each route candidate is tried in + order until one succeeds; if all fail, the last structured failure is + returned with every attempted URL recorded in ``detail_excerpt``. + + Returns + ------- + dict[str, Any] + Same contract as :func:`run_local_agent_chat`. + """ + base_url = build_agent_runtimes_base_url(ingress) + candidates = [c for c in route_candidates if str(c or "").strip()] + if not candidates: + candidates = [DEFAULT_LOCAL_AGENT_NAME] + + attempted: list[str] = [] + last_result: dict[str, Any] | None = None + for route in candidates: + endpoint = f"{base_url}/api/v1/vercel-ai/{route}" + attempted.append(endpoint) + result = _post_vercel_ai_chat( + endpoint=endpoint, + token=token, + prompt=prompt, + timeout=timeout, + source_label="Cloud agent", + ) + if str(result.get("status") or "").strip().lower() == "completed": + return result + last_result = result + + if last_result is None: + last_result = { + "status": "failed", + "output": {"text": "", "raw_stream_excerpt": ""}, + "failure_cause": { + "stage": "runtime_execution", + "type": "runtime_unreachable", + "message": "No cloud agent route candidates available.", + "detail_excerpt": "No cloud agent route candidates available.", + "execution_url": base_url, + }, + } + elif len(attempted) > 1: + failure_cause = last_result.get("failure_cause") + if isinstance(failure_cause, dict): + tried = "; ".join(attempted) + base_detail = str(failure_cause.get("detail_excerpt") or "") + failure_cause["detail_excerpt"] = ( + f"{base_detail}\nAttempted routes: {tried}" + ).strip() + failure_cause["attempted_urls"] = attempted + return last_result + diff --git a/datalayer_core/runtimes/runtime.py b/datalayer_core/runtimes/runtime.py index e17fcab6..dd292ccc 100644 --- a/datalayer_core/runtimes/runtime.py +++ b/datalayer_core/runtimes/runtime.py @@ -16,13 +16,13 @@ from jupyter_kernel_client import KernelClient from datalayer_core.mixins.authn import AuthnMixin -from datalayer_core.mixins.runtime_snapshots import RuntimeSnapshotsMixin +from datalayer_core.mixins.sandbox_snapshots import SandboxSnapshotsMixin from datalayer_core.mixins.runtimes import RuntimesMixin from datalayer_core.models import ExecutionResponse from datalayer_core.models.runtime import RuntimeModel -from datalayer_core.runtimes.runtime_snapshot import ( - RuntimeSnapshotModel, - as_runtime_snapshots, +from datalayer_core.runtimes.sandbox_snapshot import ( + SandboxSnapshotModel, + as_code_sandbox_snapshots, create_snapshot, ) from datalayer_core.utils.defaults import ( @@ -38,7 +38,7 @@ from datalayer_core.utils.urls import DEFAULT_DATALAYER_RUN_URL, DatalayerURLs -class RuntimeService(AuthnMixin, RuntimesMixin, RuntimeSnapshotsMixin): +class RuntimeService(AuthnMixin, RuntimesMixin, SandboxSnapshotsMixin): """ Service for managing Datalayer runtime operations. @@ -678,7 +678,7 @@ def create_snapshot( name: Optional[str] = None, description: Optional[str] = None, stop: bool = True, - ) -> "RuntimeSnapshotModel": + ) -> "SandboxSnapshotModel": """ Create a new snapshot from the current state. @@ -693,7 +693,7 @@ def create_snapshot( Returns ------- - RuntimeSnapshot + SandboxSnapshot A new snapshot object. """ if self.model.pod_name is None: @@ -720,8 +720,8 @@ def create_snapshot( pass response = self._list_snapshots() - snapshot_objects = as_runtime_snapshots(response) - snapshot: Optional[RuntimeSnapshotModel] = None + snapshot_objects = as_code_sandbox_snapshots(response) + snapshot: Optional[SandboxSnapshotModel] = None max_poll_attempts = max( 1, int(os.getenv("DATALAYER_SNAPSHOT_POLL_ATTEMPTS", "30")), @@ -736,14 +736,14 @@ def create_snapshot( break time.sleep(poll_interval_seconds) response = self._list_snapshots() - snapshot_objects = as_runtime_snapshots(response) + snapshot_objects = as_code_sandbox_snapshots(response) if snapshot is None: raise RuntimeError( f"Snapshot '{name}' was created but not found in snapshot listing" ) - return RuntimeSnapshotModel( + return SandboxSnapshotModel( uid=snapshot.uid, name=name, description=description, diff --git a/datalayer_core/runtimes/runtime_snapshot.py b/datalayer_core/runtimes/sandbox_snapshot.py similarity index 78% rename from datalayer_core/runtimes/runtime_snapshot.py rename to datalayer_core/runtimes/sandbox_snapshot.py index d2ea786c..a02198eb 100644 --- a/datalayer_core/runtimes/runtime_snapshot.py +++ b/datalayer_core/runtimes/sandbox_snapshot.py @@ -4,13 +4,13 @@ """ Snapshot services for Datalayer. -Provides runtime snapshot management and operations in Datalayer environments. +Provides code sandbox snapshot management and operations in Datalayer environments. """ import uuid from typing import Any, List, Optional, Tuple -from datalayer_core.models.runtime_snapshot import RuntimeSnapshotModel +from datalayer_core.models.sandbox_snapshot import SandboxSnapshotModel def create_snapshot(name: Optional[str], description: Optional[str]) -> Tuple[str, str]: @@ -39,9 +39,9 @@ def create_snapshot(name: Optional[str], description: Optional[str]) -> Tuple[st return name, description -def as_runtime_snapshots(response: dict[str, Any]) -> List["RuntimeSnapshotModel"]: +def as_code_sandbox_snapshots(response: dict[str, Any]) -> List["SandboxSnapshotModel"]: """ - Parse API response and create RuntimeSnapshot objects. + Parse API response and create SandboxSnapshot objects. Parameters ---------- @@ -50,15 +50,15 @@ def as_runtime_snapshots(response: dict[str, Any]) -> List["RuntimeSnapshotModel Returns ------- - List[RuntimeSnapshot] - List of RuntimeSnapshot objects parsed from the response. + List[SandboxSnapshot] + List of SandboxSnapshot objects parsed from the response. """ snapshot_objects = [] if response["success"]: snapshots = response["snapshots"] for snapshot in snapshots: snapshot_objects.append( - RuntimeSnapshotModel( + SandboxSnapshotModel( uid=snapshot["uid"], name=snapshot["name"], description=snapshot["description"], diff --git a/datalayer_core/tests/test_cli.py b/datalayer_core/tests/test_cli.py index bcd6cbb5..9083a986 100644 --- a/datalayer_core/tests/test_cli.py +++ b/datalayer_core/tests/test_cli.py @@ -42,6 +42,7 @@ def _delete_all_runtimes(secs: int = 5) -> None: (["--version"], "1."), (["--help"], "The Datalayer CLI application"), (["about"], "About"), + (["evals", "--help"], "Launch and monitor SaaS evals"), ], ) def test_cli(args: List[str], expected_output: str) -> None: diff --git a/datalayer_core/tests/test_cli_main.py b/datalayer_core/tests/test_cli_main.py new file mode 100644 index 00000000..fe12f845 --- /dev/null +++ b/datalayer_core/tests/test_cli_main.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023-2026 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Tests for CLI main argument normalization.""" + +from datalayer_core.cli.__main__ import _normalize_global_options + + +def test_normalize_global_options_hoists_runtimes_url_after_subcommands(): + argv = [ + "d", + "ray", + "clusters", + "ls", + "--runtimes-url", + "http://localhost:9500", + ] + + normalized = _normalize_global_options(argv) + + assert normalized == [ + "d", + "--runtimes-url", + "http://localhost:9500", + "ray", + "clusters", + "ls", + ] + + +def test_normalize_global_options_preserves_equals_syntax(): + argv = ["d", "whoami", "--iam-url=https://iam.example"] + + normalized = _normalize_global_options(argv) + + assert normalized == ["d", "--iam-url=https://iam.example", "whoami"] diff --git a/datalayer_core/tests/test_client.py b/datalayer_core/tests/test_client.py index 342e3ad6..532fa133 100644 --- a/datalayer_core/tests/test_client.py +++ b/datalayer_core/tests/test_client.py @@ -11,7 +11,7 @@ from dotenv import load_dotenv from datalayer_core import DatalayerClient -from datalayer_core.models.runtime_snapshot import RuntimeSnapshotModel +from datalayer_core.models.sandbox_snapshot import SandboxSnapshotModel load_dotenv() @@ -101,7 +101,7 @@ def test_runtime_create_execute_and_list() -> None: not bool(TEST_DATALAYER_API_KEY), reason="TEST_DATALAYER_API_KEY is not set, skipping secret tests.", ) -def test_runtime_snapshot_create_and_delete() -> None: +def test_code_sandbox_snapshot_create_and_delete() -> None: """ Test the creation and deletion of runtime. """ @@ -114,7 +114,7 @@ def test_runtime_snapshot_create_and_delete() -> None: def _delete_with_retry( client: DatalayerClient, - snap: RuntimeSnapshotModel, + snap: SandboxSnapshotModel, retries: int = 10, delay: float = 5.0, ) -> None: diff --git a/datalayer_core/tests/test_ray.py b/datalayer_core/tests/test_ray.py new file mode 100644 index 00000000..3b2b193f --- /dev/null +++ b/datalayer_core/tests/test_ray.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023-2026 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Tests for Ray URL resolution and Ray mixin requests.""" + +from __future__ import annotations + +from datalayer_core.mixins.ray import RayMixin +from datalayer_core.utils.urls import DatalayerURLs + + +class _FakeResponse: + def __init__(self, payload): + self._payload = payload + + def json(self): + return self._payload + + +class _FakeRayClient(RayMixin): + def __init__(self): + self.urls = DatalayerURLs.from_environment(ray_url="https://ray.example") + self.calls = [] + + def _fetch(self, url: str, **kwargs): + self.calls.append((url, kwargs)) + return _FakeResponse({"success": True, "url": url, "kwargs": kwargs}) + + +def test_urls_resolve_ray_url_from_environment(monkeypatch): + monkeypatch.setenv("DATALAYER_RAY_URL", "https://ray-from-env.example/") + urls = DatalayerURLs.from_environment() + assert urls.ray_url == "https://ray-from-env.example" + + +def test_urls_resolve_ray_url_from_default(monkeypatch): + monkeypatch.delenv("DATALAYER_RAY_URL", raising=False) + urls = DatalayerURLs.from_environment() + assert urls.ray_url == "https://prod1.datalayer.run" + + +def test_ray_mixin_job_logs_and_events_paths(): + client = _FakeRayClient() + + logs_payload = client.ray_get_job_logs( + "job-1", + namespace="team-a", + pod_name="pod-1", + container="submitter", + tail_lines=50, + ) + events_payload = client.ray_get_job_events("job-1", namespace="team-a", limit=25) + + assert logs_payload["success"] is True + assert events_payload["success"] is True + + logs_url, logs_kwargs = client.calls[0] + assert logs_url.endswith("/api/runtimes/v1/ray/jobs/job-1/logs") + assert logs_kwargs["params"] == { + "namespace": "team-a", + "tail_lines": 50, + "pod_name": "pod-1", + "container": "submitter", + } + + events_url, events_kwargs = client.calls[1] + assert events_url.endswith("/api/runtimes/v1/ray/jobs/job-1/events") + assert events_kwargs["params"] == { + "namespace": "team-a", + "limit": 25, + } + + +def test_ray_mixin_uses_runtimes_path_by_default(): + client = _FakeRayClient() + + payload = client.ray_list_clusters(namespace="default") + + assert payload["success"] is True + assert len(client.calls) == 1 + first_url, _ = client.calls[0] + assert first_url.endswith("/api/runtimes/v1/ray/clusters") + + +def test_ray_mixin_uses_addon_path_in_direct_mode(): + client = _FakeRayClient() + client._ray_direct_addon = True + + payload = client.ray_list_clusters(namespace="default") + + assert payload["success"] is True + assert len(client.calls) == 1 + first_url, _ = client.calls[0] + assert first_url.endswith("/api/ray/v1/clusters") diff --git a/datalayer_core/tests/test_usage.py b/datalayer_core/tests/test_usage.py new file mode 100644 index 00000000..4d0d5786 --- /dev/null +++ b/datalayer_core/tests/test_usage.py @@ -0,0 +1,263 @@ +# Copyright (c) 2023-2026 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +"""Integration tests for usage history across billable account scopes.""" + +import os +import time +import uuid +from datetime import datetime +from typing import Any +from urllib.parse import urlencode + +import pytest +from dotenv import load_dotenv + +from datalayer_core import DatalayerClient +from datalayer_core.utils.urls import DatalayerURLs + +load_dotenv() + +TEST_DATALAYER_API_KEY = os.environ.get("TEST_DATALAYER_API_KEY") or os.environ.get( + "DATALAYER_API_KEY" +) + +LOCAL_RUN_URL = os.environ.get("TEST_DATALAYER_RUN_URL", "http://localhost:9700") +LOCAL_IAM_URL = os.environ.get("TEST_DATALAYER_IAM_URL", "http://localhost:9700") +LOCAL_RUNTIMES_URL = os.environ.get( + "TEST_DATALAYER_RUNTIMES_URL", + "http://localhost:9500", +) + + +def _build_test_client() -> DatalayerClient: + return DatalayerClient( + token=TEST_DATALAYER_API_KEY, + urls=DatalayerURLs.from_environment( + run_url=LOCAL_RUN_URL, + iam_url=LOCAL_IAM_URL, + runtimes_url=LOCAL_RUNTIMES_URL, + ), + ) + + +def _parse_timestamp(value: Any) -> datetime | None: + if not value: + return None + if isinstance(value, datetime): + return value + text = str(value).strip() + if not text: + return None + if text.endswith("Z"): + text = text[:-1] + "+00:00" + try: + return datetime.fromisoformat(text) + except ValueError: + return None + + +def _iam_get_json(client: DatalayerClient, path: str) -> dict[str, Any]: + response = client._fetch(f"{client.urls.iam_url}{path}") + payload = response.json() + if not payload.get("success", True): + raise RuntimeError(payload.get("message", f"Request failed for path {path}")) + return payload + + +def _resolve_billable_accounts(client: DatalayerClient) -> dict[str, dict[str, str]]: + whoami_payload = _iam_get_json(client, "/api/iam/v1/whoami") + profile = whoami_payload.get("profile") or {} + if not profile.get("uid"): + raise RuntimeError("Unable to resolve authenticated user profile uid") + + memberships_payload = _iam_get_json(client, "/api/iam/v1/memberships") + memberships = memberships_payload.get("memberships") or [] + + first_team = next( + (m for m in memberships if str(m.get("type") or "").lower() == "team"), + None, + ) + datalayer_org = next( + ( + m + for m in memberships + if str(m.get("type") or "").lower() == "organization" + and str(m.get("handle") or "").lower() == "datalayer" + ), + None, + ) + + accounts: dict[str, dict[str, str]] = { + "user": { + "uid": str(profile["uid"]), + "kind": "user", + "handle": str(profile.get("handle") or ""), + } + } + + if first_team and first_team.get("uid"): + accounts["team"] = { + "uid": str(first_team["uid"]), + "kind": "team", + "handle": str(first_team.get("handle") or ""), + } + + if datalayer_org and datalayer_org.get("uid"): + accounts["datalayer"] = { + "uid": str(datalayer_org["uid"]), + "kind": "organization", + "handle": str(datalayer_org.get("handle") or "datalayer"), + } + + return accounts + + +def _fetch_usage_history( + client: DatalayerClient, + account_uid: str, + account_kind: str, +) -> list[dict[str, Any]]: + query: dict[str, str] = { + "billable_account_uid": account_uid, + } + # API currently recognizes only user|organization kinds. + if account_kind in {"user", "organization"}: + query["billable_account_kind"] = account_kind + + payload = _iam_get_json( + client, + f"/api/iam/v1/usage/user?{urlencode(query)}", + ) + return payload.get("usages") or [] + + +def _find_usage_row(usages: list[dict[str, Any]], runtime_uid: str) -> dict[str, Any] | None: + for usage in usages: + if str(usage.get("resource_uid") or "") == runtime_uid: + return usage + return None + + +def _wait_for_usage_row( + client: DatalayerClient, + account_uid: str, + account_kind: str, + runtime_uid: str, + expect_closed: bool, + timeout_seconds: int = 240, + poll_seconds: int = 5, +) -> dict[str, Any]: + deadline = time.time() + timeout_seconds + last_seen: dict[str, Any] | None = None + + while time.time() < deadline: + usages = _fetch_usage_history(client, account_uid, account_kind) + row = _find_usage_row(usages, runtime_uid) + if row is not None: + last_seen = row + has_end_date = bool(row.get("end_date")) + if expect_closed == has_end_date: + return row + time.sleep(poll_seconds) + + state = "closed" if expect_closed else "open" + raise AssertionError( + f"Timed out waiting for {state} usage row for runtime={runtime_uid}. Last seen={last_seen}" + ) + + +@pytest.mark.parametrize("account_case", ["user", "team", "datalayer"]) +@pytest.mark.skipif( + not bool(TEST_DATALAYER_API_KEY), + reason="TEST_DATALAYER_API_KEY is not set, skipping usage integration tests.", +) +def test_usage_matrix_creation_reservation_and_history(account_case: str) -> None: + """ + Validate usage lifecycle with a 1-minute reservation and manual stop at ~30s. + + Matrix: + - user billable account + - team billable account + - datalayer organization billable account + + Coverage: + - runtime creation + - active reservation/open usage row while running + - closed usage history row after manual stop + """ + client = _build_test_client() + accounts = _resolve_billable_accounts(client) + + if account_case not in accounts: + pytest.skip(f"No available account for case={account_case}") + + account = accounts[account_case] + runtime = None + + runtime_name = f"test_usage_{account_case}_{uuid.uuid4().hex[:8]}" + + try: + runtime = client.create_runtime( + name=runtime_name, + time_reservation=1, + billable_account_uid=account["uid"], + billable_account_type=account["kind"], + billable_account_handle=account["handle"] or None, + ) + + # Creation coverage. + assert runtime.pod_name, "Runtime pod_name should be set after creation" + assert runtime.reservation_id, "Runtime reservation_id should be present" + + # Reservation coverage: usage row should be open while runtime is running. + open_usage = _wait_for_usage_row( + client=client, + account_uid=account["uid"], + account_kind=account["kind"], + runtime_uid=runtime.pod_name, + expect_closed=False, + timeout_seconds=180, + ) + assert not open_usage.get("end_date"), "Expected open usage row while runtime is running" + + # Manual stop after ~30 seconds for a 1-minute reservation scenario. + stop_wait_start = time.monotonic() + time.sleep(30) + stop_wait_elapsed = time.monotonic() - stop_wait_start + assert client.terminate_runtime(runtime), "Runtime termination should succeed" + assert stop_wait_elapsed >= 25, ( + f"Expected to wait about 30s before manual stop, got {stop_wait_elapsed:.2f}s" + ) + + # Usage history coverage: same runtime row should close with end_date set. + closed_usage = _wait_for_usage_row( + client=client, + account_uid=account["uid"], + account_kind=account["kind"], + runtime_uid=runtime.pod_name, + expect_closed=True, + timeout_seconds=240, + ) + assert closed_usage.get("end_date"), "Expected closed usage row after manual stop" + + # Usage history timestamps can be rounded to seconds and occasionally collapse + # to the same second; keep checks robust to that backend behavior. + start_dt = _parse_timestamp(closed_usage.get("start_date")) + end_dt = _parse_timestamp(closed_usage.get("end_date")) + assert start_dt is not None and end_dt is not None, "Usage start/end timestamps must be parseable" + duration_seconds = (end_dt - start_dt).total_seconds() + assert duration_seconds >= 0, ( + f"Expected non-negative usage duration, got {duration_seconds:.2f}s" + ) + assert duration_seconds <= 90, ( + f"Expected usage duration to remain bounded for a 1-minute reservation, got {duration_seconds:.2f}s" + ) + + finally: + if runtime is not None and runtime.pod_name: + # Best-effort cleanup for flaky failures. + try: + client.terminate_runtime(runtime) + except Exception: + pass diff --git a/datalayer_core/utils/urls.py b/datalayer_core/utils/urls.py index 028250c3..f51f7cc2 100644 --- a/datalayer_core/utils/urls.py +++ b/datalayer_core/utils/urls.py @@ -34,6 +34,8 @@ DEFAULT_DATALAYER_AI_INFERENCE_URL = DEFAULT_DATALAYER_RUN_URL +DEFAULT_DATALAYER_RAY_URL = DEFAULT_DATALAYER_RUN_URL + DEFAULT_DATALAYER_MCP_SERVERS_URL = DEFAULT_DATALAYER_RUN_URL DEFAULT_DATALAYER_OTEL_URL = DEFAULT_DATALAYER_RUN_URL @@ -85,6 +87,8 @@ class DatalayerURLs: The Datalayer support service URL mcp_server_url : str The Datalayer MCP server service URL + ray_url : str + The Datalayer Ray service URL """ run_url: str @@ -101,6 +105,7 @@ class DatalayerURLs: status_url: str support_url: str mcp_server_url: str + ray_url: str @classmethod def from_environment( @@ -119,6 +124,7 @@ def from_environment( status_url: Optional[str] = None, support_url: Optional[str] = None, mcp_server_url: Optional[str] = None, + ray_url: Optional[str] = None, ) -> "DatalayerURLs": """ Create DatalayerURLs instance from environment variables and parameters. @@ -167,6 +173,9 @@ def from_environment( mcp_server_url : Optional[str] Override for the MCP server URL. If None, will check DATALAYER_MCP_SERVER_URL env var then fallback to DEFAULT_DATALAYER_MCP_SERVER_URL. + ray_url : Optional[str] + Override for the Ray URL. If None, will check DATALAYER_RAY_URL env var + then fallback to DEFAULT_DATALAYER_RAY_URL. Returns ------- @@ -276,6 +285,12 @@ def from_environment( or base_url_for_services or DEFAULT_DATALAYER_MCP_SERVERS_URL ) + resolved_ray_url = ( + ray_url + or os.environ.get("DATALAYER_RAY_URL") + or base_url_for_services + or DEFAULT_DATALAYER_RAY_URL + ) # Strip trailing slashes for consistency resolved_run_url = resolved_run_url.rstrip("/") @@ -292,6 +307,7 @@ def from_environment( resolved_status_url = resolved_status_url.rstrip("/") resolved_support_url = resolved_support_url.rstrip("/") resolved_mcp_server_url = resolved_mcp_server_url.rstrip("/") + resolved_ray_url = resolved_ray_url.rstrip("/") return cls( run_url=resolved_run_url, @@ -308,6 +324,7 @@ def from_environment( status_url=resolved_status_url, support_url=resolved_support_url, mcp_server_url=resolved_mcp_server_url, + ray_url=resolved_ray_url, ) def __post_init__(self) -> None: @@ -326,3 +343,4 @@ def __post_init__(self) -> None: self.status_url = self.status_url.rstrip("/") self.support_url = self.support_url.rstrip("/") self.mcp_server_url = self.mcp_server_url.rstrip("/") + self.ray_url = self.ray_url.rstrip("/") diff --git a/docs/docusaurus.config.js b/docs/docusaurus.config.js index ccc3b224..e9dad98f 100644 --- a/docs/docusaurus.config.js +++ b/docs/docusaurus.config.js @@ -160,7 +160,7 @@ module.exports = { }, { label: 'Datalayer Docs', - href: 'https://docs.datalayer.app', + href: 'https://datalayer.ai/docs', }, { label: 'Datalayer Blog', diff --git a/examples/README.md b/examples/README.md index fe1ee0f2..7c300469 100644 --- a/examples/README.md +++ b/examples/README.md @@ -10,6 +10,14 @@ This directory contains practical examples demonstrating how to use the Datalaye ## 🎯 Client Fundamentals +### 📈 [Evals CLI Workflows](./evals/README.md) + +Beginner-friendly walkthrough for launching and monitoring SaaS evals with `datalayer evals`. + +- **Use Case**: Run evals/experiments from CLI and track in the SaaS UI +- **Technologies**: Datalayer Core CLI, AI Agents eval APIs +- **Features**: Eval/experiment/run creation, run watching, live target inspection, make targets for quick onboarding + ### 🎭 [Datalayer Decorator](./decorator/README.md) Comprehensive examples demonstrating the `@datalayer` decorator for seamless remote function execution. @@ -79,7 +87,7 @@ This project is licensed under the MIT License - see the [LICENSE](../../LICENSE ## Support -- **Documentation**: [Datalayer Platform Documentation](https://docs.datalayer.app/) +- **Documentation**: [Datalayer Platform Documentation](https://datalayer.ai/docs/) - **Issues**: [GitHub Issues](https://github.com/datalayer/core/issues) - **Community**: [Datalayer Platform](https://datalayer.app/) diff --git a/examples/decorator/README.md b/examples/decorator/README.md index 153894bc..a82ff012 100644 --- a/examples/decorator/README.md +++ b/examples/decorator/README.md @@ -13,7 +13,7 @@ This example showcases: - **Function Decoration**: Transform regular functions into distributed computations using `@datalayer` - **Remote Execution**: Execute functions on cloud-based runtimes with different environments - **Variable Management**: Pass inputs and retrieve outputs from remote execution contexts -- **Snapshot Integration**: Use pre-configured runtime snapshots for consistent environments +- **Snapshot Integration**: Use pre-configured code sandbox snapshots for consistent environments - **Error Handling**: Timeout configuration and debug mode for development ## Features @@ -210,7 +210,7 @@ This project is licensed under the MIT License - see the [LICENSE](../../LICENSE ## Support -- **Documentation**: [Datalayer Platform Documentation](https://docs.datalayer.app/) +- **Documentation**: [Datalayer Platform Documentation](https://datalayer.ai/docs/) - **Issues**: [GitHub Issues](https://github.com/datalayer/core/issues) - **Community**: [Datalayer Platform](https://datalayer.app/) diff --git a/examples/fastapi/README.md b/examples/fastapi/README.md index 66bb5b38..332da6f1 100644 --- a/examples/fastapi/README.md +++ b/examples/fastapi/README.md @@ -154,7 +154,7 @@ This project is licensed under the MIT License - see the [LICENSE](../../LICENSE ## Support -- **Documentation**: [Datalayer Platform Documentation](https://docs.datalayer.app/) +- **Documentation**: [Datalayer Platform Documentation](https://datalayer.ai/docs/) - **Issues**: [GitHub Issues](https://github.com/datalayer/core/issues) - **Community**: [Datalayer Platform](https://datalayer.app/) diff --git a/examples/nextjs/README.md b/examples/nextjs/README.md index b0fe1dd7..d6a8f7f3 100644 --- a/examples/nextjs/README.md +++ b/examples/nextjs/README.md @@ -278,7 +278,7 @@ This project is licensed under the Modified BSD License - see the [LICENSE](../. ## Support -- **Documentation**: [Datalayer Platform Documentation](https://docs.datalayer.app/) +- **Documentation**: [Datalayer Platform Documentation](https://datalayer.ai/docs/) - **Issues**: [GitHub Issues](https://github.com/datalayer/core/issues) - **Community**: [Datalayer Platform](https://datalayer.app/) diff --git a/examples/nextjs/src/components/Footer.tsx b/examples/nextjs/src/components/Footer.tsx index d55f5c25..349383e3 100644 --- a/examples/nextjs/src/components/Footer.tsx +++ b/examples/nextjs/src/components/Footer.tsx @@ -61,7 +61,7 @@ export default function Footer() { =2.10,<3", - "keyring==23.0.1", + "keyring", "mcp", "pydantic-settings", "pydantic[email]", diff --git a/src/api/DatalayerApi.ts b/src/api/DatalayerApi.ts index 469fa4a1..d8130cf9 100644 --- a/src/api/DatalayerApi.ts +++ b/src/api/DatalayerApi.ts @@ -301,10 +301,35 @@ async function handleAxiosRedirection( ): Promise { let redirect = response.headers.location; if (redirect) { - const parsedURL = URLExt.parse(originalConfig.url!); - const baseUrl = parsedURL.protocol + '//' + parsedURL.hostname; - if (!redirect.startsWith(baseUrl)) { - redirect = URLExt.join(baseUrl, redirect); + const baseUrl = originalConfig.url ?? ''; + const normalizedRedirect = String(redirect).replace( + /^([a-z][a-z0-9+.-]*):\/(?!\/)/i, + '$1://', + ); + + try { + const resolved = new URL(normalizedRedirect, baseUrl); + const base = new URL(baseUrl, typeof window !== 'undefined' ? window.location.origin : undefined); + + // If a proxy emits an http Location for the same host while the + // original request is https, force https to avoid mixed-content errors + // that browsers often report as CORS/network failures. + if ( + base.protocol === 'https:' && + resolved.protocol === 'http:' && + resolved.hostname === base.hostname + ) { + resolved.protocol = 'https:'; + if (resolved.port === '80') { + resolved.port = ''; + } + } + + redirect = resolved.toString(); + } catch { + const parsedURL = URLExt.parse(baseUrl); + const fallbackBase = parsedURL.protocol + '//' + parsedURL.hostname; + redirect = URLExt.join(fallbackBase, normalizedRedirect); } } diff --git a/src/api/__tests__/runtimes.integration.test.ts b/src/api/__tests__/runtimes.integration.test.ts index a626c479..077b857d 100644 --- a/src/api/__tests__/runtimes.integration.test.ts +++ b/src/api/__tests__/runtimes.integration.test.ts @@ -556,7 +556,7 @@ describe.skipIf(skipTests || skipInCi)( } }); - it('should successfully list runtime snapshots', async () => { + it('should successfully list code sandbox snapshots', async () => { console.log('Testing list snapshots endpoint...'); const response = await snapshots.listSnapshots( @@ -564,7 +564,7 @@ describe.skipIf(skipTests || skipInCi)( BASE_URL, ); - console.log(`Found ${response.snapshots.length} runtime snapshots`); + console.log(`Found ${response.snapshots.length} code sandbox snapshots`); expect(response).toBeDefined(); expect(response).toHaveProperty('success'); diff --git a/src/api/iam/profile.ts b/src/api/iam/profile.ts index 85ae8682..feefc7c1 100644 --- a/src/api/iam/profile.ts +++ b/src/api/iam/profile.ts @@ -15,6 +15,7 @@ import { requestDatalayerAPI } from '../DatalayerApi'; import { API_BASE_PATHS, DEFAULT_SERVICE_URLS } from '../constants'; import { MembershipsResponse, + ShareablePrincipalsResponse, UserMeResponse, WhoAmIResponse, } from '../../models/IAM'; @@ -76,3 +77,23 @@ export const memberships = async ( token, }); }; + +/** + * Get the set of principals the authenticated user can share artifacts with + * (self + member organizations + member teams). + * + * @param token - Authentication token (required) + * @param baseUrl - Base URL for the API (defaults to production IAM URL) + */ +export const principalsShareable = async ( + token: string, + baseUrl: string = DEFAULT_SERVICE_URLS.IAM, +): Promise => { + validateToken(token); + + return requestDatalayerAPI({ + url: `${baseUrl}${API_BASE_PATHS.IAM}/principals/shareable`, + method: 'GET', + token, + }); +}; diff --git a/src/api/runtimes/checkpoints.ts b/src/api/runtimes/checkpoints.ts index c318561c..eae229d3 100644 --- a/src/api/runtimes/checkpoints.ts +++ b/src/api/runtimes/checkpoints.ts @@ -7,7 +7,7 @@ * Runtime checkpoints API functions for the Datalayer platform. * * Provides functions for managing CRIU full-pod checkpoints. - * These are distinct from runtime snapshots (Jupyter sandbox snapshots). + * These are distinct from code sandbox snapshots (Jupyter sandbox snapshots). * * @module api/runtimes/checkpoints */ diff --git a/src/api/runtimes/snapshots.ts b/src/api/runtimes/snapshots.ts index de62661b..6399f0f5 100644 --- a/src/api/runtimes/snapshots.ts +++ b/src/api/runtimes/snapshots.ts @@ -4,9 +4,9 @@ */ /** - * Runtime snapshots API functions for the Datalayer platform. + * Code Sandbox snapshots API functions for the Datalayer platform. * - * Provides functions for managing runtime snapshots (saved runtime states). + * Provides functions for managing code sandbox snapshots (saved runtime states). * * @module api/runtimes/snapshots */ @@ -14,11 +14,11 @@ import { requestDatalayerAPI } from '../DatalayerApi'; import { API_BASE_PATHS, DEFAULT_SERVICE_URLS } from '../constants'; import { - CreateRuntimeSnapshotRequest, - ListRuntimeSnapshotsResponse, - GetRuntimeSnapshotResponse, - CreateRuntimeSnapshotResponse, -} from '../../models/RuntimeSnapshotDTO'; + CreateCodeSandboxSnapshotRequest, + ListCodeSandboxSnapshotsResponse, + GetCodeSandboxSnapshotResponse, + CreateCodeSandboxSnapshotResponse, +} from '../../models/CodeSandboxSnapshotDTO'; import { validateToken, validateRequiredString } from '../utils/validation'; /** @@ -31,13 +31,13 @@ import { validateToken, validateRequiredString } from '../utils/validation'; */ export const createSnapshot = async ( token: string, - data: CreateRuntimeSnapshotRequest, + data: CreateCodeSandboxSnapshotRequest, baseUrl: string = DEFAULT_SERVICE_URLS.RUNTIMES, -): Promise => { +): Promise => { validateToken(token); - return requestDatalayerAPI({ - url: `${baseUrl}${API_BASE_PATHS.RUNTIMES}/runtime-snapshots`, + return requestDatalayerAPI({ + url: `${baseUrl}${API_BASE_PATHS.RUNTIMES}/sandbox-snapshots`, method: 'POST', token, body: data, @@ -45,7 +45,7 @@ export const createSnapshot = async ( }; /** - * List all runtime snapshots. + * List all code sandbox snapshots. * @param token - Authentication token * @param baseUrl - Base URL for the API (defaults to production Runtimes URL) * @returns Promise resolving to list of snapshots @@ -54,18 +54,18 @@ export const createSnapshot = async ( export const listSnapshots = async ( token: string, baseUrl: string = DEFAULT_SERVICE_URLS.RUNTIMES, -): Promise => { +): Promise => { validateToken(token); - return requestDatalayerAPI({ - url: `${baseUrl}${API_BASE_PATHS.RUNTIMES}/runtime-snapshots`, + return requestDatalayerAPI({ + url: `${baseUrl}${API_BASE_PATHS.RUNTIMES}/sandbox-snapshots`, method: 'GET', token, }); }; /** - * Get details for a specific runtime snapshot. + * Get details for a specific code sandbox snapshot. * @param token - Authentication token * @param snapshotId - The unique identifier of the snapshot * @param baseUrl - Base URL for the API (defaults to production Runtimes URL) @@ -77,19 +77,19 @@ export const getSnapshot = async ( token: string, snapshotId: string, baseUrl: string = DEFAULT_SERVICE_URLS.RUNTIMES, -): Promise => { +): Promise => { validateToken(token); validateRequiredString(snapshotId, 'Snapshot ID'); - return requestDatalayerAPI({ - url: `${baseUrl}${API_BASE_PATHS.RUNTIMES}/runtime-snapshots/${snapshotId}`, + return requestDatalayerAPI({ + url: `${baseUrl}${API_BASE_PATHS.RUNTIMES}/sandbox-snapshots/${snapshotId}`, method: 'GET', token, }); }; /** - * Delete a runtime snapshot. + * Delete a code sandbox snapshot. * @param token - Authentication token * @param snapshotId - The unique identifier of the snapshot to delete * @param baseUrl - Base URL for the API (defaults to production Runtimes URL) @@ -106,7 +106,7 @@ export const deleteSnapshot = async ( validateRequiredString(snapshotId, 'Snapshot ID'); return requestDatalayerAPI({ - url: `${baseUrl}${API_BASE_PATHS.RUNTIMES}/runtime-snapshots/${snapshotId}`, + url: `${baseUrl}${API_BASE_PATHS.RUNTIMES}/sandbox-snapshots/${snapshotId}`, method: 'DELETE', token, }); diff --git a/src/client/__tests__/client.models.integration.test.ts b/src/client/__tests__/client.models.integration.test.ts index c35341f3..61236e70 100644 --- a/src/client/__tests__/client.models.integration.test.ts +++ b/src/client/__tests__/client.models.integration.test.ts @@ -9,7 +9,7 @@ import { describe, it, expect, beforeAll, afterAll } from 'vitest'; import { DatalayerClient } from '..'; import { RuntimeDTO } from '../../models/RuntimeDTO'; import { DEFAULT_SERVICE_URLS } from '../../api/constants'; -import { RuntimeSnapshotDTO } from '../../models/RuntimeSnapshotDTO'; +import { CodeSandboxSnapshotDTO } from '../../models/CodeSandboxSnapshotDTO'; import { SpaceDTO } from '../../models/SpaceDTO'; import { NotebookDTO } from '../../models/NotebookDTO'; import { LexicalDTO } from '../../models/LexicalDTO'; @@ -45,7 +45,7 @@ describe.skipIf(skipInCi)('Client Models Integration Tests', () => { let testNotebook: NotebookDTO | null = null; let testLexical: LexicalDTO | null = null; let testRuntime: RuntimeDTO | null = null; - let testSnapshot: RuntimeSnapshotDTO | null = null; + let testSnapshot: CodeSandboxSnapshotDTO | null = null; beforeAll(async () => { if (!testConfig.hasToken()) { @@ -288,7 +288,7 @@ describe.skipIf(skipInCi)('Client Models Integration Tests', () => { 'Test snapshot from model test', ); - expect(testSnapshot).toBeInstanceOf(RuntimeSnapshotDTO); + expect(testSnapshot).toBeInstanceOf(CodeSandboxSnapshotDTO); // Snapshots don't have a podName property // Instead, check that the snapshot was created successfully expect(testSnapshot.uid).toBeDefined(); @@ -297,7 +297,7 @@ describe.skipIf(skipInCi)('Client Models Integration Tests', () => { console.log(`Created snapshot ${testSnapshot.uid} from runtime`); }); - it('should list runtime snapshots', async () => { + it('should list code sandbox snapshots', async () => { if (!testRuntime) { const environmentName = await resolveEnvironmentName(client); testRuntime = await client.createRuntime( @@ -315,7 +315,7 @@ describe.skipIf(skipInCi)('Client Models Integration Tests', () => { ); } - console.log('Testing runtime snapshot listing...'); + console.log('Testing code sandbox snapshot listing...'); // List all snapshots const snapshots = await client.listSnapshots(); diff --git a/src/client/__tests__/client.runtimes.integration.test.ts b/src/client/__tests__/client.runtimes.integration.test.ts index d7be3dee..a4237ee8 100644 --- a/src/client/__tests__/client.runtimes.integration.test.ts +++ b/src/client/__tests__/client.runtimes.integration.test.ts @@ -8,7 +8,7 @@ import { describe, it, expect, beforeAll, afterAll } from 'vitest'; import { DatalayerClient } from '..'; import { RuntimeDTO } from '../../models/RuntimeDTO'; -import { RuntimeSnapshotDTO } from '../../models/RuntimeSnapshotDTO'; +import { CodeSandboxSnapshotDTO } from '../../models/CodeSandboxSnapshotDTO'; import { testConfig } from '../../__tests__/shared/test-config'; import { DEFAULT_SERVICE_URLS } from '../../api/constants'; import { performCleanup } from '../../__tests__/shared/cleanup-shared'; @@ -39,7 +39,7 @@ const resolveEnvironmentName = async ( describe.skipIf(skipInCi)('Client Runtimes Integration Tests', () => { let client: DatalayerClient; let createdRuntime: RuntimeDTO | null = null; - let createdSnapshot: RuntimeSnapshotDTO | null = null; + let createdSnapshot: CodeSandboxSnapshotDTO | null = null; const ensureRuntime = async (): Promise => { if (createdRuntime) { @@ -56,7 +56,7 @@ describe.skipIf(skipInCi)('Client Runtimes Integration Tests', () => { return createdRuntime; }; - const ensureSnapshot = async (): Promise => { + const ensureSnapshot = async (): Promise => { if (createdSnapshot) { return createdSnapshot; } @@ -219,7 +219,7 @@ describe.skipIf(skipInCi)('Client Runtimes Integration Tests', () => { 'Test snapshot from Client', ); - expect(snapshot).toBeInstanceOf(RuntimeSnapshotDTO); + expect(snapshot).toBeInstanceOf(CodeSandboxSnapshotDTO); expect(snapshot.uid).toBeDefined(); expect(snapshot.name).toContain('client-test-snapshot'); @@ -239,7 +239,7 @@ describe.skipIf(skipInCi)('Client Runtimes Integration Tests', () => { const found = snapshots.find(s => s.uid === snapshotRef.uid); expect(found).toBeDefined(); - expect(found).toBeInstanceOf(RuntimeSnapshotDTO); + expect(found).toBeInstanceOf(CodeSandboxSnapshotDTO); console.log(`Found ${snapshots.length} snapshot(s)`); console.log(`Created snapshot found in list: ${found!.uid}`); @@ -251,7 +251,7 @@ describe.skipIf(skipInCi)('Client Runtimes Integration Tests', () => { console.log('Getting snapshot details...'); const snapshot = await client.getSnapshot(snapshotRef.uid); - expect(snapshot).toBeInstanceOf(RuntimeSnapshotDTO); + expect(snapshot).toBeInstanceOf(CodeSandboxSnapshotDTO); expect(snapshot.uid).toBe(snapshotRef.uid); expect(snapshot.environment).toBe(snapshotRef.environment); diff --git a/src/client/index.ts b/src/client/index.ts index 5d4e4205..cfdd5a82 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -34,7 +34,7 @@ import type { UserDTO } from './../models/UserDTO'; import type { CreditsDTO } from '../models/CreditsDTO'; import type { EnvironmentDTO } from '../models/EnvironmentDTO'; import type { RuntimeDTO } from '../models/RuntimeDTO'; -import type { RuntimeSnapshotDTO } from '../models/RuntimeSnapshotDTO'; +import type { CodeSandboxSnapshotDTO } from '../models/CodeSandboxSnapshotDTO'; import type { SpaceDTO } from '../models/SpaceDTO'; import type { NotebookDTO } from '../models/NotebookDTO'; import type { LexicalDTO } from '../models/LexicalDTO'; @@ -124,15 +124,15 @@ export type { EnvironmentData, ListEnvironmentsResponse, } from '../models/EnvironmentDTO'; -export { RuntimeSnapshotDTO as Snapshot } from '../models/RuntimeSnapshotDTO'; +export { CodeSandboxSnapshotDTO as Snapshot } from '../models/CodeSandboxSnapshotDTO'; export type { - RuntimeSnapshotJSON, - RuntimeSnapshotData, - CreateRuntimeSnapshotRequest, - CreateRuntimeSnapshotResponse, - GetRuntimeSnapshotResponse, - ListRuntimeSnapshotsResponse, -} from '../models/RuntimeSnapshotDTO'; + CodeSandboxSnapshotJSON, + CodeSandboxSnapshotData, + CreateCodeSandboxSnapshotRequest, + CreateCodeSandboxSnapshotResponse, + GetCodeSandboxSnapshotResponse, + ListCodeSandboxSnapshotsResponse, +} from '../models/CodeSandboxSnapshotDTO'; export { SpaceDTO as Space } from '../models/SpaceDTO'; export type { SpaceJSON, @@ -246,7 +246,7 @@ export type { IRuntimeLocation, IRuntimeCapabilities, } from '../models/Runtime'; -export type { IRuntimeSnapshot } from '../models/RuntimeSnapshot'; +export type { ICodeSandboxSnapshot } from '../models/CodeSandboxSnapshot'; export type { IDatalayerEnvironment, IResources, @@ -394,9 +394,9 @@ export interface DatalayerClient { name: string, description: string, stop?: boolean, - ): Promise; - listSnapshots(): Promise; - getSnapshot(id: string): Promise; + ): Promise; + listSnapshots(): Promise; + getSnapshot(id: string): Promise; deleteSnapshot(id: string): Promise; checkRuntimesHealth(): Promise; diff --git a/src/client/mixins/RuntimesMixin.ts b/src/client/mixins/RuntimesMixin.ts index dc9ba30c..a5a6a309 100644 --- a/src/client/mixins/RuntimesMixin.ts +++ b/src/client/mixins/RuntimesMixin.ts @@ -12,11 +12,11 @@ import * as environments from '../../api/runtimes/environments'; import * as runtimes from '../../api/runtimes/runtimes'; import * as snapshots from '../../api/runtimes/snapshots'; import type { CreateRuntimeRequest } from '../../models/RuntimeDTO'; -import type { CreateRuntimeSnapshotRequest } from '../../models/RuntimeSnapshotDTO'; +import type { CreateCodeSandboxSnapshotRequest } from '../../models/CodeSandboxSnapshotDTO'; import type { Constructor } from '../utils/mixins'; import { EnvironmentDTO } from '../../models/EnvironmentDTO'; import { RuntimeDTO } from '../../models/RuntimeDTO'; -import { RuntimeSnapshotDTO } from '../../models/RuntimeSnapshotDTO'; +import { CodeSandboxSnapshotDTO } from '../../models/CodeSandboxSnapshotDTO'; import { HealthCheck } from '../../models/HealthCheck'; /** Options for ensuring a runtime is available. */ @@ -51,7 +51,7 @@ export function RuntimesMixin(Base: TBase) { } _extractSnapshotId( - snapshotIdOrInstance: string | RuntimeSnapshotDTO, + snapshotIdOrInstance: string | CodeSandboxSnapshotDTO, ): string { return typeof snapshotIdOrInstance === 'string' ? snapshotIdOrInstance @@ -212,11 +212,11 @@ export function RuntimesMixin(Base: TBase) { name: string, description: string, stop: boolean = false, - ): Promise { + ): Promise { const token = (this as any).getToken(); const runtimesRunUrl = (this as any).getRuntimesRunUrl(); - const data: CreateRuntimeSnapshotRequest = { + const data: CreateCodeSandboxSnapshotRequest = { pod_name: podName, name, description, @@ -228,19 +228,19 @@ export function RuntimesMixin(Base: TBase) { data, runtimesRunUrl, ); - return new RuntimeSnapshotDTO(response.snapshot, this as any); + return new CodeSandboxSnapshotDTO(response.snapshot, this as any); } /** - * List all runtime snapshots. + * List all code sandbox snapshots. * @returns Array of snapshots */ - async listSnapshots(): Promise { + async listSnapshots(): Promise { const token = (this as any).getToken(); const runtimesRunUrl = (this as any).getRuntimesRunUrl(); const response = await snapshots.listSnapshots(token, runtimesRunUrl); return response.snapshots.map( - s => new RuntimeSnapshotDTO(s, this as any), + s => new CodeSandboxSnapshotDTO(s, this as any), ); } @@ -249,11 +249,11 @@ export function RuntimesMixin(Base: TBase) { * @param id - Snapshot ID * @returns Snapshot details */ - async getSnapshot(id: string): Promise { + async getSnapshot(id: string): Promise { const token = (this as any).getToken(); const runtimesRunUrl = (this as any).getRuntimesRunUrl(); const response = await snapshots.getSnapshot(token, id, runtimesRunUrl); - return new RuntimeSnapshotDTO(response.snapshot, this as any); + return new CodeSandboxSnapshotDTO(response.snapshot, this as any); } /** diff --git a/src/components/billing/BillableAccountSelect.tsx b/src/components/billing/BillableAccountSelect.tsx new file mode 100644 index 00000000..81466562 --- /dev/null +++ b/src/components/billing/BillableAccountSelect.tsx @@ -0,0 +1,679 @@ +/* + * Copyright (c) 2023-2026 Datalayer, Inc. + * Distributed under the terms of the Modified BSD License. + */ + +/** + * BillableAccountSelect — self-contained dropdown that lets the user pick a + * billable account (personal, organization, or eligible team) for runs that + * consume wallet credits. + * + * Encapsulates eligibility merge logic and account-detail fetching. Callers + * only need to provide a value/onChange pair and optionally observe the full + * resolved account via `onSelectedAccountChange`. + */ + +import { useCallback, useEffect, useMemo, useRef, Fragment } from 'react'; +import { + ActionList, + ActionMenu, + Button, + Flash, + FormControl, + Label, + Spinner, + Text, +} from '@primer/react'; +import { + OrganizationIcon, + PeopleIcon, + PersonIcon, +} from '@primer/octicons-react'; +import { Box } from '@datalayer/primer-addons'; +import { useCache } from '../../hooks/useCache'; +import { useSelectedPrincipal } from '../../hooks/useSelectedPrincipal'; +import { useIAMStore } from '../../state'; + +export type BillableAccountType = 'user' | 'organization' | 'team'; + +export type BillableAccount = { + accountUid: string; + accountType: BillableAccountType; + accountHandle: string; + accountName: string; + planName: string; + isEligible: boolean; + isPaidPlan: boolean; + sourceOrganizationUid?: string; + sourceOrganizationHandle?: string; + teamHandle?: string; +}; + +export type BillableAccountSelectProps = { + value: string; + onChange: (accountUid: string) => void; + onSelectedAccountChange?: (account: BillableAccount | undefined) => void; + onAccountsResolved?: (state: { + accounts: BillableAccount[]; + eligibleAccounts: BillableAccount[]; + isLoading: boolean; + hasEligibleAccount: boolean; + }) => void; + disabled?: boolean; + label?: string; + caption?: string; + emptyMessage?: string; + flashMessage?: string; + width?: string | number; + preferOrganizationDefault?: boolean; +}; + +const PLAN_FREE_TERMS = ['free', 'starter']; +const PLAN_PRO_TERMS = ['pro', 'paid', 'team', 'enterprise', 'business']; + +const BILLABLE_ACCOUNT_COOKIE = 'datalayer-billable-account-uid'; +const BILLABLE_ACCOUNT_COOKIE_MAX_AGE = 60 * 60 * 24 * 365; + +function readBillableAccountCookie(): string | null { + if (typeof document === 'undefined') return null; + const escaped = BILLABLE_ACCOUNT_COOKIE.replace( + /[.$?*|{}()[\]\\/+^]/g, + '\\$&', + ); + const match = document.cookie.match( + new RegExp('(?:^|; )' + escaped + '=([^;]*)'), + ); + return match ? decodeURIComponent(match[1]) : null; +} + +function writeBillableAccountCookie(value: string): void { + if (typeof document === 'undefined') return; + document.cookie = + `${BILLABLE_ACCOUNT_COOKIE}=${encodeURIComponent(value)};` + + ` path=/; max-age=${BILLABLE_ACCOUNT_COOKIE_MAX_AGE}; SameSite=Lax`; +} + +const planContains = (value: string, terms: string[]) => + terms.some(term => value.includes(term)); + +export function resolveBillablePlanTier(value: unknown): 'free' | 'pro' { + const normalized = String(value ?? '').toLowerCase(); + if (!normalized || normalized === 'unknown') return 'free'; + if (planContains(normalized, PLAN_FREE_TERMS)) return 'free'; + if (planContains(normalized, PLAN_PRO_TERMS)) return 'pro'; + return 'free'; +} + +export function formatBillableAccountPlanLabel(planName: string): string { + return resolveBillablePlanTier(planName) === 'pro' + ? 'Team Plan' + : 'Free Plan'; +} + +export function BillableAccountSelect({ + value, + onChange, + onSelectedAccountChange, + onAccountsResolved, + disabled = false, + label = 'Run under', + caption = 'Personal, organization, and eligible team accounts can be selected for billable assignment. For team billing, runtime runs are attributed to the parent organization while credits are consumed from the selected team wallet.', + emptyMessage = 'No billable accounts available', + flashMessage = 'Runs and credits are charged to the selected billable account. Wallet credits of that account are consumed; LLM token usage is tracked for visibility only. Accounts without an eligible plan or wallet balance are disabled.', + width = 'min(100%, 520px)', + preferOrganizationDefault = false, +}: BillableAccountSelectProps): JSX.Element { + const { user } = useIAMStore(); + const { + useEligibleSubscriptionAccounts, + useSubscriptionAccountsDetails, + useUserOrganizations, + } = useCache(); + + const { selectedPrincipalKind, selectedPrincipalUid } = + useSelectedPrincipal(); + + const userOrganizationsQuery = useUserOrganizations(); + const { data: eligibleAccountsRaw, isLoading: eligibleAccountsLoading } = + useEligibleSubscriptionAccounts({ + refetchInterval: 10_000, + refetchOnMount: true, + refetchOnWindowFocus: true, + staleTime: 0, + }); + + const eligibleAccounts = useMemo( + () => + (eligibleAccountsRaw || []).map((entry: any) => ({ + accountUid: String(entry.account_uid || ''), + accountType: String(entry.account_type || 'user'), + accountHandle: String(entry.account_handle || '').trim(), + accountName: + String(entry.account_name || '').trim() || + String(entry.account_handle || '').trim() || + String(entry.account_uid || '').trim(), + planName: String( + entry?.subscription?.plan_name || entry?.plan?.plan_name || '', + ).trim(), + })), + [eligibleAccountsRaw], + ); + + const personalAccountUid = String((user as any)?.id || ''); + const allContextAccounts = useMemo(() => { + const accountMap = new Map< + string, + { + accountUid: string; + accountType: string; + accountHandle: string; + accountName: string; + planName: string; + } + >(); + + if (personalAccountUid) { + accountMap.set(personalAccountUid, { + accountUid: personalAccountUid, + accountType: 'user', + accountHandle: String((user as any)?.handle || '').trim(), + accountName: + String((user as any)?.handle || '').trim() || personalAccountUid, + planName: '', + }); + } + + for (const organization of (userOrganizationsQuery.data || []) as any[]) { + const orgUid = String(organization?.uid || organization?.id || '').trim(); + if (!orgUid) continue; + accountMap.set(orgUid, { + accountUid: orgUid, + accountType: 'organization', + accountHandle: String(organization?.handle || '').trim(), + accountName: + String(organization?.handle || '').trim() || + String(organization?.name || '').trim() || + orgUid, + planName: String( + organization?.subscription?.plan_name || + organization?.plan_name || + '', + ).trim(), + }); + } + + return Array.from(accountMap.values()); + }, [personalAccountUid, user, userOrganizationsQuery.data]); + + const eligibleAccountByUid = useMemo( + () => new Map(eligibleAccounts.map(a => [a.accountUid, a])), + [eligibleAccounts], + ); + + const candidateUids = useMemo(() => { + const values = new Set(); + for (const a of allContextAccounts) values.add(a.accountUid); + for (const a of eligibleAccounts) values.add(a.accountUid); + if (value) values.add(value); + return Array.from(values).filter(Boolean); + }, [allContextAccounts, eligibleAccounts, value]); + + const { data: detailsRaw, isLoading: detailsLoading } = + useSubscriptionAccountsDetails(candidateUids, { + refetchInterval: 10_000, + refetchOnMount: true, + refetchOnWindowFocus: true, + staleTime: 0, + }); + + const detailsByUid = useMemo( + () => + new Map( + (detailsRaw || []).map((entry: any) => [ + String(entry.account_uid || ''), + entry, + ]), + ), + [detailsRaw], + ); + + const accounts = useMemo(() => { + const accountMap = new Map(); + for (const a of allContextAccounts) accountMap.set(a.accountUid, a); + for (const a of eligibleAccounts) accountMap.set(a.accountUid, a); + + const merged = Array.from(accountMap.values()); + const mergedByUid = new Map(merged.map(a => [a.accountUid, a])); + + return merged.map(account => { + const eligible = eligibleAccountByUid.get(account.accountUid); + const details = detailsByUid.get(account.accountUid); + const accountType = String( + details?.account_type || account.accountType || 'user', + ) as BillableAccountType; + const accountHandle = String( + details?.account_handle || account.accountHandle || '', + ); + const planName = String( + details?.subscription?.plan_name || + eligible?.planName || + account.planName || + '', + ).trim(); + + const walletBalance = Number( + accountType === 'team' + ? (details?.wallet_balance ?? 0) + : (details?.subscription?.wallet_balance ?? + details?.wallet_balance ?? + 0), + ); + const hasPositiveWallet = + Number.isFinite(walletBalance) && walletBalance > 0; + + const isEligible = + accountType === 'team' + ? hasPositiveWallet + : typeof details?.is_eligible === 'boolean' + ? details.is_eligible || + (accountType === 'user' && hasPositiveWallet) + : Boolean(eligible); + + const sourceOrganizationUid = + accountType === 'team' + ? String(details?.plan_source_account_uid || '').trim() || undefined + : undefined; + const sourceOrgDetails = sourceOrganizationUid + ? detailsByUid.get(sourceOrganizationUid) + : undefined; + const sourceOrgMerged = sourceOrganizationUid + ? mergedByUid.get(sourceOrganizationUid) + : undefined; + const sourceOrganizationHandle = sourceOrganizationUid + ? String( + sourceOrgDetails?.account_handle || + sourceOrgMerged?.accountHandle || + '', + ).trim() || undefined + : undefined; + + return { + accountUid: account.accountUid, + accountType, + accountHandle, + accountName: String( + details?.account_name || account.accountName || account.accountUid, + ), + planName, + isEligible, + isPaidPlan: resolveBillablePlanTier(planName || 'free') === 'pro', + sourceOrganizationUid, + sourceOrganizationHandle, + teamHandle: accountType === 'team' ? accountHandle : undefined, + }; + }); + }, [ + allContextAccounts, + eligibleAccounts, + eligibleAccountByUid, + detailsByUid, + ]); + + const eligibleBillable = useMemo( + () => accounts.filter(a => a.isEligible), + [accounts], + ); + const hasEligibleAccount = eligibleBillable.length > 0; + const isLoading = eligibleAccountsLoading || detailsLoading; + + const storedBillableAccountUid = useMemo( + () => readBillableAccountCookie(), + [], + ); + + const preferredEligible = useMemo(() => { + const fromCookie = storedBillableAccountUid + ? eligibleBillable.find( + account => account.accountUid === storedBillableAccountUid, + ) + : undefined; + if (fromCookie) return fromCookie; + + const personalEligible = eligibleBillable.find( + a => a.accountType === 'user' && a.accountUid === personalAccountUid, + ); + if (personalEligible) return personalEligible; + + const byPrincipal = selectedPrincipalUid + ? eligibleBillable.find(account => { + if (account.accountUid !== selectedPrincipalUid) return false; + if (selectedPrincipalKind === 'organization') + return account.accountType === 'organization'; + if (selectedPrincipalKind === 'team') + return account.accountType === 'team'; + return account.accountType === 'user'; + }) + : undefined; + if (byPrincipal) return byPrincipal; + + const firstOrg = eligibleBillable.find( + a => a.accountType === 'organization', + ); + if (preferOrganizationDefault && firstOrg) return firstOrg; + return firstOrg || eligibleBillable[0]; + }, [ + storedBillableAccountUid, + eligibleBillable, + personalAccountUid, + preferOrganizationDefault, + selectedPrincipalKind, + selectedPrincipalUid, + ]); + + const handleAccountSelect = useCallback( + (accountUid: string) => { + writeBillableAccountCookie(accountUid); + onChange(accountUid); + }, + [onChange], + ); + + // Apply persisted selection (cookie) once accounts are resolved. Runs once + // per mount so users can still change selection afterwards; falls back to + // personal user account when the stored uid is unknown/ineligible. + const initialSelectionAppliedRef = useRef(false); + useEffect(() => { + if (isLoading) return; + if (initialSelectionAppliedRef.current) return; + if (eligibleBillable.length === 0 && accounts.length === 0) return; + initialSelectionAppliedRef.current = true; + + const storedAccount = storedBillableAccountUid + ? eligibleBillable.find(a => a.accountUid === storedBillableAccountUid) + : undefined; + const personalAccount = eligibleBillable.find( + a => a.accountType === 'user' && a.accountUid === personalAccountUid, + ); + const target = storedAccount || personalAccount || preferredEligible; + if (!target) return; + if (target.accountUid !== value) { + writeBillableAccountCookie(target.accountUid); + onChange(target.accountUid); + } + }, [ + isLoading, + accounts, + eligibleBillable, + storedBillableAccountUid, + personalAccountUid, + preferredEligible, + value, + onChange, + ]); + + // Auto-select a sensible default when current value is empty/ineligible. + useEffect(() => { + if (isLoading) return; + if (!initialSelectionAppliedRef.current) return; + if (!preferredEligible) { + if (value) onChange(''); + return; + } + const current = accounts.find(a => a.accountUid === value); + if (!current || !current.isEligible) { + writeBillableAccountCookie(preferredEligible.accountUid); + onChange(preferredEligible.accountUid); + } + }, [isLoading, preferredEligible, accounts, value, onChange]); + + const selectedAccount = useMemo( + () => accounts.find(a => a.accountUid === value), + [accounts, value], + ); + + useEffect(() => { + onSelectedAccountChange?.(selectedAccount); + }, [selectedAccount, onSelectedAccountChange]); + + useEffect(() => { + onAccountsResolved?.({ + accounts, + eligibleAccounts: eligibleBillable, + isLoading, + hasEligibleAccount, + }); + }, [ + accounts, + eligibleBillable, + isLoading, + hasEligibleAccount, + onAccountsResolved, + ]); + + return ( + + {label} + + + + + + + + {flashMessage} + + + {isLoading ? ( + + + + Loading plan status... + + + ) : !hasEligibleAccount ? ( + {emptyMessage} + ) : ( + (() => { + const typeOrder: Record = { + user: 0, + organization: 1, + team: 2, + }; + const sorted = [...accounts].sort( + (a, b) => + (typeOrder[a.accountType] ?? 99) - + (typeOrder[b.accountType] ?? 99), + ); + return sorted.map((account, idx) => { + const prevType = + idx > 0 ? sorted[idx - 1].accountType : undefined; + const showDivider = + prevType !== undefined && + prevType !== account.accountType; + return ( + + {showDivider && } + { + if (account.isEligible) { + handleAccountSelect(account.accountUid); + } + }} + > + + + @{account.accountName} + + + + {!account.isEligible && + account.accountType === 'team' && ( + + )} + + + + + {account.isEligible + ? 'Eligible' + : account.accountType === 'team' + ? 'Not eligible — no team credits allocated' + : 'Not eligible — activate a plan or add credits to use this account'} + + + + ); + }); + })() + )} + + + + + {caption} + + ); +} + +export default BillableAccountSelect; diff --git a/src/components/billing/index.ts b/src/components/billing/index.ts new file mode 100644 index 00000000..c6ecd8b9 --- /dev/null +++ b/src/components/billing/index.ts @@ -0,0 +1,6 @@ +/* + * Copyright (c) 2023-2026 Datalayer, Inc. + * Distributed under the terms of the Modified BSD License. + */ + +export * from './BillableAccountSelect'; diff --git a/src/components/checkout/StripeCheckout.tsx b/src/components/checkout/StripeCheckout.tsx index 7006f81b..e647dc07 100644 --- a/src/components/checkout/StripeCheckout.tsx +++ b/src/components/checkout/StripeCheckout.tsx @@ -58,6 +58,10 @@ export interface IPrice { * Computational credits to receive */ credits: number; + /** + * Whether this price is the server-selected default option + */ + default?: boolean; } export interface ISubscriptionPlan { @@ -69,11 +73,23 @@ export interface ISubscriptionPlan { included_runs?: number; } +type TopUpConfirmation = { + purchasedCredits: number; + oldWalletBalance: number; + newWalletBalance: number; + oldAvailableCredits: number; + newAvailableCredits: number; +}; + export type StripeCheckoutProps = { checkoutPortal: ICheckoutPortal | null; appearance?: StripeElementsOptions['appearance']; accountUid?: string; showStatusUsageSummary?: boolean; + onCheckoutSuccess?: (event: { + checkoutType: 'topup' | 'subscription' | 'resume'; + purchasedCredits?: number; + }) => void; }; const PLAN_INCLUDED_RUNS_DEFAULTS: Record = { @@ -311,12 +327,12 @@ export function StripeCheckout({ checkoutPortal, appearance, accountUid, - showStatusUsageSummary = true, + showStatusUsageSummary = false, + onCheckoutSuccess, }: StripeCheckoutProps) { const { useCreateTopUpPaymentIntent, useCreateSubscriptionPaymentIntent, - useCreateResumeSetupIntent, useSubscriptionPlans, useTopUpPrices, useSubscriptionStatus, @@ -335,11 +351,37 @@ export function StripeCheckout({ 'topup' | 'subscription' | 'resume' >('topup'); const [cancelViewOpen, setCancelViewOpen] = useState(false); + const [isConfirmingCancel, setIsConfirmingCancel] = useState(false); + const [isResumingTransition, setIsResumingTransition] = useState(false); const [paymentMessage, setPaymentMessage] = useState(null); + const [resumeConfirmationMessage, setResumeConfirmationMessage] = useState< + string | null + >(null); + const [isReturningFromCheckout, setIsReturningFromCheckout] = useState(false); + const [topUpConfirmation, setTopUpConfirmation] = + useState(null); + const [pendingTopUpTarget, setPendingTopUpTarget] = useState<{ + targetWalletBalance: number; + } | null>(null); + const topUpPurchaseRef = useRef<{ + purchasedCredits: number; + oldWalletBalance: number; + oldAvailableCredits: number; + } | null>(null); // Get Stripe prices using TanStack Query hook - const { data: pricesData } = useTopUpPrices(); - const items = (pricesData as IPrice[] | undefined) ?? null; + const { + data: pricesData, + isPending: isTopUpPricesPending, + isError: isTopUpPricesError, + error: topUpPricesError, + } = useTopUpPrices(); + const items = useMemo(() => { + if (Array.isArray(pricesData)) { + return pricesData as IPrice[]; + } + return []; + }, [pricesData]); const sortedTopUpItems = useMemo( () => [...(items ?? [])].sort( @@ -368,7 +410,6 @@ export function StripeCheckout({ const subscriptionPaymentIntentMutation = useCreateSubscriptionPaymentIntent({ accountUid, }); - const resumeSetupIntentMutation = useCreateResumeSetupIntent({ accountUid }); // Load stripe API useEffect(() => { @@ -407,12 +448,14 @@ export function StripeCheckout({ setProduct(null); setSubscriptionPlan(null); setPaymentMessage(null); + setIsReturningFromCheckout(true); if (checkoutType === 'resume') { try { const resp = await resumeSubscriptionMutation.mutateAsync(); setPaymentMessage( resp?.message || 'Payment confirmed and plan resumed successfully.', ); + onCheckoutSuccess?.({ checkoutType: 'resume' }); } catch (error) { setPaymentMessage( error instanceof Error @@ -420,6 +463,7 @@ export function StripeCheckout({ : 'Payment confirmed, but unable to resume your plan right now.', ); } + setIsReturningFromCheckout(false); return; } if (checkoutType === 'subscription') { @@ -438,14 +482,55 @@ export function StripeCheckout({ setPaymentMessage( 'Plan payment confirmed. Your plan status may take a few seconds to refresh.', ); + onCheckoutSuccess?.({ checkoutType: 'subscription' }); } else { + const topUpPurchase = topUpPurchaseRef.current; + const purchasedCredits = topUpPurchase?.purchasedCredits || 0; + if (topUpPurchase && topUpPurchase.purchasedCredits > 0) { + const targetWalletBalance = + topUpPurchase.oldWalletBalance + topUpPurchase.purchasedCredits; + setTopUpConfirmation({ + purchasedCredits: topUpPurchase.purchasedCredits, + oldWalletBalance: topUpPurchase.oldWalletBalance, + newWalletBalance: targetWalletBalance, + oldAvailableCredits: topUpPurchase.oldAvailableCredits, + newAvailableCredits: + topUpPurchase.oldAvailableCredits + topUpPurchase.purchasedCredits, + }); + setPendingTopUpTarget({ + targetWalletBalance, + }); + } + + for (let attempt = 0; attempt < 5; attempt += 1) { + try { + await refetchSubscriptionStatus(); + } catch { + // Keep confirmation visible even if refresh fails transiently. + } + if (attempt < 4) { + await new Promise(resolve => setTimeout(resolve, 800)); + } + } + setPaymentMessage( 'Payment confirmed. Credits update may take a few seconds.', ); + onCheckoutSuccess?.({ + checkoutType: 'topup', + purchasedCredits, + }); + topUpPurchaseRef.current = null; } - }, [checkoutType, refetchSubscriptionStatus, resumeSubscriptionMutation]); + setIsReturningFromCheckout(false); + }, [ + checkoutType, + onCheckoutSuccess, + refetchSubscriptionStatus, + resumeSubscriptionMutation, + ]); - const subscription = subscriptionResp?.subscription || null; + const subscription = subscriptionResp?.plan || null; const availablePlans = useMemo(() => { const byId = new Map(); const add = (plan: any) => { @@ -466,9 +551,9 @@ export function StripeCheckout({ }); }; plans.forEach(add); - (subscriptionResp?.available_subscriptions || []).forEach(add); + (subscriptionResp?.available_plans || []).forEach(add); return Array.from(byId.values()); - }, [plans, subscriptionResp?.available_subscriptions]); + }, [plans, subscriptionResp?.available_plans]); const subscriptionStatus = subscription?.status || 'unknown'; const normalizedSubscriptionStatus = String(subscriptionStatus).toLowerCase(); @@ -638,6 +723,12 @@ export function StripeCheckout({ const walletBalance = walletIsQuota ? Math.max(0, remainingCredits) : Math.max(0, walletBalanceRaw); + const displayedWalletBalance = pendingTopUpTarget + ? Math.max(walletBalance, pendingTopUpTarget.targetWalletBalance) + : walletBalance; + const displayedAvailableCredits = pendingTopUpTarget + ? Math.max(remainingCredits, pendingTopUpTarget.targetWalletBalance) + : remainingCredits; const isRunsOverQuota = runsTotal > 0 && usedRuns > runsTotal; const hasBillablePlan = useMemo(() => { @@ -686,6 +777,18 @@ export function StripeCheckout({ return !nonCancelable; }, [hasBillablePlan, subscriptionStatus, isCancellationScheduled]); + const isCancelActionPending = + cancelSubscriptionMutation.isPending || isConfirmingCancel; + const isResumeActionPending = + resumeSubscriptionMutation.isPending || isResumingTransition; + const showResumeAction = isCancellationScheduled && !isCancelActionPending; + + useEffect(() => { + if (isResumingTransition && !isCancellationScheduled) { + setIsResumingTransition(false); + } + }, [isCancellationScheduled, isResumingTransition]); + useEffect(() => { if (isPaidSubscription && paymentMessage) { setPaymentMessage(null); @@ -700,10 +803,21 @@ export function StripeCheckout({ useEffect(() => { if (!product && sortedTopUpItems.length > 0) { - setProduct(sortedTopUpItems[sortedTopUpItems.length - 1]); + const secondCard = + sortedTopUpItems.length > 1 ? sortedTopUpItems[1] : sortedTopUpItems[0]; + setProduct(secondCard); } }, [product, sortedTopUpItems]); + useEffect(() => { + if (!pendingTopUpTarget) { + return; + } + if (walletBalance >= pendingTopUpTarget.targetWalletBalance) { + setPendingTopUpTarget(null); + } + }, [pendingTopUpTarget, walletBalance]); + // Auto-open the in-app cancel/downgrade view when the page is opened with // `?action=downgrade` (e.g. from the Plan Overview "Downgrade" CTA). // When opened with `?action=resume`, immediately trigger the resume flow. @@ -734,6 +848,12 @@ export function StripeCheckout({ if (!product) { return; } + topUpPurchaseRef.current = { + purchasedCredits: Math.max(0, Number(product.credits || 0)), + oldWalletBalance: displayedWalletBalance, + oldAvailableCredits: displayedAvailableCredits, + }; + setTopUpConfirmation(null); setPaymentMessage(null); setCheckoutType('topup'); setCheckout(true); @@ -755,11 +875,17 @@ export function StripeCheckout({ error instanceof Error ? error.message : 'Unable to initialize Stripe checkout. Please try again.'; + topUpPurchaseRef.current = null; setPaymentClientSecret(null); setCheckout(false); setPaymentMessage(detail); } - }, [topUpPaymentIntentMutation, product]); + }, [ + displayedAvailableCredits, + displayedWalletBalance, + topUpPaymentIntentMutation, + product, + ]); const startSubscriptionCheckout = useCallback( async (planOverride?: ISubscriptionPlan | null) => { @@ -831,6 +957,7 @@ export function StripeCheckout({ const onCancelSubscription = useCallback(() => { setPaymentMessage(null); + setResumeConfirmationMessage(null); setCancelViewOpen(true); }, []); @@ -840,6 +967,7 @@ export function StripeCheckout({ const onConfirmCancelSubscription = useCallback(async () => { setPaymentMessage(null); + setIsConfirmingCancel(true); try { const resp = await cancelSubscriptionMutation.mutateAsync(); if (resp?.success === false) { @@ -848,19 +976,6 @@ export function StripeCheckout({ ); } - // Refresh plan status so stale "incomplete" snapshots disappear - // as soon as cancellation is applied upstream. - for (let attempt = 0; attempt < 5; attempt += 1) { - try { - await refetchSubscriptionStatus(); - } catch { - // Ignore transient refetch errors and keep trying. - } - if (attempt < 4) { - await new Promise(resolve => setTimeout(resolve, 800)); - } - } - const responseStatus = String(resp?.status || '').toLowerCase(); const responseCancelAtPeriodEnd = Boolean(resp?.cancel_at_period_end); const isNowCanceled = @@ -878,7 +993,23 @@ export function StripeCheckout({ 'Plan change requested successfully.', ); setCancelViewOpen(false); + setIsConfirmingCancel(false); + + // Refresh plan status in the background so UI feedback is immediate. + void (async () => { + for (let attempt = 0; attempt < 5; attempt += 1) { + try { + await refetchSubscriptionStatus(); + } catch { + // Ignore transient refetch errors and keep trying. + } + if (attempt < 4) { + await new Promise(resolve => setTimeout(resolve, 800)); + } + } + })(); } catch (error) { + setIsConfirmingCancel(false); setPaymentMessage( error instanceof Error ? error.message @@ -893,28 +1024,54 @@ export function StripeCheckout({ const onResumeSubscription = useCallback(async () => { setPaymentMessage(null); + setResumeConfirmationMessage(null); + setIsResumingTransition(true); try { - const clientSecret = await resumeSetupIntentMutation.mutateAsync(); - if (!clientSecret) { - setCheckout(false); - setPaymentClientSecret(null); - setPaymentMessage( - 'Unable to initialize Stripe checkout. Please try again.', + const resp = await resumeSubscriptionMutation.mutateAsync(); + if (resp?.success === false) { + throw new Error( + resp?.message || 'Unable to resume your plan right now.', ); - return; } - setCheckoutType('resume'); - setPaymentClientSecret(clientSecret); - setCheckout(true); + + setCheckout(false); + setPaymentClientSecret(null); setPaymentMessage(null); + const periodEndText = + subscriptionPeriodEndLabel && subscriptionPeriodEndLabel !== 'N/A' + ? ` through ${subscriptionPeriodEndLabel}` + : ''; + setResumeConfirmationMessage( + `Resume complete. Your plan remains active${periodEndText} and will renew automatically after that date.`, + ); + setIsResumingTransition(false); + + // Refresh plan status in the background so success feedback appears fast. + void (async () => { + for (let attempt = 0; attempt < 5; attempt += 1) { + try { + await refetchSubscriptionStatus(); + } catch { + // Ignore transient refetch errors and keep trying. + } + if (attempt < 4) { + await new Promise(resolve => setTimeout(resolve, 800)); + } + } + })(); } catch (error) { + setIsResumingTransition(false); setPaymentMessage( error instanceof Error ? error.message - : 'Unable to initialize resume checkout right now.', + : 'Unable to resume your plan right now.', ); } - }, [resumeSetupIntentMutation]); + }, [ + refetchSubscriptionStatus, + resumeSubscriptionMutation, + subscriptionPeriodEndLabel, + ]); const onRefreshSubscriptionStatus = useCallback(async () => { setPaymentMessage(null); @@ -947,10 +1104,6 @@ export function StripeCheckout({ return `${product.name} (${amount}, ${product.credits} credits)`; } - if (checkoutType === 'resume') { - return 'Plan resume (card update required)'; - } - return null; }, [checkoutType, product, subscriptionPlan]); @@ -960,7 +1113,10 @@ export function StripeCheckout({ marginBottom: 'var(--stack-gap-normal)', } as const; - const monthlySubscriptionSection = ( + const shouldShowMonthlySubscriptionSection = + !isPaidSubscription || isIncompleteSubscription; + + const monthlySubscriptionSection = shouldShowMonthlySubscriptionSection ? ( {isIncompleteSubscription ? ( - - A pending plan change already exists. Complete payment or cancel it - from the billing portal before creating a new one. - + <> + + A pending plan change already exists. Complete payment or cancel it + from the billing portal before creating a new one. + + + + + + ) : !isPaidSubscription ? ( <> - ) : ( - - {isCancellationScheduled - ? `Your monthly plan will cancel on ${subscriptionPeriodEndLabel}.` - : 'Your monthly plan is active. You can manage plan details from plan controls.'} - - )} + ) : null} - ); + ) : null; const topUpSection = ( @@ -1130,344 +1307,523 @@ export function StripeCheckout({ ? 'Preparing top-up checkout...' : 'Checkout'} + {topUpConfirmation ? ( + + + Top-up confirmed: + + {topUpConfirmation.purchasedCredits.toLocaleString()} credits + + + {`Wallet balance: ${topUpConfirmation.oldWalletBalance.toLocaleString()} to ${topUpConfirmation.newWalletBalance.toLocaleString()}`} + + + {`Available credits: ${topUpConfirmation.oldAvailableCredits.toLocaleString()} to ${topUpConfirmation.newAvailableCredits.toLocaleString()}`} + + + ) : null} ); - const topCards = showStatusUsageSummary ? ( - + const topCards = + showStatusUsageSummary && !isPaidSubscription ? ( - - - Plan status - - Plan: {String(currentSubscriptionPlan)} - {isPendingSubscriptionCheckout && ( - + + - Upgrade pending payment. Your Team plan is not active until card - payment succeeds. - - )} - {currentPlanPriceLabel !== 'N/A' && ( - Price: {currentPlanPriceLabel} - )} - {displaySubscriptionStatus && ( - - Status: {displaySubscriptionStatus} + Plan status - )} - + Plan: {String(currentSubscriptionPlan)} + {isPendingSubscriptionCheckout && ( + + Upgrade pending payment. Your Team plan is not active until card + payment succeeds. + + )} + {currentPlanPriceLabel !== 'N/A' && ( + Price: {currentPlanPriceLabel} + )} + {displaySubscriptionStatus && ( + + Status: {displaySubscriptionStatus} + + )} - - Current usage - - - - - - - Runs: {usedRuns.toLocaleString()} / {runsTotal.toLocaleString()} - - - - - - - - - Used in quota - - - - Remaining - - - - Over quota - + + Current usage + + - - {periodProgress ? ( - Usage period days: {periodProgress.elapsedDays} /{' '} - {periodProgress.totalDays} + Runs: {usedRuns.toLocaleString()} /{' '} + {runsTotal.toLocaleString()} + - - {periodProgress.remainingDays} day(s) remaining in current - period - + + + Used in quota + + + + Remaining + + + + Over quota + + - ) : null} - - - Wallet balance: {walletBalance.toLocaleString()} - - - Spent credits in current period:{' '} - {usedCredits.toLocaleString(undefined, { - minimumFractionDigits: 2, - maximumFractionDigits: 2, - })} - - - Wallet credits are additive on renewal and top-ups. - + {periodProgress ? ( + + + Usage period days: {periodProgress.elapsedDays} /{' '} + {periodProgress.totalDays} + + + + + + + {periodProgress.remainingDays} day(s) remaining in current + period + + + ) : null} + + + + Wallet balance: {displayedWalletBalance.toLocaleString()} + + + Spent credits in current period:{' '} + {usedCredits.toLocaleString(undefined, { + minimumFractionDigits: 2, + maximumFractionDigits: 2, + })} + + + Wallet credits are additive on renewal and top-ups. + + - - {isCancellationScheduled && ( - + Plan will switch to Free at the end of the current period on{' '} + {subscriptionPeriodEndLabel}. + + )} + - Plan will switch to Free at the end of the current period on{' '} - {subscriptionPeriodEndLabel}. - - )} - - {subscriptionPortalUrl && ( + {subscriptionPortalUrl && ( + + )} - )} - - {canCancelSubscription && !cancelViewOpen && ( - - )} - {isIncompleteSubscription && !cancelViewOpen && ( - <> + {canCancelSubscription && !cancelViewOpen && ( + + )} + {isIncompleteSubscription && !cancelViewOpen && ( + <> + + + + )} + {showResumeAction && ( - - - )} - {isCancellationScheduled && ( - + + {isIncompleteSubscription + ? 'Cancel pending plan change' + : 'Downgrade to Free Plan'} + + + {isIncompleteSubscription + ? 'This pending plan change will be canceled immediately.' + : 'Your plan will switch at the end of the current usage period.'} + + + + + + )} + + + ) : null; + + const currentPlanSection = isPaidSubscription ? ( + + + + Current plan + + + {String(currentSubscriptionPlan)} + + + You are currently on {String(currentSubscriptionPlan)}. + + {currentPlanPriceLabel !== 'N/A' && ( + + {currentPlanPriceLabel} + + )} + {displaySubscriptionStatus && ( + + + + )} + + {isCancellationScheduled ? ( + + Your downgrade to Free Plan is scheduled at period end on{' '} + {subscriptionPeriodEndLabel}. + + ) : null} + + {!isCancellationScheduled || showResumeAction ? ( - Next step:{' '} - {isCancellationScheduled - ? 'Your plan is already scheduled to switch at period end. You can keep using it until then.' - : isIncompleteSubscription - ? 'Your payment is pending. Open the in-app cancel view below to cancel this plan change or continue with payment.' - : isPaidSubscription - ? 'Keep your plan active. You can top-up credits any time.' - : 'Top-up credits are available on Free and Team plans.'} + {showResumeAction + ? 'Possible action: Resume Team Plan.' + : 'Possible action: Downgrade to Free Plan.'} - {cancelViewOpen && ( + ) : null} + + + {canCancelSubscription && !cancelViewOpen && ( + + )} + {showResumeAction && ( + + )} + + + {cancelViewOpen && ( + + + Downgrade to Free Plan + + + Your plan will switch at the end of the current usage period. + - - {isIncompleteSubscription - ? 'Cancel pending plan change' - : 'Downgrade to Free Plan'} - - - {isIncompleteSubscription - ? 'This pending plan change will be canceled immediately.' - : 'Your plan will switch at the end of the current usage period.'} - - void onConfirmCancelSubscription()} + disabled={isCancelActionPending} + leadingVisual={() => + isCancelActionPending ? : undefined + } > - - - + {isCancelActionPending + ? 'Waiting for confirmation...' + : 'Confirm downgrade'} + + - )} - + + )} ) : null; @@ -1540,13 +1896,7 @@ export function StripeCheckout({ 'Cancel', ), ), - checkoutType === 'resume' - ? createElement( - Flash, - { variant: 'warning' }, - 'Enter a new payment card to resume your plan.', - ) - : null, + null, createElement( Elements, { @@ -1603,7 +1953,47 @@ export function StripeCheckout({ ); } - } else if (items) { + } else if (isReturningFromCheckout) { + view = ( + + + + Refreshing plan status… + + {disabledTopCards} + + ); + } else if (isTopUpPricesPending) { + view = ( + + + + ); + } else if (isTopUpPricesError) { + view = ( + + {topCards} + + {topUpPricesError instanceof Error + ? topUpPricesError.message + : 'Unable to fetch the available products. Please try again later.'} + + + ); + } else { view = items.length ? ( + {shouldShowMonthlySubscriptionSection ? ( + + {monthlySubscriptionSection} + + ) : null} + {currentPlanSection} - {monthlySubscriptionSection} - - {topUpSection} @@ -1656,7 +2056,12 @@ export function StripeCheckout({ ) : null} - {paymentMessage && ( + {resumeConfirmationMessage && ( + + {resumeConfirmationMessage} + + )} + {paymentMessage && !resumeConfirmationMessage && ( {paymentMessage} @@ -1665,10 +2070,18 @@ export function StripeCheckout({ ) : ( + {resumeConfirmationMessage && ( + + {resumeConfirmationMessage} + + )} + {paymentMessage && !resumeConfirmationMessage && ( + + {paymentMessage} + + )} {topCards} - - Unable to fetch the available products. Please try again later. - + No products are available yet. ); } diff --git a/src/components/display/JupyterDialog.tsx b/src/components/display/JupyterDialog.tsx index 49dd436f..0ec4d4a9 100644 --- a/src/components/display/JupyterDialog.tsx +++ b/src/components/display/JupyterDialog.tsx @@ -14,7 +14,7 @@ import { ReactWidget } from '@jupyterlab/ui-components'; import { PromiseDelegate } from '@lumino/coreutils'; import { Widget } from '@lumino/widgets'; import { FocusKeys } from '@primer/behaviors'; -import { Checkbox, FormControl, useFocusZone } from '@primer/react'; +import { Checkbox, FormControl, Spinner, useFocusZone } from '@primer/react'; import { DialogButtonProps, DialogProps, @@ -93,6 +93,13 @@ export interface IDialogWrapperOptions { * The top level text for the dialog. */ title: string; + /** + * Optional async hook called before an accept button closes the dialog. + * Return false to keep the dialog open. + */ + onWillAccept: ( + result: Dialog.IResult, + ) => Promise | boolean | void; } /** @@ -106,6 +113,10 @@ export class JupyterDialog extends ReactWidget { protected buttons: Dialog.IButton[]; protected host: HTMLElement; protected dialogTitle?: string; + protected onWillAccept?: ( + result: Dialog.IResult, + ) => Promise | boolean | void; + private _pendingButtonIndex: number | null = null; private _closing = new PromiseDelegate(); private _result: Dialog.IResult = { button: null as any, @@ -126,6 +137,7 @@ export class JupyterDialog extends ReactWidget { Dialog.okButton(), ]; this.dialogTitle = options.title; + this.onWillAccept = options.onWillAccept; } private _renderBody = (props: PropsWithChildren) => ( @@ -156,7 +168,11 @@ export class JupyterDialog extends ReactWidget { {this.dialogTitle} } - onClose={this.close} + onClose={() => { + if (this._pendingButtonIndex === null) { + this.close(); + } + }} renderBody={this._renderBody} renderFooter={this._renderFooter} footerButtons={this.buttons.map((but, idx) => { @@ -170,8 +186,14 @@ export class JupyterDialog extends ReactWidget { onClick: () => { this.handleButton(idx); }, - content: but.label, + content: + this._pendingButtonIndex === idx ? ( + + ) : ( + but.label + ), 'aria-label': but.ariaLabel, + disabled: this._pendingButtonIndex !== null, autoFocus: but.accept, }; return footerButton; @@ -192,8 +214,29 @@ export class JupyterDialog extends ReactWidget { return this._result; } - protected handleButton = (idx: number): void => { - this.setButton(this.buttons[idx]); + protected handleButton = async (idx: number): Promise => { + if (this._pendingButtonIndex !== null) { + return; + } + const button = this.buttons[idx]; + this.setButton(button); + if (button.accept && this.onWillAccept) { + this._pendingButtonIndex = idx; + this.update(); + try { + const shouldClose = await this.onWillAccept(this._result); + if (shouldClose === false) { + this._pendingButtonIndex = null; + this.update(); + return; + } + } catch (error) { + this._pendingButtonIndex = null; + this.update(); + throw error; + } + } + this._pendingButtonIndex = null; this.close(); }; @@ -214,6 +257,9 @@ export class JupyterDialog extends ReactWidget { }; close = (): void => { + if (this._pendingButtonIndex !== null) { + return; + } Widget.detach(this); this._closing.resolve(); }; diff --git a/src/components/index.ts b/src/components/index.ts index 85a005b5..d292b021 100644 --- a/src/components/index.ts +++ b/src/components/index.ts @@ -4,4 +4,6 @@ */ export * from './auth'; +export * from './billing'; +export * from './sharing'; export * from './sparklines'; diff --git a/src/components/principal/Principal.tsx b/src/components/principal/Principal.tsx new file mode 100644 index 00000000..a4607f5e --- /dev/null +++ b/src/components/principal/Principal.tsx @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2023-2025 Datalayer, Inc. + * Distributed under the terms of the Modified BSD License. + */ + +/** + * Principal – common, tunable display for an actor (user / team / + * organization). Combines a {@link PrincipalAvatar} with a + * {@link PrincipalDetailsOverlay} so all spots that need to show + * "avatar + clickable name with details overlay" can share a single + * component. + */ + +import * as React from 'react'; +import { Box } from '@datalayer/primer-addons'; +import { useCache } from '../../hooks'; +import { PrincipalAvatar, PrincipalAvatarKind } from './PrincipalAvatar'; +import { PrincipalDetailsOverlay } from './PrincipalDetailsOverlay'; + +type PrincipalKind = PrincipalAvatarKind; + +/** + * Normalised actor descriptor used by all caching resolvers. Views are + * expected to produce one of these out of their raw API data so the + * common component can render consistently. + */ +export type PrincipalDescriptor = { + kind: PrincipalKind; + uid?: string; + displayName: string; + handle?: string; + accountHandle?: string; + firstName?: string; + lastName?: string; + email?: string; + origin?: string; + avatarUrl?: string; +}; + +export type PrincipalProps = { + principal: PrincipalDescriptor; + isAdmin?: boolean; + avatarSize?: number; + gap?: number; + square?: boolean; + sx?: any; +}; + +export const Principal: React.FC = ({ + principal, + isAdmin = false, + avatarSize = 20, + gap = 2, + square = false, + sx, +}) => { + const { useUser, useOrganization } = useCache(); + + const hydratedUserQuery = useUser( + principal.kind === 'user' ? String(principal.uid || '') : '', + ); + const hydratedOrgQuery = useOrganization( + principal.kind === 'organization' ? String(principal.uid || '') : '', + ); + + const hydratedEntity = + principal.kind === 'user' + ? hydratedUserQuery.data + : principal.kind === 'organization' + ? hydratedOrgQuery.data + : undefined; + + const hydratedDisplayName = + principal.kind === 'user' + ? String( + (hydratedEntity as any)?.displayName || + [ + (hydratedEntity as any)?.firstName, + (hydratedEntity as any)?.lastName, + ] + .filter(Boolean) + .join(' ') || + '', + ).trim() + : String( + (hydratedEntity as any)?.displayName || + (hydratedEntity as any)?.name || + '', + ).trim(); + + const resolvedPrincipal: PrincipalDescriptor = { + ...principal, + displayName: + hydratedDisplayName || + principal.displayName || + principal.handle || + principal.uid || + 'Unknown', + handle: + principal.handle || + String((hydratedEntity as any)?.handle || '').trim() || + undefined, + accountHandle: + principal.accountHandle || + String((hydratedEntity as any)?.handle || '').trim() || + undefined, + avatarUrl: + principal.avatarUrl || (hydratedEntity as any)?.avatarUrl || undefined, + firstName: + principal.firstName || (hydratedEntity as any)?.firstName || undefined, + lastName: + principal.lastName || (hydratedEntity as any)?.lastName || undefined, + email: principal.email || (hydratedEntity as any)?.email || undefined, + origin: principal.origin || (hydratedEntity as any)?.origin || undefined, + }; + + return ( + + + + + ); +}; + +export default Principal; diff --git a/src/components/principal/PrincipalAvatar.tsx b/src/components/principal/PrincipalAvatar.tsx new file mode 100644 index 00000000..b28b7ab9 --- /dev/null +++ b/src/components/principal/PrincipalAvatar.tsx @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2023-2025 Datalayer, Inc. + * Distributed under the terms of the Modified BSD License. + */ + +import { Box, useColorPalette } from '@datalayer/primer-addons'; +import { OrganizationIcon, PeopleIcon } from '@primer/octicons-react'; +import { AlienIcon } from '@datalayer/icons-react'; +import { DLAvatar } from '../avatars'; + +export type PrincipalAvatarKind = 'user' | 'team' | 'organization'; + +export type PrincipalAvatarProps = { + kind: PrincipalAvatarKind; + avatarUrl?: string; + alt?: string; + size?: number; + square?: boolean; +}; + +function hasRealAvatar(url?: string): boolean { + if (!url) { + return false; + } + if (url.startsWith('https://www.gravatar.com/avatar')) { + return false; + } + return true; +} + +function getFallbackIconSize(size: number): number { + return Math.max(12, Math.round(size * 0.62)); +} + +export function PrincipalAvatar({ + kind, + avatarUrl, + alt, + size = 20, + square = false, +}: PrincipalAvatarProps): JSX.Element { + const palette = useColorPalette(); + if (kind === 'user' && hasRealAvatar(avatarUrl)) { + return ( + + ); + } + + const iconSize = getFallbackIconSize(size); + const borderRadius = square ? 2 : '50%'; + + if (kind === 'user') { + return ( + + + + ); + } + + const Icon = kind === 'team' ? PeopleIcon : OrganizationIcon; + + return ( + + + + ); +} + +export default PrincipalAvatar; diff --git a/src/components/principal/PrincipalBadge.tsx b/src/components/principal/PrincipalBadge.tsx new file mode 100644 index 00000000..fd7a1dd8 --- /dev/null +++ b/src/components/principal/PrincipalBadge.tsx @@ -0,0 +1,279 @@ +/* + * Copyright (c) 2023-2025 Datalayer, Inc. + * Distributed under the terms of the Modified BSD License. + */ + +import { useMemo } from 'react'; +import { Box, Label, Text } from '@primer/react'; +import { useCache } from '../../hooks'; +import { useIAMStore } from '../../state/substates'; +import { useSelectedPrincipal } from '../../hooks/useSelectedPrincipal'; +import { formatFriendlyHandle } from '../../utils/Handles'; +import { Principal, type PrincipalDescriptor } from './Principal'; + +function normalizeUserOrigin(originRaw?: string): string | undefined { + const value = (originRaw || '').trim(); + if (!value) { + return undefined; + } + const lower = value.toLowerCase(); + if (lower === 'github') { + return 'GitHub'; + } + if (lower === 'google') { + return 'Google'; + } + if (lower === 'linkedin') { + return 'LinkedIn'; + } + if (lower === 'microsoft') { + return 'Microsoft'; + } + if (lower === 'datalayer') { + return 'Datalayer'; + } + return value; +} + +export type PrincipalBadgeInput = Omit & { + displayName?: string; +}; + +type PrincipalBadgeProps = { + principal?: PrincipalBadgeInput; + showPrincipalLabel?: boolean; + showApplyingToText?: boolean; + showOriginLabel?: boolean; + principalLabel?: string; + isAdmin?: boolean; + sx?: any; +}; + +/** + * PrincipalBadge — small inline pill that displays a resolved principal + * (user / organization / team). Falls back to the currently selected + * principal when no explicit `principal` prop is supplied. + */ +export const PrincipalBadge = ({ + principal: providedPrincipal, + showPrincipalLabel = true, + showApplyingToText = true, + showOriginLabel = true, + principalLabel = 'Principal', + isAdmin = false, + sx, +}: PrincipalBadgeProps = {}) => { + const { user } = useIAMStore(); + const { + selectedPrincipalKind, + selectedPrincipalUid, + selectedPrincipalHandle, + selectedTeamParentOrganizationHandle, + } = useSelectedPrincipal(); + const { useUser, useOrganization } = useCache(); + + const basePrincipal = useMemo(() => { + if (providedPrincipal) { + return { + ...providedPrincipal, + displayName: + providedPrincipal.displayName || + providedPrincipal.handle || + providedPrincipal.uid || + 'Principal', + }; + } + + if (selectedPrincipalKind === 'organization') { + return { + kind: 'organization', + uid: selectedPrincipalUid, + handle: selectedPrincipalHandle, + accountHandle: selectedPrincipalHandle, + displayName: selectedPrincipalHandle + ? `@${formatFriendlyHandle(selectedPrincipalHandle)}` + : 'Organization', + origin: 'Datalayer', + }; + } + + if (selectedPrincipalKind === 'team') { + const teamHandle = selectedPrincipalHandle || 'team'; + const orgHandle = selectedTeamParentOrganizationHandle || 'organization'; + return { + kind: 'team', + uid: selectedPrincipalUid, + handle: `${orgHandle}/${teamHandle}`, + accountHandle: teamHandle, + displayName: `@${formatFriendlyHandle(orgHandle)}/${formatFriendlyHandle(teamHandle)}`, + origin: 'Datalayer', + }; + } + + const fullName = [user?.firstName, user?.lastName] + .filter(Boolean) + .join(' ') + .trim(); + const resolvedHandle = user?.handle || selectedPrincipalHandle; + const fallbackHandle = resolvedHandle + ? `@${formatFriendlyHandle(resolvedHandle)}` + : '@me'; + + return { + kind: 'user', + uid: user?.id || selectedPrincipalUid, + displayName: fullName || fallbackHandle, + handle: resolvedHandle, + accountHandle: resolvedHandle, + firstName: user?.firstName, + lastName: user?.lastName, + email: user?.email, + avatarUrl: user?.avatarUrl, + origin: normalizeUserOrigin(user?.origin), + }; + }, [ + providedPrincipal, + selectedPrincipalKind, + selectedPrincipalUid, + selectedPrincipalHandle, + selectedTeamParentOrganizationHandle, + user?.id, + user?.origin, + user?.handle, + user?.firstName, + user?.lastName, + user?.email, + user?.avatarUrl, + ]); + + const userLookupUid = + basePrincipal.kind === 'user' ? String(basePrincipal.uid || '') : ''; + const organizationLookupUid = + basePrincipal.kind === 'organization' + ? String(basePrincipal.uid || '') + : ''; + + const { data: resolvedUser } = useUser(userLookupUid); + const { data: resolvedOrganization } = useOrganization(organizationLookupUid); + + const principal = useMemo(() => { + if (basePrincipal.kind === 'organization') { + const resolvedHandle = + resolvedOrganization?.handle || + basePrincipal.handle || + basePrincipal.accountHandle; + const normalizedHandle = resolvedHandle + ? formatFriendlyHandle(resolvedHandle) + : 'organization'; + return { + kind: 'organization', + uid: basePrincipal.uid, + displayName: + resolvedOrganization?.name || + basePrincipal.displayName || + `@${normalizedHandle}`, + handle: resolvedHandle, + accountHandle: resolvedHandle, + origin: basePrincipal.origin || 'Datalayer', + }; + } + + if (basePrincipal.kind === 'team') { + return { + kind: 'team', + uid: basePrincipal.uid, + displayName: + basePrincipal.displayName || + basePrincipal.handle || + basePrincipal.uid || + 'Team', + handle: basePrincipal.handle, + accountHandle: basePrincipal.accountHandle, + avatarUrl: basePrincipal.avatarUrl, + origin: basePrincipal.origin, + }; + } + + const fullName = [ + resolvedUser?.firstName || basePrincipal.firstName, + resolvedUser?.lastName || basePrincipal.lastName, + ] + .filter(Boolean) + .join(' ') + .trim(); + const resolvedHandle = + resolvedUser?.handle || + basePrincipal.handle || + basePrincipal.accountHandle; + const fallbackHandle = resolvedHandle + ? `@${formatFriendlyHandle(resolvedHandle)}` + : '@me'; + const resolvedDisplayName = + resolvedUser?.displayName || + basePrincipal.displayName || + fullName || + fallbackHandle; + const origin = normalizeUserOrigin( + resolvedUser?.origin || basePrincipal.origin, + ); + + return { + kind: 'user', + uid: resolvedUser?.uid || basePrincipal.uid, + displayName: resolvedDisplayName, + handle: resolvedHandle, + accountHandle: resolvedHandle, + firstName: resolvedUser?.firstName || basePrincipal.firstName, + lastName: resolvedUser?.lastName || basePrincipal.lastName, + email: resolvedUser?.email || basePrincipal.email, + avatarUrl: resolvedUser?.avatarUrl || basePrincipal.avatarUrl, + origin, + }; + }, [ + basePrincipal, + resolvedOrganization?.name, + resolvedOrganization?.handle, + resolvedUser?.uid, + resolvedUser?.displayName, + resolvedUser?.handle, + resolvedUser?.firstName, + resolvedUser?.lastName, + resolvedUser?.email, + resolvedUser?.avatarUrl, + resolvedUser?.origin, + ]); + + return ( + + {showPrincipalLabel && ( + + )} + {showApplyingToText && ( + Applying to + )} + + {showOriginLabel && principal.origin && ( + + )} + + ); +}; + +export default PrincipalBadge; diff --git a/src/components/principal/PrincipalBanner.tsx b/src/components/principal/PrincipalBanner.tsx new file mode 100644 index 00000000..1caa12e2 --- /dev/null +++ b/src/components/principal/PrincipalBanner.tsx @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2023-2025 Datalayer, Inc. + * Distributed under the terms of the Modified BSD License. + */ + +/** + * PrincipalBanner — displays the currently selected principal + * (user, organization, or team) with a colored visual so a user + * can immediately see which principal a settings page applies to. + */ + +import { Box, Label, Text } from '@primer/react'; +import { + OrganizationIcon, + PeopleIcon, + PersonIcon, +} from '@primer/octicons-react'; +import type { ReactNode } from 'react'; +import { useSelectedPrincipal } from '../../hooks/useSelectedPrincipal'; +import { useIAMStore } from '../../state/substates'; + +export type PrincipalBannerProps = { + caption?: string; + rightContent?: ReactNode; +}; + +export const PrincipalBanner = ({ + caption, + rightContent, +}: PrincipalBannerProps) => { + const { user } = useIAMStore(); + const { + selectedPrincipalKind, + selectedPrincipalHandle, + selectedTeamParentOrganizationHandle, + } = useSelectedPrincipal(); + + const isOrganization = selectedPrincipalKind === 'organization'; + const isTeam = selectedPrincipalKind === 'team'; + const handle = isOrganization + ? selectedPrincipalHandle || '' + : isTeam + ? `${selectedTeamParentOrganizationHandle || 'organization'}/${selectedPrincipalHandle || 'team'}` + : user?.handle || selectedPrincipalHandle || ''; + const Icon = isOrganization + ? OrganizationIcon + : isTeam + ? PeopleIcon + : PersonIcon; + + const accent = isOrganization ? 'done' : isTeam ? 'attention' : 'accent'; + const bg = isOrganization + ? 'done.subtle' + : isTeam + ? 'attention.subtle' + : 'accent.subtle'; + const borderColor = isOrganization + ? 'done.muted' + : isTeam + ? 'attention.muted' + : 'accent.muted'; + const fg = isOrganization ? 'done.fg' : isTeam ? 'attention.fg' : 'accent.fg'; + + return ( + + + + + + + + Principal + + + + + {handle + ? `@${handle}` + : isOrganization + ? 'Organization' + : isTeam + ? 'Team' + : 'User'} + + {caption && ( + + {caption} + + )} + + {rightContent ? ( + + {rightContent} + + ) : null} + + ); +}; + +export default PrincipalBanner; diff --git a/src/components/principal/PrincipalDetailsOverlay.tsx b/src/components/principal/PrincipalDetailsOverlay.tsx new file mode 100644 index 00000000..18e641df --- /dev/null +++ b/src/components/principal/PrincipalDetailsOverlay.tsx @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2023-2025 Datalayer, Inc. + * Distributed under the terms of the Modified BSD License. + */ + +import { ActionMenu, Box, Button, Text } from '@primer/react'; +import { useNavigate } from '../../hooks'; +import { PrincipalAvatar } from './PrincipalAvatar'; + +export type PrincipalKind = 'user' | 'team' | 'organization'; + +export type PrincipalDetailsOverlayProps = { + kind: PrincipalKind; + uid?: string; + displayName: string; + handle?: string; + accountHandle?: string; + firstName?: string; + lastName?: string; + email?: string; + origin?: string; + avatarUrl?: string; + isAdmin?: boolean; +}; + +function normalize(value?: string): string { + return (value || '').trim(); +} + +export function buildPrincipalProfilePath({ + kind, + uid, + handle, + accountHandle, + isAdmin, +}: { + kind: PrincipalKind; + uid?: string; + handle?: string; + accountHandle?: string; + isAdmin?: boolean; +}): string | null { + const normalizedUid = normalize(uid); + const normalizedHandle = normalize(handle); + const normalizedAccountHandle = normalize(accountHandle); + const safeHandle = + normalizedHandle && normalizedHandle !== normalizedUid + ? normalizedHandle + : ''; + + if (kind === 'user') { + if (isAdmin && normalizedUid) { + return `/admin/management/iam/users/${normalizedUid}`; + } + if (safeHandle) { + return `/${safeHandle}`; + } + return null; + } + + if (kind === 'team') { + if (normalizedAccountHandle && safeHandle) { + return `/${normalizedAccountHandle}/team/${safeHandle}`; + } + if (safeHandle.includes('/')) { + const [orgHandle, teamHandle] = safeHandle.split('/', 2); + if (orgHandle && teamHandle) { + return `/${orgHandle}/team/${teamHandle}`; + } + } + if (safeHandle) { + return `/datalayer/team/${safeHandle}`; + } + return null; + } + + if (safeHandle) { + return `/${safeHandle}`; + } + return null; +} + +export function PrincipalDetailsOverlay({ + kind, + uid, + displayName, + handle, + accountHandle, + firstName, + lastName, + email, + origin, + avatarUrl, + isAdmin = false, +}: PrincipalDetailsOverlayProps): JSX.Element { + const navigate = useNavigate(); + + const normalizedDisplayName = + normalize(displayName) || + normalize(handle) || + normalize(uid) || + 'Principal'; + const normalizedHandle = normalize(handle); + const normalizedUid = normalize(uid); + const targetPath = buildPrincipalProfilePath({ + kind, + uid: normalizedUid, + handle: normalizedHandle, + accountHandle, + isAdmin, + }); + + return ( + + + + {normalizedDisplayName} + + + + + + + + + {normalizedDisplayName} + + {normalizedHandle ? ( + + @{normalizedHandle} + + ) : null} + + + + Type + {kind} + {normalizedHandle ? ( + <> + Handle + @{normalizedHandle} + + ) : null} + {normalizedUid ? ( + <> + UID + {normalizedUid} + + ) : null} + {kind === 'user' ? ( + <> + First name + {firstName || 'N/A'} + Last name + {lastName || 'N/A'} + Origin + {origin || 'Datalayer'} + {email ? ( + <> + Email + {email} + + ) : null} + + ) : ( + <> + Origin + {origin || 'Datalayer'} + + )} + + + + + + + + ); +} + +export default PrincipalDetailsOverlay; diff --git a/src/components/principal/PrincipalSwitcherMenu.tsx b/src/components/principal/PrincipalSwitcherMenu.tsx new file mode 100644 index 00000000..67a3161a --- /dev/null +++ b/src/components/principal/PrincipalSwitcherMenu.tsx @@ -0,0 +1,477 @@ +/* + * Copyright (c) 2023-2025 Datalayer, Inc. + * Distributed under the terms of the Modified BSD License. + */ + +import { useEffect, useMemo, useState } from 'react'; +import { ActionList, ActionMenu, Box, Text } from '@primer/react'; +import { + OrganizationIcon, + PeopleIcon, + PersonIcon, +} from '@primer/octicons-react'; +import { useCache, useAuthorization } from '../../hooks'; +import { useCoreStore } from '../../state'; +import { useIAMStore } from '../../state/substates'; +import { memberships as fetchMemberships } from '../../api/iam/profile'; +import { usePrincipalStore } from '../../hooks/usePrincipalStore'; +import { useBillableAccountStore } from '../../hooks/useBillableAccountStore'; +import { useSelectedPrincipal } from '../../hooks/useSelectedPrincipal'; +import { formatFriendlyHandle } from '../../utils/Handles'; + +type TeamMembership = { + uid: string; + handle: string; + organizationUid?: string; + organizationHandle?: string; +}; + +export type PrincipalSwitcherMenuProps = { + maxLabelChars?: number; + fullWidth?: boolean; + showClosedBorder?: boolean; +}; + +function truncatePrincipalLabel(label: string, maxChars: number): string { + const trimmed = (label || '').trim(); + if (!trimmed) { + return ''; + } + if (trimmed.length <= maxChars) { + return trimmed; + } + return `${trimmed.slice(0, Math.max(0, maxChars - 1))}…`; +} + +/** + * PrincipalSwitcherMenu — the *only* component allowed to write to the + * principal store and the billable account store. It keeps both stores in + * sync per the rule: + * - selecting a user/org principal → billable account = same user/org + * - selecting a team principal → billable account = the team's parent org + */ +export function PrincipalSwitcherMenu({ + maxLabelChars = 48, + fullWidth = true, + showClosedBorder = true, +}: PrincipalSwitcherMenuProps): JSX.Element { + const { user, token, iamRunUrl } = useIAMStore(); + const { configuration } = useCoreStore(); + const { checkIsPlatformAdmin } = useAuthorization(); + const { useUserOrganizations } = useCache(); + const organizationsQuery = useUserOrganizations(); + const organizations = organizationsQuery.data || []; + const isOrganizationsLoading = organizationsQuery.isLoading; + const isPlatformAdmin = user ? checkIsPlatformAdmin(user) : false; + const [teams, setTeams] = useState([]); + const [teamsLoading, setTeamsLoading] = useState(false); + + const selectUserPrincipal = usePrincipalStore( + state => state.selectUserPrincipal, + ); + const selectOrganizationPrincipal = usePrincipalStore( + state => state.selectOrganizationPrincipal, + ); + const selectTeamPrincipal = usePrincipalStore( + state => state.selectTeamPrincipal, + ); + const setBillableAccount = useBillableAccountStore( + state => state.setBillableAccount, + ); + + const { + selectedPrincipalKind, + selectedPrincipalUid, + selectedPrincipalHandle, + selectedTeamParentOrganizationHandle, + } = useSelectedPrincipal(); + + const personalUid = user?.uid || user?.id || ''; + const personalHandle = user?.handle || ''; + + const selectUser = (uid: string, handle: string) => { + selectUserPrincipal(uid, handle); + setBillableAccount({ kind: 'user', uid, handle }); + }; + + const selectOrganization = (uid: string, handle: string) => { + selectOrganizationPrincipal(uid, handle); + setBillableAccount({ kind: 'organization', uid, handle }); + }; + + const selectTeam = (team: TeamMembership, orgHandle: string) => { + if (!team.organizationUid) { + return; + } + selectTeamPrincipal({ + teamUid: team.uid, + teamHandle: team.handle, + organizationUid: team.organizationUid, + organizationHandle: orgHandle, + }); + setBillableAccount({ + kind: 'organization', + uid: team.organizationUid, + handle: orgHandle, + }); + }; + + const getOrganizationUid = (organization: any): string => + String(organization?.uid || organization?.id || ''); + + const selectedOrganization = useMemo( + () => + organizations.find( + (org: any) => getOrganizationUid(org) === selectedPrincipalUid, + ), + [organizations, selectedPrincipalUid], + ); + + const selectedTeam = useMemo( + () => teams.find(team => team.uid === selectedPrincipalUid), + [teams, selectedPrincipalUid], + ); + + useEffect(() => { + let cancelled = false; + const loadTeams = async () => { + if (!token) { + setTeams([]); + return; + } + setTeamsLoading(true); + try { + const baseUrl = iamRunUrl || configuration.iamRunUrl; + const response = await fetchMemberships(token, baseUrl); + const rawMemberships = Array.isArray((response as any)?.memberships) + ? (response as any).memberships + : []; + const mappedTeams = rawMemberships + .filter((membership: any) => membership?.type === 'team') + .map((membership: any) => ({ + uid: String(membership?.uid || membership?.id || '').trim(), + handle: String(membership?.handle || '').trim(), + organizationUid: + String(membership?.organization_uid || '').trim() || undefined, + organizationHandle: + String( + membership?.organization_handle || + membership?.organization?.handle || + '', + ).trim() || undefined, + })) + .filter((team: TeamMembership) => Boolean(team.uid && team.handle)); + if (!cancelled) { + setTeams(mappedTeams); + } + } catch { + if (!cancelled) { + setTeams([]); + } + } finally { + if (!cancelled) { + setTeamsLoading(false); + } + } + }; + void loadTeams(); + return () => { + cancelled = true; + }; + }, [token, iamRunUrl, configuration.iamRunUrl]); + + useEffect(() => { + if (!personalUid || !personalHandle) { + return; + } + if (!selectedPrincipalUid) { + selectUser(personalUid, personalHandle); + return; + } + if (selectedPrincipalKind === 'organization' && isOrganizationsLoading) { + return; + } + if (selectedPrincipalKind === 'organization' && !selectedOrganization) { + selectUser(personalUid, personalHandle); + return; + } + if (selectedPrincipalKind === 'team' && teamsLoading) { + return; + } + if (selectedPrincipalKind === 'team' && !selectedTeam) { + selectUser(personalUid, personalHandle); + return; + } + if ( + selectedPrincipalKind === 'user' && + selectedPrincipalUid !== personalUid + ) { + selectUser(personalUid, personalHandle); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [ + personalUid, + personalHandle, + selectedPrincipalUid, + selectedPrincipalKind, + isOrganizationsLoading, + teamsLoading, + selectedOrganization, + selectedTeam, + ]); + + const effectiveHandle = + selectedPrincipalKind === 'organization' + ? selectedOrganization?.handle || + selectedPrincipalHandle || + personalHandle + : selectedPrincipalKind === 'team' + ? selectedTeam?.handle || selectedPrincipalHandle || personalHandle + : personalHandle; + + const organizationHandleByUid = useMemo(() => { + const byUid = new Map(); + for (const organization of organizations) { + const uid = getOrganizationUid(organization); + const handle = String(organization?.handle || '').trim(); + if (uid && handle) { + byUid.set(uid, handle); + } + } + return byUid; + }, [organizations]); + + const resolveTeamOrganizationHandle = (team?: TeamMembership): string => { + if (!team) { + return ''; + } + const directHandle = String(team.organizationHandle || '').trim(); + if (directHandle) { + return directHandle; + } + const fromOrganizations = team.organizationUid + ? organizationHandleByUid.get(team.organizationUid) || '' + : ''; + return fromOrganizations.trim(); + }; + + const effectiveOrganizationHandle = + selectedPrincipalKind === 'team' + ? resolveTeamOrganizationHandle(selectedTeam) || + selectedTeamParentOrganizationHandle || + '' + : ''; + + const selectedPrincipalLabel = + selectedPrincipalKind === 'team' + ? `@${formatFriendlyHandle(effectiveOrganizationHandle || personalHandle || 'organization')}/${formatFriendlyHandle(effectiveHandle)}` + : `@${formatFriendlyHandle(effectiveHandle)}`; + const selectedPrincipalLabelClosed = truncatePrincipalLabel( + selectedPrincipalLabel, + maxLabelChars, + ); + + const isCurrentUserPrincipal = selectedPrincipalKind === 'user'; + const selectedItemSx = { + bg: 'accent.subtle', + borderColor: 'accent.muted', + color: 'accent.fg', + fontWeight: 'semibold', + } as const; + const adminBadgeSx = { + ml: 'auto', + px: 1, + py: '2px', + borderRadius: 999, + bg: 'attention.subtle', + color: 'attention.fg', + fontSize: 0, + fontWeight: 'semibold', + lineHeight: 1.2, + textTransform: 'lowercase', + } as const; + + return ( + + + + + {selectedPrincipalKind === 'organization' ? ( + + ) : selectedPrincipalKind === 'team' ? ( + + ) : ( + + )} + + + + {selectedPrincipalLabelClosed} + + + {isPlatformAdmin && isCurrentUserPrincipal ? ( + + + admin + + + ) : null} + + + + + + User + { + if (isCurrentUserPrincipal) { + return; + } + if (personalUid && personalHandle) { + selectUser(personalUid, personalHandle); + } + }} + > + + + + @{formatFriendlyHandle(personalHandle || 'me')} + {isPlatformAdmin ? ( + + + admin + + + ) : null} + + + + Organizations + {organizations.length === 0 ? ( + No organizations + ) : ( + organizations.map((organization: any) => { + const organizationUid = getOrganizationUid(organization); + const isCurrentOrganizationPrincipal = + selectedPrincipalKind === 'organization' && + selectedPrincipalUid === organizationUid; + return ( + { + if (isCurrentOrganizationPrincipal) { + return; + } + if (organizationUid && organization.handle) { + selectOrganization( + organizationUid, + organization.handle, + ); + } + }} + > + + + + @{organization.handle} + + ); + }) + )} + + + Teams + {teams.length === 0 ? ( + No teams + ) : ( + teams.map(team => { + const isCurrentTeamPrincipal = + selectedPrincipalKind === 'team' && + selectedPrincipalUid === team.uid; + const orgHandle = + resolveTeamOrganizationHandle(team) || + personalHandle || + 'organization'; + return ( + { + if (isCurrentTeamPrincipal) { + return; + } + selectTeam(team, orgHandle); + }} + > + + + + @{formatFriendlyHandle(orgHandle)}/ + {formatFriendlyHandle(team.handle)} + + ); + }) + )} + + + + + ); +} + +export default PrincipalSwitcherMenu; diff --git a/src/components/principal/index.ts b/src/components/principal/index.ts new file mode 100644 index 00000000..9e8bb801 --- /dev/null +++ b/src/components/principal/index.ts @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023-2025 Datalayer, Inc. + * Distributed under the terms of the Modified BSD License. + */ + +export * from './Principal'; +export * from './PrincipalAvatar'; +export * from './PrincipalBadge'; +export * from './PrincipalBanner'; +export { + buildPrincipalProfilePath, + PrincipalDetailsOverlay, +} from './PrincipalDetailsOverlay'; +export type { PrincipalDetailsOverlayProps } from './PrincipalDetailsOverlay'; +export * from './PrincipalSwitcherMenu'; diff --git a/src/components/runtimes/RuntimeLauncherDialog.tsx b/src/components/runtimes/RuntimeLauncherDialog.tsx index 64c12869..7b088b2c 100644 --- a/src/components/runtimes/RuntimeLauncherDialog.tsx +++ b/src/components/runtimes/RuntimeLauncherDialog.tsx @@ -26,7 +26,7 @@ import { useNavigate } from '../../hooks'; import { NO_RUNTIME_AVAILABLE_LABEL } from '../../i18n'; import type { IRemoteServicesManager } from '../../stateful/runtimes'; import type { RunResponseError } from '../../api/DatalayerApi'; -import type { IRuntimeSnapshot, IRuntimeDesc } from '../../models'; +import type { ICodeSandboxSnapshot, IRuntimeDesc } from '../../models'; import { iamStore, useCoreStore, useIAMStore } from '../../state'; import { createNotebook, sleep } from '../../utils'; import { Markdown } from '../display'; @@ -88,7 +88,7 @@ export interface IRuntimeLauncherDialogProps { * If provided the kernel will be started and will * restore the provided snapshot in the kernel. */ - kernelSnapshot?: IRuntimeSnapshot; + kernelSnapshot?: ICodeSandboxSnapshot; /** * HTML sanitizer @@ -99,6 +99,11 @@ export interface IRuntimeLauncherDialogProps { * Upgrade subscription URL */ upgradeSubscription?: string; + + /** + * Optional submit button label override. + */ + submitLabel?: string; } /** @@ -115,10 +120,12 @@ export function RuntimeLauncherDialog( markdownParser, sanitizer, upgradeSubscription, + submitLabel, startRuntime = true, } = props; const hasExample = startRuntime === 'with-example'; + const shouldStartRuntime = startRuntime !== 'defer'; const user = iamStore.getState().user; const environments = manager.environments.get(); @@ -141,9 +148,7 @@ export function RuntimeLauncherDialog( const [selection, setSelection] = useState( (kernelSnapshot?.environment || environments[0]?.name) ?? '', ); - const [timeLimit, setTimeLimit] = useState( - Math.min(credits?.available ?? 0, 10), - ); + const [timeLimit, setTimeLimit] = useState(10); const [runtimeName, setRuntimeName] = useState( environments[0]?.runtime?.givenNameTemplate || environments[0]?.title || '', ); @@ -156,10 +161,10 @@ export function RuntimeLauncherDialog( const [flashLevel, setFlashLevel] = useState<'danger' | 'warning'>('danger'); const isMounted = useIsMounted(); useEffect(() => { - if (startRuntime) { + if (shouldStartRuntime) { refreshCredits(); } - }, [startRuntime]); + }, [shouldStartRuntime]); const spec = useMemo( () => environments.find(spec => spec.name === selection), [environments, selection], @@ -167,9 +172,33 @@ export function RuntimeLauncherDialog( const description = spec?.description ?? ''; const burningRate = spec?.burning_rate ?? 1; const creditsToMinutes = 1.0 / burningRate / 60.0; - const max = Math.floor((credits?.available ?? 0) * creditsToMinutes); + const includedRuns = + user?.subscription?.usage?.included_runs ?? + user?.subscription?.included_runs; + const currentRuns = + user?.subscription?.usage?.current_runs ?? + user?.subscription?.current_runs ?? + user?.subscription?.used_runs; + const hasKnownRunAllowance = typeof includedRuns === 'number'; + const hasRemainingRuns = + hasKnownRunAllowance && + typeof currentRuns === 'number' && + includedRuns > 0 && + currentRuns < includedRuns; + const hasKnownCredits = typeof credits?.available === 'number'; + const maxFromCredits = hasKnownCredits + ? Math.floor((credits.available ?? 0) * creditsToMinutes) + : 10; + const effectiveMaxMinutes = + hasKnownCredits && hasKnownRunAllowance && !hasRemainingRuns + ? Math.max(1, maxFromCredits) + : Math.max(10, maxFromCredits > 0 ? maxFromCredits : 0); const outOfCredits = - startRuntime && (!credits?.available || max < Number.EPSILON); + shouldStartRuntime && + hasKnownCredits && + hasKnownRunAllowance && + !hasRemainingRuns && + ((credits.available ?? 0) <= 0 || maxFromCredits < Number.EPSILON); const handleSelectionChange = useCallback( (e: any) => { const selection = (e.target as HTMLSelectElement).value; @@ -184,7 +213,7 @@ export function RuntimeLauncherDialog( const handleSubmitRuntime = useCallback(async () => { if (selection) { setError(undefined); - setWaitingForRuntime(true); + setWaitingForRuntime(shouldStartRuntime); const spec = environments.find(s => s.name === selection); const desc: IRuntimeDesc = { name: selection, @@ -203,7 +232,7 @@ export function RuntimeLauncherDialog( desc.params['capabilities'] = ['user_storage']; } let success = true; - if (startRuntime && startRuntime !== 'defer') { + if (shouldStartRuntime) { success = false; let availableTrial = 1; let retryDelay = NOT_AVAILABLE_INIT_RETRY; @@ -299,6 +328,9 @@ export function RuntimeLauncherDialog( success = await startNewKernel(); } if (success && isMounted()) { + if (!shouldStartRuntime) { + setWaitingForRuntime(false); + } onSubmit(desc); } } @@ -312,6 +344,7 @@ export function RuntimeLauncherDialog( openExample, jupyterLabAdapter, timeLimit, + shouldStartRuntime, isMounted, ]); const handleUserStorageChange = useCallback( @@ -365,13 +398,10 @@ export function RuntimeLauncherDialog( onClick: handleSubmitRuntime, content: waitingForRuntime ? ( - ) : (startRuntime ?? true) ? ( - 'Launch' ) : ( - 'Assign from the Environment' + (submitLabel ?? ((startRuntime ?? true) ? 'Launch' : 'Assign')) ), - disabled: - waitingForRuntime || outOfCredits || timeLimit < Number.EPSILON, + disabled: waitingForRuntime || !selection || outOfCredits, autoFocus: true, }, ]} @@ -457,7 +487,7 @@ export function RuntimeLauncherDialog( } disabled={outOfCredits} label={'Time reservation'} - max={max} + max={effectiveMaxMinutes} time={timeLimit} burningRate={burningRate} onTimeChange={setTimeLimit} diff --git a/src/components/runtimes/RuntimePickerBase.tsx b/src/components/runtimes/RuntimePickerBase.tsx index 42d67c26..e06bf14d 100644 --- a/src/components/runtimes/RuntimePickerBase.tsx +++ b/src/components/runtimes/RuntimePickerBase.tsx @@ -256,43 +256,45 @@ export function RuntimePickerBase( ([group, runtimeDescs]) => ( {group} - {runtimeDescs.map(runtimeDesc => { - const annotation = runtimeDesc.podName - ? ` - ${runtimeDesc.podName.split('-', 2).reverse()[0]}` - : runtimeDesc.kernelId - ? ` - ${runtimeDesc.kernelId}` + {runtimeDescs.map(candidateRuntimeDesc => { + const annotation = candidateRuntimeDesc.podName + ? ` - ${candidateRuntimeDesc.podName.split('-', 2).reverse()[0]}` + : candidateRuntimeDesc.kernelId + ? ` - ${candidateRuntimeDesc.kernelId}` : ''; const fullDisplayName = - (runtimeDesc.displayName ?? '') + annotation; + (candidateRuntimeDesc.displayName ?? '') + annotation; const displayName = - (runtimeDesc.displayName?.length ?? 0) > + (candidateRuntimeDesc.displayName?.length ?? 0) > RUNTIME_DISPLAY_NAME_MAX_LENGTH - ? runtimeDesc.displayName!.slice( + ? candidateRuntimeDesc.displayName!.slice( 0, RUNTIME_DISPLAY_NAME_MAX_LENGTH, ) + '…' - : (runtimeDesc.displayName ?? ''); + : (candidateRuntimeDesc.displayName ?? ''); return ( { - setRuntimeDesc(runtimeDesc); + setRuntimeDesc(candidateRuntimeDesc); }} > - {runtimeDesc.location === 'local' ? ( + {candidateRuntimeDesc.location === 'local' ? ( - ) : runtimeDesc.location === 'browser' ? ( + ) : candidateRuntimeDesc.location === 'browser' ? ( ) : ( @@ -337,9 +339,11 @@ export function RuntimePickerBase( setRuntimeDesc(k); }} checked={ - (k.location === k?.location || + (k.location === runtimeDesc?.location || (isRuntimeRemote(k.location) && - isRuntimeRemote(k?.location ?? 'local'))) && + isRuntimeRemote( + runtimeDesc?.location ?? 'local', + ))) && (k.kernelId ?? k.name) === (runtimeDesc?.kernelId ?? runtimeDesc?.name) } diff --git a/src/components/runtimes/RuntimePickerCell.tsx b/src/components/runtimes/RuntimePickerCell.tsx index ea8447d5..ffdfe991 100644 --- a/src/components/runtimes/RuntimePickerCell.tsx +++ b/src/components/runtimes/RuntimePickerCell.tsx @@ -245,7 +245,8 @@ export function RuntimePickerCell(props: IRuntimePickerCellProps): JSX.Element { diff --git a/src/components/runtimes/RuntimePickerNotebook.tsx b/src/components/runtimes/RuntimePickerNotebook.tsx index aa6599dd..439551f4 100644 --- a/src/components/runtimes/RuntimePickerNotebook.tsx +++ b/src/components/runtimes/RuntimePickerNotebook.tsx @@ -76,12 +76,10 @@ export function RuntimePickerNotebook( ): JSX.Element { const { multiServiceManager, sessionContext, setValue, translator } = props; const { configuration } = useCoreStore(); - const { credits, refreshCredits, token } = useIAMStore(); + const { credits, refreshCredits, token, user } = useIAMStore(); const [selectedRuntimeDesc, setSelectedRuntimeDesc] = useState(); - const [timeLimit, setTimeLimit] = useState( - Math.min(credits?.available ?? 0, 10), - ); + const [timeLimit, setTimeLimit] = useState(10); const [userStorage, setUserStorage] = useState(false); const [canTransferFrom, setTransferFrom] = useState(false); const [canTransferTo, setTransferTo] = useState(false); @@ -192,42 +190,124 @@ export function RuntimePickerNotebook( [userStorage], ); useEffect((): void => { + const resolvedBurningRate = + selectedRuntimeDesc?.burningRate ?? + multiServiceManager.remote?.environments + .get() + .find(env => env.name === selectedRuntimeDesc?.name)?.burning_rate; + const includedRuns = + user?.subscription?.usage?.included_runs ?? + user?.subscription?.included_runs; + const currentRuns = + user?.subscription?.usage?.current_runs ?? + user?.subscription?.current_runs ?? + user?.subscription?.used_runs; + const hasKnownRunAllowance = typeof includedRuns === 'number'; + const hasRemainingRuns = + hasKnownRunAllowance && + typeof currentRuns === 'number' && + includedRuns > 0 && + currentRuns < includedRuns; + const hasKnownCredits = typeof credits?.available === 'number'; + const maxMinutes = + selectedRuntimeDesc?.location === 'remote' && resolvedBurningRate + ? Math.floor((credits?.available ?? 0) / resolvedBurningRate / 60.0) + : undefined; + const effectiveTimeLimit = + selectedRuntimeDesc?.location === 'remote' + ? Math.max( + 1, + Math.min(timeLimit, maxMinutes && maxMinutes > 0 ? maxMinutes : 10), + ) + : timeLimit; const creditsLimit = - selectedRuntimeDesc?.location === 'remote' && - selectedRuntimeDesc.burningRate - ? Math.min(timeLimit, MAXIMAL_RUNTIME_TIME_RESERVATION_MINUTES) * - selectedRuntimeDesc.burningRate * + selectedRuntimeDesc?.location === 'remote' && resolvedBurningRate + ? Math.min( + effectiveTimeLimit, + MAXIMAL_RUNTIME_TIME_RESERVATION_MINUTES, + ) * + resolvedBurningRate * 60 : undefined; - setValue( - creditsLimit !== 0 - ? { - runtime: selectedRuntimeDesc - ? ({ - environmentName: ['browser', 'remote'].includes( - selectedRuntimeDesc.location, - ) - ? `${selectedRuntimeDesc.location}-${selectedRuntimeDesc.name}` - : selectedRuntimeDesc.name, - id: selectedRuntimeDesc.kernelId, - creditsLimit, - capabilities: userStorage ? ['user_storage'] : undefined, - } satisfies Partial< - Omit & { id: string } - > | null) - : null, - selectedVariables: toTransfer, - } - : new Error('Credits limit must be strictly positive.'), - ); - }, [selectedRuntimeDesc, userStorage, toTransfer, timeLimit]); + const requiresRuntimeStart = + !!selectedRuntimeDesc && !selectedRuntimeDesc.kernelId; + if (requiresRuntimeStart && selectedRuntimeDesc.location === 'remote') { + if (!resolvedBurningRate || !Number.isFinite(resolvedBurningRate)) { + setValue({ runtime: null, selectedVariables: toTransfer }); + return; + } + if ( + hasKnownCredits && + hasKnownRunAllowance && + !hasRemainingRuns && + (!creditsLimit || creditsLimit <= 0) + ) { + setValue({ runtime: null, selectedVariables: toTransfer }); + return; + } + } + setValue({ + runtime: selectedRuntimeDesc + ? ({ + environmentName: ['browser', 'remote'].includes( + selectedRuntimeDesc.location, + ) + ? `${selectedRuntimeDesc.location}-${selectedRuntimeDesc.name}` + : selectedRuntimeDesc.name, + id: selectedRuntimeDesc.kernelId, + creditsLimit, + capabilities: userStorage ? ['user_storage'] : undefined, + } satisfies Partial< + Omit & { id: string } + > | null) + : null, + selectedVariables: toTransfer, + }); + }, [ + selectedRuntimeDesc, + userStorage, + toTransfer, + timeLimit, + multiServiceManager.remote, + credits?.available, + user, + ]); const { kernelPreference: { canStart }, } = sessionContext; - const max = Math.floor( - (credits?.available ?? 0) / (selectedRuntimeDesc?.burningRate ?? -1) / 60.0, - ); - const outOfCredits = !credits?.available || max < Number.EPSILON; + const resolvedBurningRate = + selectedRuntimeDesc?.burningRate ?? + multiServiceManager.remote?.environments + .get() + .find(env => env.name === selectedRuntimeDesc?.name)?.burning_rate; + const maxFromCredits = resolvedBurningRate + ? Math.floor((credits?.available ?? 0) / resolvedBurningRate / 60.0) + : -1; + const includedRuns = + user?.subscription?.usage?.included_runs ?? + user?.subscription?.included_runs; + const currentRuns = + user?.subscription?.usage?.current_runs ?? + user?.subscription?.current_runs ?? + user?.subscription?.used_runs; + const hasKnownRunAllowance = typeof includedRuns === 'number'; + const hasRemainingRuns = + hasKnownRunAllowance && + typeof currentRuns === 'number' && + includedRuns > 0 && + currentRuns < includedRuns; + const hasKnownCredits = typeof credits?.available === 'number'; + const effectiveMaxMinutes = + selectedRuntimeDesc?.location === 'remote' + ? hasKnownCredits && hasKnownRunAllowance && !hasRemainingRuns + ? Math.max(1, maxFromCredits) + : Math.max(10, maxFromCredits > 0 ? maxFromCredits : 0) + : Math.max(1, maxFromCredits); + const outOfCredits = + hasKnownCredits && + hasKnownRunAllowance && + !hasRemainingRuns && + maxFromCredits < Number.EPSILON; return ( @@ -279,11 +359,11 @@ export function RuntimePickerNotebook( outOfCredits || selectedRuntimeDesc?.location !== 'remote' } label={'Time reservation'} - max={max < 0 ? 1 : max} + max={effectiveMaxMinutes} time={timeLimit} onTimeChange={setTimeLimit} error={ - outOfCredits && max >= 0 + outOfCredits && maxFromCredits >= 0 ? 'You must add credits to your account.' : timeLimit === 0 ? 'You must set a time limit.' diff --git a/src/components/runtimes/RuntimeReservationControl.tsx b/src/components/runtimes/RuntimeReservationControl.tsx index 69731973..94252fcc 100644 --- a/src/components/runtimes/RuntimeReservationControl.tsx +++ b/src/components/runtimes/RuntimeReservationControl.tsx @@ -69,7 +69,22 @@ export function RuntimeReservationControl( onTimeChange, time, } = props; - const max = Math.min(maxProps, MAXIMAL_RUNTIME_TIME_RESERVATION_MINUTES); + const max = Math.max( + 1, + Math.min(maxProps, MAXIMAL_RUNTIME_TIME_RESERVATION_MINUTES), + ); + const displayedTime = Number.isFinite(time) + ? Math.min(max, Math.max(1, time)) + : 1; + const handleTimeChange = (valueOrEvent: any) => { + const rawValue = + typeof valueOrEvent === 'number' + ? valueOrEvent + : parseFloat(valueOrEvent?.target?.value); + if (Number.isFinite(rawValue)) { + onTimeChange(Math.min(max, Math.max(1, rawValue))); + } + }; // Temporary workaround to not show disabled components. const hidden = disabled; return !hidden ? ( @@ -91,8 +106,8 @@ export function RuntimeReservationControl( step={1} min={1} max={max} - value={time} - onChange={onTimeChange} + value={displayedTime} + onChange={handleTimeChange} disabled={disabled} label="" displayValue={false} @@ -103,10 +118,8 @@ export function RuntimeReservationControl( min="1" max={max} disabled={disabled} - value={Math.min(max, time).toFixed(2)} - onChange={event => { - onTimeChange(parseFloat(event.target.value)); - }} + value={displayedTime.toFixed(2)} + onChange={handleTimeChange} /> {(max === 0 || max > Number.EPSILON) && ( <> diff --git a/src/components/sharing/ShareAccessComponent.tsx b/src/components/sharing/ShareAccessComponent.tsx new file mode 100644 index 00000000..c2b058b4 --- /dev/null +++ b/src/components/sharing/ShareAccessComponent.tsx @@ -0,0 +1,2246 @@ +/* + * Copyright (c) 2023-2025 Datalayer, Inc. + * Distributed under the terms of the Modified BSD License. + */ + +import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { + KeyIcon, + PersonIcon, + OrganizationIcon, + PeopleIcon, +} from '@primer/octicons-react'; +import { Box } from '@datalayer/primer-addons'; +import { + ActionList, + ActionMenu, + Button, + Dialog, + Label, + Spinner, + Text, + TextInput, +} from '@primer/react'; +import { useToast } from '../../hooks'; +import { useCoreStore, useIAMStore } from '../../state'; +import { PrincipalAvatar } from '../principal/PrincipalAvatar'; +import { PrincipalBadge } from '../principal/PrincipalBadge'; + +// --------------------------------------------------------------------------- +// Public types (do not break callers). +// --------------------------------------------------------------------------- + +export type ItemAccessLevel = 'view' | 'update' | 'execute'; +type PrincipalKind = 'user' | 'team' | 'organization'; + +type SharingLevelPayload = { + userUids?: string[]; + teamUids?: string[]; + organizationUids?: string[]; +}; + +type SharingPayload = { + access?: Partial>; +}; + +export type ShareAccessComponentProps = { + isOpen: boolean; + requestUrl?: string; + resourceLabel: string; + resourceName?: string; + resourceDescription?: string; + onSharingAccessRestrictedChange?: ( + restricted: boolean, + message?: string, + ) => void; + defaultAccessLevel?: ItemAccessLevel; + principalKinds?: readonly PrincipalKind[]; + displayMode?: 'dialog' | 'inline'; + onClose: () => void; +}; + +// --------------------------------------------------------------------------- +// Internal types. +// --------------------------------------------------------------------------- + +type AccessByLevel = Record< + ItemAccessLevel, + { + userUids: string[]; + teamUids: string[]; + organizationUids: string[]; + } +>; + +type ACLPrincipalEntry = { + kind: PrincipalKind; + uid: string; + levels: ItemAccessLevel[]; +}; + +type OwnerPrincipal = { + kind: PrincipalKind; + uid: string; + handle: string; + displayName: string; + avatarUrl?: string; + origin?: string; + accountHandle?: string; +}; + +type ShareablePrincipal = { + kind: PrincipalKind; + uid: string; + handle: string; + name?: string | null; + email?: string | null; + avatarUrl?: string | null; + organizationUid?: string | null; + organizationHandle?: string | null; +}; + +type PrincipalSearchItem = { + kind: PrincipalKind; + uid: string; + handle: string; + displayName: string; + avatarUrl?: string; + origin?: string; + accountHandle?: string; +}; + +type PrincipalCacheEntry = { + displayName?: string; + avatarUrl?: string; + origin?: string; + handle?: string; + accountHandle?: string; +}; + +type PrincipalCache = Record; + +// --------------------------------------------------------------------------- +// Constants. +// --------------------------------------------------------------------------- + +const ACCESS_LEVELS: ItemAccessLevel[] = ['view', 'update', 'execute']; +const DEFAULT_PRINCIPAL_KINDS: readonly PrincipalKind[] = [ + 'user', + 'team', + 'organization', +]; + +const ACCESS_LEVEL_LABELS: Record = { + view: 'Viewer', + update: 'Editor', + execute: 'Executor', +}; + +// --------------------------------------------------------------------------- +// String / payload helpers. +// --------------------------------------------------------------------------- + +function pickFirstString(...values: unknown[]): string { + for (const value of values) { + if (typeof value === 'string' && value.trim()) { + return value.trim(); + } + } + return ''; +} + +function normalizePrincipalKind(kindRaw?: string): PrincipalKind { + const kind = (kindRaw || '').trim().toLowerCase(); + if (kind === 'team') { + return 'team'; + } + if (kind === 'organization' || kind === 'org') { + return 'organization'; + } + return 'user'; +} + +function toTitleCase(value: string): string { + if (!value) { + return value; + } + return value.charAt(0).toUpperCase() + value.slice(1); +} + +function normalizeUserOrigin(originRaw?: string): string | undefined { + const value = (originRaw || '').trim(); + if (!value) { + return undefined; + } + const lower = value.toLowerCase(); + if (lower === 'datalayer') { + return 'Datalayer'; + } + const extPrefix = 'urn:dla:iam:ext::'; + if (lower.startsWith(extPrefix)) { + const suffix = value.slice(extPrefix.length); + const provider = suffix.split(':')[0]?.trim(); + if (!provider) { + return 'External'; + } + return toTitleCase(provider.toLowerCase()); + } + return toTitleCase(lower); +} + +function ensurePrincipalDisplayName( + kind: PrincipalKind, + ...candidates: Array +): string { + for (const candidate of candidates) { + if (typeof candidate === 'string' && candidate.trim()) { + return candidate.trim(); + } + } + return kind === 'organization' ? 'Organization' : 'Principal'; +} + +function isSharingAuthorizationMessage(message?: string): boolean { + const normalized = (message || '').trim().toLowerCase(); + return normalized.includes('not authorized'); +} + +function principalKey(kind: PrincipalKind, uid: string): string { + return `${kind}:${uid.toLowerCase()}`; +} + +// --------------------------------------------------------------------------- +// Owner extraction (preserves all current fallbacks). +// --------------------------------------------------------------------------- + +function extractOwnerPrincipals(payload: any): OwnerPrincipal[] { + const ownersFromSharing = Array.isArray(payload?.sharing?.owners) + ? payload.sharing.owners + : []; + const ownersFromSpaceField = [ + ...(Array.isArray(payload?.space?.shared_owner_user_uids_ss) + ? payload.space.shared_owner_user_uids_ss + : []), + ...(Array.isArray(payload?.space?.shared_ower_user_uids_ss) + ? payload.space.shared_ower_user_uids_ss + : []), + ]; + + const ownerPayload = + payload?.owner || + payload?.data?.owner || + payload?.item?.owner || + payload?.space?.owner || + payload?.notebook?.owner || + payload?.lexical?.owner || + payload?.document?.owner || + payload?.cell?.owner || + payload?.resource?.owner || + payload?.sharing?.owner || + {}; + + const ownerUid = pickFirstString( + ownerPayload?.uid, + ownerPayload?.owner_uid, + ownerPayload?.ownerUid, + ownerPayload?.id, + payload?.owner_uid, + payload?.ownerUid, + ); + const ownerHandle = pickFirstString( + ownerPayload?.handle_s, + ownerPayload?.handle, + ownerPayload?.owner_handle, + ownerPayload?.ownerHandle, + payload?.owner_handle, + payload?.ownerHandle, + ); + const kindFromOwnerPayload = normalizePrincipalKind( + pickFirstString( + ownerPayload?.kind, + ownerPayload?.type, + ownerPayload?.principal_kind, + ownerPayload?.principalKind, + payload?.owner_kind, + payload?.ownerKind, + payload?.owner_type, + payload?.ownerType, + ), + ); + const accountHandle = pickFirstString( + ownerPayload?.organization_handle_s, + ownerPayload?.organizationHandle, + ownerPayload?.organization_handle, + payload?.space?.organization_handle_s, + payload?.space?.organizationHandle, + payload?.space?.organization_handle, + ); + const firstName = pickFirstString( + ownerPayload?.first_name_t, + ownerPayload?.firstName, + ); + const lastName = pickFirstString( + ownerPayload?.last_name_t, + ownerPayload?.lastName, + ); + const fullName = `${firstName} ${lastName}`.trim(); + const displayName = + fullName || + pickFirstString( + ownerPayload?.display_name_t, + ownerPayload?.display_name, + ownerPayload?.name_t, + ownerPayload?.name, + ownerHandle, + ownerUid, + ); + const origin = normalizeUserOrigin( + pickFirstString( + ownerPayload?.origin, + ownerPayload?.origin_s, + ownerPayload?.origin_t, + ), + ); + + const fallbackOwner = + ownerUid || ownerHandle + ? { + kind: kindFromOwnerPayload, + uid: ownerUid || ownerHandle, + handle: ownerHandle || accountHandle || ownerUid, + displayName, + avatarUrl: + pickFirstString( + ownerPayload?.avatar_url_s, + ownerPayload?.avatarUrl, + ownerPayload?.avatar_url, + payload?.owner_avatar_url, + payload?.owner_avatar_url_s, + payload?.ownerAvatarUrl, + ) || undefined, + origin, + accountHandle: accountHandle || undefined, + } + : null; + + const ownersFromSharingMapped = ownersFromSharing + .map((entry: any): OwnerPrincipal | null => { + if (typeof entry === 'string') { + const uid = entry.trim(); + return uid + ? { kind: 'user', uid, handle: uid, displayName: uid } + : null; + } + const uid = pickFirstString(entry?.uid, entry?.owner_uid, entry?.id); + if (!uid) { + return null; + } + const handle = pickFirstString(entry?.handle_s, entry?.handle, uid); + const ownerKind = normalizePrincipalKind( + pickFirstString(entry?.kind, entry?.type, entry?.principal_kind), + ); + const ownerOrigin = normalizeUserOrigin( + pickFirstString(entry?.origin, entry?.origin_s, entry?.origin_t), + ); + const ownerDisplayName = + pickFirstString( + entry?.display_name_t, + entry?.display_name, + entry?.name_t, + entry?.name, + handle, + uid, + ) || uid; + const ownerAvatarUrl = + pickFirstString( + entry?.avatar_url_s, + entry?.avatarUrl, + entry?.avatar_url, + ) || undefined; + return { + kind: ownerKind, + uid, + handle, + displayName: ownerDisplayName, + avatarUrl: ownerAvatarUrl, + origin: ownerOrigin, + accountHandle: + pickFirstString( + entry?.organization_handle_s, + entry?.organization_handle, + entry?.organizationHandle, + ) || undefined, + }; + }) + .filter(Boolean) as OwnerPrincipal[]; + + const ownersFromSpaceMapped = ownersFromSpaceField + .map((uid: unknown): OwnerPrincipal | null => { + if (typeof uid !== 'string' || !uid.trim()) { + return null; + } + const normalizedUid = uid.trim(); + return { + kind: 'user', + uid: normalizedUid, + handle: normalizedUid, + displayName: normalizedUid, + }; + }) + .filter(Boolean) as OwnerPrincipal[]; + + const allOwners = [ + ...ownersFromSharingMapped, + ...ownersFromSpaceMapped, + ...(fallbackOwner ? [fallbackOwner] : []), + ]; + + const deduped = new Map(); + allOwners.forEach(owner => { + const key = principalKey(owner.kind, owner.uid); + if (!deduped.has(key)) { + deduped.set(key, owner); + } + }); + return Array.from(deduped.values()); +} + +// --------------------------------------------------------------------------- +// AccessByLevel helpers. +// --------------------------------------------------------------------------- + +function emptyAccessByLevel(): AccessByLevel { + return { + view: { userUids: [], teamUids: [], organizationUids: [] }, + update: { userUids: [], teamUids: [], organizationUids: [] }, + execute: { userUids: [], teamUids: [], organizationUids: [] }, + }; +} + +function bucketFor( + kind: PrincipalKind, +): 'userUids' | 'teamUids' | 'organizationUids' { + return kind === 'user' + ? 'userUids' + : kind === 'team' + ? 'teamUids' + : 'organizationUids'; +} + +function hasPrincipal( + state: AccessByLevel, + level: ItemAccessLevel, + kind: PrincipalKind, + uid: string, +): boolean { + const lower = uid.toLowerCase(); + return state[level][bucketFor(kind)].some( + value => value.toLowerCase() === lower, + ); +} + +function withPrincipalAdded( + state: AccessByLevel, + level: ItemAccessLevel, + kind: PrincipalKind, + uid: string, +): AccessByLevel { + if (hasPrincipal(state, level, kind, uid)) { + return state; + } + const bucket = bucketFor(kind); + return { + ...state, + [level]: { + ...state[level], + [bucket]: [...state[level][bucket], uid], + }, + }; +} + +function withPrincipalRemoved( + state: AccessByLevel, + kind: PrincipalKind, + uid: string, +): AccessByLevel { + const lower = uid.toLowerCase(); + const bucket = bucketFor(kind); + const next: AccessByLevel = { + view: { ...state.view }, + update: { ...state.update }, + execute: { ...state.execute }, + }; + for (const level of ACCESS_LEVELS) { + next[level][bucket] = next[level][bucket].filter( + value => value.toLowerCase() !== lower, + ); + } + return next; +} + +function buildAclEntries( + state: AccessByLevel, + principalKinds: readonly PrincipalKind[], +): ACLPrincipalEntry[] { + const allowed = new Set(principalKinds); + const byPrincipal = new Map(); + const upsert = (kind: PrincipalKind, uid: string, level: ItemAccessLevel) => { + if (!allowed.has(kind)) { + return; + } + const key = principalKey(kind, uid); + const existing = byPrincipal.get(key); + if (!existing) { + byPrincipal.set(key, { kind, uid, levels: [level] }); + return; + } + if (!existing.levels.includes(level)) { + existing.levels.push(level); + } + }; + for (const level of ACCESS_LEVELS) { + state[level].userUids.forEach(uid => upsert('user', uid, level)); + state[level].teamUids.forEach(uid => upsert('team', uid, level)); + state[level].organizationUids.forEach(uid => + upsert('organization', uid, level), + ); + } + return Array.from(byPrincipal.values()).sort((a, b) => { + if (a.kind !== b.kind) { + return a.kind.localeCompare(b.kind); + } + return a.uid.localeCompare(b.uid); + }); +} + +function hydrateAccessFromSharing(sharing: SharingPayload): AccessByLevel { + const access = sharing.access || {}; + const view = access.view || {}; + const update = access.update || {}; + const execute = access.execute || {}; + return { + view: { + userUids: [...(view.userUids || [])], + teamUids: [...(view.teamUids || [])], + organizationUids: [...(view.organizationUids || [])], + }, + update: { + userUids: [...(update.userUids || [])], + teamUids: [...(update.teamUids || [])], + organizationUids: [...(update.organizationUids || [])], + }, + execute: { + userUids: [...(execute.userUids || [])], + teamUids: [...(execute.teamUids || [])], + organizationUids: [...(execute.organizationUids || [])], + }, + }; +} + +// --------------------------------------------------------------------------- +// Avatar shimmer (used while a user row is being hydrated). +// --------------------------------------------------------------------------- + +function AvatarShimmer({ size = 20 }: { size?: number }): JSX.Element { + return ( + + ); +} + +// --------------------------------------------------------------------------- +// Row components. +// --------------------------------------------------------------------------- + +type OwnerPrincipalRowProps = { + ownerPrincipal: OwnerPrincipal; + cache: PrincipalCache; + showAvatarSkeleton?: boolean; + isPlatformAdmin: boolean; +}; + +function OwnerPrincipalRow({ + ownerPrincipal, + cache, + showAvatarSkeleton = false, + isPlatformAdmin, +}: OwnerPrincipalRowProps): JSX.Element { + const entry = + cache[principalKey(ownerPrincipal.kind, ownerPrincipal.uid)] || {}; + const cachedHandle = entry.handle; + const safeCachedHandle = + cachedHandle && cachedHandle !== ownerPrincipal.uid ? cachedHandle : ''; + const safeOwnerHandle = + ownerPrincipal.handle && ownerPrincipal.handle !== ownerPrincipal.uid + ? ownerPrincipal.handle + : ''; + const resolvedHandle = + safeCachedHandle || safeOwnerHandle || ownerPrincipal.accountHandle; + const resolvedAccountHandle = + entry.accountHandle || ownerPrincipal.accountHandle; + const resolvedDisplayName = ensurePrincipalDisplayName( + ownerPrincipal.kind, + ownerPrincipal.displayName, + entry.displayName, + resolvedHandle, + ownerPrincipal.handle, + ownerPrincipal.accountHandle, + ownerPrincipal.uid, + ); + const resolvedAvatarUrl = ownerPrincipal.avatarUrl || entry.avatarUrl; + const resolvedOrigin = + ownerPrincipal.origin || + entry.origin || + (ownerPrincipal.kind === 'user' ? 'Datalayer' : undefined); + + return ( + + {showAvatarSkeleton ? ( + <> + + {resolvedDisplayName} + + ) : ( + + )} + + ); +} + +type AccessPrincipalRowProps = { + entry: ACLPrincipalEntry; + cache: PrincipalCache; + showAvatarSkeleton?: boolean; + isPlatformAdmin: boolean; +}; + +function AccessPrincipalRow({ + entry, + cache, + showAvatarSkeleton = false, + isPlatformAdmin, +}: AccessPrincipalRowProps): JSX.Element { + const cached = cache[principalKey(entry.kind, entry.uid)] || {}; + const cachedHandle = cached.handle; + const safeCachedHandle = + cachedHandle && cachedHandle !== entry.uid ? cachedHandle : ''; + const resolvedHandle = safeCachedHandle || cached.accountHandle; + const resolvedDisplayName = ensurePrincipalDisplayName( + entry.kind, + cached.displayName, + resolvedHandle, + entry.uid, + ); + + return ( + + {showAvatarSkeleton ? ( + <> + + {resolvedDisplayName} + + ) : ( + + )} + + ); +} + +// --------------------------------------------------------------------------- +// Main component. +// --------------------------------------------------------------------------- + +export function ShareAccessComponent({ + isOpen, + requestUrl, + resourceLabel, + resourceName, + resourceDescription: _resourceDescription, + onSharingAccessRestrictedChange, + defaultAccessLevel = 'view', + principalKinds = DEFAULT_PRINCIPAL_KINDS, + displayMode = 'dialog', + onClose, +}: ShareAccessComponentProps): JSX.Element | null { + void _resourceDescription; + const { token, user } = useIAMStore(); + const { configuration } = useCoreStore(); + const { enqueueToast } = useToast(); + const isPlatformAdmin = Boolean( + Array.isArray(user?.roles) && user.roles.includes('platform_admin'), + ); + + // ----- State ----- + const [isLoading, setIsLoading] = useState(false); + const [isSaving, setIsSaving] = useState(false); + const [selectedAccessLevel, setSelectedAccessLevel] = + useState(defaultAccessLevel); + + const [access, setAccess] = useState(emptyAccessByLevel()); + const [ownerPrincipals, setOwnerPrincipals] = useState([]); + const [shareablePrincipals, setShareablePrincipals] = useState< + ShareablePrincipal[] + >([]); + const [isLoadingShareable, setIsLoadingShareable] = useState(false); + + const [principalCache, setPrincipalCache] = useState({}); + const [hydratingUserUids, setHydratingUserUids] = useState< + Record + >({}); + + const [searchQuery, setSearchQuery] = useState(''); + const [debouncedSearchQuery, setDebouncedSearchQuery] = useState(''); + const [searchResults, setSearchResults] = useState([]); + const [isSearching, setIsSearching] = useState(false); + const [isSearchOverlayOpen, setIsSearchOverlayOpen] = useState(false); + + const [sharingAccessMessage, setSharingAccessMessage] = useState< + string | null + >(null); + const [isSharingAccessConfirmed, setIsSharingAccessConfirmed] = + useState(false); + + // ----- Refs ----- + const hasLoadedForOpenRef = useRef(false); + const hasHydratedSharingRef = useRef(false); + const lastSavedSharingRef = useRef(null); + const autoSaveTimerRef = useRef | null>(null); + const activeSearchRequestRef = useRef(0); + const searchContainerRef = useRef(null); + const searchInputRef = useRef(null); + const userHydrationMissesRef = useRef>(new Set()); + const enqueueToastRef = useRef(enqueueToast); + + useEffect(() => { + enqueueToastRef.current = enqueueToast; + }, [enqueueToast]); + + // ----- Notify caller about restricted access state ----- + useEffect(() => { + if (!onSharingAccessRestrictedChange) { + return; + } + const restricted = + !isSharingAccessConfirmed || Boolean(sharingAccessMessage); + onSharingAccessRestrictedChange( + restricted, + sharingAccessMessage || undefined, + ); + }, [ + isSharingAccessConfirmed, + sharingAccessMessage, + onSharingAccessRestrictedChange, + ]); + + // ----- Derived ----- + const canRequest = Boolean(requestUrl && token); + const canSearchPrincipals = Boolean(configuration?.iamRunUrl && token); + const iamRunUrl = configuration?.iamRunUrl; + + const principalKindsSet = useMemo( + () => new Set(principalKinds), + [principalKinds], + ); + const principalKindsKey = useMemo( + () => [...principalKinds].sort().join('|'), + [principalKinds], + ); + + const aclEntries = useMemo( + () => buildAclEntries(access, principalKinds), + [access, principalKinds], + ); + + const normalizedSearch = searchQuery.trim(); + const normalizedDebouncedSearch = debouncedSearchQuery.trim(); + const canShowSearchResults = + isSearchOverlayOpen && normalizedSearch.length > 0; + + // ----- Cache mutators (single consolidated record) ----- + const mergePrincipalCacheEntry = useCallback( + (kind: PrincipalKind, uid: string, patch: PrincipalCacheEntry) => { + if (!uid) { + return; + } + const key = principalKey(kind, uid); + setPrincipalCache(prev => { + const existing = prev[key] || {}; + const merged: PrincipalCacheEntry = { ...existing }; + let changed = false; + (Object.keys(patch) as Array).forEach( + field => { + const value = patch[field]; + if (typeof value === 'string') { + const trimmed = value.trim(); + if (trimmed && existing[field] !== trimmed) { + merged[field] = trimmed; + changed = true; + } + } + }, + ); + return changed ? { ...prev, [key]: merged } : prev; + }); + }, + [], + ); + + // ----- Reset on close / load on open ----- + useEffect(() => { + if (isOpen) { + setSelectedAccessLevel(defaultAccessLevel); + } + }, [isOpen, defaultAccessLevel]); + + useEffect(() => { + if (!isOpen) { + hasLoadedForOpenRef.current = false; + hasHydratedSharingRef.current = false; + lastSavedSharingRef.current = null; + if (autoSaveTimerRef.current) { + clearTimeout(autoSaveTimerRef.current); + autoSaveTimerRef.current = null; + } + setSearchQuery(''); + setIsSearchOverlayOpen(false); + setSearchResults([]); + setIsSearching(false); + setSharingAccessMessage(null); + setIsSharingAccessConfirmed(false); + setPrincipalCache({}); + setOwnerPrincipals([]); + setShareablePrincipals([]); + setHydratingUserUids({}); + userHydrationMissesRef.current = new Set(); + return; + } + + if (!canRequest || !requestUrl) { + setIsLoading(false); + setIsSharingAccessConfirmed(false); + return; + } + + if (hasLoadedForOpenRef.current) { + return; + } + hasLoadedForOpenRef.current = true; + + let cancelled = false; + const run = async () => { + setIsLoading(true); + setIsSharingAccessConfirmed(false); + setSharingAccessMessage(null); + try { + const response = await fetch(requestUrl, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + }); + const payload = await response.json(); + const message = + payload?.detail || + payload?.message || + `Unable to load ${resourceLabel.toLowerCase()} sharing.`; + + if (payload?.success === false) { + if (!cancelled && isSharingAuthorizationMessage(message)) { + setSharingAccessMessage(message); + setIsSharingAccessConfirmed(true); + setAccess(emptyAccessByLevel()); + setOwnerPrincipals([]); + return; + } + throw new Error(message); + } + if (!response.ok) { + if ( + response.status === 403 || + isSharingAuthorizationMessage(message) + ) { + if (!cancelled) { + setSharingAccessMessage(message); + setIsSharingAccessConfirmed(true); + setAccess(emptyAccessByLevel()); + setOwnerPrincipals([]); + } + return; + } + throw new Error(message); + } + if (cancelled) { + return; + } + + const sharing = (payload?.sharing || {}) as SharingPayload; + const owners = extractOwnerPrincipals(payload); + const hydrated = hydrateAccessFromSharing(sharing); + + setOwnerPrincipals(owners); + owners.forEach(owner => { + mergePrincipalCacheEntry(owner.kind, owner.uid, { + displayName: owner.displayName || owner.handle || owner.uid, + handle: owner.handle || owner.uid, + avatarUrl: owner.avatarUrl, + accountHandle: owner.accountHandle, + origin: owner.kind === 'user' ? owner.origin : undefined, + }); + }); + setAccess(hydrated); + lastSavedSharingRef.current = JSON.stringify(hydrated); + hasHydratedSharingRef.current = true; + setIsSharingAccessConfirmed(true); + } catch (error) { + if (cancelled) { + return; + } + const message = + error instanceof Error + ? error.message + : `Unable to load ${resourceLabel.toLowerCase()} sharing.`; + enqueueToastRef.current(message, { variant: 'error' }); + } finally { + if (!cancelled) { + setIsLoading(false); + } + } + }; + void run(); + return () => { + cancelled = true; + }; + }, [ + isOpen, + canRequest, + requestUrl, + token, + resourceLabel, + resourceName, + mergePrincipalCacheEntry, + ]); + + // ----- Fetch shareable principals on open ----- + useEffect(() => { + if (!isOpen || !canSearchPrincipals || !iamRunUrl || !token) { + return; + } + let cancelled = false; + const run = async () => { + setIsLoadingShareable(true); + try { + const response = await fetch( + `${iamRunUrl}/api/iam/v1/principals/shareable`, + { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + }, + ); + const payload = await response.json(); + if (!response.ok || payload?.success === false) { + const message = + payload?.detail || + payload?.message || + 'Unable to load shareable principals.'; + throw new Error(message); + } + if (cancelled) { + return; + } + const raw = Array.isArray(payload?.principals) + ? payload.principals + : []; + const mapped: ShareablePrincipal[] = raw + .map((entry: any): ShareablePrincipal | null => { + const uid = pickFirstString(entry?.uid); + const handle = pickFirstString(entry?.handle, entry?.handle_s); + if (!uid) { + return null; + } + const kind = normalizePrincipalKind(pickFirstString(entry?.kind)); + return { + kind, + uid, + handle: handle || uid, + name: pickFirstString(entry?.name) || null, + email: pickFirstString(entry?.email) || null, + avatarUrl: + pickFirstString(entry?.avatar_url, entry?.avatarUrl) || null, + organizationUid: + pickFirstString( + entry?.organization_uid, + entry?.organizationUid, + ) || null, + organizationHandle: + pickFirstString( + entry?.organization_handle, + entry?.organizationHandle, + ) || null, + }; + }) + .filter(Boolean) as ShareablePrincipal[]; + setShareablePrincipals(mapped); + mapped.forEach(principal => { + mergePrincipalCacheEntry(principal.kind, principal.uid, { + displayName: principal.name || principal.handle, + handle: principal.handle, + avatarUrl: principal.avatarUrl || undefined, + accountHandle: principal.organizationHandle || undefined, + origin: principal.kind === 'user' ? 'Datalayer' : undefined, + }); + }); + } catch (error) { + if (cancelled) { + return; + } + const message = + error instanceof Error + ? error.message + : 'Unable to load shareable principals.'; + enqueueToastRef.current(message, { variant: 'error' }); + } finally { + if (!cancelled) { + setIsLoadingShareable(false); + } + } + }; + void run(); + return () => { + cancelled = true; + }; + }, [isOpen, canSearchPrincipals, iamRunUrl, token, mergePrincipalCacheEntry]); + + // ----- Debounce search query ----- + useEffect(() => { + const timeout = window.setTimeout(() => { + setDebouncedSearchQuery(searchQuery); + }, 350); + return () => { + window.clearTimeout(timeout); + }; + }, [searchQuery]); + + // ----- Run search against /principals/search ----- + useEffect(() => { + if (!isOpen || !canSearchPrincipals || !iamRunUrl || !token) { + setSearchResults([]); + return; + } + const query = normalizedDebouncedSearch; + if (!query || query.length < 2) { + setSearchResults([]); + setIsSearching(false); + return; + } + let cancelled = false; + const requestId = activeSearchRequestRef.current + 1; + activeSearchRequestRef.current = requestId; + const run = async () => { + setIsSearching(true); + try { + const controller = new AbortController(); + const timeoutId = window.setTimeout(() => controller.abort(), 8000); + let response: Response; + let payload: any; + try { + response = await fetch(`${iamRunUrl}/api/iam/v1/principals/search`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + query, + principalTypes: [...principalKinds], + }), + signal: controller.signal, + }); + payload = await response.json(); + } finally { + window.clearTimeout(timeoutId); + } + if (requestId !== activeSearchRequestRef.current || cancelled) { + return; + } + if (!response.ok || !payload?.success) { + const message = + payload?.detail || + payload?.message || + 'Unable to search principals.'; + throw new Error(message); + } + const data = + payload?.data && typeof payload.data === 'object' + ? payload.data + : payload; + const users: any[] = Array.isArray(data?.users) ? data.users : []; + const teams: any[] = Array.isArray(data?.teams) ? data.teams : []; + const organizations: any[] = Array.isArray(data?.organizations) + ? data.organizations + : []; + + const mappedUsers: PrincipalSearchItem[] = users + .map((entry: any): PrincipalSearchItem | null => { + const uid = pickFirstString(entry?.uid); + const handle = pickFirstString(entry?.handle_s, entry?.handle); + if (!uid) { + return null; + } + const firstName = pickFirstString( + entry?.first_name_t, + entry?.firstName, + ); + const lastName = pickFirstString( + entry?.last_name_t, + entry?.lastName, + ); + const displayName = + `${firstName} ${lastName}`.trim() || + pickFirstString( + entry?.display_name_t, + entry?.display_name, + handle, + ); + const origin = normalizeUserOrigin( + pickFirstString(entry?.origin, entry?.origin_s, entry?.origin_t), + ); + const avatarUrl = pickFirstString( + entry?.avatar_url_s, + entry?.avatarUrl, + entry?.avatar_url, + ); + return { + kind: 'user', + uid, + handle: handle || uid, + displayName: displayName || handle || uid, + avatarUrl: avatarUrl || undefined, + origin, + }; + }) + .filter(Boolean) as PrincipalSearchItem[]; + + const mappedTeams: PrincipalSearchItem[] = teams + .map((entry: any): PrincipalSearchItem | null => { + const uid = pickFirstString(entry?.uid); + const handle = pickFirstString(entry?.handle_s, entry?.handle); + if (!uid) { + return null; + } + return { + kind: 'team', + uid, + handle: handle || uid, + displayName: + pickFirstString(entry?.name_t, entry?.name) || handle || uid, + accountHandle: + pickFirstString( + entry?.organization_handle_s, + entry?.organizationHandle, + entry?.organization_handle, + ) || undefined, + }; + }) + .filter(Boolean) as PrincipalSearchItem[]; + + const mappedOrganizations: PrincipalSearchItem[] = organizations + .map((entry: any): PrincipalSearchItem | null => { + const uid = pickFirstString(entry?.uid); + const handle = pickFirstString(entry?.handle_s, entry?.handle); + if (!uid) { + return null; + } + return { + kind: 'organization', + uid, + handle: handle || uid, + displayName: + pickFirstString(entry?.name_t, entry?.name) || handle || uid, + }; + }) + .filter(Boolean) as PrincipalSearchItem[]; + + const filtered = [ + ...mappedUsers, + ...mappedTeams, + ...mappedOrganizations, + ].filter(result => principalKindsSet.has(result.kind)); + + filtered.forEach(result => { + mergePrincipalCacheEntry(result.kind, result.uid, { + displayName: result.displayName || result.handle, + handle: result.handle, + avatarUrl: result.avatarUrl, + accountHandle: result.accountHandle, + origin: result.kind === 'user' ? result.origin : undefined, + }); + }); + + setSearchResults(filtered); + } catch (error) { + if (cancelled || requestId !== activeSearchRequestRef.current) { + return; + } + setSearchResults([]); + const message = + error instanceof Error + ? error.message + : 'Unable to search principals.'; + enqueueToastRef.current(message, { variant: 'error' }); + } finally { + if (requestId === activeSearchRequestRef.current && !cancelled) { + setIsSearching(false); + } + } + }; + void run(); + return () => { + cancelled = true; + }; + }, [ + isOpen, + canSearchPrincipals, + iamRunUrl, + token, + normalizedDebouncedSearch, + principalKindsKey, + principalKinds, + principalKindsSet, + mergePrincipalCacheEntry, + ]); + + // ----- Hydrate ACL user uids in bulk ----- + useEffect(() => { + if (!isOpen || !canSearchPrincipals || !iamRunUrl || !token) { + return; + } + const userUids = Array.from( + new Set([ + ...aclEntries + .filter(e => e.kind === 'user') + .map(e => e.uid) + .filter(Boolean), + ...ownerPrincipals + .filter(o => o.kind === 'user') + .map(o => o.uid) + .filter(Boolean), + ]), + ); + const unknown = userUids.filter(uid => { + if (!uid || userHydrationMissesRef.current.has(uid)) { + return false; + } + const cached = principalCache[principalKey('user', uid)] || {}; + return ( + !cached.displayName || + !cached.avatarUrl || + !cached.origin || + !cached.handle + ); + }); + if (unknown.length === 0) { + setHydratingUserUids({}); + return; + } + let cancelled = false; + const run = async () => { + setHydratingUserUids( + unknown.reduce( + (acc, uid) => { + acc[uid] = true; + return acc; + }, + {} as Record, + ), + ); + try { + const response = await fetch(`${iamRunUrl}/api/iam/v1/users/bulk`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ userIds: unknown }), + }); + const payload = await response.json(); + if (!response.ok || !payload?.success || cancelled) { + unknown.forEach(uid => userHydrationMissesRef.current.add(uid)); + return; + } + const data = + payload?.data && typeof payload.data === 'object' + ? payload.data + : payload; + const users: any[] = Array.isArray(data?.users) ? data.users : []; + const hydratedSet = new Set(); + users.forEach((entry: any) => { + const uid = pickFirstString(entry?.uid); + if (!uid) { + return; + } + hydratedSet.add(uid); + const handle = pickFirstString(entry?.handle_s, entry?.handle) || uid; + const firstName = pickFirstString( + entry?.first_name_t, + entry?.firstName, + ); + const lastName = pickFirstString(entry?.last_name_t, entry?.lastName); + const displayName = + `${firstName} ${lastName}`.trim() || + pickFirstString(entry?.display_name_t, entry?.display_name) || + handle; + const avatarUrl = pickFirstString( + entry?.avatar_url_s, + entry?.avatarUrl, + entry?.avatar_url, + ); + const origin = normalizeUserOrigin( + pickFirstString(entry?.origin, entry?.origin_s, entry?.origin_t), + ); + mergePrincipalCacheEntry('user', uid, { + displayName, + handle, + avatarUrl: avatarUrl || undefined, + origin, + }); + }); + unknown.forEach(uid => { + if (!hydratedSet.has(uid)) { + userHydrationMissesRef.current.add(uid); + } + }); + } catch { + unknown.forEach(uid => userHydrationMissesRef.current.add(uid)); + } finally { + if (!cancelled) { + setHydratingUserUids({}); + } + } + }; + void run(); + return () => { + cancelled = true; + }; + }, [ + isOpen, + canSearchPrincipals, + iamRunUrl, + token, + aclEntries, + ownerPrincipals, + principalCache, + mergePrincipalCacheEntry, + ]); + + // ----- Hydrate ACL team uids individually ----- + useEffect(() => { + if (!isOpen || !canSearchPrincipals || !iamRunUrl || !token) { + return; + } + const unknownTeams = aclEntries.filter(entry => { + if (entry.kind !== 'team') { + return false; + } + const cached = principalCache[principalKey('team', entry.uid)] || {}; + return !cached.displayName; + }); + if (unknownTeams.length === 0) { + return; + } + let cancelled = false; + void Promise.all( + unknownTeams.map(async entry => { + try { + const response = await fetch( + `${iamRunUrl}/api/iam/v1/teams/${encodeURIComponent(entry.uid)}`, + { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + }, + ); + const payload = await response.json(); + if (!response.ok || !payload?.success || cancelled) { + return; + } + const data = + payload?.data && typeof payload.data === 'object' + ? payload.data + : payload; + const obj = data?.team || data; + const name = pickFirstString(obj?.name_t, obj?.name); + const handle = pickFirstString(obj?.handle_s, obj?.handle); + const accountHandle = pickFirstString( + obj?.organization_handle_s, + obj?.organizationHandle, + obj?.organization_handle, + ); + mergePrincipalCacheEntry('team', entry.uid, { + displayName: name || handle, + handle, + accountHandle, + }); + } catch { + // Best effort. + } + }), + ); + return () => { + cancelled = true; + }; + }, [ + isOpen, + canSearchPrincipals, + iamRunUrl, + token, + aclEntries, + principalCache, + mergePrincipalCacheEntry, + ]); + + // ----- Auto-save on access change after hydration ----- + const saveAccess = useCallback( + async (snapshot: AccessByLevel) => { + if (!canRequest || !requestUrl) { + return; + } + setIsSaving(true); + try { + const body: SharingPayload = { + access: { + view: { + userUids: snapshot.view.userUids, + teamUids: snapshot.view.teamUids, + organizationUids: snapshot.view.organizationUids, + }, + update: { + userUids: snapshot.update.userUids, + teamUids: snapshot.update.teamUids, + organizationUids: snapshot.update.organizationUids, + }, + execute: { + userUids: snapshot.execute.userUids, + teamUids: snapshot.execute.teamUids, + organizationUids: snapshot.execute.organizationUids, + }, + }, + }; + const response = await fetch(requestUrl, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify(body), + }); + const payload = await response.json(); + const message = + payload?.detail || + payload?.message || + `Unable to update ${resourceLabel.toLowerCase()} sharing.`; + if (payload?.success === false) { + if (isSharingAuthorizationMessage(message)) { + setSharingAccessMessage(message); + return; + } + throw new Error(message); + } + if (!response.ok) { + if ( + response.status === 403 || + isSharingAuthorizationMessage(message) + ) { + setSharingAccessMessage(message); + return; + } + throw new Error(message); + } + enqueueToastRef.current(`${resourceLabel} sharing updated.`, { + variant: 'success', + }); + } catch (error) { + const message = + error instanceof Error + ? error.message + : `Unable to update ${resourceLabel.toLowerCase()} sharing.`; + enqueueToastRef.current(message, { variant: 'error' }); + } finally { + setIsSaving(false); + } + }, + [canRequest, requestUrl, token, resourceLabel], + ); + + useEffect(() => { + if (!hasHydratedSharingRef.current) { + return; + } + if (!canRequest || !requestUrl) { + return; + } + if (!isSharingAccessConfirmed || sharingAccessMessage) { + return; + } + const serialized = JSON.stringify(access); + if (lastSavedSharingRef.current === serialized) { + return; + } + if (autoSaveTimerRef.current) { + clearTimeout(autoSaveTimerRef.current); + } + autoSaveTimerRef.current = setTimeout(() => { + autoSaveTimerRef.current = null; + lastSavedSharingRef.current = serialized; + void saveAccess(access); + }, 400); + return () => { + if (autoSaveTimerRef.current) { + clearTimeout(autoSaveTimerRef.current); + autoSaveTimerRef.current = null; + } + }; + }, [ + access, + canRequest, + requestUrl, + isSharingAccessConfirmed, + sharingAccessMessage, + saveAccess, + ]); + + // ----- Action handlers ----- + const addPrincipal = useCallback( + (kind: PrincipalKind, uid: string) => { + if (!principalKindsSet.has(kind)) { + return; + } + setAccess(prev => + withPrincipalAdded(prev, selectedAccessLevel, kind, uid), + ); + }, + [principalKindsSet, selectedAccessLevel], + ); + + const removePrincipal = useCallback((kind: PrincipalKind, uid: string) => { + setAccess(prev => withPrincipalRemoved(prev, kind, uid)); + }, []); + + const handleSearchResultSelect = useCallback( + (result: PrincipalSearchItem) => { + addPrincipal(result.kind, result.uid); + setSearchQuery(''); + setIsSearchOverlayOpen(false); + setSearchResults([]); + }, + [addPrincipal], + ); + + // ----- Search overlay outside-click + escape ----- + useEffect(() => { + if (!isSearchOverlayOpen) { + return; + } + const handlePointer = (event: MouseEvent) => { + const target = event.target as Node | null; + if (target && searchContainerRef.current?.contains(target)) { + return; + } + setIsSearchOverlayOpen(false); + }; + const handleEscape = (event: KeyboardEvent) => { + if (event.key !== 'Escape') { + return; + } + event.preventDefault(); + setIsSearchOverlayOpen(false); + searchInputRef.current?.focus(); + }; + document.addEventListener('mousedown', handlePointer); + document.addEventListener('keydown', handleEscape); + return () => { + document.removeEventListener('mousedown', handlePointer); + document.removeEventListener('keydown', handleEscape); + }; + }, [isSearchOverlayOpen]); + + // ----- Shareable picker groupings ----- + const groupedShareable = useMemo(() => { + const filtered = shareablePrincipals.filter(p => + principalKindsSet.has(p.kind), + ); + const selfUid = pickFirstString(user?.uid); + const self = filtered.filter(p => p.kind === 'user' && p.uid === selfUid); + const otherUsers = filtered.filter( + p => p.kind === 'user' && p.uid !== selfUid, + ); + const orgs = filtered.filter(p => p.kind === 'organization'); + const teams = filtered.filter(p => p.kind === 'team'); + return { self, otherUsers, orgs, teams }; + }, [shareablePrincipals, principalKindsSet, user?.uid]); + + if (!isOpen) { + return null; + } + + const isReadOnly = + !canRequest || + !isSharingAccessConfirmed || + isLoading || + Boolean(sharingAccessMessage); + + // ----- Sub-renderers (kept inline for locality) ----- + const renderShareablePrincipalRow = (principal: ShareablePrincipal) => { + const alreadyAdded = hasPrincipal( + access, + selectedAccessLevel, + principal.kind, + principal.uid, + ); + const cached = + principalCache[principalKey(principal.kind, principal.uid)] || {}; + const displayName = + principal.name || cached.displayName || principal.handle || principal.uid; + const Icon = + principal.kind === 'user' + ? PersonIcon + : principal.kind === 'organization' + ? OrganizationIcon + : PeopleIcon; + return ( + { + if (!alreadyAdded) { + addPrincipal(principal.kind, principal.uid); + } + }} + disabled={alreadyAdded || isSaving || isReadOnly} + sx={{ + all: 'unset', + display: 'flex', + alignItems: 'center', + justifyContent: 'space-between', + gap: 2, + px: 2, + py: 2, + cursor: + alreadyAdded || isSaving || isReadOnly ? 'not-allowed' : 'pointer', + opacity: alreadyAdded ? 0.55 : 1, + borderRadius: 2, + borderWidth: 1, + borderStyle: 'solid', + borderColor: 'border.default', + bg: 'canvas.default', + ':hover': { + bg: + alreadyAdded || isSaving || isReadOnly + ? 'canvas.default' + : 'canvas.subtle', + }, + }} + > + + + + + {displayName} + {principal.kind === 'user' && + user?.uid && + principal.uid === user.uid && ( + + )} + + + @{principal.handle} + {principal.kind === 'team' && principal.organizationHandle && ( + + {' · '}org @{principal.organizationHandle} + + )} + + + + + + {alreadyAdded ? ( + + ) : ( + Add + )} + + + ); + }; + + const renderShareableGroup = ( + label: string, + items: ShareablePrincipal[], + ): JSX.Element | null => { + if (items.length === 0) { + return null; + } + return ( + + + {label} + + + {items.map(renderShareablePrincipalRow)} + + + ); + }; + + const content = ( + + {sharingAccessMessage && ( + + Sharing access is restricted + {sharingAccessMessage} + + )} + + + {/* Header: resource info + level selector */} + + + + Share {resourceName || `this ${resourceLabel.toLowerCase()}`}. + Pick a principal below — they will be granted the selected access + level. + + + + + + + + + {ACCESS_LEVELS.map(level => ( + setSelectedAccessLevel(level)} + > + {ACCESS_LEVEL_LABELS[level]} + + ))} + + + + + + {/* Owner */} + + Owner + {ownerPrincipals.length > 0 ? ( + + {ownerPrincipals.map((ownerPrincipal, index) => ( + + + + ))} + + ) : ( + + Owner information is not available. + + )} + + + {/* Share with… (shareable principals picker — PROMINENT) */} + + + Share with… + {isLoadingShareable && ( + + + Loading… + + )} + + {!isLoadingShareable && + groupedShareable.self.length === 0 && + groupedShareable.otherUsers.length === 0 && + groupedShareable.orgs.length === 0 && + groupedShareable.teams.length === 0 ? ( + + No principals available to share with. + + ) : ( + + {renderShareableGroup('You', groupedShareable.self)} + {renderShareableGroup('Other users', groupedShareable.otherUsers)} + {renderShareableGroup( + 'Your organizations', + groupedShareable.orgs, + )} + {renderShareableGroup('Your teams', groupedShareable.teams)} + + )} + + + {/* Secondary advanced search */} + + + Or search for any user, team, or organization + + + { + const next = e.target.value; + setSearchQuery(next); + setIsSearchOverlayOpen(next.trim().length > 0); + }} + onFocus={() => { + if (searchQuery.trim().length > 0) { + setIsSearchOverlayOpen(true); + } + }} + onKeyDown={e => { + if (e.key === 'Escape') { + e.preventDefault(); + setIsSearchOverlayOpen(false); + } + }} + placeholder="Search by handle, name, or email" + aria-label="Search principals" + disabled={isSaving} + /> + {canShowSearchResults && ( + + {isSearching ? ( + + + + Searching… + + + ) : searchResults.length === 0 ? ( + + + No principals found. + + + ) : ( + + {searchResults.map(result => ( + handleSearchResultSelect(result)} + > + + + + + {result.displayName} + {result.kind === 'user' && ( + + )} + + + @{result.handle} + + + ))} + + )} + + )} + + + + {/* ACL list */} + + + Access Control List (ACL) + + + {aclEntries.length === 0 ? ( + + + No principals shared yet. + + + ) : ( + + {aclEntries.map(entry => ( + + + + {entry.levels.map(level => ( + + ))} + + + + ))} + + )} + + + + + + {isSaving && ( + + + Saving… + + )} + {displayMode === 'dialog' && ( + + )} + + + {isLoading && ( + + Loading current sharing settings… + + )} + + ); + + if (displayMode === 'inline') { + return ( + + {content} + + ); + } + + return ( + + {content} + + ); +} + +export default ShareAccessComponent; + +// Backward-compatible aliases (deprecated: use ShareAccessComponent). +export const ShareAccessDialog = ShareAccessComponent; +export type ShareAccessDialogProps = ShareAccessComponentProps; diff --git a/src/components/sharing/SharingEditor.tsx b/src/components/sharing/SharingEditor.tsx new file mode 100644 index 00000000..dc6fe7e2 --- /dev/null +++ b/src/components/sharing/SharingEditor.tsx @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2023-2026 Datalayer, Inc. + * Distributed under the terms of the Modified BSD License. + */ + +/** + * SharingEditor — inline editor for the sharing payload shape used by + * `ShareAccessComponent` ACL endpoints. + * + * Unlike `ShareAccessComponent` (which is bound to a server-side resource + * via `requestUrl`), this component edits a free-form + * `{ access: { view/update/execute: { userUids, teamUids, organizationUids } } }` + * blob in memory. It is intended for "create" flows where the resource does + * not yet exist and the sharing payload must be POSTed alongside the rest of + * the configuration. + * + * This is a scaffold: it currently exposes a structured JSON editor with + * validation and the canonical default shape. Future iterations can replace + * the textarea with the same principal-picker UI used by + * `ShareAccessComponent`. + */ + +import { useEffect, useMemo, useState } from 'react'; +import { Box } from '@datalayer/primer-addons'; +import { FormControl, Text, Textarea } from '@primer/react'; + +export type SharingAccessLevel = 'view' | 'update' | 'execute'; + +export type SharingLevelPayload = { + userUids?: string[]; + teamUids?: string[]; + organizationUids?: string[]; +}; + +export type SharingPayload = { + access?: Partial>; +}; + +export const EMPTY_SHARING_PAYLOAD: SharingPayload = { + access: { + view: { userUids: [], teamUids: [], organizationUids: [] }, + update: { userUids: [], teamUids: [], organizationUids: [] }, + execute: { userUids: [], teamUids: [], organizationUids: [] }, + }, +}; + +export type SharingEditorProps = { + value: SharingPayload; + onChange: (next: SharingPayload) => void; + label?: string; + caption?: string; + rows?: number; + disabled?: boolean; +}; + +const stringify = (value: SharingPayload): string => { + try { + return JSON.stringify(value ?? {}, null, 2); + } catch { + return '{}'; + } +}; + +export function SharingEditor({ + value, + onChange, + label = 'Sharing', + caption = 'Edit the sharing payload. Each access level (view/update/execute) can grant access to user, team, and organization UIDs.', + rows = 10, + disabled = false, +}: SharingEditorProps): JSX.Element { + const initial = useMemo(() => stringify(value), [value]); + const [raw, setRaw] = useState(initial); + const [error, setError] = useState(null); + + useEffect(() => { + setRaw(stringify(value)); + }, [value]); + + return ( + + {label} +