From 9ee6cfab9b3fd24ff2ad22c06af41871832cebfb Mon Sep 17 00:00:00 2001 From: Carlos Villela Date: Tue, 19 May 2026 19:09:44 -0700 Subject: [PATCH 01/12] refactor(cli): add onboard FSM transition types --- src/lib/onboard/machine/transitions.test.ts | 164 ++++++++++++++++++++ src/lib/onboard/machine/transitions.ts | 107 +++++++++++++ src/lib/onboard/machine/types.ts | 101 ++++++++++++ 3 files changed, 372 insertions(+) create mode 100644 src/lib/onboard/machine/transitions.test.ts create mode 100644 src/lib/onboard/machine/transitions.ts create mode 100644 src/lib/onboard/machine/types.ts diff --git a/src/lib/onboard/machine/transitions.test.ts b/src/lib/onboard/machine/transitions.test.ts new file mode 100644 index 0000000000..875a0ec45a --- /dev/null +++ b/src/lib/onboard/machine/transitions.test.ts @@ -0,0 +1,164 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import { describe, expect, it } from "vitest"; + +import { + ONBOARD_MACHINE_EVENT_TYPES, + ONBOARD_MACHINE_STATES, + ONBOARD_NON_TERMINAL_MACHINE_STATES, +} from "./types"; +import { + assertValidOnboardMachineTransition, + canTransitionOnboardMachineState, + getNextOnboardMachineStates, + getOnboardMachineTransition, + InvalidOnboardMachineTransitionError, + isOnboardMachineState, + isTerminalOnboardMachineState, + ONBOARD_MACHINE_DIRECT_TRANSITIONS, + ONBOARD_MACHINE_NEXT_STATES, + ONBOARD_MACHINE_TRANSITIONS, +} from "./transitions"; + +const canonicalDirectTransitions = [ + ["init", "preflight", "advance"], + ["preflight", "gateway", "advance"], + ["gateway", "provider_selection", "advance"], + ["provider_selection", "inference", "advance"], + ["inference", "provider_selection", "retry"], + ["inference", "sandbox", "advance"], + ["sandbox", "openclaw", "branch"], + ["sandbox", "agent_setup", "branch"], + ["openclaw", "policies", "advance"], + ["agent_setup", "policies", "advance"], + ["policies", 
"finalizing", "advance"], + ["finalizing", "post_verify", "advance"], + ["post_verify", "complete", "advance"], +] as const; + +describe("onboard machine vocabulary", () => { + it("defines the initial coarse state vocabulary from issue #3802", () => { + expect(ONBOARD_MACHINE_STATES).toEqual([ + "init", + "preflight", + "gateway", + "provider_selection", + "inference", + "sandbox", + "agent_setup", + "openclaw", + "policies", + "finalizing", + "post_verify", + "complete", + "failed", + ]); + }); + + it("defines the initial observe-only event vocabulary from issue #3802", () => { + expect(ONBOARD_MACHINE_EVENT_TYPES).toEqual([ + "onboard.started", + "onboard.resumed", + "onboard.completed", + "onboard.failed", + "state.entered", + "state.exited", + "state.skipped", + "state.completed", + "state.failed", + "state.repair.started", + "state.repair.completed", + "state.repair.failed", + "context.updated", + "resume.conflict", + "hook.started", + "hook.completed", + "hook.failed", + ]); + }); + + it("recognizes valid machine state names", () => { + expect(isOnboardMachineState("preflight")).toBe(true); + expect(isOnboardMachineState("messaging")).toBe(false); + expect(isOnboardMachineState(null)).toBe(false); + }); +}); + +describe("onboard machine transitions", () => { + it("encodes the canonical direct transition graph", () => { + expect(ONBOARD_MACHINE_DIRECT_TRANSITIONS).toEqual( + canonicalDirectTransitions.map(([from, to, kind]) => ({ from, to, kind })), + ); + }); + + it("allows every non-terminal state to fail", () => { + for (const state of ONBOARD_NON_TERMINAL_MACHINE_STATES) { + expect(canTransitionOnboardMachineState(state, "failed")).toBe(true); + expect(getOnboardMachineTransition(state, "failed")?.kind).toBe("failure"); + } + }); + + it("keeps terminal states terminal", () => { + expect(isTerminalOnboardMachineState("complete")).toBe(true); + expect(isTerminalOnboardMachineState("failed")).toBe(true); + 
expect(getNextOnboardMachineStates("complete")).toEqual([]); + expect(getNextOnboardMachineStates("failed")).toEqual([]); + expect(canTransitionOnboardMachineState("complete", "failed")).toBe(false); + expect(canTransitionOnboardMachineState("failed", "init")).toBe(false); + }); + + it("exposes next states in deterministic order", () => { + expect(ONBOARD_MACHINE_NEXT_STATES).toEqual({ + init: ["preflight", "failed"], + preflight: ["gateway", "failed"], + gateway: ["provider_selection", "failed"], + provider_selection: ["inference", "failed"], + inference: ["provider_selection", "sandbox", "failed"], + sandbox: ["openclaw", "agent_setup", "failed"], + agent_setup: ["policies", "failed"], + openclaw: ["policies", "failed"], + policies: ["finalizing", "failed"], + finalizing: ["post_verify", "failed"], + post_verify: ["complete", "failed"], + complete: [], + failed: [], + }); + }); + + it("classifies retry and branch transitions", () => { + expect(assertValidOnboardMachineTransition("inference", "provider_selection")).toMatchObject({ + kind: "retry", + }); + expect(assertValidOnboardMachineTransition("sandbox", "openclaw")).toMatchObject({ + kind: "branch", + }); + expect(assertValidOnboardMachineTransition("sandbox", "agent_setup")).toMatchObject({ + kind: "branch", + }); + }); + + it("rejects transitions outside the graph", () => { + expect(() => assertValidOnboardMachineTransition("init", "sandbox")).toThrow( + InvalidOnboardMachineTransitionError, + ); + expect(() => assertValidOnboardMachineTransition("complete", "failed")).toThrow( + "complete -> failed", + ); + }); + + it("keeps the next-state map aligned with the transition list", () => { + for (const state of ONBOARD_MACHINE_STATES) { + expect( + ONBOARD_MACHINE_TRANSITIONS.filter((transition) => transition.from === state).map( + (transition) => transition.to, + ), + ).toEqual(getNextOnboardMachineStates(state)); + } + }); + + it("does not contain duplicate transition edges", () => { + const edges = 
ONBOARD_MACHINE_TRANSITIONS.map(({ from, to }) => `${from}->${to}`); + expect(new Set(edges).size).toBe(edges.length); + }); +}); diff --git a/src/lib/onboard/machine/transitions.ts b/src/lib/onboard/machine/transitions.ts new file mode 100644 index 0000000000..9f23e3895a --- /dev/null +++ b/src/lib/onboard/machine/transitions.ts @@ -0,0 +1,107 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import type { OnboardMachineState, OnboardMachineTransition } from "./types"; +import { + ONBOARD_MACHINE_STATES, + ONBOARD_NON_TERMINAL_MACHINE_STATES, + ONBOARD_TERMINAL_MACHINE_STATES, +} from "./types"; + +export const ONBOARD_MACHINE_NEXT_STATES = { + init: ["preflight", "failed"], + preflight: ["gateway", "failed"], + gateway: ["provider_selection", "failed"], + provider_selection: ["inference", "failed"], + inference: ["provider_selection", "sandbox", "failed"], + sandbox: ["openclaw", "agent_setup", "failed"], + agent_setup: ["policies", "failed"], + openclaw: ["policies", "failed"], + policies: ["finalizing", "failed"], + finalizing: ["post_verify", "failed"], + post_verify: ["complete", "failed"], + complete: [], + failed: [], +} as const satisfies Readonly>; + +export const ONBOARD_MACHINE_DIRECT_TRANSITIONS = [ + { from: "init", to: "preflight", kind: "advance" }, + { from: "preflight", to: "gateway", kind: "advance" }, + { from: "gateway", to: "provider_selection", kind: "advance" }, + { from: "provider_selection", to: "inference", kind: "advance" }, + { from: "inference", to: "provider_selection", kind: "retry" }, + { from: "inference", to: "sandbox", kind: "advance" }, + { from: "sandbox", to: "openclaw", kind: "branch" }, + { from: "sandbox", to: "agent_setup", kind: "branch" }, + { from: "openclaw", to: "policies", kind: "advance" }, + { from: "agent_setup", to: "policies", kind: "advance" }, + { from: "policies", to: "finalizing", kind: "advance" }, + { from: 
"finalizing", to: "post_verify", kind: "advance" }, + { from: "post_verify", to: "complete", kind: "advance" }, +] as const satisfies readonly OnboardMachineTransition[]; + +export const ONBOARD_MACHINE_FAILURE_TRANSITIONS = ONBOARD_NON_TERMINAL_MACHINE_STATES.map( + (from) => ({ from, to: "failed" as const, kind: "failure" as const }), +) satisfies readonly OnboardMachineTransition[]; + +export const ONBOARD_MACHINE_TRANSITIONS = [ + ...ONBOARD_MACHINE_DIRECT_TRANSITIONS, + ...ONBOARD_MACHINE_FAILURE_TRANSITIONS, +] as const satisfies readonly OnboardMachineTransition[]; + +export class InvalidOnboardMachineTransitionError extends Error { + readonly from: OnboardMachineState; + readonly to: OnboardMachineState; + + constructor(from: OnboardMachineState, to: OnboardMachineState) { + super(`Invalid onboarding machine transition: ${from} -> ${to}`); + this.name = "InvalidOnboardMachineTransitionError"; + this.from = from; + this.to = to; + } +} + +export function isOnboardMachineState(value: unknown): value is OnboardMachineState { + return typeof value === "string" && ONBOARD_MACHINE_STATES.includes(value as OnboardMachineState); +} + +export function isTerminalOnboardMachineState( + state: OnboardMachineState, +): state is "complete" | "failed" { + return ONBOARD_TERMINAL_MACHINE_STATES.includes(state as "complete" | "failed"); +} + +export function getNextOnboardMachineStates( + from: OnboardMachineState, +): readonly OnboardMachineState[] { + return ONBOARD_MACHINE_NEXT_STATES[from]; +} + +export function canTransitionOnboardMachineState( + from: OnboardMachineState, + to: OnboardMachineState, +): boolean { + return getNextOnboardMachineStates(from).includes(to); +} + +export function getOnboardMachineTransition( + from: OnboardMachineState, + to: OnboardMachineState, +): OnboardMachineTransition | null { + return ( + ONBOARD_MACHINE_TRANSITIONS.find( + (transition) => transition.from === from && transition.to === to, + ) ?? 
null + ); +} + +export function assertValidOnboardMachineTransition( + from: OnboardMachineState, + to: OnboardMachineState, +): OnboardMachineTransition { + const transition = getOnboardMachineTransition(from, to); + if (!transition) { + throw new InvalidOnboardMachineTransitionError(from, to); + } + return transition; +} diff --git a/src/lib/onboard/machine/types.ts b/src/lib/onboard/machine/types.ts new file mode 100644 index 0000000000..bbba7bd5f6 --- /dev/null +++ b/src/lib/onboard/machine/types.ts @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +/** + * Coarse onboarding finite-state-machine vocabulary. + * + * These types intentionally model only major step boundaries. Mid-operation + * resume inside gateway startup, sandbox creation, credential upserts, model + * probes, or policy application is out of scope for the initial FSM shell. + */ + +export const ONBOARD_MACHINE_STATES = [ + "init", + "preflight", + "gateway", + "provider_selection", + "inference", + "sandbox", + "agent_setup", + "openclaw", + "policies", + "finalizing", + "post_verify", + "complete", + "failed", +] as const; + +export type OnboardMachineState = (typeof ONBOARD_MACHINE_STATES)[number]; + +export const ONBOARD_TERMINAL_MACHINE_STATES = ["complete", "failed"] as const; + +export type OnboardTerminalMachineState = + (typeof ONBOARD_TERMINAL_MACHINE_STATES)[number]; + +export type OnboardNonTerminalMachineState = Exclude< + OnboardMachineState, + OnboardTerminalMachineState +>; + +export const ONBOARD_NON_TERMINAL_MACHINE_STATES: readonly OnboardNonTerminalMachineState[] = + ONBOARD_MACHINE_STATES.filter( + (state): state is OnboardNonTerminalMachineState => + !ONBOARD_TERMINAL_MACHINE_STATES.includes(state as OnboardTerminalMachineState), + ); + +export const ONBOARD_MACHINE_EVENT_TYPES = [ + "onboard.started", + "onboard.resumed", + "onboard.completed", + "onboard.failed", 
+ "state.entered", + "state.exited", + "state.skipped", + "state.completed", + "state.failed", + "state.repair.started", + "state.repair.completed", + "state.repair.failed", + "context.updated", + "resume.conflict", + "hook.started", + "hook.completed", + "hook.failed", +] as const; + +export type OnboardMachineEventType = (typeof ONBOARD_MACHINE_EVENT_TYPES)[number]; + +export type OnboardMachineTransitionKind = + | "advance" + | "retry" + | "branch" + | "failure"; + +export interface OnboardMachineTransition { + from: OnboardMachineState; + to: OnboardMachineState; + kind: OnboardMachineTransitionKind; +} + +/** + * Stable, redacted context keys that machine events may expose. + * + * Do not add raw secrets or unredacted URLs here. Runtime-derived topology + * decisions such as Docker/WSL reachability, Ollama proxy necessity, or live + * gateway health should be recomputed during execution rather than stored as + * durable FSM context. + */ +export interface OnboardMachineContext { + agent?: string | null; + sandboxName?: string | null; + provider?: string | null; + model?: string | null; + endpointUrl?: string | null; + credentialEnv?: string | null; + preferredInferenceApi?: string | null; + hermesAuthMethod?: "oauth" | "api_key" | null; + hermesToolGateways?: string[] | null; + policyPresets?: string[] | null; + messagingChannels?: string[] | null; + gpuPassthrough?: boolean; +} From b9e4545e44066975dab7945a93b580b366ec82c2 Mon Sep 17 00:00:00 2001 From: Carlos Villela Date: Tue, 19 May 2026 19:27:06 -0700 Subject: [PATCH 02/12] refactor(cli): emit onboard session machine events --- src/lib/onboard/machine/events.ts | 166 ++++++++++++++++++++++++++ src/lib/state/onboard-session.test.ts | 90 ++++++++++++++ src/lib/state/onboard-session.ts | 94 +++++++++++++-- 3 files changed, 343 insertions(+), 7 deletions(-) create mode 100644 src/lib/onboard/machine/events.ts diff --git a/src/lib/onboard/machine/events.ts b/src/lib/onboard/machine/events.ts new file mode 
100644 index 0000000000..9a68d3f899 --- /dev/null +++ b/src/lib/onboard/machine/events.ts @@ -0,0 +1,166 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import type { JsonObject, JsonValue } from "../../core/json-types"; +import { redactSensitiveText, redactUrl } from "../../security/redact"; +import type { HermesAuthMethod, Session } from "../../state/onboard-session"; +import type { + OnboardMachineContext, + OnboardMachineEventType, + OnboardMachineState, +} from "./types"; + +export const ONBOARD_SESSION_STEP_TO_MACHINE_STATE = { + preflight: "preflight", + gateway: "gateway", + provider_selection: "provider_selection", + inference: "inference", + sandbox: "sandbox", + agent_setup: "agent_setup", + openclaw: "openclaw", + policies: "policies", +} as const satisfies Readonly>; + +export type OnboardSessionStepName = keyof typeof ONBOARD_SESSION_STEP_TO_MACHINE_STATE; + +export interface OnboardMachineEvent { + version: 1; + type: OnboardMachineEventType; + occurredAt: string; + sessionId: string | null; + state: OnboardMachineState | null; + step: OnboardSessionStepName | null; + context: OnboardMachineContext; + error: string | null; + metadata: JsonObject; +} + +export type OnboardMachineEventListener = (event: OnboardMachineEvent) => void; + +const listeners = new Set(); + +export function addOnboardMachineEventListener( + listener: OnboardMachineEventListener, +): () => void { + listeners.add(listener); + return () => { + listeners.delete(listener); + }; +} + +export function clearOnboardMachineEventListeners(): void { + listeners.clear(); +} + +export function isOnboardSessionStepName(value: string): value is OnboardSessionStepName { + return Object.prototype.hasOwnProperty.call(ONBOARD_SESSION_STEP_TO_MACHINE_STATE, value); +} + +export function machineStateFromOnboardSessionStep( + stepName: string | null | undefined, +): OnboardMachineState | null { + if 
(!stepName || !isOnboardSessionStepName(stepName)) return null; + return ONBOARD_SESSION_STEP_TO_MACHINE_STATE[stepName]; +} + +function nullableString(value: unknown): string | null { + return typeof value === "string" ? value : null; +} + +function stringArray(value: unknown): string[] | null { + if (!Array.isArray(value)) return null; + return value.filter((entry): entry is string => typeof entry === "string"); +} + +function hermesAuthMethod(value: unknown): HermesAuthMethod | null { + return value === "oauth" || value === "api_key" ? value : null; +} + +function booleanValue(value: unknown): boolean | undefined { + return typeof value === "boolean" ? value : undefined; +} + +function sanitizeJsonValue(value: unknown): JsonValue { + if (typeof value === "string") return redactUrl(value) ?? redactSensitiveText(value) ?? ""; + if (typeof value === "number" && Number.isFinite(value)) return value; + if (typeof value === "boolean" || value === null) return value; + if (Array.isArray(value)) return value.map((entry) => sanitizeJsonValue(entry)); + if (typeof value !== "object" || value === null) return String(value); + + const result: JsonObject = {}; + for (const [key, entry] of Object.entries(value)) { + result[key] = sanitizeJsonValue(entry); + } + return result; +} + +export function sanitizeOnboardMachineEventMetadata( + metadata: Record | null | undefined, +): JsonObject { + if (!metadata || typeof metadata !== "object" || Array.isArray(metadata)) return {}; + const sanitized: JsonObject = {}; + for (const [key, value] of Object.entries(metadata)) { + sanitized[key] = sanitizeJsonValue(value); + } + return sanitized; +} + +export function buildOnboardMachineContext(session: Session): OnboardMachineContext { + const endpointUrl = redactUrl(session.endpointUrl); + return { + agent: nullableString(session.agent), + sandboxName: nullableString(session.sandboxName), + provider: nullableString(session.provider), + model: nullableString(session.model), + endpointUrl, 
+ credentialEnv: nullableString(session.credentialEnv), + preferredInferenceApi: nullableString(session.preferredInferenceApi), + hermesAuthMethod: hermesAuthMethod(session.hermesAuthMethod), + hermesToolGateways: stringArray(session.hermesToolGateways), + policyPresets: stringArray(session.policyPresets), + messagingChannels: stringArray(session.messagingChannels), + gpuPassthrough: booleanValue(session.gpuPassthrough), + }; +} + +export function createOnboardMachineEvent({ + type, + session, + step, + state, + error = null, + metadata = {}, +}: { + type: OnboardMachineEventType; + session: Session; + step?: string | null; + state?: OnboardMachineState | null; + error?: string | null; + metadata?: Record | null; +}): OnboardMachineEvent { + const normalizedStep = step && isOnboardSessionStepName(step) ? step : null; + return { + version: 1, + type, + occurredAt: new Date().toISOString(), + sessionId: nullableString(session.sessionId), + state: state ?? machineStateFromOnboardSessionStep(normalizedStep), + step: normalizedStep, + context: buildOnboardMachineContext(session), + error: redactSensitiveText(error), + metadata: sanitizeOnboardMachineEventMetadata(metadata), + }; +} + +export function emitOnboardMachineEvent(event: OnboardMachineEvent): void { + if (listeners.size === 0) return; + for (const listener of listeners) { + try { + listener(event); + } catch { + // Event observers are diagnostics only. A broken observer must not + // change onboarding behavior; hook failure events are introduced by the + // later observe-only hook API. 
+ } + } +} diff --git a/src/lib/state/onboard-session.test.ts b/src/lib/state/onboard-session.test.ts index b2c925858f..5ddd94908d 100644 --- a/src/lib/state/onboard-session.test.ts +++ b/src/lib/state/onboard-session.test.ts @@ -9,11 +9,15 @@ import { createRequire } from "node:module"; const require = createRequire(import.meta.url); const distPath = require.resolve("../../../dist/lib/state/onboard-session"); +const eventsDistPath = require.resolve("../../../dist/lib/onboard/machine/events"); const originalHome = process.env.HOME; type OnboardSessionModule = typeof import("../../../dist/lib/state/onboard-session"); +type OnboardMachineEventsModule = typeof import("../../../dist/lib/onboard/machine/events"); +type OnboardMachineEvent = import("../../../dist/lib/onboard/machine/events").OnboardMachineEvent; type LoadedSession = NonNullable>; type DebugSummary = NonNullable>; let session: OnboardSessionModule; +let machineEvents: OnboardMachineEventsModule; let tmpDir: string; function requireLoadedSession( @@ -44,13 +48,18 @@ beforeEach(() => { tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "nemoclaw-onboard-session-")); process.env.HOME = tmpDir; delete require.cache[distPath]; + delete require.cache[eventsDistPath]; session = require("../../../dist/lib/state/onboard-session"); + machineEvents = require("../../../dist/lib/onboard/machine/events"); + machineEvents.clearOnboardMachineEventListeners(); session.clearSession(); session.releaseOnboardLock(); }); afterEach(() => { + machineEvents.clearOnboardMachineEventListeners(); delete require.cache[distPath]; + delete require.cache[eventsDistPath]; fs.rmSync(tmpDir, { recursive: true, force: true }); if (originalHome === undefined) { delete process.env.HOME; @@ -117,6 +126,87 @@ describe("onboard session", () => { expect(loaded.failure.message).toMatch(/Sandbox creation failed/); }); + it("emits redacted structured machine events for session step mutations", () => { + const emitted: OnboardMachineEvent[] = []; + 
machineEvents.addOnboardMachineEventListener((event) => emitted.push(event)); + + session.saveSession(session.createSession({ sessionId: "session-1" })); + session.markStepStarted("gateway"); + session.markStepComplete("gateway", { + sandboxName: "my-assistant", + endpointUrl: + "https://alice:super-secret-token@example.com/v1?token=super-secret-token&keep=yes#token=super-secret-token", + credentialEnv: "NVIDIA_API_KEY", + }); + session.markStepSkipped("openclaw"); + session.markStepFailed("sandbox", "NVIDIA_API_KEY=super-secret-token"); + session.completeSession({ provider: "ollama-local", credentialEnv: null }); + + expect(emitted.map((event) => event.type)).toEqual([ + "state.entered", + "context.updated", + "state.completed", + "state.skipped", + "state.failed", + "onboard.failed", + "context.updated", + "onboard.completed", + ]); + expect(emitted[0]).toMatchObject({ + version: 1, + sessionId: "session-1", + state: "gateway", + step: "gateway", + error: null, + }); + expect(emitted[1].context).toMatchObject({ + sandboxName: "my-assistant", + credentialEnv: "NVIDIA_API_KEY", + }); + expect(emitted[1].context.endpointUrl).toBe( + "https://example.com/v1?token=%3CREDACTED%3E&keep=yes", + ); + expect(emitted[1].metadata.fields).toEqual([ + "sandboxName", + "endpointUrl", + "credentialEnv", + ]); + expect(emitted[4]).toMatchObject({ + type: "state.failed", + state: "sandbox", + step: "sandbox", + error: "NVIDIA_API_KEY=", + }); + expect(emitted[5]).toMatchObject({ type: "onboard.failed", state: "failed" }); + expect(emitted.at(-1)).toMatchObject({ type: "onboard.completed", state: "complete" }); + expect(JSON.stringify(emitted)).not.toContain("super-secret-token"); + + const persisted = JSON.parse(fs.readFileSync(session.SESSION_FILE, "utf8")); + expect(persisted.events).toBeUndefined(); + }); + + it("keeps event observer failures from changing session mutation behavior", () => { + machineEvents.addOnboardMachineEventListener(() => { + throw new Error("observer 
failed"); + }); + + session.saveSession(session.createSession()); + expect(() => session.markStepStarted("preflight")).not.toThrow(); + + const loaded = requireLoadedSession(session.loadSession()); + expect(loaded.steps.preflight.status).toBe("in_progress"); + }); + + it("does not emit machine events for unknown session step names", () => { + const emitted: OnboardMachineEvent[] = []; + machineEvents.addOnboardMachineEventListener((event) => emitted.push(event)); + + session.saveSession(session.createSession()); + session.markStepStarted("not_a_real_step"); + + expect(emitted).toEqual([]); + }); + it("persists safe provider metadata without persisting secrets", () => { session.saveSession(session.createSession()); const unsafeProviderUpdate: Parameters[1] & { diff --git a/src/lib/state/onboard-session.ts b/src/lib/state/onboard-session.ts index f05c1116e8..7fe94d8096 100644 --- a/src/lib/state/onboard-session.ts +++ b/src/lib/state/onboard-session.ts @@ -18,6 +18,10 @@ import { sanitizeMessagingChannelConfig, type MessagingChannelConfig, } from "../messaging-channel-config"; +import { + createOnboardMachineEvent, + emitOnboardMachineEvent, +} from "../onboard/machine/events"; import { redactSensitiveText, redactUrl } from "../security/redact"; export const SESSION_VERSION = 1; @@ -883,7 +887,8 @@ export function updateSession(mutator: (session: Session) => Session | void): Se } export function markStepStarted(stepName: string): Session { - return updateSession((session) => { + let shouldEmit = false; + const updatedSession = updateSession((session) => { const step = session.steps[stepName]; if (!step) return session; step.status = "in_progress"; @@ -893,12 +898,21 @@ export function markStepStarted(stepName: string): Session { session.lastStepStarted = stepName; session.failure = null; session.status = "in_progress"; + shouldEmit = true; return session; }); + if (shouldEmit) { + emitOnboardMachineEvent( + createOnboardMachineEvent({ type: "state.entered", session: 
updatedSession, step: stepName }), + ); + } + return updatedSession; } export function markStepComplete(stepName: string, updates: SessionUpdates = {}): Session { - return updateSession((session) => { + const safeUpdates = filterSafeUpdates(updates); + let shouldEmit = false; + const updatedSession = updateSession((session) => { const step = session.steps[stepName]; if (!step) return session; step.status = "complete"; @@ -906,13 +920,31 @@ export function markStepComplete(stepName: string, updates: SessionUpdates = {}) step.error = null; session.lastCompletedStep = stepName; session.failure = null; - Object.assign(session, filterSafeUpdates(updates)); + Object.assign(session, safeUpdates); + shouldEmit = true; return session; }); + if (shouldEmit) { + if (Object.keys(safeUpdates).length > 0) { + emitOnboardMachineEvent( + createOnboardMachineEvent({ + type: "context.updated", + session: updatedSession, + step: stepName, + metadata: { fields: Object.keys(safeUpdates) }, + }), + ); + } + emitOnboardMachineEvent( + createOnboardMachineEvent({ type: "state.completed", session: updatedSession, step: stepName }), + ); + } + return updatedSession; } export function markStepSkipped(stepName: string): Session { - return updateSession((session) => { + let shouldEmit = false; + const updatedSession = updateSession((session) => { const step = session.steps[stepName]; if (!step) return session; if (step.status === "complete" || step.status === "failed") return session; @@ -920,12 +952,20 @@ export function markStepSkipped(stepName: string): Session { step.startedAt = null; step.completedAt = null; step.error = null; + shouldEmit = true; return session; }); + if (shouldEmit) { + emitOnboardMachineEvent( + createOnboardMachineEvent({ type: "state.skipped", session: updatedSession, step: stepName }), + ); + } + return updatedSession; } export function markStepFailed(stepName: string, message: string | null = null): Session { - return updateSession((session) => { + let shouldEmit = 
false; + const updatedSession = updateSession((session) => { const step = session.steps[stepName]; if (!step) return session; step.status = "failed"; @@ -937,18 +977,58 @@ export function markStepFailed(stepName: string, message: string | null = null): recordedAt: new Date().toISOString(), }); session.status = "failed"; + shouldEmit = true; return session; }); + if (shouldEmit) { + emitOnboardMachineEvent( + createOnboardMachineEvent({ + type: "state.failed", + session: updatedSession, + step: stepName, + error: message, + }), + ); + emitOnboardMachineEvent( + createOnboardMachineEvent({ + type: "onboard.failed", + session: updatedSession, + state: "failed", + step: stepName, + error: message, + }), + ); + } + return updatedSession; } export function completeSession(updates: SessionUpdates = {}): Session { - return updateSession((session) => { - Object.assign(session, filterSafeUpdates(updates)); + const safeUpdates = filterSafeUpdates(updates); + const updatedSession = updateSession((session) => { + Object.assign(session, safeUpdates); session.status = "complete"; session.resumable = false; session.failure = null; return session; }); + if (Object.keys(safeUpdates).length > 0) { + emitOnboardMachineEvent( + createOnboardMachineEvent({ + type: "context.updated", + session: updatedSession, + state: "complete", + metadata: { fields: Object.keys(safeUpdates) }, + }), + ); + } + emitOnboardMachineEvent( + createOnboardMachineEvent({ + type: "onboard.completed", + session: updatedSession, + state: "complete", + }), + ); + return updatedSession; } export function summarizeForDebug( From 651e2a07c3f34bd38cf942d08ad350e5d6b5eb86 Mon Sep 17 00:00:00 2001 From: Carlos Villela Date: Tue, 19 May 2026 21:47:57 -0700 Subject: [PATCH 03/12] refactor(cli): persist onboard machine snapshots --- src/lib/actions/inference-set.test.ts | 8 +- src/lib/state/onboard-session.test.ts | 115 ++++++++++++++++++++ src/lib/state/onboard-session.ts | 145 ++++++++++++++++++++++++-- 3 files 
changed, 259 insertions(+), 9 deletions(-) diff --git a/src/lib/actions/inference-set.test.ts b/src/lib/actions/inference-set.test.ts index ae091f7adf..f6c178f0cf 100644 --- a/src/lib/actions/inference-set.test.ts +++ b/src/lib/actions/inference-set.test.ts @@ -86,9 +86,15 @@ function baseSession(overrides: Partial = {}): Session { telegramConfig: null, wechatConfig: null, metadata: { gatewayName: "nemoclaw", fromDockerfile: null }, + machine: { + version: 1, + state: "complete", + stateEnteredAt: "2026-05-11T00:00:00.000Z", + revision: 0, + }, steps: {}, ...overrides, - }; + } as Session; } function createDeps(options: { diff --git a/src/lib/state/onboard-session.test.ts b/src/lib/state/onboard-session.test.ts index 5ddd94908d..8e4b9f5cbc 100644 --- a/src/lib/state/onboard-session.test.ts +++ b/src/lib/state/onboard-session.test.ts @@ -40,6 +40,14 @@ function requireDebugSummary( return summary; } +function normalizeLegacySession( + legacy: unknown, +): ReturnType { + return session.normalizeSession( + legacy as Parameters[0], + ); +} + beforeEach(() => { // Recreate tmpDir per test so lock artifacts (and any other on-disk state) // from a previous test cannot leak into this one. 
Without this, malformed @@ -80,6 +88,12 @@ describe("onboard session", () => { const dirStat = fs.statSync(path.dirname(session.SESSION_FILE)); expect(saved.mode).toBe("non-interactive"); + expect(saved.machine).toMatchObject({ + version: 1, + state: "init", + revision: 0, + }); + expect(saved.machine.stateEnteredAt).toBeTruthy(); expect(fs.existsSync(session.SESSION_FILE)).toBe(true); expect(stat.mode & 0o777).toBe(0o600); expect(dirStat.mode & 0o777).toBe(0o700); @@ -124,6 +138,107 @@ describe("onboard session", () => { } expect(loaded.failure.step).toBe("sandbox"); expect(loaded.failure.message).toMatch(/Sandbox creation failed/); + expect(loaded.machine.state).toBe("failed"); + }); + + it("persists a compact machine snapshot across step boundaries", () => { + session.saveSession(session.createSession()); + let loaded = requireLoadedSession(session.loadSession()); + expect(loaded.machine).toMatchObject({ state: "init", revision: 0 }); + + session.markStepStarted("preflight"); + loaded = requireLoadedSession(session.loadSession()); + expect(loaded.machine).toMatchObject({ state: "preflight", revision: 1 }); + expect(loaded.machine.stateEnteredAt).toBe(loaded.steps.preflight.startedAt); + + session.markStepComplete("preflight"); + loaded = requireLoadedSession(session.loadSession()); + expect(loaded.machine).toMatchObject({ state: "gateway", revision: 2 }); + expect(loaded.machine.stateEnteredAt).toBe(loaded.steps.preflight.completedAt); + + session.markStepComplete("gateway"); + loaded = requireLoadedSession(session.loadSession()); + expect(loaded.machine).toMatchObject({ state: "provider_selection", revision: 3 }); + + session.completeSession(); + loaded = requireLoadedSession(session.loadSession()); + expect(loaded.machine).toMatchObject({ state: "complete", revision: 4 }); + expect(requireDebugSummary(session.summarizeForDebug()).machine).toEqual(loaded.machine); + }); + + it("normalizes old sessions without machine snapshots", () => { + type LegacySession = 
Omit, "machine"> & { + machine?: unknown; + }; + const legacy = session.createSession({ + sessionId: "legacy-session", + startedAt: "2026-01-01T00:00:00.000Z", + updatedAt: "2026-01-01T00:05:00.000Z", + }) as unknown as LegacySession; + delete legacy.machine; + legacy.steps.gateway.status = "in_progress"; + legacy.steps.gateway.startedAt = "2026-01-01T00:02:00.000Z"; + legacy.lastStepStarted = "gateway"; + + let normalized = requireLoadedSession(normalizeLegacySession(legacy)); + expect(normalized.machine).toEqual({ + version: 1, + state: "gateway", + stateEnteredAt: "2026-01-01T00:02:00.000Z", + revision: 0, + }); + + legacy.steps.gateway.status = "complete"; + legacy.steps.gateway.completedAt = "2026-01-01T00:03:00.000Z"; + legacy.lastCompletedStep = "gateway"; + normalized = requireLoadedSession(normalizeLegacySession(legacy)); + expect(normalized.machine).toEqual({ + version: 1, + state: "provider_selection", + stateEnteredAt: "2026-01-01T00:03:00.000Z", + revision: 0, + }); + + legacy.status = "failed"; + legacy.failure = { + step: "gateway", + message: "boom", + recordedAt: "2026-01-01T00:04:00.000Z", + }; + normalized = requireLoadedSession(normalizeLegacySession(legacy)); + expect(normalized.machine).toEqual({ + version: 1, + state: "failed", + stateEnteredAt: "2026-01-01T00:04:00.000Z", + revision: 0, + }); + + legacy.status = "complete"; + normalized = requireLoadedSession(normalizeLegacySession(legacy)); + expect(normalized.machine.state).toBe("complete"); + }); + + it("normalizes invalid machine snapshots from old sessions", () => { + type LegacySession = Omit, "machine"> & { + machine?: unknown; + }; + const legacy = session.createSession({ lastCompletedStep: "policies" }) as unknown as LegacySession; + legacy.steps.policies.status = "complete"; + legacy.steps.policies.completedAt = "2026-01-01T00:08:00.000Z"; + legacy.machine = { + version: 1, + state: "not-a-state", + stateEnteredAt: "2026-01-01T00:09:00.000Z", + revision: -1, + }; + + const 
normalized = requireLoadedSession(normalizeLegacySession(legacy)); + expect(normalized.machine).toEqual({ + version: 1, + state: "finalizing", + stateEnteredAt: "2026-01-01T00:08:00.000Z", + revision: 0, + }); }); it("emits redacted structured machine events for session step mutations", () => { diff --git a/src/lib/state/onboard-session.ts b/src/lib/state/onboard-session.ts index 7fe94d8096..f739f330d2 100644 --- a/src/lib/state/onboard-session.ts +++ b/src/lib/state/onboard-session.ts @@ -21,10 +21,14 @@ import { import { createOnboardMachineEvent, emitOnboardMachineEvent, + machineStateFromOnboardSessionStep, } from "../onboard/machine/events"; +import { isOnboardMachineState } from "../onboard/machine/transitions"; +import type { OnboardMachineState } from "../onboard/machine/types"; import { redactSensitiveText, redactUrl } from "../security/redact"; export const SESSION_VERSION = 1; +export const MACHINE_SNAPSHOT_VERSION = 1; export const SESSION_DIR = path.join(process.env.HOME || "/tmp", ".nemoclaw"); export const SESSION_FILE = path.join(SESSION_DIR, "onboard-session.json"); export const LOCK_FILE = path.join(SESSION_DIR, "onboard.lock"); @@ -64,6 +68,13 @@ export interface SessionMetadata { fromDockerfile: string | null; } +export interface OnboardMachineSnapshot { + version: typeof MACHINE_SNAPSHOT_VERSION; + state: OnboardMachineState; + stateEnteredAt: string | null; + revision: number; +} + export interface Session { version: number; sessionId: string; @@ -115,6 +126,7 @@ export interface Session { telegramConfig: TelegramConfig | null; wechatConfig: WechatConfig | null; metadata: SessionMetadata; + machine: OnboardMachineSnapshot; steps: Record; } @@ -198,6 +210,7 @@ export interface DebugSessionSummary { lastStepStarted: string | null; lastCompletedStep: string | null; failure: SessionFailure | null; + machine: OnboardMachineSnapshot; steps: Record; } @@ -240,6 +253,10 @@ function readPositiveInteger(value: SessionJsonValue | undefined): number | 
null return typeof value === "number" && Number.isInteger(value) && value > 0 ? value : null; } +function readNonNegativeInteger(value: SessionJsonValue | undefined): number | null { + return typeof value === "number" && Number.isInteger(value) && value >= 0 ? value : null; +} + function readStringArray(value: SessionJsonValue | undefined): string[] | null { if (!Array.isArray(value)) return null; return value.filter((entry): entry is string => typeof entry === "string"); @@ -308,6 +325,17 @@ function parseStepState(value: SessionJsonValue | undefined): StepState | null { }; } +function parseMachineSnapshot(value: SessionJsonValue | undefined): OnboardMachineSnapshot | null { + if (!isObject(value) || value.version !== MACHINE_SNAPSHOT_VERSION) return null; + if (!isOnboardMachineState(value.state)) return null; + return { + version: MACHINE_SNAPSHOT_VERSION, + state: value.state, + stateEnteredAt: readString(value.stateEnteredAt), + revision: readNonNegativeInteger(value.revision) ?? 0, + }; +} + function parseLockInfo(value: SessionJsonValue | undefined): LockInfo | null { if (!isObject(value) || typeof value.pid !== "number") return null; return { @@ -335,9 +363,97 @@ export function sanitizeFailure( // ── Session CRUD ───────────────────────────────────────────────── +function createMachineSnapshot( + state: OnboardMachineState, + stateEnteredAt: string | null, + revision = 0, +): OnboardMachineSnapshot { + return { + version: MACHINE_SNAPSHOT_VERSION, + state, + stateEnteredAt, + revision: Math.max(0, Math.trunc(revision)), + }; +} + +function nextMachineStateAfterCompletedStep( + stepName: string | null | undefined, + session: Pick, +): OnboardMachineState | null { + switch (stepName) { + case "preflight": + return "gateway"; + case "gateway": + return "provider_selection"; + case "provider_selection": + return "inference"; + case "inference": + return "sandbox"; + case "sandbox": + return session.agent ? 
"agent_setup" : "openclaw"; + case "openclaw": + case "agent_setup": + return "policies"; + case "policies": + return "finalizing"; + default: + return null; + } +} + +function inferMachineState(session: Session): OnboardMachineState { + if (session.status === "complete") return "complete"; + if (session.status === "failed") return "failed"; + + const startedState = machineStateFromOnboardSessionStep(session.lastStepStarted); + const startedStep = session.lastStepStarted ? session.steps[session.lastStepStarted] : null; + if (startedState && startedStep?.status === "in_progress") return startedState; + + return nextMachineStateAfterCompletedStep(session.lastCompletedStep, session) ?? "init"; +} + +function inferMachineStateEnteredAt(session: Session, state: OnboardMachineState): string | null { + if (state === "failed") return session.failure?.recordedAt ?? session.updatedAt; + if (state === "complete") return session.updatedAt; + + const startedState = machineStateFromOnboardSessionStep(session.lastStepStarted); + const startedStep = session.lastStepStarted ? session.steps[session.lastStepStarted] : null; + if (state === startedState && startedStep?.status === "in_progress") { + return startedStep.startedAt ?? session.updatedAt; + } + + if (nextMachineStateAfterCompletedStep(session.lastCompletedStep, session) === state) { + const completedStep = session.lastCompletedStep ? session.steps[session.lastCompletedStep] : null; + return completedStep?.completedAt ?? session.updatedAt; + } + + return session.startedAt; +} + +function inferMachineSnapshot(session: Session): OnboardMachineSnapshot { + const state = inferMachineState(session); + return createMachineSnapshot(state, inferMachineStateEnteredAt(session, state)); +} + +function transitionMachineSnapshot(session: Session, state: OnboardMachineState, now: string): void { + const current = session.machine ?? 
createMachineSnapshot("init", session.startedAt); + if (current.state === state) { + session.machine = { + ...current, + stateEnteredAt: current.stateEnteredAt ?? now, + }; + return; + } + session.machine = createMachineSnapshot(state, now, current.revision + 1); +} + export function createSession(overrides: Partial = {}): Session { const now = new Date().toISOString(); - return { + const steps = { + ...defaultSteps(), + ...(overrides.steps ?? {}), + }; + const session: Session = { version: SESSION_VERSION, sessionId: overrides.sessionId ?? `${Date.now()}-${randomUUID()}`, resumable: true, @@ -376,11 +492,11 @@ export function createSession(overrides: Partial = {}): Session { gatewayName: overrides.metadata?.gatewayName ?? "nemoclaw", fromDockerfile: overrides.metadata?.fromDockerfile ?? null, }, - steps: { - ...defaultSteps(), - ...(overrides.steps ?? {}), - }, + machine: parseMachineSnapshot(overrides.machine as SessionJsonValue | undefined) ?? + createMachineSnapshot("init", now), + steps, }; + return session; } export function normalizeSession(data: Session | SessionJsonValue | undefined): Session | null { @@ -429,6 +545,8 @@ export function normalizeSession(data: Session | SessionJsonValue | undefined): } } + normalized.machine = parseMachineSnapshot(data.machine) ?? 
inferMachineSnapshot(normalized); + return normalized; } @@ -891,13 +1009,16 @@ export function markStepStarted(stepName: string): Session { const updatedSession = updateSession((session) => { const step = session.steps[stepName]; if (!step) return session; + const now = new Date().toISOString(); step.status = "in_progress"; - step.startedAt = new Date().toISOString(); + step.startedAt = now; step.completedAt = null; step.error = null; session.lastStepStarted = stepName; session.failure = null; session.status = "in_progress"; + const state = machineStateFromOnboardSessionStep(stepName); + if (state) transitionMachineSnapshot(session, state, now); shouldEmit = true; return session; }); @@ -915,12 +1036,15 @@ export function markStepComplete(stepName: string, updates: SessionUpdates = {}) const updatedSession = updateSession((session) => { const step = session.steps[stepName]; if (!step) return session; + const now = new Date().toISOString(); step.status = "complete"; - step.completedAt = new Date().toISOString(); + step.completedAt = now; step.error = null; session.lastCompletedStep = stepName; session.failure = null; Object.assign(session, safeUpdates); + const nextState = nextMachineStateAfterCompletedStep(stepName, session); + if (nextState) transitionMachineSnapshot(session, nextState, now); shouldEmit = true; return session; }); @@ -968,15 +1092,17 @@ export function markStepFailed(stepName: string, message: string | null = null): const updatedSession = updateSession((session) => { const step = session.steps[stepName]; if (!step) return session; + const now = new Date().toISOString(); step.status = "failed"; step.completedAt = null; step.error = redactSensitiveText(message); session.failure = sanitizeFailure({ step: stepName, message, - recordedAt: new Date().toISOString(), + recordedAt: now, }); session.status = "failed"; + transitionMachineSnapshot(session, "failed", now); shouldEmit = true; return session; }); @@ -1005,10 +1131,12 @@ export function 
markStepFailed(stepName: string, message: string | null = null): export function completeSession(updates: SessionUpdates = {}): Session { const safeUpdates = filterSafeUpdates(updates); const updatedSession = updateSession((session) => { + const now = new Date().toISOString(); Object.assign(session, safeUpdates); session.status = "complete"; session.resumable = false; session.failure = null; + transitionMachineSnapshot(session, "complete", now); return session; }); if (Object.keys(safeUpdates).length > 0) { @@ -1057,6 +1185,7 @@ export function summarizeForDebug( lastStepStarted: session.lastStepStarted, lastCompletedStep: session.lastCompletedStep, failure: sanitizeFailure(session.failure), + machine: session.machine, steps: Object.fromEntries( Object.entries(session.steps).map(([name, step]) => [ name, From f756907b5c07a0bb2d09049ab6b4fa7cda681709 Mon Sep 17 00:00:00 2001 From: Carlos Villela Date: Tue, 19 May 2026 22:12:25 -0700 Subject: [PATCH 04/12] refactor(cli): add onboard runtime shell --- src/lib/onboard/machine/runtime.test.ts | 184 +++++++++++++++++ src/lib/onboard/machine/runtime.ts | 263 ++++++++++++++++++++++++ 2 files changed, 447 insertions(+) create mode 100644 src/lib/onboard/machine/runtime.test.ts create mode 100644 src/lib/onboard/machine/runtime.ts diff --git a/src/lib/onboard/machine/runtime.test.ts b/src/lib/onboard/machine/runtime.test.ts new file mode 100644 index 0000000000..becca6028e --- /dev/null +++ b/src/lib/onboard/machine/runtime.test.ts @@ -0,0 +1,184 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +import { describe, expect, it } from "vitest"; + +import { + createSession, + filterSafeUpdates, + normalizeSession, + type Session, +} from "../../state/onboard-session"; +import type { OnboardMachineEvent } from "./events"; +import { OnboardRuntime, type OnboardRuntimeDeps } from "./runtime"; +import { InvalidOnboardMachineTransitionError } from "./transitions"; + +function cloneSession(session: Session): Session { + return normalizeSession(JSON.parse(JSON.stringify(session))) ?? session; +} + +function createHarness(initialSession: Session | null = createSession()) { + let session = initialSession ? cloneSession(initialSession) : null; + const events: OnboardMachineEvent[] = []; + let tick = 0; + const deps: OnboardRuntimeDeps = { + loadSession: () => (session ? cloneSession(session) : null), + createSession: (overrides) => createSession(overrides), + saveSession: (next) => { + session = cloneSession(next); + return cloneSession(session); + }, + updateSession: (mutator) => { + const current = session ? cloneSession(session) : createSession(); + const next = mutator(current) ?? 
current; + session = cloneSession(next); + return cloneSession(session); + }, + filterSafeUpdates, + emitEvent: (event) => events.push(event), + now: () => `2026-05-19T00:00:${String(tick++).padStart(2, "0")}.000Z`, + }; + return { + runtime: new OnboardRuntime(deps), + events, + getSession: () => { + if (!session) throw new Error("Expected runtime session"); + return cloneSession(session); + }, + }; +} + +function sessionInState(state: Session["machine"]["state"]): Session { + const session = createSession(); + session.machine = { + version: 1, + state, + stateEnteredAt: "2026-05-19T00:00:00.000Z", + revision: 7, + }; + return session; +} + +describe("OnboardRuntime", () => { + it("starts a session and emits started/resumed lifecycle events", async () => { + const { runtime, events, getSession } = createHarness(null); + + const started = await runtime.start(); + expect(started.machine.state).toBe("init"); + expect(getSession().machine.state).toBe("init"); + expect(events[0]).toMatchObject({ type: "onboard.started", state: "init" }); + + await runtime.start({ resumed: true }); + expect(events[1]).toMatchObject({ type: "onboard.resumed", state: "init" }); + }); + + it("validates and persists explicit transitions", async () => { + const { runtime, events, getSession } = createHarness(); + + await runtime.transition("preflight"); + + expect(getSession().machine).toEqual({ + version: 1, + state: "preflight", + stateEnteredAt: "2026-05-19T00:00:00.000Z", + revision: 1, + }); + expect(events.map((event) => event.type)).toEqual(["state.exited", "state.entered"]); + expect(events[0]).toMatchObject({ state: "init" }); + expect(events[1]).toMatchObject({ state: "preflight" }); + + await expect(runtime.transition("sandbox")).rejects.toThrow( + InvalidOnboardMachineTransitionError, + ); + expect(getSession().machine.state).toBe("preflight"); + }); + + it("applies only safe context updates and emits redacted context events", async () => { + const { runtime, events, getSession } 
= createHarness(); + + await runtime.updateContext({ + provider: "nvidia-prod", + endpointUrl: "https://alice:secret@example.com/v1?token=super-secret&keep=yes#token=frag", + credentialEnv: "NVIDIA_API_KEY", + apiKey: "super-secret", + } as Parameters[0] & { apiKey: string }); + + expect(getSession()).toMatchObject({ + provider: "nvidia-prod", + endpointUrl: "https://example.com/v1?token=%3CREDACTED%3E&keep=yes", + credentialEnv: "NVIDIA_API_KEY", + }); + expect("apiKey" in getSession()).toBe(false); + expect(events).toHaveLength(1); + expect(events[0]).toMatchObject({ type: "context.updated", state: "init" }); + expect(events[0].metadata.fields).toEqual(["provider", "endpointUrl", "credentialEnv"]); + expect(JSON.stringify(events)).not.toContain("super-secret"); + }); + + it("fails non-terminal sessions with redacted failure events", async () => { + const { runtime, events, getSession } = createHarness(sessionInState("gateway")); + + await runtime.fail("NVIDIA_API_KEY=super-secret", { step: "gateway" }); + + expect(getSession()).toMatchObject({ + status: "failed", + failure: { step: "gateway", message: "NVIDIA_API_KEY=" }, + machine: { state: "failed", revision: 8 }, + }); + expect(events.map((event) => event.type)).toEqual(["state.failed", "onboard.failed"]); + expect(events[0]).toMatchObject({ state: "gateway", step: "gateway" }); + expect(events[1]).toMatchObject({ state: "failed", step: "gateway" }); + expect(JSON.stringify(events)).not.toContain("super-secret"); + }); + + it("rejects terminal-state failure and invalid completion transitions", async () => { + const completeHarness = createHarness(sessionInState("complete")); + await expect(completeHarness.runtime.fail("boom")).rejects.toThrow("complete -> failed"); + expect(completeHarness.getSession().machine.state).toBe("complete"); + + const policiesHarness = createHarness(sessionInState("policies")); + await expect(policiesHarness.runtime.complete()).rejects.toThrow("policies -> complete"); + 
expect(policiesHarness.getSession().machine.state).toBe("policies"); + }); + + it("completes from post_verify and emits completion events", async () => { + const { runtime, events, getSession } = createHarness(sessionInState("post_verify")); + + await runtime.complete({ sandboxName: "my-assistant" }); + + expect(getSession()).toMatchObject({ + status: "complete", + resumable: false, + sandboxName: "my-assistant", + machine: { state: "complete", revision: 8 }, + }); + expect(events.map((event) => event.type)).toEqual([ + "context.updated", + "state.completed", + "state.entered", + "onboard.completed", + ]); + }); + + it("emits skipped and repair events without mutating durable state", async () => { + const { runtime, events, getSession } = createHarness(sessionInState("provider_selection")); + + await runtime.markSkipped("provider_selection", { reason: "resume" }); + await runtime.emitRepairEvent("state.repair.started", { + state: "provider_selection", + metadata: { action: "ollama-systemd" }, + }); + await runtime.emitRepairEvent("state.repair.completed", { state: "provider_selection" }); + + expect(getSession().machine.state).toBe("provider_selection"); + expect(events.map((event) => event.type)).toEqual([ + "state.skipped", + "state.repair.started", + "state.repair.completed", + ]); + expect(events[0].metadata.reason).toBe("resume"); + await expect(runtime.markSkipped("complete")).rejects.toThrow( + "Terminal onboarding state cannot be skipped", + ); + }); +}); diff --git a/src/lib/onboard/machine/runtime.ts b/src/lib/onboard/machine/runtime.ts new file mode 100644 index 0000000000..3e72cd0ccc --- /dev/null +++ b/src/lib/onboard/machine/runtime.ts @@ -0,0 +1,263 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +import type { JsonObject } from "../../core/json-types"; +import * as onboardSession from "../../state/onboard-session"; +import type { Session, SessionUpdates } from "../../state/onboard-session"; +import { + createOnboardMachineEvent, + emitOnboardMachineEvent, + type OnboardMachineEvent, +} from "./events"; +import { + assertValidOnboardMachineTransition, + canTransitionOnboardMachineState, + isTerminalOnboardMachineState, +} from "./transitions"; +import type { OnboardMachineEventType, OnboardMachineState } from "./types"; + +export interface OnboardRuntimeDeps { + loadSession(): Session | null; + createSession(overrides?: Partial): Session; + saveSession(session: Session): Session; + updateSession(mutator: (session: Session) => Session | void): Session; + filterSafeUpdates(updates: SessionUpdates): Partial; + emitEvent(event: OnboardMachineEvent): void; + now(): string; +} + +export type OnboardRuntimeTransitionOptions = { + metadata?: Record | null; +}; + +export type OnboardRuntimeUpdateOptions = { + state?: OnboardMachineState | null; + metadata?: Record | null; +}; + +export type OnboardRuntimeFailureOptions = { + step?: string | null; + metadata?: Record | null; +}; + +function defaultDeps(): OnboardRuntimeDeps { + return { + loadSession: onboardSession.loadSession, + createSession: onboardSession.createSession, + saveSession: onboardSession.saveSession, + updateSession: onboardSession.updateSession, + filterSafeUpdates: onboardSession.filterSafeUpdates, + emitEvent: emitOnboardMachineEvent, + now: () => new Date().toISOString(), + }; +} + +function eventMetadata(metadata: Record | null | undefined): JsonObject { + return metadata && typeof metadata === "object" && !Array.isArray(metadata) + ? 
(metadata as JsonObject) + : {}; +} + +function snapshotFor( + state: OnboardMachineState, + stateEnteredAt: string | null, + revision: number, +): onboardSession.OnboardMachineSnapshot { + return { + version: onboardSession.MACHINE_SNAPSHOT_VERSION, + state, + stateEnteredAt, + revision: Math.max(0, Math.trunc(revision)), + }; +} + +export class OnboardRuntime { + private readonly deps: OnboardRuntimeDeps; + + constructor(deps: Partial = {}) { + this.deps = { ...defaultDeps(), ...deps }; + } + + async session(): Promise { + return this.ensureSession(); + } + + async start(options: { resumed?: boolean; metadata?: Record | null } = {}): Promise { + const session = this.ensureSession(); + this.emit(options.resumed === true ? "onboard.resumed" : "onboard.started", session, { + state: session.machine.state, + metadata: options.metadata, + }); + return session; + } + + async transition( + to: OnboardMachineState, + options: OnboardRuntimeTransitionOptions = {}, + ): Promise { + const current = this.ensureSession(); + const from = current.machine.state; + assertValidOnboardMachineTransition(from, to); + + const enteredAt = this.deps.now(); + const updated = this.deps.updateSession((session) => { + session.machine = snapshotFor(to, enteredAt, session.machine.revision + 1); + if (to === "failed") { + session.status = "failed"; + } else if (to === "complete") { + session.status = "complete"; + session.resumable = false; + session.failure = null; + } else if (session.status !== "failed") { + session.status = "in_progress"; + } + return session; + }); + + this.emit("state.exited", updated, { state: from, metadata: options.metadata }); + this.emit("state.entered", updated, { state: to, metadata: options.metadata }); + return updated; + } + + async updateContext( + updates: SessionUpdates, + options: OnboardRuntimeUpdateOptions = {}, + ): Promise { + const safeUpdates = this.deps.filterSafeUpdates(updates); + const fields = Object.keys(safeUpdates); + const updated = 
this.deps.updateSession((session) => { + Object.assign(session, safeUpdates); + return session; + }); + if (fields.length > 0) { + this.emit("context.updated", updated, { + state: options.state ?? updated.machine.state, + metadata: { ...eventMetadata(options.metadata), fields }, + }); + } + return updated; + } + + async complete(updates: SessionUpdates = {}): Promise { + const current = this.ensureSession(); + const from = current.machine.state; + assertValidOnboardMachineTransition(from, "complete"); + + const safeUpdates = this.deps.filterSafeUpdates(updates); + const fields = Object.keys(safeUpdates); + const enteredAt = this.deps.now(); + const updated = this.deps.updateSession((session) => { + Object.assign(session, safeUpdates); + session.status = "complete"; + session.resumable = false; + session.failure = null; + session.machine = snapshotFor("complete", enteredAt, session.machine.revision + 1); + return session; + }); + + if (fields.length > 0) { + this.emit("context.updated", updated, { + state: "complete", + metadata: { fields }, + }); + } + this.emit("state.completed", updated, { state: from }); + this.emit("state.entered", updated, { state: "complete" }); + this.emit("onboard.completed", updated, { state: "complete" }); + return updated; + } + + async fail(message: string | null, options: OnboardRuntimeFailureOptions = {}): Promise { + const current = this.ensureSession(); + const from = current.machine.state; + if (!canTransitionOnboardMachineState(from, "failed")) { + assertValidOnboardMachineTransition(from, "failed"); + } + + const recordedAt = this.deps.now(); + const updated = this.deps.updateSession((session) => { + session.status = "failed"; + session.failure = onboardSession.sanitizeFailure({ + step: options.step ?? 
null, + message, + recordedAt, + }); + session.machine = snapshotFor("failed", recordedAt, session.machine.revision + 1); + return session; + }); + + this.emit("state.failed", updated, { + state: from, + step: options.step, + error: message, + metadata: options.metadata, + }); + this.emit("onboard.failed", updated, { + state: "failed", + step: options.step, + error: message, + metadata: options.metadata, + }); + return updated; + } + + async markSkipped( + state: OnboardMachineState, + metadata: Record | null = null, + ): Promise { + const session = this.ensureSession(); + if (isTerminalOnboardMachineState(state)) { + throw new Error(`Terminal onboarding state cannot be skipped: ${state}`); + } + this.emit("state.skipped", session, { state, metadata }); + return session; + } + + async emitRepairEvent( + type: Extract< + OnboardMachineEventType, + "state.repair.started" | "state.repair.completed" | "state.repair.failed" + >, + options: { + state?: OnboardMachineState | null; + error?: string | null; + metadata?: Record | null; + } = {}, + ): Promise { + const session = this.ensureSession(); + this.emit(type, session, { + state: options.state ?? session.machine.state, + error: options.error ?? null, + metadata: options.metadata, + }); + return session; + } + + private ensureSession(): Session { + const existing = this.deps.loadSession(); + if (existing) return existing; + return this.deps.saveSession(this.deps.createSession()); + } + + private emit( + type: OnboardMachineEventType, + session: Session, + options: { + state?: OnboardMachineState | null; + step?: string | null; + error?: string | null; + metadata?: Record | null; + } = {}, + ): void { + this.deps.emitEvent( + createOnboardMachineEvent({ + type, + session, + state: options.state ?? session.machine.state, + step: options.step ?? null, + error: options.error ?? 
null, + metadata: options.metadata, + }), + ); + } +} From 702454b2d9c3547a95a4c51f4f4ec9a6c5780ca0 Mon Sep 17 00:00:00 2001 From: Carlos Villela Date: Tue, 19 May 2026 22:26:42 -0700 Subject: [PATCH 05/12] refactor(cli): route onboard step boundaries through runtime --- src/lib/agent/onboard.ts | 4 +- src/lib/onboard.ts | 87 +++++++++++++++---------- src/lib/onboard/machine/runtime.test.ts | 56 ++++++++++++++-- src/lib/onboard/machine/runtime.ts | 30 +++++++++ 4 files changed, 135 insertions(+), 42 deletions(-) diff --git a/src/lib/agent/onboard.ts b/src/lib/agent/onboard.ts index 2446108910..f08c32b9c6 100644 --- a/src/lib/agent/onboard.ts +++ b/src/lib/agent/onboard.ts @@ -31,7 +31,7 @@ export interface OnboardContext { buildSandboxConfigSyncScript: (config: LooseObject) => string; writeSandboxConfigSyncFile: (script: string) => string; cleanupTempDir: (file: string, prefix: string) => void; - startRecordedStep: (stepName: string, updates: LooseObject) => void; + startRecordedStep: (stepName: string, updates: LooseObject) => Promise; skippedStepMessage: (stepName: string, sandboxName: string) => void; } @@ -424,7 +424,7 @@ export async function handleAgentSetup( } } - startRecordedStep("agent_setup", { sandboxName, provider, model }); + await startRecordedStep("agent_setup", { sandboxName, provider, model }); step(7, 8, `Setting up ${agent.displayName} inside sandbox`); const binaryAvailability = verifyAgentBinaryAvailable(sandboxName, agent, runCaptureOpenshell); diff --git a/src/lib/onboard.ts b/src/lib/onboard.ts index bc231df3a5..470639b346 100644 --- a/src/lib/onboard.ts +++ b/src/lib/onboard.ts @@ -279,6 +279,7 @@ const { resolveSandboxImageTagFromCreateOutput } = require("./domain/sandbox/image-tag") as typeof import("./domain/sandbox/image-tag"); const nim: typeof import("./inference/nim") = require("./inference/nim"); const onboardSession: typeof import("./state/onboard-session") = require("./state/onboard-session"); +const { OnboardRuntime }: typeof 
import("./onboard/machine/runtime") = require("./onboard/machine/runtime"); const policies: typeof import("./policy") = require("./policy"); const tiers: typeof import("./policy/tiers") = require("./policy/tiers"); const { ensureUsageNoticeConsent } = require("./onboard/usage-notice"); @@ -409,6 +410,7 @@ const USE_COLOR = !process.env.NO_COLOR && !!process.stdout.isTTY; const DIM = USE_COLOR ? "\x1b[2m" : ""; const RESET = USE_COLOR ? "\x1b[0m" : ""; let OPENSHELL_BIN: string | null = null; +let ONBOARD_RUNTIME: import("./onboard/machine/runtime").OnboardRuntime | null = null; const GATEWAY_NAME = "nemoclaw"; const BACK_TO_SELECTION = "__NEMOCLAW_BACK_TO_SELECTION__"; type HermesAuthMethod = "oauth" | "api_key"; @@ -9017,7 +9019,12 @@ function toSessionUpdates( return normalized; } -function startRecordedStep( +function getOnboardRuntime(): import("./onboard/machine/runtime").OnboardRuntime { + if (!ONBOARD_RUNTIME) ONBOARD_RUNTIME = new OnboardRuntime(); + return ONBOARD_RUNTIME; +} + +async function startRecordedStep( stepName: string, updates: { sandboxName?: string | null; @@ -9025,20 +9032,30 @@ function startRecordedStep( model?: string | null; policyPresets?: string[] | null; } = {}, -): void { - onboardSession.markStepStarted(stepName); +): Promise { + const runtime = getOnboardRuntime(); + await runtime.markStepStarted(stepName); if (Object.keys(updates).length > 0) { - onboardSession.updateSession((session: Session) => { - if (updates.sandboxName !== undefined) session.sandboxName = updates.sandboxName; - if (updates.provider !== undefined) session.provider = updates.provider; - if (updates.model !== undefined) session.model = updates.model; - if (updates.policyPresets !== undefined) session.policyPresets = updates.policyPresets; - return session; - }); + await runtime.updateContext(toSessionUpdates(updates)); } maybeForceE2eStepFailure(stepName); } +async function recordStepComplete( + stepName: string, + updates: SessionUpdates = {}, +): Promise { + 
return getOnboardRuntime().markStepComplete(stepName, updates); +} + +async function recordStepSkipped(stepName: string): Promise { + return getOnboardRuntime().markStepSkipped(stepName); +} + +async function recordSessionComplete(updates: SessionUpdates = {}): Promise { + return getOnboardRuntime().completeSession(updates); +} + const ONBOARD_STEP_INDEX: Record = { preflight: { number: 1, title: "Preflight checks" }, gateway: { number: 2, title: "Starting OpenShell gateway" }, @@ -9074,6 +9091,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { RECREATE_SANDBOX = opts.recreateSandbox || process.env.NEMOCLAW_RECREATE_SANDBOX === "1"; AUTO_YES = opts.autoYes === true || process.env.NEMOCLAW_YES === "1"; _preflightDashboardPort = opts.controlUiPort || null; + ONBOARD_RUNTIME = new OnboardRuntime(); delete process.env.OPENSHELL_GATEWAY; const resume = opts.resume === true; const fresh = opts.fresh === true; @@ -9422,9 +9440,9 @@ async function onboard(opts: OnboardOptions = {}): Promise { }), ); } else { - startRecordedStep("preflight"); + await startRecordedStep("preflight"); gpu = await preflight({ ...opts, optedOutGpuPassthrough: opts.noGpu === true }); - onboardSession.markStepComplete("preflight"); + await recordStepComplete("preflight"); } const sandboxGpuConfig = resolveSandboxGpuConfig(gpu, { flag: effectiveSandboxGpuFlag, @@ -9560,11 +9578,11 @@ async function onboard(opts: OnboardOptions = {}): Promise { resume && session?.steps?.gateway?.status === "complete" && canReuseHealthyGateway; if (resumeGateway) { skippedStepMessage("gateway", "running"); - onboardSession.markStepComplete("gateway"); + await recordStepComplete("gateway"); } else if (!resume && canReuseHealthyGateway) { skippedStepMessage("gateway", "running", "reuse"); note(" Reusing healthy NemoClaw gateway."); - onboardSession.markStepComplete("gateway"); + await recordStepComplete("gateway"); } else { if (resume && session?.steps?.gateway?.status === "complete") { if 
(gatewayReuseState === "active-unnamed") { @@ -9582,9 +9600,9 @@ async function onboard(opts: OnboardOptions = {}): Promise { retireLegacyGatewayForDockerDriverUpgrade(); gatewayReuseState = "missing"; } - startRecordedStep("gateway"); + await startRecordedStep("gateway"); await startGateway(gpu, { gpuPassthrough }); - onboardSession.markStepComplete("gateway"); + await recordStepComplete("gateway"); } // #2753: prefer requestedSandboxName over an unconfirmed session name. @@ -9635,7 +9653,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { // below). A SIGINT between any earlier step and createSandbox would // otherwise leave a phantom that `nemoclaw list` resurrects until // manually destroyed. - startRecordedStep("provider_selection"); + await startRecordedStep("provider_selection"); const selection = await setupNim(gpu, sandboxName, agent); model = selection.model; provider = selection.provider; @@ -9645,7 +9663,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { hermesToolGateways = selection.hermesToolGateways; preferredInferenceApi = selection.preferredInferenceApi; nimContainer = selection.nimContainer; - onboardSession.markStepComplete( + await recordStepComplete( "provider_selection", toSessionUpdates({ provider, @@ -9678,7 +9696,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { if (!sandboxName) { sandboxName = await promptValidatedSandboxName(agent); } - startRecordedStep("inference", { provider, model }); + await startRecordedStep("inference", { provider, model }); const inferenceResult = await setupInference( sandboxName, model, @@ -9692,7 +9710,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { forceProviderSelection = true; continue; } - onboardSession.markStepComplete( + await recordStepComplete( "inference", toSessionUpdates({ provider, model, hermesAuthMethod, nimContainer, hermesToolGateways }), ); @@ -9712,7 +9730,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { if 
(nimContainer && sandboxName) { registry.updateSandbox(sandboxName, { nimContainer }); } - onboardSession.markStepComplete( + await recordStepComplete( "inference", toSessionUpdates({ provider, model, hermesAuthMethod, nimContainer, hermesToolGateways }), ); @@ -9751,7 +9769,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { } } - startRecordedStep("inference", { provider, model }); + await startRecordedStep("inference", { provider, model }); const inferenceResult = await setupInference( sandboxName, model, @@ -9769,7 +9787,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { if (nimContainer && sandboxName) { registry.updateSandbox(sandboxName, { nimContainer }); } - onboardSession.markStepComplete( + await recordStepComplete( "inference", toSessionUpdates({ provider, model, hermesAuthMethod, nimContainer, hermesToolGateways }), ); @@ -9906,7 +9924,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { } else { nextWebSearchConfig = await configureWebSearch(null, agent, webSearchSupportProbePath); } - startRecordedStep("sandbox", { provider, model }); + await startRecordedStep("sandbox", { provider, model }); const recordedMessagingChannels = getRecordedMessagingChannelsForResume(resume, session, sandboxName); if (recordedMessagingChannels) { selectedMessagingChannels = recordedMessagingChannels; @@ -9960,7 +9978,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { ...getSandboxAgentRegistryFields(agent, !fromDockerfile), }); registry.setDefault(sandboxName); - onboardSession.markStepComplete( + await recordStepComplete( "sandbox", toSessionUpdates({ sandboxName, @@ -9996,24 +10014,24 @@ async function onboard(opts: OnboardOptions = {}): Promise { skippedStepMessage, }); ensureAgentDashboardForward(sandboxName, agent); - onboardSession.markStepSkipped("openclaw"); + await recordStepSkipped("openclaw"); } else { const resumeOpenclaw = resume && sandboxName && isOpenclawReady(sandboxName); if (resumeOpenclaw) { 
skippedStepMessage("openclaw", sandboxName); - onboardSession.markStepComplete( + await recordStepComplete( "openclaw", toSessionUpdates({ sandboxName, provider, model, hermesAuthMethod, hermesToolGateways }), ); } else { - startRecordedStep("openclaw", { sandboxName, provider, model }); + await startRecordedStep("openclaw", { sandboxName, provider, model }); await setupOpenclaw(sandboxName, model, provider); - onboardSession.markStepComplete( + await recordStepComplete( "openclaw", toSessionUpdates({ sandboxName, provider, model, hermesAuthMethod, hermesToolGateways }), ); } - onboardSession.markStepSkipped("agent_setup"); + await recordStepSkipped("agent_setup"); } const latestSession = onboardSession.loadSession(); @@ -10066,7 +10084,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { arePolicyPresetsApplied(sandboxName, recordedPolicyPresetsForSupport); if (resumePolicies) { skippedStepMessage("policies", recordedPolicyPresetsForSupport.join(", ")); - onboardSession.markStepComplete( + await recordStepComplete( "policies", toSessionUpdates({ sandboxName, @@ -10076,7 +10094,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { }), ); } else { - startRecordedStep("policies", { + await startRecordedStep("policies", { sandboxName, provider, model, @@ -10102,7 +10120,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { }); }, }); - onboardSession.markStepComplete( + await recordStepComplete( "policies", toSessionUpdates({ sandboxName, provider, model, policyPresets: appliedPolicyPresets }), ); @@ -10112,7 +10130,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { ensureAgentDashboardForward(sandboxName, agent); } - onboardSession.completeSession( + await recordSessionComplete( toSessionUpdates({ sandboxName, provider, model, hermesAuthMethod, hermesToolGateways }), ); completed = true; @@ -10192,6 +10210,7 @@ async function onboard(opts: OnboardOptions = {}): Promise { printDashboard(sandboxName, model, provider, 
nimContainer, agent); } finally { releaseOnboardLock(); + ONBOARD_RUNTIME = null; } } diff --git a/src/lib/onboard/machine/runtime.test.ts b/src/lib/onboard/machine/runtime.test.ts index becca6028e..7b26269541 100644 --- a/src/lib/onboard/machine/runtime.test.ts +++ b/src/lib/onboard/machine/runtime.test.ts @@ -7,7 +7,9 @@ import { createSession, filterSafeUpdates, normalizeSession, + sanitizeFailure, type Session, + type SessionUpdates, } from "../../state/onboard-session"; import type { OnboardMachineEvent } from "./events"; import { OnboardRuntime, type OnboardRuntimeDeps } from "./runtime"; @@ -21,6 +23,12 @@ function createHarness(initialSession: Session | null = createSession()) { let session = initialSession ? cloneSession(initialSession) : null; const events: OnboardMachineEvent[] = []; let tick = 0; + const updateSession = (mutator: (value: Session) => Session | void): Session => { + const current = session ? cloneSession(session) : createSession(); + const next = mutator(current) ?? current; + session = cloneSession(next); + return cloneSession(session); + }; const deps: OnboardRuntimeDeps = { loadSession: () => (session ? cloneSession(session) : null), createSession: (overrides) => createSession(overrides), @@ -28,12 +36,48 @@ function createHarness(initialSession: Session | null = createSession()) { session = cloneSession(next); return cloneSession(session); }, - updateSession: (mutator) => { - const current = session ? cloneSession(session) : createSession(); - const next = mutator(current) ?? 
current; - session = cloneSession(next); - return cloneSession(session); - }, + updateSession, + markStepStarted: (stepName) => + updateSession((current) => { + const step = current.steps[stepName]; + if (!step) return current; + step.status = "in_progress"; + current.lastStepStarted = stepName; + current.status = "in_progress"; + return current; + }), + markStepComplete: (stepName, updates: SessionUpdates = {}) => + updateSession((current) => { + const step = current.steps[stepName]; + if (!step) return current; + step.status = "complete"; + current.lastCompletedStep = stepName; + Object.assign(current, filterSafeUpdates(updates)); + return current; + }), + markStepSkipped: (stepName) => + updateSession((current) => { + const step = current.steps[stepName]; + if (!step) return current; + step.status = "skipped"; + return current; + }), + markStepFailed: (stepName, message) => + updateSession((current) => { + const step = current.steps[stepName]; + if (!step) return current; + step.status = "failed"; + current.status = "failed"; + current.failure = sanitizeFailure({ step: stepName, message, recordedAt: "now" }); + return current; + }), + completeSession: (updates: SessionUpdates = {}) => + updateSession((current) => { + Object.assign(current, filterSafeUpdates(updates)); + current.status = "complete"; + current.resumable = false; + return current; + }), filterSafeUpdates, emitEvent: (event) => events.push(event), now: () => `2026-05-19T00:00:${String(tick++).padStart(2, "0")}.000Z`, diff --git a/src/lib/onboard/machine/runtime.ts b/src/lib/onboard/machine/runtime.ts index 3e72cd0ccc..2e5d584f3b 100644 --- a/src/lib/onboard/machine/runtime.ts +++ b/src/lib/onboard/machine/runtime.ts @@ -21,6 +21,11 @@ export interface OnboardRuntimeDeps { createSession(overrides?: Partial): Session; saveSession(session: Session): Session; updateSession(mutator: (session: Session) => Session | void): Session; + markStepStarted(stepName: string): Session; + markStepComplete(stepName: 
string, updates?: SessionUpdates): Session; + markStepSkipped(stepName: string): Session; + markStepFailed(stepName: string, message?: string | null): Session; + completeSession(updates?: SessionUpdates): Session; filterSafeUpdates(updates: SessionUpdates): Partial; emitEvent(event: OnboardMachineEvent): void; now(): string; @@ -46,6 +51,11 @@ function defaultDeps(): OnboardRuntimeDeps { createSession: onboardSession.createSession, saveSession: onboardSession.saveSession, updateSession: onboardSession.updateSession, + markStepStarted: onboardSession.markStepStarted, + markStepComplete: onboardSession.markStepComplete, + markStepSkipped: onboardSession.markStepSkipped, + markStepFailed: onboardSession.markStepFailed, + completeSession: onboardSession.completeSession, filterSafeUpdates: onboardSession.filterSafeUpdates, emitEvent: emitOnboardMachineEvent, now: () => new Date().toISOString(), @@ -91,6 +101,26 @@ export class OnboardRuntime { return session; } + async markStepStarted(stepName: string): Promise { + return this.deps.markStepStarted(stepName); + } + + async markStepComplete(stepName: string, updates: SessionUpdates = {}): Promise { + return this.deps.markStepComplete(stepName, updates); + } + + async markStepSkipped(stepName: string): Promise { + return this.deps.markStepSkipped(stepName); + } + + async markStepFailed(stepName: string, message: string | null = null): Promise { + return this.deps.markStepFailed(stepName, message); + } + + async completeSession(updates: SessionUpdates = {}): Promise { + return this.deps.completeSession(updates); + } + async transition( to: OnboardMachineState, options: OnboardRuntimeTransitionOptions = {}, From 60acb65261157d19741963f604f148474da04218 Mon Sep 17 00:00:00 2001 From: Carlos Villela Date: Tue, 19 May 2026 22:35:53 -0700 Subject: [PATCH 06/12] refactor(cli): add observe-only onboard hooks --- src/lib/onboard/machine/hooks.test.ts | 150 ++++++++++++++++++++++++++ src/lib/onboard/machine/hooks.ts | 132 
+++++++++++++++++++++++ 2 files changed, 282 insertions(+) create mode 100644 src/lib/onboard/machine/hooks.test.ts create mode 100644 src/lib/onboard/machine/hooks.ts diff --git a/src/lib/onboard/machine/hooks.test.ts b/src/lib/onboard/machine/hooks.test.ts new file mode 100644 index 0000000000..ec0fe0fcc7 --- /dev/null +++ b/src/lib/onboard/machine/hooks.test.ts @@ -0,0 +1,150 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import fs from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, describe, expect, it } from "vitest"; + +import { createSession } from "../../state/onboard-session"; +import { + clearOnboardMachineEventListeners, + createOnboardMachineEvent, + emitOnboardMachineEvent, + type OnboardMachineEvent, +} from "./events"; +import { createJsonlOnboardHook, OnboardHookDispatcher, registerOnboardHooks } from "./hooks"; + +function sampleEvent(): OnboardMachineEvent { + const session = createSession({ + sessionId: "session-1", + provider: "nvidia-prod", + endpointUrl: "https://example.com/v1?token=secret&keep=yes", + }); + return createOnboardMachineEvent({ + type: "state.entered", + session, + state: "gateway", + step: "gateway", + }); +} + +afterEach(() => { + clearOnboardMachineEventListeners(); +}); + +describe("onboard machine hooks", () => { + it("dispatches observe-only events and emits hook lifecycle events", async () => { + const observed: string[] = []; + const lifecycle: OnboardMachineEvent[] = []; + const dispatcher = new OnboardHookDispatcher( + [ + { + name: "observer", + onEvent(event) { + observed.push(event.type); + }, + }, + ], + { + emitEvent: (event) => lifecycle.push(event), + now: () => "2026-05-19T01:00:00.000Z", + }, + ); + + await dispatcher.dispatch(sampleEvent()); + + expect(observed).toEqual(["state.entered"]); + expect(lifecycle.map((event) => event.type)).toEqual(["hook.started", 
"hook.completed"]); + expect(lifecycle[0]).toMatchObject({ + sessionId: "session-1", + state: "gateway", + step: "gateway", + metadata: { hook: "observer", sourceType: "state.entered" }, + }); + }); + + it("warns and emits hook.failed without throwing when a hook fails", async () => { + const warnings: string[] = []; + const lifecycle: OnboardMachineEvent[] = []; + const dispatcher = new OnboardHookDispatcher( + [ + { + name: "bad-hook", + async onEvent() { + throw new Error("Bearer super-secret-token"); + }, + }, + ], + { + warn: (message) => warnings.push(message), + emitEvent: (event) => lifecycle.push(event), + now: () => "2026-05-19T01:00:00.000Z", + }, + ); + + await expect(dispatcher.dispatch(sampleEvent())).resolves.toBeUndefined(); + + expect(lifecycle.map((event) => event.type)).toEqual(["hook.started", "hook.failed"]); + expect(lifecycle[1]).toMatchObject({ + type: "hook.failed", + error: "Bearer ", + metadata: { hook: "bad-hook", sourceType: "state.entered" }, + }); + expect(warnings).toEqual(["Onboard hook 'bad-hook' failed: Bearer "]); + expect(JSON.stringify(lifecycle)).not.toContain("super-secret-token"); + expect(warnings.join("\n")).not.toContain("super-secret-token"); + }); + + it("writes JSONL hook events to an external sink", async () => { + const tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "nemoclaw-hooks-")); + try { + const filePath = path.join(tmpDir, "events.jsonl"); + const hook = createJsonlOnboardHook(filePath); + + await hook.onEvent?.(sampleEvent()); + await hook.onEvent?.( + createOnboardMachineEvent({ + type: "state.completed", + session: createSession({ sessionId: "session-1" }), + state: "gateway", + step: "gateway", + }), + ); + + const lines = fs + .readFileSync(filePath, "utf8") + .trim() + .split("\n") + .map((line) => JSON.parse(line)); + expect(lines.map((event) => event.type)).toEqual(["state.entered", "state.completed"]); + expect(lines[0].context.endpointUrl).toBe( + 
"https://example.com/v1?token=%3CREDACTED%3E&keep=yes", + ); + } finally { + fs.rmSync(tmpDir, { recursive: true, force: true }); + } + }); + + it("registers hooks on the machine event bus without redispatching hook lifecycle events", async () => { + const observed: string[] = []; + const unregister = registerOnboardHooks([ + { + name: "bus-observer", + onEvent(event) { + observed.push(event.type); + }, + }, + ]); + + emitOnboardMachineEvent(sampleEvent()); + await Promise.resolve(); + emitOnboardMachineEvent({ ...sampleEvent(), type: "hook.failed" }); + await Promise.resolve(); + unregister(); + emitOnboardMachineEvent({ ...sampleEvent(), type: "state.completed" }); + await Promise.resolve(); + + expect(observed).toEqual(["state.entered"]); + }); +}); diff --git a/src/lib/onboard/machine/hooks.ts b/src/lib/onboard/machine/hooks.ts new file mode 100644 index 0000000000..1dfcd7544a --- /dev/null +++ b/src/lib/onboard/machine/hooks.ts @@ -0,0 +1,132 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import fs from "node:fs"; +import path from "node:path"; + +import { redactSensitiveText } from "../../security/redact"; +import { + addOnboardMachineEventListener, + emitOnboardMachineEvent, + sanitizeOnboardMachineEventMetadata, + type OnboardMachineEvent, + type OnboardMachineEventListener, +} from "./events"; + +export interface OnboardHook { + name?: string; + onEvent?(event: OnboardMachineEvent): Promise | void; +} + +export interface OnboardHookDispatchOptions { + warn?: (message: string) => void; + emitEvent?: (event: OnboardMachineEvent) => void; + now?: () => string; +} + +export interface OnboardHookRegistrationOptions extends OnboardHookDispatchOptions { + includeHookEvents?: boolean; +} + +function hookName(hook: OnboardHook, index: number): string { + const name = typeof hook.name === "string" ? 
hook.name.trim() : ""; + return name || `hook-${index + 1}`; +} + +function hookLifecycleEvent( + source: OnboardMachineEvent, + type: "hook.started" | "hook.completed" | "hook.failed", + hook: OnboardHook, + index: number, + options: { + occurredAt: string; + error?: unknown; + metadata?: Record; + }, +): OnboardMachineEvent { + return { + version: 1, + type, + occurredAt: options.occurredAt, + sessionId: source.sessionId, + state: source.state, + step: source.step, + context: source.context, + error: redactSensitiveText(options.error instanceof Error ? options.error.message : options.error), + metadata: sanitizeOnboardMachineEventMetadata({ + hook: hookName(hook, index), + sourceType: source.type, + ...options.metadata, + }), + }; +} + +function isHookLifecycleEvent(event: OnboardMachineEvent): boolean { + return event.type === "hook.started" || event.type === "hook.completed" || event.type === "hook.failed"; +} + +export class OnboardHookDispatcher { + private readonly hooks: readonly OnboardHook[]; + private readonly warn: (message: string) => void; + private readonly emitEvent: (event: OnboardMachineEvent) => void; + private readonly now: () => string; + + constructor(hooks: readonly OnboardHook[], options: OnboardHookDispatchOptions = {}) { + this.hooks = hooks; + this.warn = options.warn ?? ((message) => console.warn(message)); + this.emitEvent = options.emitEvent ?? emitOnboardMachineEvent; + this.now = options.now ?? 
(() => new Date().toISOString()); + } + + async dispatch(event: OnboardMachineEvent): Promise { + for (const [index, hook] of this.hooks.entries()) { + if (typeof hook.onEvent !== "function") continue; + this.emitEvent( + hookLifecycleEvent(event, "hook.started", hook, index, { + occurredAt: this.now(), + }), + ); + try { + await hook.onEvent(event); + this.emitEvent( + hookLifecycleEvent(event, "hook.completed", hook, index, { + occurredAt: this.now(), + }), + ); + } catch (error) { + const name = hookName(hook, index); + const message = error instanceof Error ? error.message : String(error); + this.warn(`Onboard hook '${name}' failed: ${redactSensitiveText(message) ?? ""}`); + this.emitEvent( + hookLifecycleEvent(event, "hook.failed", hook, index, { + occurredAt: this.now(), + error: message, + }), + ); + } + } + } +} + +export function registerOnboardHooks( + hooks: readonly OnboardHook[], + options: OnboardHookRegistrationOptions = {}, +): () => void { + const dispatcher = new OnboardHookDispatcher(hooks, options); + const listener: OnboardMachineEventListener = (event) => { + if (options.includeHookEvents !== true && isHookLifecycleEvent(event)) return; + void dispatcher.dispatch(event); + }; + return addOnboardMachineEventListener(listener); +} + +export function createJsonlOnboardHook(filePath: string): OnboardHook { + const resolvedPath = path.resolve(filePath); + return { + name: "jsonl", + onEvent(event) { + fs.mkdirSync(path.dirname(resolvedPath), { recursive: true, mode: 0o700 }); + fs.appendFileSync(resolvedPath, `${JSON.stringify(event)}\n`, { mode: 0o600 }); + }, + }; +} From c2a58e6053babf96f876e0bddbd44a1f865a9340 Mon Sep 17 00:00:00 2001 From: Carlos Villela Date: Tue, 19 May 2026 22:45:51 -0700 Subject: [PATCH 07/12] refactor(cli): extract onboard preflight handler --- src/lib/onboard.ts | 87 ++++----- .../machine/handlers/preflight.test.ts | 183 ++++++++++++++++++ src/lib/onboard/machine/handlers/preflight.ts | 147 ++++++++++++++ 3 files 
changed, 363 insertions(+), 54 deletions(-) create mode 100644 src/lib/onboard/machine/handlers/preflight.test.ts create mode 100644 src/lib/onboard/machine/handlers/preflight.ts diff --git a/src/lib/onboard.ts b/src/lib/onboard.ts index 470639b346..50c5187326 100644 --- a/src/lib/onboard.ts +++ b/src/lib/onboard.ts @@ -280,6 +280,7 @@ const { resolveSandboxImageTagFromCreateOutput } = const nim: typeof import("./inference/nim") = require("./inference/nim"); const onboardSession: typeof import("./state/onboard-session") = require("./state/onboard-session"); const { OnboardRuntime }: typeof import("./onboard/machine/runtime") = require("./onboard/machine/runtime"); +const { handlePreflightState }: typeof import("./onboard/machine/handlers/preflight") = require("./onboard/machine/handlers/preflight"); const policies: typeof import("./policy") = require("./policy"); const tiers: typeof import("./policy/tiers") = require("./policy/tiers"); const { ensureUsageNoticeConsent } = require("./onboard/usage-notice"); @@ -9403,54 +9404,39 @@ async function onboard(opts: OnboardOptions = {}): Promise { console.log(" ==================="); const explicitSandboxGpuFlag = resolveSandboxGpuFlagFromOptions(opts); - const resumePreflight = resume && session?.steps?.preflight?.status === "complete"; - const resumeHasResolvedGpuIntent = - resumePreflight && - explicitSandboxGpuFlag === null && - opts.sandboxGpuDevice == null && - process.env.NEMOCLAW_SANDBOX_GPU === undefined && - process.env.NEMOCLAW_SANDBOX_GPU_DEVICE === undefined; - const resumedSandboxGpuOverrides = resumeHasResolvedGpuIntent - ? getResumeSandboxGpuOverrides( - resumeSandboxNameForGpu ? registry.getSandbox(resumeSandboxNameForGpu) : null, - session?.gpuPassthrough, - ) - : { flag: null, device: null }; - const effectiveSandboxGpuFlag = explicitSandboxGpuFlag ?? resumedSandboxGpuOverrides.flag; - const effectiveSandboxGpuDevice = opts.sandboxGpuDevice ?? 
resumedSandboxGpuOverrides.device; - let gpu; - if (resumePreflight) { - skippedStepMessage("preflight", "cached"); - gpu = nim.detectGpu(); - // Re-check the CDI spec gap on resume (#3152). The cached preflight - // result does not capture host CDI state, and the original onboard - // attempt that wrote the cache likely aborted at gateway-start with - // exactly this CDI failure — so resuming without re-checking would - // walk into the same wall. Honour persisted `gpuPassthrough: false` - // from the prior session as an opt-out, since the resume invocation - // does not need to re-pass `--no-gpu` to keep that intent (the same - // resolution is replayed a few lines below for `gpuPassthrough`). - const resumeOptedOutGpuPassthrough = - opts.noGpu === true || (opts.gpu !== true && session?.gpuPassthrough === false); - assertCdiNvidiaGpuSpecPresent(assessHost(), resumeOptedOutGpuPassthrough); - validateSandboxGpuPreflight( - resolveSandboxGpuConfig(gpu, { - flag: effectiveSandboxGpuFlag, - device: effectiveSandboxGpuDevice, - }), - ); - } else { - await startRecordedStep("preflight"); - gpu = await preflight({ ...opts, optedOutGpuPassthrough: opts.noGpu === true }); - await recordStepComplete("preflight"); - } - const sandboxGpuConfig = resolveSandboxGpuConfig(gpu, { - flag: effectiveSandboxGpuFlag, - device: effectiveSandboxGpuDevice, + const preflightResult = await handlePreflightState({ + resume, + session, + recordedSandboxName, + requestedSandboxName, + explicitSandboxGpuFlag, + sandboxGpuDevice: opts.sandboxGpuDevice ?? 
null, + gpuRequested: opts.gpu === true, + noGpu: opts.noGpu === true, + env: process.env, + deps: { + getSandbox: registry.getSandbox.bind(registry), + getResumeSandboxGpuOverrides, + detectGpu: nim.detectGpu, + runPreflight: (preflightOptions) => preflight({ ...opts, ...preflightOptions }), + assessHost, + assertCdiNvidiaGpuSpecPresent, + resolveSandboxGpuConfig, + validateSandboxGpuPreflight, + skippedStepMessage, + startRecordedStep, + recordStepComplete, + updateSession: onboardSession.updateSession, + }, }); - - const requestedGpuPassthrough = opts.gpu === true; - const gpuPassthrough = sandboxGpuConfig.sandboxGpuEnabled; + session = preflightResult.session; + const { + sandboxGpuConfig, + resumeHasResolvedGpuIntent, + requestedGpuPassthrough, + gpuPassthrough, + } = preflightResult; + const gpu = preflightResult.gpu ?? null; if (gpuPassthrough) { note( resumeHasResolvedGpuIntent && session?.gpuPassthrough === true @@ -9472,13 +9458,6 @@ async function onboard(opts: OnboardOptions = {}): Promise { /* lspci not available — skip hint */ } } - // Persist GPU intent in the session so resume can restore it. - if (session && session.gpuPassthrough !== gpuPassthrough) { - session = onboardSession.updateSession((current: Session) => { - current.gpuPassthrough = gpuPassthrough; - return current; - }); - } dockerGpuLocalInference.configureLocalInferenceForDockerGpuHostNetwork(sandboxGpuConfig, { dockerDriverGateway: isLinuxDockerDriverGatewayEnabled(), note, diff --git a/src/lib/onboard/machine/handlers/preflight.test.ts b/src/lib/onboard/machine/handlers/preflight.test.ts new file mode 100644 index 0000000000..fa4b859915 --- /dev/null +++ b/src/lib/onboard/machine/handlers/preflight.test.ts @@ -0,0 +1,183 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +import { describe, expect, it, vi } from "vitest"; + +import { createSession, type Session } from "../../../state/onboard-session"; +import { handlePreflightState, type PreflightStateOptions } from "./preflight"; + +type Gpu = { type: string } | null; +type SandboxEntry = { sandboxGpuEnabled?: boolean }; +type Host = { cdiNvidiaGpuSpecMissing?: boolean }; + +function createDeps(overrides: Partial["deps"]> = {}) { + let session = createSession(); + return { + calls: { + start: vi.fn(), + complete: vi.fn(), + skipped: vi.fn(), + detectGpu: vi.fn(() => ({ type: "nvidia" }) as Gpu), + runPreflight: vi.fn(async () => ({ type: "nvidia" }) as Gpu), + validate: vi.fn(), + cdi: vi.fn(), + updateSession: vi.fn(), + getSandbox: vi.fn(() => ({ sandboxGpuEnabled: true })), + getOverrides: vi.fn(() => ({ flag: "enable" as const, device: "0" })), + }, + deps: { + getSandbox: (name: string) => { + const value = ({ sandboxGpuEnabled: true } satisfies SandboxEntry); + return overrides.getSandbox ? overrides.getSandbox(name) : value; + }, + getResumeSandboxGpuOverrides: ( + sandbox: SandboxEntry | null, + sessionGpuPassthrough: boolean | null | undefined, + ) => { + if (overrides.getResumeSandboxGpuOverrides) { + return overrides.getResumeSandboxGpuOverrides(sandbox, sessionGpuPassthrough); + } + return { flag: "enable" as const, device: "0" }; + }, + detectGpu: () => ({ type: "nvidia" }) as Gpu, + runPreflight: async () => ({ type: "nvidia" }) as Gpu, + assessHost: () => ({ cdiNvidiaGpuSpecMissing: false }), + assertCdiNvidiaGpuSpecPresent: vi.fn(), + resolveSandboxGpuConfig: (_gpu: Gpu, opts: { flag: "enable" | "disable" | null; device: string | null | undefined }) => ({ + sandboxGpuEnabled: opts.flag === "enable", + mode: opts.flag === "enable" ? 
"1" : "0", + sandboxGpuDevice: opts.device, + }), + validateSandboxGpuPreflight: vi.fn(), + skippedStepMessage: vi.fn(), + startRecordedStep: vi.fn(async () => undefined), + recordStepComplete: vi.fn(async () => session), + updateSession: vi.fn((mutator: (value: Session) => Session | void) => { + session = mutator(session) ?? session; + return session; + }), + ...overrides, + }, + getSession: () => session, + }; +} + +function baseOptions( + deps: PreflightStateOptions["deps"], + session: Session | null = createSession(), +): PreflightStateOptions { + return { + resume: false, + session, + recordedSandboxName: null, + requestedSandboxName: "my-assistant", + explicitSandboxGpuFlag: null, + sandboxGpuDevice: null, + gpuRequested: false, + noGpu: false, + env: {}, + deps, + }; +} + +describe("handlePreflightState", () => { + it("runs full preflight through recorded step boundaries", async () => { + const harness = createDeps({ + startRecordedStep: vi.fn(async () => undefined), + runPreflight: vi.fn(async () => ({ type: "nvidia" }) as Gpu), + recordStepComplete: vi.fn(async () => createSession()), + }); + + const result = await handlePreflightState({ + ...baseOptions(harness.deps), + explicitSandboxGpuFlag: "enable", + sandboxGpuDevice: "GPU-0", + }); + + expect(harness.deps.startRecordedStep).toHaveBeenCalledWith("preflight"); + expect(harness.deps.runPreflight).toHaveBeenCalledWith({ optedOutGpuPassthrough: false }); + expect(harness.deps.recordStepComplete).toHaveBeenCalledWith("preflight"); + expect(result.sandboxGpuConfig).toMatchObject({ + sandboxGpuEnabled: true, + mode: "1", + sandboxGpuDevice: "GPU-0", + }); + expect(result.gpuPassthrough).toBe(true); + }); + + it("skips full preflight on resume but re-detects GPU and revalidates CDI/sandbox GPU", async () => { + const session = createSession(); + session.steps.preflight.status = "complete"; + session.gpuPassthrough = false; + const harness = createDeps({ + detectGpu: vi.fn(() => ({ type: "nvidia" }) as Gpu), 
+ assertCdiNvidiaGpuSpecPresent: vi.fn(), + validateSandboxGpuPreflight: vi.fn(), + skippedStepMessage: vi.fn(), + startRecordedStep: vi.fn(async () => undefined), + runPreflight: vi.fn(async () => ({ type: "should-not-run" }) as Gpu), + }); + + const result = await handlePreflightState({ + ...baseOptions(harness.deps, session), + resume: true, + gpuRequested: false, + }); + + expect(harness.deps.skippedStepMessage).toHaveBeenCalledWith("preflight", "cached"); + expect(harness.deps.detectGpu).toHaveBeenCalledOnce(); + expect(harness.deps.runPreflight).not.toHaveBeenCalled(); + expect(harness.deps.startRecordedStep).not.toHaveBeenCalled(); + expect(harness.deps.assertCdiNvidiaGpuSpecPresent).toHaveBeenCalledWith( + { cdiNvidiaGpuSpecMissing: false }, + true, + ); + expect(harness.deps.validateSandboxGpuPreflight).toHaveBeenCalledOnce(); + expect(result.resumePreflight).toBe(true); + }); + + it("restores saved sandbox GPU intent only when resume has no explicit override", async () => { + const session = createSession(); + session.steps.preflight.status = "complete"; + session.gpuPassthrough = true; + const getResumeSandboxGpuOverrides = vi.fn(() => ({ flag: "enable" as const, device: "1" })); + const getSandbox = vi.fn(() => ({ sandboxGpuEnabled: true })); + const harness = createDeps({ getResumeSandboxGpuOverrides, getSandbox }); + + const result = await handlePreflightState({ + ...baseOptions(harness.deps, session), + resume: true, + recordedSandboxName: "saved", + }); + + expect(getSandbox).toHaveBeenCalledWith("saved"); + expect(getResumeSandboxGpuOverrides).toHaveBeenCalledWith( + { sandboxGpuEnabled: true }, + true, + ); + expect(result.resumeHasResolvedGpuIntent).toBe(true); + expect(result.effectiveSandboxGpuFlag).toBe("enable"); + expect(result.effectiveSandboxGpuDevice).toBe("1"); + + await handlePreflightState({ + ...baseOptions(harness.deps, session), + resume: true, + explicitSandboxGpuFlag: "disable", + }); + 
expect(getResumeSandboxGpuOverrides).toHaveBeenCalledTimes(1);
+  });
+
+  it("persists effective GPU passthrough intent for later resume", async () => {
+    const session = createSession();
+    session.gpuPassthrough = false;
+    const harness = createDeps();
+
+    const result = await handlePreflightState({
+      ...baseOptions(harness.deps, session),
+      explicitSandboxGpuFlag: "enable",
+    });
+
+    expect(result.session?.gpuPassthrough).toBe(true);
+    expect(harness.deps.updateSession).toHaveBeenCalledOnce();
+  });
+});
diff --git a/src/lib/onboard/machine/handlers/preflight.ts b/src/lib/onboard/machine/handlers/preflight.ts
new file mode 100644
index 0000000000..cc5bd6633d
--- /dev/null
+++ b/src/lib/onboard/machine/handlers/preflight.ts
@@ -0,0 +1,147 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+import type { Session } from "../../../state/onboard-session";
+
+/** Explicit sandbox GPU choice; `null` means no explicit choice was supplied. */
+export type PreflightSandboxGpuFlag = "enable" | "disable" | null;
+
+/** Flag/device pair produced by `deps.getResumeSandboxGpuOverrides` when resuming. */
+export interface PreflightSandboxGpuOverrides {
+  flag: PreflightSandboxGpuFlag;
+  device: string | null;
+}
+
+/**
+ * Shape every `Config` type argument must satisfy. The handler itself only
+ * reads `sandboxGpuEnabled`; the remaining fields are passed through untouched.
+ */
+export interface PreflightSandboxGpuConfig {
+  sandboxGpuEnabled: boolean;
+  mode: string;
+  sandboxGpuDevice?: string | null;
+  errors?: readonly string[];
+}
+
+/**
+ * Inputs for {@link handlePreflightState}. Every side effect the handler
+ * performs goes through `deps`.
+ *
+ * NOTE(review): the type arguments of the Promise-returning deps were missing
+ * in this copy of the patch and are reconstructed from how the handler uses
+ * the results. `runPreflight` is assigned to `gpu: Gpu`; `recordStepComplete`
+ * is assigned to `session: Session | null`; the `startRecordedStep` result is
+ * never read, so it is typed as loosely as possible. Confirm against upstream.
+ */
+export interface PreflightStateOptions<
+  Gpu,
+  SandboxEntry,
+  Host,
+  Config extends PreflightSandboxGpuConfig,
+> {
+  resume: boolean;
+  session: Session | null;
+  recordedSandboxName: string | null;
+  requestedSandboxName: string | null;
+  explicitSandboxGpuFlag: PreflightSandboxGpuFlag;
+  sandboxGpuDevice?: string | null;
+  gpuRequested: boolean;
+  noGpu: boolean;
+  env: NodeJS.ProcessEnv;
+  deps: {
+    getSandbox(name: string): SandboxEntry | null;
+    getResumeSandboxGpuOverrides(
+      sandbox: SandboxEntry | null,
+      sessionGpuPassthrough: boolean | null | undefined,
+    ): PreflightSandboxGpuOverrides;
+    detectGpu(): Gpu;
+    runPreflight(options: { optedOutGpuPassthrough?: boolean }): Promise<Gpu>;
+    assessHost(): Host;
+    assertCdiNvidiaGpuSpecPresent(host: Host, optedOutGpuPassthrough: boolean): void;
+    resolveSandboxGpuConfig(
+      gpu: Gpu,
+      options: { flag: PreflightSandboxGpuFlag; device: string | null | undefined },
+    ): Config;
+    validateSandboxGpuPreflight(config: Config): void;
+    skippedStepMessage(stepName: string, detail?: string | null): void;
+    startRecordedStep(stepName: string): Promise<unknown>;
+    recordStepComplete(stepName: string): Promise<Session | null>;
+    updateSession(mutator: (session: Session) => Session | void): Session;
+  };
+}
+
+/** Values computed by {@link handlePreflightState} for the rest of the onboard flow. */
+export interface PreflightStateResult<Gpu, Config extends PreflightSandboxGpuConfig> {
+  gpu: Gpu;
+  sandboxGpuConfig: Config;
+  resumePreflight: boolean;
+  resumeHasResolvedGpuIntent: boolean;
+  requestedGpuPassthrough: boolean;
+  gpuPassthrough: boolean;
+  effectiveSandboxGpuFlag: PreflightSandboxGpuFlag;
+  effectiveSandboxGpuDevice: string | null | undefined;
+  session: Session | null;
+}
+
+/** True when either sandbox GPU environment variable is set (even to an empty string). */
+function envHasSandboxGpuOverride(env: NodeJS.ProcessEnv): boolean {
+  return env.NEMOCLAW_SANDBOX_GPU !== undefined || env.NEMOCLAW_SANDBOX_GPU_DEVICE !== undefined;
+}
+
+/**
+ * Runs (or, on resume, re-validates) the preflight step and resolves the
+ * sandbox GPU configuration.
+ *
+ * - When `resume` is set and the session records preflight as complete, the
+ *   step is reported as skipped, the GPU is re-detected, and the CDI and
+ *   sandbox-GPU checks are re-run against the effective flag/device.
+ * - Otherwise the step is started, `deps.runPreflight` is awaited, and the
+ *   step is recorded complete (which also refreshes `session`).
+ *
+ * If the resolved `sandboxGpuEnabled` differs from `session.gpuPassthrough`,
+ * the new value is written back through `deps.updateSession`.
+ *
+ * @returns the detected GPU, resolved GPU config, effective overrides and the
+ *   (possibly updated) session.
+ */
+export async function handlePreflightState<
+  Gpu,
+  SandboxEntry,
+  Host,
+  Config extends PreflightSandboxGpuConfig,
+>({
+  resume,
+  session,
+  recordedSandboxName,
+  requestedSandboxName,
+  explicitSandboxGpuFlag,
+  sandboxGpuDevice,
+  gpuRequested,
+  noGpu,
+  env,
+  deps,
+}: PreflightStateOptions<Gpu, SandboxEntry, Host, Config>): Promise<PreflightStateResult<Gpu, Config>> {
+  // Prefer the recorded sandbox name; fall back to the requested one.
+  const resumeSandboxNameForGpu = recordedSandboxName || requestedSandboxName || null;
+  const resumePreflight = resume && session?.steps?.preflight?.status === "complete";
+  // Stored GPU intent is only consulted when no flag, device or env var overrides it.
+  const resumeHasResolvedGpuIntent =
+    resumePreflight &&
+    explicitSandboxGpuFlag === null &&
+    sandboxGpuDevice == null &&
+    !envHasSandboxGpuOverride(env);
+  const resumedSandboxGpuOverrides = resumeHasResolvedGpuIntent
+    ? deps.getResumeSandboxGpuOverrides(
+        resumeSandboxNameForGpu ? deps.getSandbox(resumeSandboxNameForGpu) : null,
+        session?.gpuPassthrough,
+      )
+    : { flag: null, device: null };
+  const effectiveSandboxGpuFlag = explicitSandboxGpuFlag ?? resumedSandboxGpuOverrides.flag;
+  const effectiveSandboxGpuDevice = sandboxGpuDevice ?? resumedSandboxGpuOverrides.device;
+
+  let gpu: Gpu;
+  if (resumePreflight) {
+    deps.skippedStepMessage("preflight", "cached");
+    gpu = deps.detectGpu();
+    // Opted out either explicitly (noGpu) or by a stored `false` that was not re-requested.
+    const resumeOptedOutGpuPassthrough = noGpu || (!gpuRequested && session?.gpuPassthrough === false);
+    deps.assertCdiNvidiaGpuSpecPresent(deps.assessHost(), resumeOptedOutGpuPassthrough);
+    deps.validateSandboxGpuPreflight(
+      deps.resolveSandboxGpuConfig(gpu, {
+        flag: effectiveSandboxGpuFlag,
+        device: effectiveSandboxGpuDevice,
+      }),
+    );
+  } else {
+    await deps.startRecordedStep("preflight");
+    gpu = await deps.runPreflight({ optedOutGpuPassthrough: noGpu });
+    session = await deps.recordStepComplete("preflight");
+  }
+
+  const sandboxGpuConfig = deps.resolveSandboxGpuConfig(gpu, {
+    flag: effectiveSandboxGpuFlag,
+    device: effectiveSandboxGpuDevice,
+  });
+  const gpuPassthrough = sandboxGpuConfig.sandboxGpuEnabled;
+  // Persist the effective intent so a later resume sees the same choice.
+  if (session && session.gpuPassthrough !== gpuPassthrough) {
+    session = deps.updateSession((current) => {
+      current.gpuPassthrough = gpuPassthrough;
+      return current;
+    });
+  }
+
+  return {
+    gpu,
+    sandboxGpuConfig,
+    resumePreflight,
+    resumeHasResolvedGpuIntent,
+    requestedGpuPassthrough: gpuRequested,
+    gpuPassthrough,
+    effectiveSandboxGpuFlag,
+    effectiveSandboxGpuDevice,
+    session,
+  };
+}
From f17000a73716c16fe69007aeb1c8d218e646cb1d Mon Sep 17 00:00:00 2001
From: Carlos Villela
Date: Tue, 19 May 2026 23:01:05 -0700
Subject: [PATCH 08/12] refactor(cli): extract onboard gateway handler

---
 src/lib/onboard.ts | 147 +++----------
 .../onboard/machine/handlers/gateway.test.ts | 203 ++++++++++++++++++
 src/lib/onboard/machine/handlers/gateway.ts | 178 +++++++++++++++
 3 files changed, 413 insertions(+), 115 deletions(-)
 create mode 100644
src/lib/onboard/machine/handlers/gateway.test.ts create mode 100644 src/lib/onboard/machine/handlers/gateway.ts diff --git a/src/lib/onboard.ts b/src/lib/onboard.ts index 50c5187326..9d9b047748 100644 --- a/src/lib/onboard.ts +++ b/src/lib/onboard.ts @@ -280,6 +280,7 @@ const { resolveSandboxImageTagFromCreateOutput } = const nim: typeof import("./inference/nim") = require("./inference/nim"); const onboardSession: typeof import("./state/onboard-session") = require("./state/onboard-session"); const { OnboardRuntime }: typeof import("./onboard/machine/runtime") = require("./onboard/machine/runtime"); +const { handleGatewayState }: typeof import("./onboard/machine/handlers/gateway") = require("./onboard/machine/handlers/gateway"); const { handlePreflightState }: typeof import("./onboard/machine/handlers/preflight") = require("./onboard/machine/handlers/preflight"); const policies: typeof import("./policy") = require("./policy"); const tiers: typeof import("./policy/tiers") = require("./policy/tiers"); @@ -9464,125 +9465,41 @@ async function onboard(opts: OnboardOptions = {}): Promise { }); const gatewaySnapshot = selectNamedGatewayForReuseIfNeeded(getGatewayReuseSnapshot()); - let gatewayReuseState = gatewaySnapshot.gatewayReuseState; - gatewayReuseState = await refreshDockerDriverGatewayReuseState(gatewayReuseState); - - // Verify the legacy gateway container is actually running — openshell CLI - // metadata can be stale after a manual `docker rm`. See #2020. Newer - // package-managed OpenShell gateways do not have an openshell-cluster-* - // Docker container, so the live CLI health check is the source of truth. - if (gatewayReuseState === "healthy" && gatewayCliSupportsLifecycleCommands(runCaptureOpenshell)) { - const containerState = verifyGatewayContainerRunning(GATEWAY_NAME); - if (containerState === "missing") { - console.log(" Gateway metadata is stale (container not running). 
Cleaning up..."); - runOpenshell(["forward", "stop", String(DASHBOARD_PORT)], { ignoreError: true }); - gatewayReuseState = destroyGatewayForReuse( - destroyGateway, - " ✓ Stale gateway metadata cleaned up", - " ! Stale gateway metadata cleanup failed; leaving registry state intact.", - ); - } else if (containerState === "unknown") { - // Docker probe failed but cached metadata says healthy. Try the host-level - // HTTP probe — it doesn't depend on Docker, so it can confirm the gateway - // is genuinely serving even when the daemon is flaky. - if (await waitForGatewayHttpReady()) { - console.log( - " Warning: could not verify gateway container state (Docker may be unavailable), but the gateway is responding on HTTP. Proceeding with reuse.", - ); - } else { - // Docker can't be probed AND the gateway HTTP endpoint isn't - // responding. We cannot tell whether the existing gateway is live - // (transient `docker inspect` flake + warm-up miss) or genuinely - // gone. Per #2020 we must not destroy in this branch, and we must - // not downgrade to "missing" either: that would push execution into - // `startGatewayWithOptions`, whose retry hook calls - // `destroyGateway()` between attempts — which would tear down a - // possibly-live gateway. Bail with an actionable error instead. - console.log( - ` Error: could not verify gateway container state and ${getGatewayLocalEndpoint()}/ is not responding.`, - ); - console.log( - " Refusing to proceed without a clear Docker signal — restarting Docker and re-running onboard is the safe path. See #3258 / #2020.", - ); - process.exit(1); - } - } else if (!(await waitForGatewayHttpReady())) { - // Container is running but the gateway HTTP endpoint is not responding. - // Common immediately after a Docker daemon restart — the container comes - // back before the OpenShell gateway upstream finishes warming up. Safe to - // recreate because Docker is functional. See #3258. 
- console.log( - ` Gateway container is running but ${getGatewayLocalEndpoint()}/ is not responding. Recreating...`, - ); - runOpenshell(["forward", "stop", String(DASHBOARD_PORT)], { ignoreError: true }); - gatewayReuseState = destroyGatewayForReuse( - destroyGateway, - " ✓ Stale gateway cleaned up", - " ! Stale gateway cleanup failed; leaving registry state intact.", - ); - } else { - const imageDrift = getGatewayClusterImageDrift(); - if (imageDrift) { - console.log( - ` Gateway image ${imageDrift.currentVersion} does not match openshell ${imageDrift.expectedVersion}. Recreating...`, - ); - stopAllDashboardForwards(); - gatewayReuseState = destroyGatewayForReuse( - destroyGateway, - " ✓ Previous gateway cleaned up", - " ! Previous gateway cleanup failed; leaving registry state intact.", - ); - } - } - } - - gatewayReuseState = reconcileGatewayGpuReuseForGpuIntent({ - gatewayReuseState, + const gatewayResult = await handleGatewayState({ + resume, + session, + initialGatewayReuseState: gatewaySnapshot.gatewayReuseState, + gpu, gpuPassthrough, gatewayName: GATEWAY_NAME, - currentSandboxName: recordedSandboxName || requestedSandboxName, + dashboardPort: DASHBOARD_PORT, + recordedSandboxName, + requestedSandboxName, recreateSandbox: isRecreateSandbox(), - confirmedDockerDriverGateway: - isLinuxDockerDriverGatewayEnabled() && - gatewayReuseState === "healthy" && - !gatewayCliSupportsLifecycleCommands(runCaptureOpenshell), - stopDashboardForwards: stopAllDashboardForwards, - retireLegacyGatewayForDockerDriverUpgrade, - destroyGatewayRuntimeForGpuReuse: () => destroyGateway(() => undefined, () => false), + deps: { + refreshDockerDriverGatewayReuseState, + gatewayCliSupportsLifecycleCommands: () => gatewayCliSupportsLifecycleCommands(runCaptureOpenshell), + verifyGatewayContainerRunning, + waitForGatewayHttpReady, + getGatewayLocalEndpoint, + runOpenshell, + destroyGateway, + destroyGatewayForReuse, + getGatewayClusterImageDrift, + stopAllDashboardForwards, + 
reconcileGatewayGpuReuseForGpuIntent, + isLinuxDockerDriverGatewayEnabled, + retireLegacyGatewayForDockerDriverUpgrade, + destroyGatewayRuntimeForGpuReuse: () => destroyGateway(() => undefined, () => false), + skippedStepMessage, + note, + startRecordedStep, + startGateway, + recordStepComplete, + exitProcess: (code) => process.exit(code), + }, }); - - const canReuseHealthyGateway = gatewayReuseState === "healthy"; - - const resumeGateway = - resume && session?.steps?.gateway?.status === "complete" && canReuseHealthyGateway; - if (resumeGateway) { - skippedStepMessage("gateway", "running"); - await recordStepComplete("gateway"); - } else if (!resume && canReuseHealthyGateway) { - skippedStepMessage("gateway", "running", "reuse"); - note(" Reusing healthy NemoClaw gateway."); - await recordStepComplete("gateway"); - } else { - if (resume && session?.steps?.gateway?.status === "complete") { - if (gatewayReuseState === "active-unnamed") { - note(" [resume] Gateway is active but named metadata is missing; recreating it safely."); - } else if (gatewayReuseState === "foreign-active") { - note(" [resume] A different OpenShell gateway is active; NemoClaw will not reuse it."); - } else if (gatewayReuseState === "stale") { - note(" [resume] Recorded gateway is unhealthy; recreating it."); - } else { - note(" [resume] Recorded gateway state is unavailable; recreating it."); - } - } - if (isLinuxDockerDriverGatewayEnabled() && gatewayReuseState !== "missing") { - note(" Replacing legacy OpenShell gateway metadata with Docker-driver gateway."); - retireLegacyGatewayForDockerDriverUpgrade(); - gatewayReuseState = "missing"; - } - await startRecordedStep("gateway"); - await startGateway(gpu, { gpuPassthrough }); - await recordStepComplete("gateway"); - } + session = gatewayResult.session; // #2753: prefer requestedSandboxName over an unconfirmed session name. 
// A pre-fix session may carry sandboxName even though sandbox creation
diff --git a/src/lib/onboard/machine/handlers/gateway.test.ts b/src/lib/onboard/machine/handlers/gateway.test.ts
new file mode 100644
index 0000000000..266ba10360
--- /dev/null
+++ b/src/lib/onboard/machine/handlers/gateway.test.ts
@@ -0,0 +1,203 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+import { describe, expect, it, vi } from "vitest";
+
+import { createSession, type Session } from "../../../state/onboard-session";
+import type { GatewayReuseState } from "../../../state/gateway";
+import { handleGatewayState, type GatewayStateOptions } from "./gateway";
+
+type Gpu = { type: string } | null;
+
+/**
+ * Builds a full set of mocked gateway deps. `calls` exposes the default mocks
+ * for assertions; entries in `overrides` replace the matching dep, so tests
+ * asserting on an overridden dep must go through `deps`, not `calls`.
+ */
+function createDeps(overrides: Partial<GatewayStateOptions<Gpu>["deps"]> = {}) {
+  const calls = {
+    refresh: vi.fn(async (state: GatewayReuseState) => state),
+    lifecycle: vi.fn(() => false),
+    verifyContainer: vi.fn(() => "running"),
+    waitHttp: vi.fn(async () => true),
+    runOpenshell: vi.fn(),
+    destroy: vi.fn(() => true),
+    destroyForReuse: vi.fn(() => "missing" as GatewayReuseState),
+    imageDrift: vi.fn(() => null),
+    stopForwards: vi.fn(),
+    reconcileGpu: vi.fn((opts: { gatewayReuseState: GatewayReuseState }) => opts.gatewayReuseState),
+    dockerDriver: vi.fn(() => false),
+    retireLegacy: vi.fn(),
+    destroyGpuRuntime: vi.fn(() => true),
+    skipped: vi.fn(),
+    note: vi.fn(),
+    startStep: vi.fn(async () => undefined),
+    startGateway: vi.fn(async () => undefined),
+    complete: vi.fn(async () => createSession()),
+    // Throws instead of exiting so tests can assert on the rejection.
+    exit: vi.fn((code: number): never => {
+      throw new Error(`exit ${code}`);
+    }),
+  };
+  return {
+    calls,
+    deps: {
+      refreshDockerDriverGatewayReuseState: calls.refresh,
+      gatewayCliSupportsLifecycleCommands: calls.lifecycle,
+      verifyGatewayContainerRunning: calls.verifyContainer,
+      waitForGatewayHttpReady: calls.waitHttp,
+      getGatewayLocalEndpoint: () => "http://127.0.0.1:31818",
+      runOpenshell: calls.runOpenshell,
+      destroyGateway: calls.destroy,
+      destroyGatewayForReuse: calls.destroyForReuse,
+      getGatewayClusterImageDrift: calls.imageDrift,
+      stopAllDashboardForwards: calls.stopForwards,
+      reconcileGatewayGpuReuseForGpuIntent: calls.reconcileGpu,
+      isLinuxDockerDriverGatewayEnabled: calls.dockerDriver,
+      retireLegacyGatewayForDockerDriverUpgrade: calls.retireLegacy,
+      destroyGatewayRuntimeForGpuReuse: calls.destroyGpuRuntime,
+      skippedStepMessage: calls.skipped,
+      note: calls.note,
+      startRecordedStep: calls.startStep,
+      startGateway: calls.startGateway,
+      recordStepComplete: calls.complete,
+      exitProcess: calls.exit,
+      ...overrides,
+    },
+  };
+}
+
+/** Default handler options for a fresh (non-resume) run with GPU passthrough enabled. */
+function baseOptions(
+  deps: GatewayStateOptions<Gpu>["deps"],
+  initialGatewayReuseState: GatewayReuseState = "missing",
+  session: Session | null = createSession(),
+): GatewayStateOptions<Gpu> {
+  return {
+    resume: false,
+    session,
+    initialGatewayReuseState,
+    gpu: { type: "nvidia" },
+    gpuPassthrough: true,
+    gatewayName: "nemoclaw",
+    dashboardPort: 18789,
+    recordedSandboxName: null,
+    requestedSandboxName: "my-assistant",
+    recreateSandbox: false,
+    deps,
+  };
+}
+
+describe("handleGatewayState", () => {
+  it("starts the gateway when no reusable gateway exists", async () => {
+    const { deps, calls } = createDeps();
+
+    const result = await handleGatewayState(baseOptions(deps, "missing"));
+
+    expect(calls.startStep).toHaveBeenCalledWith("gateway");
+    expect(calls.startGateway).toHaveBeenCalledWith({ type: "nvidia" }, { gpuPassthrough: true });
+    expect(calls.complete).toHaveBeenCalledWith("gateway");
+    expect(result.gatewayReuseState).toBe("missing");
+  });
+
+  it("reuses healthy gateways on fresh runs", async () => {
+    const { deps, calls } = createDeps();
+
+    await handleGatewayState(baseOptions(deps, "healthy"));
+
+    expect(calls.skipped).toHaveBeenCalledWith("gateway", "running", "reuse");
+    expect(calls.note).toHaveBeenCalledWith(" Reusing healthy NemoClaw gateway.");
+    expect(calls.startGateway).not.toHaveBeenCalled();
+    expect(calls.complete).toHaveBeenCalledWith("gateway");
+  });
+
+  it("reuses healthy gateways on resume only when the gateway step was complete", async () => {
+    const session = createSession();
+    session.steps.gateway.status = "complete";
+    const { deps, calls } = createDeps();
+
+    await handleGatewayState({ ...baseOptions(deps, "healthy", session), resume: true });
+
+    expect(calls.skipped).toHaveBeenCalledWith("gateway", "running");
+    expect(calls.startGateway).not.toHaveBeenCalled();
+  });
+
+  it("cleans stale lifecycle metadata when the gateway container is missing", async () => {
+    const { deps, calls } = createDeps({
+      gatewayCliSupportsLifecycleCommands: vi.fn(() => true),
+      // This dep returns a container state, not a GatewayReuseState, so no cast.
+      verifyGatewayContainerRunning: vi.fn(() => "missing"),
+      destroyGatewayForReuse: vi.fn(() => "missing" as GatewayReuseState),
+    });
+
+    await handleGatewayState(baseOptions(deps, "healthy"));
+
+    expect(calls.runOpenshell).toHaveBeenCalledWith(["forward", "stop", "18789"], {
+      ignoreError: true,
+    });
+    expect(deps.destroyGatewayForReuse).toHaveBeenCalledWith(
+      deps.destroyGateway,
+      " ✓ Stale gateway metadata cleaned up",
+      " ! Stale gateway metadata cleanup failed; leaving registry state intact.",
+    );
+    expect(calls.startGateway).toHaveBeenCalled();
+  });
+
+  it("refuses to destroy an unknown container state when HTTP is also unavailable", async () => {
+    const { deps, calls } = createDeps({
+      gatewayCliSupportsLifecycleCommands: vi.fn(() => true),
+      verifyGatewayContainerRunning: vi.fn(() => "unknown"),
+      waitForGatewayHttpReady: vi.fn(async () => false),
+    });
+
+    await expect(handleGatewayState(baseOptions(deps, "healthy"))).rejects.toThrow("exit 1");
+
+    expect(calls.exit).toHaveBeenCalledWith(1);
+    expect(calls.destroyForReuse).not.toHaveBeenCalled();
+  });
+
+  it("recreates a running lifecycle gateway when the HTTP endpoint is unhealthy", async () => {
+    const { deps, calls } = createDeps({
+      gatewayCliSupportsLifecycleCommands: vi.fn(() => true),
+      waitForGatewayHttpReady: vi.fn(async () => false),
+      destroyGatewayForReuse: vi.fn(() => "missing" as GatewayReuseState),
+    });
+
+    await handleGatewayState(baseOptions(deps, "healthy"));
+
+    expect(calls.runOpenshell).toHaveBeenCalledWith(["forward", "stop", "18789"], {
+      ignoreError: true,
+    });
+    expect(deps.destroyGatewayForReuse).toHaveBeenCalledWith(
+      deps.destroyGateway,
+      " ✓ Stale gateway cleaned up",
+      " ! Stale gateway cleanup failed; leaving registry state intact.",
+    );
+  });
+
+  it("recreates on gateway image drift after stopping dashboard forwards", async () => {
+    const { deps, calls } = createDeps({
+      gatewayCliSupportsLifecycleCommands: vi.fn(() => true),
+      waitForGatewayHttpReady: vi.fn(async () => true),
+      getGatewayClusterImageDrift: vi.fn(() => ({ currentVersion: "0.0.38", expectedVersion: "0.0.39" })),
+      destroyGatewayForReuse: vi.fn(() => "missing" as GatewayReuseState),
+    });
+
+    await handleGatewayState(baseOptions(deps, "healthy"));
+
+    expect(calls.stopForwards).toHaveBeenCalledOnce();
+    expect(deps.destroyGatewayForReuse).toHaveBeenCalledWith(
+      deps.destroyGateway,
+      " ✓ Previous gateway cleaned up",
+      " ! Previous gateway cleanup failed; leaving registry state intact.",
+    );
+  });
+
+  it("replaces legacy metadata before starting the Docker-driver gateway", async () => {
+    const { deps, calls } = createDeps({
+      isLinuxDockerDriverGatewayEnabled: vi.fn(() => true),
+      reconcileGatewayGpuReuseForGpuIntent: vi.fn(() => "stale" as GatewayReuseState),
+    });
+
+    const result = await handleGatewayState(baseOptions(deps, "healthy"));
+
+    expect(calls.note).toHaveBeenCalledWith(
+      " Replacing legacy OpenShell gateway metadata with Docker-driver gateway.",
+    );
+    expect(calls.retireLegacy).toHaveBeenCalledOnce();
+    expect(calls.startGateway).toHaveBeenCalledOnce();
+    expect(result.gatewayReuseState).toBe("missing");
+  });
+});
diff --git a/src/lib/onboard/machine/handlers/gateway.ts b/src/lib/onboard/machine/handlers/gateway.ts
new file mode 100644
index 0000000000..026c26e1b4
--- /dev/null
+++ b/src/lib/onboard/machine/handlers/gateway.ts
@@ -0,0 +1,178 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+import type { Session } from "../../../state/onboard-session";
+import type { GatewayReuseState } from "../../../state/gateway";
+
+/**
+ * Result of probing the gateway container. The handler special-cases only
+ * "missing" and "unknown"; any other string is treated as a running container.
+ * The union is documentary: it widens to plain `string` for the type checker.
+ */
+export type GatewayContainerState = "missing" | "unknown" | string;
+
+/**
+ * Inputs for {@link handleGatewayState}. Apart from `console.log` output,
+ * every side effect goes through `deps`.
+ *
+ * NOTE(review): the type arguments of the Promise-returning deps were missing
+ * in this copy of the patch and are reconstructed from how the handler uses
+ * the results. `refreshDockerDriverGatewayReuseState` feeds `gatewayReuseState`;
+ * `waitForGatewayHttpReady` is used as a condition; `recordStepComplete` is
+ * assigned to `session: Session | null`; the `startRecordedStep` and
+ * `startGateway` results are never read. Confirm against upstream.
+ */
+export interface GatewayStateOptions<Gpu> {
+  resume: boolean;
+  session: Session | null;
+  initialGatewayReuseState: GatewayReuseState;
+  gpu: Gpu;
+  gpuPassthrough: boolean;
+  gatewayName: string;
+  dashboardPort: number;
+  recordedSandboxName: string | null;
+  requestedSandboxName: string | null;
+  recreateSandbox: boolean;
+  deps: {
+    refreshDockerDriverGatewayReuseState(state: GatewayReuseState): Promise<GatewayReuseState>;
+    gatewayCliSupportsLifecycleCommands(): boolean;
+    verifyGatewayContainerRunning(gatewayName: string): GatewayContainerState;
+    waitForGatewayHttpReady(): Promise<boolean>;
+    getGatewayLocalEndpoint(): string;
+    runOpenshell(args: string[], opts?: { ignoreError?: boolean }): unknown;
+    destroyGateway(): boolean;
+    destroyGatewayForReuse(
+      destroyGateway: () => boolean,
+      successMessage: string,
+      failureMessage: string,
+    ): GatewayReuseState;
+    getGatewayClusterImageDrift(): { currentVersion: string; expectedVersion: string } | null;
+    stopAllDashboardForwards(): void;
+    reconcileGatewayGpuReuseForGpuIntent(options: {
+      gatewayReuseState: GatewayReuseState;
+      gpuPassthrough: boolean;
+      gatewayName: string;
+      currentSandboxName: string | null;
+      recreateSandbox: boolean;
+      confirmedDockerDriverGateway: boolean;
+      stopDashboardForwards: () => void;
+      retireLegacyGatewayForDockerDriverUpgrade: () => void;
+      destroyGatewayRuntimeForGpuReuse: () => boolean;
+    }): GatewayReuseState;
+    isLinuxDockerDriverGatewayEnabled(): boolean;
+    retireLegacyGatewayForDockerDriverUpgrade(): void;
+    destroyGatewayRuntimeForGpuReuse(): boolean;
+    skippedStepMessage(
+      stepName: string,
+      detail?: string | null,
+      reason?: "resume" | "reuse",
+    ): void;
+    note(message: string): void;
+    startRecordedStep(stepName: string): Promise<unknown>;
+    startGateway(gpu: Gpu, options: { gpuPassthrough: boolean }): Promise<unknown>;
+    recordStepComplete(stepName: string): Promise<Session | null>;
+    exitProcess(code: number): never;
+  };
+}
+
+/** Final reuse state and the session returned by the last `recordStepComplete` call. */
+export interface GatewayStateResult {
+  gatewayReuseState: GatewayReuseState;
+  session: Session | null;
+}
+
+/**
+ * Decides whether the existing gateway can be reused or must be (re)started,
+ * then records the "gateway" step as complete.
+ *
+ * Flow:
+ * 1. Refresh the initial reuse state through the Docker-driver hook.
+ * 2. If the state is "healthy" and the CLI supports lifecycle commands,
+ *    cross-check the container and HTTP endpoint; stale, unresponsive or
+ *    image-drifted gateways are destroyed via `destroyGatewayForReuse`.
+ * 3. Reconcile the reuse state with the GPU passthrough intent.
+ * 4. Skip the step when a healthy gateway can be reused; otherwise start one.
+ *
+ * Calls `deps.exitProcess(1)` when the container state is "unknown" and the
+ * HTTP endpoint is not responding, without destroying anything.
+ */
+export async function handleGatewayState<Gpu>({
+  resume,
+  session,
+  initialGatewayReuseState,
+  gpu,
+  gpuPassthrough,
+  gatewayName,
+  dashboardPort,
+  recordedSandboxName,
+  requestedSandboxName,
+  recreateSandbox,
+  deps,
+}: GatewayStateOptions<Gpu>): Promise<GatewayStateResult> {
+  let gatewayReuseState = await deps.refreshDockerDriverGatewayReuseState(initialGatewayReuseState);
+  const supportsLifecycleCommands = deps.gatewayCliSupportsLifecycleCommands();
+
+  if (gatewayReuseState === "healthy" && supportsLifecycleCommands) {
+    const containerState = deps.verifyGatewayContainerRunning(gatewayName);
+    if (containerState === "missing") {
+      console.log(" Gateway metadata is stale (container not running). Cleaning up...");
+      deps.runOpenshell(["forward", "stop", String(dashboardPort)], { ignoreError: true });
+      gatewayReuseState = deps.destroyGatewayForReuse(
+        deps.destroyGateway,
+        " ✓ Stale gateway metadata cleaned up",
+        " ! Stale gateway metadata cleanup failed; leaving registry state intact.",
+      );
+    } else if (containerState === "unknown") {
+      // Container probe was inconclusive: fall back to the HTTP readiness probe.
+      if (await deps.waitForGatewayHttpReady()) {
+        console.log(
+          " Warning: could not verify gateway container state (Docker may be unavailable), but the gateway is responding on HTTP. Proceeding with reuse.",
+        );
+      } else {
+        // Neither signal is available: bail out rather than destroy or restart.
+        console.log(
+          ` Error: could not verify gateway container state and ${deps.getGatewayLocalEndpoint()}/ is not responding.`,
+        );
+        console.log(
+          " Refusing to proceed without a clear Docker signal — restarting Docker and re-running onboard is the safe path. See #3258 / #2020.",
+        );
+        deps.exitProcess(1);
+      }
+    } else if (!(await deps.waitForGatewayHttpReady())) {
+      // Container is up but the endpoint is not serving: recreate the gateway.
+      console.log(
+        ` Gateway container is running but ${deps.getGatewayLocalEndpoint()}/ is not responding. Recreating...`,
+      );
+      deps.runOpenshell(["forward", "stop", String(dashboardPort)], { ignoreError: true });
+      gatewayReuseState = deps.destroyGatewayForReuse(
+        deps.destroyGateway,
+        " ✓ Stale gateway cleaned up",
+        " ! Stale gateway cleanup failed; leaving registry state intact.",
+      );
+    } else {
+      const imageDrift = deps.getGatewayClusterImageDrift();
+      if (imageDrift) {
+        console.log(
+          ` Gateway image ${imageDrift.currentVersion} does not match openshell ${imageDrift.expectedVersion}. Recreating...`,
+        );
+        deps.stopAllDashboardForwards();
+        gatewayReuseState = deps.destroyGatewayForReuse(
+          deps.destroyGateway,
+          " ✓ Previous gateway cleaned up",
+          " ! Previous gateway cleanup failed; leaving registry state intact.",
+        );
+      }
+    }
+  }
+
+  gatewayReuseState = deps.reconcileGatewayGpuReuseForGpuIntent({
+    gatewayReuseState,
+    gpuPassthrough,
+    gatewayName,
+    currentSandboxName: recordedSandboxName || requestedSandboxName,
+    recreateSandbox,
+    confirmedDockerDriverGateway:
+      deps.isLinuxDockerDriverGatewayEnabled() && gatewayReuseState === "healthy" && !supportsLifecycleCommands,
+    stopDashboardForwards: deps.stopAllDashboardForwards,
+    retireLegacyGatewayForDockerDriverUpgrade: deps.retireLegacyGatewayForDockerDriverUpgrade,
+    destroyGatewayRuntimeForGpuReuse: deps.destroyGatewayRuntimeForGpuReuse,
+  });
+
+  const canReuseHealthyGateway = gatewayReuseState === "healthy";
+  const resumeGateway = resume && session?.steps?.gateway?.status === "complete" && canReuseHealthyGateway;
+  if (resumeGateway) {
+    deps.skippedStepMessage("gateway", "running");
+    session = await deps.recordStepComplete("gateway");
+  } else if (!resume && canReuseHealthyGateway) {
+    deps.skippedStepMessage("gateway", "running", "reuse");
+    deps.note(" Reusing healthy NemoClaw gateway.");
+    session = await deps.recordStepComplete("gateway");
+  } else {
+    // The session says the gateway step finished, but the gateway is not reusable: say why.
+    if (resume && session?.steps?.gateway?.status === "complete") {
+      if (gatewayReuseState === "active-unnamed") {
+        deps.note(" [resume] Gateway is active but named metadata is missing; recreating it safely.");
+      } else if (gatewayReuseState === "foreign-active") {
+        deps.note(" [resume] A different OpenShell gateway is active; NemoClaw will not reuse it.");
+      } else if (gatewayReuseState === "stale") {
+        deps.note(" [resume] Recorded gateway is unhealthy; recreating it.");
+      } else {
+        deps.note(" [resume] Recorded gateway state is unavailable; recreating it.");
+      }
+    }
+    if (deps.isLinuxDockerDriverGatewayEnabled() && gatewayReuseState !== "missing") {
+      deps.note(" Replacing legacy OpenShell gateway metadata with Docker-driver gateway.");
+      deps.retireLegacyGatewayForDockerDriverUpgrade();
+      gatewayReuseState = "missing";
+    }
+    await deps.startRecordedStep("gateway");
+    await deps.startGateway(gpu, { gpuPassthrough });
+    session = await deps.recordStepComplete("gateway");
+  }
+
+  return { gatewayReuseState, session };
+}
From 3038da47214adc54345ecd37b40fe944e943d1e7 Mon Sep 17 00:00:00 2001
From: Carlos Villela
Date: Tue, 19 May 2026 23:20:19 -0700
Subject: [PATCH 09/12] refactor(cli): extract provider inference handlers

---
 src/lib/onboard.ts | 247 +++++---------
 .../handlers/provider-inference.test.ts | 216 +++++++++++++
 .../machine/handlers/provider-inference.ts | 289 ++++++++++++++++++
 3 files changed, 577 insertions(+), 175 deletions(-)
 create mode 100644 src/lib/onboard/machine/handlers/provider-inference.test.ts
 create mode 100644 src/lib/onboard/machine/handlers/provider-inference.ts

diff --git a/src/lib/onboard.ts b/src/lib/onboard.ts
index 9d9b047748..f7d95ae8ab 100644
--- a/src/lib/onboard.ts
+++ b/src/lib/onboard.ts
@@ -282,6 +282,7 @@ const onboardSession: typeof import("./state/onboard-session") = require("./stat
 const { OnboardRuntime }: typeof
import("./onboard/machine/runtime") = require("./onboard/machine/runtime"); const { handleGatewayState }: typeof import("./onboard/machine/handlers/gateway") = require("./onboard/machine/handlers/gateway"); const { handlePreflightState }: typeof import("./onboard/machine/handlers/preflight") = require("./onboard/machine/handlers/preflight"); +const { handleProviderInferenceState }: typeof import("./onboard/machine/handlers/provider-inference") = require("./onboard/machine/handlers/provider-inference"); const policies: typeof import("./policy") = require("./policy"); const tiers: typeof import("./policy/tiers") = require("./policy/tiers"); const { ensureUsageNoticeConsent } = require("./onboard/usage-notice"); @@ -9514,181 +9515,77 @@ async function onboard(opts: OnboardOptions = {}): Promise { console.error(" Start a fresh onboard with --name to choose a different name."); process.exit(1); } - let model = session?.model || null; - let provider = session?.provider || null; - let endpointUrl = session?.endpointUrl || null; - let credentialEnv = session?.credentialEnv || null; - let hermesAuthMethod: HermesAuthMethod | null = - normalizeHermesAuthMethod(session?.hermesAuthMethod) || - (provider === hermesProviderAuth.HERMES_PROVIDER_NAME && - session?.credentialEnv === HERMES_NOUS_API_KEY_CREDENTIAL_ENV - ? 
HERMES_AUTH_METHOD_API_KEY - : null); - let hermesToolGateways = normalizeHermesToolGatewaySelections(session?.hermesToolGateways); - let preferredInferenceApi = session?.preferredInferenceApi || null; - let nimContainer = session?.nimContainer || null; - let webSearchConfig = session?.webSearchConfig || null; - let forceProviderSelection = false; - while (true) { - const resumeProviderSelection = - !forceProviderSelection && - resume && - session?.steps?.provider_selection?.status === "complete" && - typeof provider === "string" && - typeof model === "string"; - if (resumeProviderSelection) { - skippedStepMessage("provider_selection", `${provider} / ${model}`); - hydrateCredentialEnv(credentialEnv); - // #3342: resume short-circuits provider selection — repair the - // ollama-local systemd loopback override here so legacy 0.0.0.0 - // drop-ins from older NemoClaw versions get rewritten every resume. - repairLocalInferenceSystemdOverrideOrExit(provider, isNonInteractive); - } else { - // #2753: do not persist sandboxName to onboard-session.json before - // the sandbox actually exists in the gateway (Step 6 markStepComplete - // below). A SIGINT between any earlier step and createSandbox would - // otherwise leave a phantom that `nemoclaw list` resurrects until - // manually destroyed. 
- await startRecordedStep("provider_selection"); - const selection = await setupNim(gpu, sandboxName, agent); - model = selection.model; - provider = selection.provider; - endpointUrl = selection.endpointUrl; - credentialEnv = selection.credentialEnv; - hermesAuthMethod = selection.hermesAuthMethod; - hermesToolGateways = selection.hermesToolGateways; - preferredInferenceApi = selection.preferredInferenceApi; - nimContainer = selection.nimContainer; - await recordStepComplete( - "provider_selection", - toSessionUpdates({ - provider, - model, - endpointUrl, - credentialEnv, - hermesAuthMethod, - hermesToolGateways, - preferredInferenceApi, - nimContainer, - }), - ); - } - - if (typeof provider !== "string" || typeof model !== "string") { - console.error(" Inference selection did not yield a provider/model."); - process.exit(1); - } - process.env.NEMOCLAW_OPENSHELL_BIN = getOpenshellBinary(); - const needsBedrockRuntimeAdapter = - provider === "compatible-anthropic-endpoint" && - bedrockRuntimeOnboard.needsBedrockRuntimeAdapter(endpointUrl); - const resumeInference = - !needsBedrockRuntimeAdapter && - !forceProviderSelection && - resume && - isInferenceRouteReady(provider, model); - if (resumeInference) { - if (provider === hermesProviderAuth.HERMES_PROVIDER_NAME) { - if (!sandboxName) { - sandboxName = await promptValidatedSandboxName(agent); - } - await startRecordedStep("inference", { provider, model }); - const inferenceResult = await setupInference( - sandboxName, - model, - provider, - endpointUrl, - credentialEnv, - hermesAuthMethod, - hermesToolGateways, - ); - if (inferenceResult?.retry === "selection") { - forceProviderSelection = true; - continue; - } - await recordStepComplete( - "inference", - toSessionUpdates({ provider, model, hermesAuthMethod, nimContainer, hermesToolGateways }), - ); - break; - } - if (isRoutedInferenceProvider(provider)) { - try { - await reconcileModelRouter(); - } catch (err) { - console.error( - ` ✗ Failed to reconcile model 
router: ${err instanceof Error ? err.message : String(err)}`, - ); - process.exit(1); - } - } - skippedStepMessage("inference", `${provider} / ${model}`); - if (nimContainer && sandboxName) { - registry.updateSandbox(sandboxName, { nimContainer }); - } - await recordStepComplete( - "inference", - toSessionUpdates({ provider, model, hermesAuthMethod, nimContainer, hermesToolGateways }), - ); - break; - } - - if (!sandboxName) { - sandboxName = await promptValidatedSandboxName(agent); - } - const buildEstimateNote = - process.env.NEMOCLAW_IGNORE_RUNTIME_RESOURCES === "1" - ? null - : formatSandboxBuildEstimateNote(assessHost()); - console.log( - formatOnboardConfigSummary({ - provider, - model, - credentialEnv, - hermesAuthMethod, - webSearchConfig, - hermesToolGateways, - enabledChannels: selectedMessagingChannels.length > 0 ? selectedMessagingChannels : null, - sandboxName, - notes: buildEstimateNote ? [buildEstimateNote] : [], - }), - ); - console.log(" Web search and messaging channels will be prompted next."); - if (!isNonInteractive()) { - if (!(await promptYesNoOrDefault(" Apply this configuration?", null, true))) { - console.log(` Aborted. 
Re-run \`${cliName()} onboard\` to start over.`); - console.log(" Credentials entered so far were only staged in memory for this run."); - console.log( - " No new gateway credential was registered because onboarding stopped here.", - ); - process.exit(0); - } - } - - await startRecordedStep("inference", { provider, model }); - const inferenceResult = await setupInference( - sandboxName, - model, - provider, - endpointUrl, - credentialEnv, - hermesAuthMethod, - hermesToolGateways, - ); - delete process.env.NVIDIA_API_KEY; - if (inferenceResult?.retry === "selection") { - forceProviderSelection = true; - continue; - } - if (nimContainer && sandboxName) { - registry.updateSandbox(sandboxName, { nimContainer }); - } - await recordStepComplete( - "inference", - toSessionUpdates({ provider, model, hermesAuthMethod, nimContainer, hermesToolGateways }), - ); - break; - } + const providerInferenceResult = await handleProviderInferenceState({ + resume, + session, + gpu, + sandboxName, + agent, + initial: { + model: session?.model || null, + provider: session?.provider || null, + endpointUrl: session?.endpointUrl || null, + credentialEnv: session?.credentialEnv || null, + hermesAuthMethod: + normalizeHermesAuthMethod(session?.hermesAuthMethod) || + (session?.provider === hermesProviderAuth.HERMES_PROVIDER_NAME && + session?.credentialEnv === HERMES_NOUS_API_KEY_CREDENTIAL_ENV + ? 
HERMES_AUTH_METHOD_API_KEY + : null), + hermesToolGateways: normalizeHermesToolGatewaySelections(session?.hermesToolGateways), + preferredInferenceApi: session?.preferredInferenceApi || null, + nimContainer: session?.nimContainer || null, + webSearchConfig: session?.webSearchConfig || null, + }, + selectedMessagingChannels, + env: process.env, + constants: { hermesProviderName: hermesProviderAuth.HERMES_PROVIDER_NAME }, + deps: { + normalizeHermesAuthMethod, + setupNim, + setupInference, + startRecordedStep, + recordStepComplete, + toSessionUpdates: (updates) => toSessionUpdates(updates as Parameters[0]), + skippedStepMessage, + hydrateCredentialEnv, + repairLocalInferenceSystemdOverrideOrExit, + isNonInteractive, + getOpenshellBinary, + needsBedrockRuntimeAdapter: (providerName, url) => + providerName === "compatible-anthropic-endpoint" && + bedrockRuntimeOnboard.needsBedrockRuntimeAdapter(url), + isInferenceRouteReady, + isRoutedInferenceProvider, + reconcileModelRouter, + registryUpdateSandbox: (name, updates) => registry.updateSandbox(name, updates), + promptValidatedSandboxName, + assessHost, + formatSandboxBuildEstimateNote, + formatOnboardConfigSummary, + promptYesNoOrDefault, + cliName, + log: (message) => console.log(message), + error: (message) => console.error(message), + exitProcess: (code) => process.exit(code), + deleteEnv: (name) => { + delete process.env[name]; + }, + }, + }); + session = providerInferenceResult.session; + sandboxName = providerInferenceResult.sandboxName; + const { + model, + provider, + endpointUrl, + credentialEnv, + hermesAuthMethod, + hermesToolGateways, + preferredInferenceApi, + nimContainer, + } = providerInferenceResult; + let webSearchConfig = providerInferenceResult.webSearchConfig as WebSearchConfig | null; const webSearchSupportProbePath = fromDockerfile ? 
path.resolve(fromDockerfile) : null; const webSearchSupported = agentSupportsWebSearch(agent, webSearchSupportProbePath, ROOT); diff --git a/src/lib/onboard/machine/handlers/provider-inference.test.ts b/src/lib/onboard/machine/handlers/provider-inference.test.ts new file mode 100644 index 0000000000..bec7ea47a3 --- /dev/null +++ b/src/lib/onboard/machine/handlers/provider-inference.test.ts @@ -0,0 +1,216 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import { describe, expect, it, vi } from "vitest"; + +import { createSession, type Session, type SessionUpdates } from "../../../state/onboard-session"; +import { + handleProviderInferenceState, + type ProviderInferenceStateOptions, + type ProviderSelectionResult, +} from "./provider-inference"; + +type Gpu = { type: string } | null; +type Agent = { name: string } | null; +type Host = { cpus?: number }; + +const baseSelection: ProviderSelectionResult = { + model: "nvidia/test", + provider: "nvidia-prod", + endpointUrl: "https://integrate.api.nvidia.com/v1", + credentialEnv: "NVIDIA_API_KEY", + hermesAuthMethod: null, + hermesToolGateways: [], + preferredInferenceApi: "openai-responses", + nimContainer: null, +}; + +function createDeps(overrides: Partial["deps"]> = {}) { + const calls = { + setupNim: vi.fn(async () => ({ ...baseSelection })), + setupInference: vi.fn(async () => ({ ok: true as const })), + startStep: vi.fn(async () => undefined), + complete: vi.fn(async () => createSession()), + skipped: vi.fn(), + hydrate: vi.fn(), + repair: vi.fn(), + routeReady: vi.fn(() => false), + reconcileRouter: vi.fn(async () => undefined), + updateSandbox: vi.fn(), + promptName: vi.fn(async () => "my-assistant"), + promptYesNo: vi.fn(async () => true), + log: vi.fn(), + error: vi.fn(), + exit: vi.fn((code: number): never => { + throw new Error(`exit ${code}`); + }), + deleteEnv: vi.fn(), + }; + return { + calls, + deps: { + 
normalizeHermesAuthMethod: (value: string | null | undefined) => value ?? null, + setupNim: calls.setupNim, + setupInference: calls.setupInference, + startRecordedStep: calls.startStep, + recordStepComplete: calls.complete, + toSessionUpdates: (updates: Record) => updates as SessionUpdates, + skippedStepMessage: calls.skipped, + hydrateCredentialEnv: calls.hydrate, + repairLocalInferenceSystemdOverrideOrExit: calls.repair, + isNonInteractive: () => true, + getOpenshellBinary: () => "/usr/bin/openshell", + needsBedrockRuntimeAdapter: () => false, + isInferenceRouteReady: calls.routeReady, + isRoutedInferenceProvider: (provider: string) => provider === "nvidia-router", + reconcileModelRouter: calls.reconcileRouter, + registryUpdateSandbox: calls.updateSandbox, + promptValidatedSandboxName: calls.promptName, + assessHost: () => ({ cpus: 8 }), + formatSandboxBuildEstimateNote: () => "estimate", + formatOnboardConfigSummary: (options: { + provider: string; + model: string; + sandboxName: string; + }) => `summary:${options.provider}/${options.model}/${options.sandboxName}`, + promptYesNoOrDefault: calls.promptYesNo, + cliName: () => "nemoclaw", + log: calls.log, + error: calls.error, + exitProcess: calls.exit, + deleteEnv: calls.deleteEnv, + ...overrides, + }, + }; +} + +function baseOptions( + deps: ProviderInferenceStateOptions["deps"], + session: Session | null = createSession(), +): ProviderInferenceStateOptions { + return { + resume: false, + session, + gpu: { type: "nvidia" }, + sandboxName: null, + agent: null, + initial: { + model: session?.model ?? null, + provider: session?.provider ?? null, + endpointUrl: session?.endpointUrl ?? null, + credentialEnv: session?.credentialEnv ?? null, + hermesAuthMethod: session?.hermesAuthMethod ?? null, + hermesToolGateways: session?.hermesToolGateways ?? [], + preferredInferenceApi: session?.preferredInferenceApi ?? null, + nimContainer: session?.nimContainer ?? null, + webSearchConfig: session?.webSearchConfig ?? 
null, + }, + selectedMessagingChannels: [], + env: {}, + constants: { hermesProviderName: "hermes-provider" }, + deps, + }; +} + +describe("handleProviderInferenceState", () => { + it("runs provider selection and inference setup on a fresh flow", async () => { + const { deps, calls } = createDeps(); + + const result = await handleProviderInferenceState(baseOptions(deps)); + + expect(calls.startStep).toHaveBeenNthCalledWith(1, "provider_selection"); + expect(calls.setupNim).toHaveBeenCalledWith({ type: "nvidia" }, null, null); + expect(calls.complete).toHaveBeenCalledWith("provider_selection", expect.objectContaining({ provider: "nvidia-prod" })); + expect(calls.promptName).toHaveBeenCalledWith(null); + expect(calls.log).toHaveBeenCalledWith("summary:nvidia-prod/nvidia/test/my-assistant"); + expect(calls.startStep).toHaveBeenNthCalledWith(2, "inference", { + provider: "nvidia-prod", + model: "nvidia/test", + }); + expect(calls.setupInference).toHaveBeenCalledWith( + "my-assistant", + "nvidia/test", + "nvidia-prod", + "https://integrate.api.nvidia.com/v1", + "NVIDIA_API_KEY", + null, + [], + ); + expect(calls.deleteEnv).toHaveBeenCalledWith("NVIDIA_API_KEY"); + expect(result).toMatchObject({ + sandboxName: "my-assistant", + model: "nvidia/test", + provider: "nvidia-prod", + preferredInferenceApi: "openai-responses", + }); + }); + + it("skips provider selection and inference setup when resume state is already ready", async () => { + const session = createSession({ + provider: "ollama-local", + model: "llama3.1", + credentialEnv: null, + }); + session.steps.provider_selection.status = "complete"; + const { deps, calls } = createDeps({ isInferenceRouteReady: vi.fn(() => true) }); + + const result = await handleProviderInferenceState({ + ...baseOptions(deps, session), + resume: true, + sandboxName: "my-assistant", + }); + + expect(calls.setupNim).not.toHaveBeenCalled(); + expect(calls.setupInference).not.toHaveBeenCalled(); + 
expect(calls.skipped).toHaveBeenCalledWith("provider_selection", "ollama-local / llama3.1"); + expect(calls.hydrate).toHaveBeenCalledWith(null); + expect(calls.repair).toHaveBeenCalledWith("ollama-local", deps.isNonInteractive); + expect(calls.skipped).toHaveBeenCalledWith("inference", "ollama-local / llama3.1"); + expect(result).toMatchObject({ provider: "ollama-local", model: "llama3.1" }); + }); + + it("reconciles model router on resumed routed inference", async () => { + const session = createSession({ provider: "nvidia-router", model: "router/model" }); + session.steps.provider_selection.status = "complete"; + const { deps, calls } = createDeps({ isInferenceRouteReady: vi.fn(() => true) }); + + await handleProviderInferenceState({ + ...baseOptions(deps, session), + resume: true, + sandboxName: "router-sandbox", + }); + + expect(calls.reconcileRouter).toHaveBeenCalledOnce(); + }); + + it("returns to provider selection when inference setup requests a retry", async () => { + const setupNim = vi + .fn() + .mockResolvedValueOnce({ ...baseSelection, model: "bad" }) + .mockResolvedValueOnce({ ...baseSelection, model: "good" }); + const setupInference = vi + .fn() + .mockResolvedValueOnce({ retry: "selection" as const }) + .mockResolvedValueOnce({ ok: true as const }); + const { deps, calls } = createDeps({ setupNim, setupInference }); + + const result = await handleProviderInferenceState(baseOptions(deps)); + + expect(setupNim).toHaveBeenCalledTimes(2); + expect(setupInference).toHaveBeenCalledTimes(2); + expect(result.model).toBe("good"); + expect(calls.startStep).toHaveBeenCalledWith("provider_selection"); + }); + + it("aborts before inference setup when the configuration summary is rejected", async () => { + const { deps, calls } = createDeps({ + isNonInteractive: () => false, + promptYesNoOrDefault: vi.fn(async () => false), + }); + + await expect(handleProviderInferenceState(baseOptions(deps))).rejects.toThrow("exit 0"); + + 
expect(calls.exit).toHaveBeenCalledWith(0); + expect(calls.setupInference).not.toHaveBeenCalled(); + }); +}); diff --git a/src/lib/onboard/machine/handlers/provider-inference.ts b/src/lib/onboard/machine/handlers/provider-inference.ts new file mode 100644 index 0000000000..525b94a059 --- /dev/null +++ b/src/lib/onboard/machine/handlers/provider-inference.ts @@ -0,0 +1,289 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import type { Session, SessionUpdates } from "../../../state/onboard-session"; + +export type ProviderInferenceRetry = { retry: "selection" } | { ok: true; retry?: undefined }; + +export interface ProviderSelectionResult { + model: string | null; + provider: string; + endpointUrl: string | null; + credentialEnv: string | null; + hermesAuthMethod: string | null; + hermesToolGateways: string[]; + preferredInferenceApi: string | null; + nimContainer: string | null; +} + +export interface ProviderInferenceStateOptions { + resume: boolean; + session: Session | null; + gpu: Gpu; + sandboxName: string | null; + agent: Agent; + initial: { + model: string | null; + provider: string | null; + endpointUrl: string | null; + credentialEnv: string | null; + hermesAuthMethod: string | null; + hermesToolGateways: string[]; + preferredInferenceApi: string | null; + nimContainer: string | null; + webSearchConfig: any; + }; + selectedMessagingChannels: string[]; + env: NodeJS.ProcessEnv; + constants: { + hermesProviderName: string; + }; + deps: { + normalizeHermesAuthMethod(value: string | null | undefined): string | null; + setupNim(gpu: Gpu, sandboxName: string | null, agent: Agent): Promise; + setupInference( + sandboxName: string | null, + model: string, + provider: string, + endpointUrl: string | null, + credentialEnv: string | null, + hermesAuthMethod: string | null, + hermesToolGateways: string[], + ): Promise; + startRecordedStep( + stepName: string, + updates?: { 
provider?: string | null; model?: string | null }, + ): Promise; + recordStepComplete(stepName: string, updates: SessionUpdates): Promise; + toSessionUpdates(updates: Record): SessionUpdates; + skippedStepMessage(stepName: string, detail?: string | null): void; + hydrateCredentialEnv(credentialEnv: string | null): void; + repairLocalInferenceSystemdOverrideOrExit(provider: string | null, isNonInteractive: () => boolean): void; + isNonInteractive(): boolean; + getOpenshellBinary(): string; + needsBedrockRuntimeAdapter(provider: string, endpointUrl: string | null): boolean; + isInferenceRouteReady(provider: string, model: string): boolean; + isRoutedInferenceProvider(provider: string): boolean; + reconcileModelRouter(): Promise; + registryUpdateSandbox(sandboxName: string, updates: { nimContainer?: string | null }): void; + promptValidatedSandboxName(agent: Agent): Promise; + assessHost(): Host; + formatSandboxBuildEstimateNote(host: Host): string | null; + formatOnboardConfigSummary(options: { + provider: string; + model: string; + credentialEnv: string | null; + hermesAuthMethod: string | null; + webSearchConfig: any; + hermesToolGateways: string[]; + enabledChannels: string[] | null; + sandboxName: string; + notes: string[]; + }): string; + promptYesNoOrDefault(question: string, envVar: string | null, defaultIsYes: boolean): Promise; + cliName(): string; + log(message?: string): void; + error(message?: string): void; + exitProcess(code: number): never; + deleteEnv(name: string): void; + }; +} + +export interface ProviderInferenceStateResult { + sandboxName: string | null; + model: string; + provider: string; + endpointUrl: string | null; + credentialEnv: string | null; + hermesAuthMethod: string | null; + hermesToolGateways: string[]; + preferredInferenceApi: string | null; + nimContainer: string | null; + webSearchConfig: any; + session: Session | null; +} + +function requireSelection(provider: string | null, model: string | null): { provider: string; model: 
string } { + if (typeof provider !== "string" || typeof model !== "string") { + throw new Error("Inference selection did not yield a provider/model."); + } + return { provider, model }; +} + +export async function handleProviderInferenceState({ + resume, + session, + gpu, + sandboxName, + agent, + initial, + selectedMessagingChannels, + env, + constants, + deps, +}: ProviderInferenceStateOptions): Promise { + let model = initial.model; + let provider = initial.provider; + let endpointUrl = initial.endpointUrl; + let credentialEnv = initial.credentialEnv; + let hermesAuthMethod = + deps.normalizeHermesAuthMethod(initial.hermesAuthMethod) || + (provider === constants.hermesProviderName ? deps.normalizeHermesAuthMethod(initial.hermesAuthMethod) : null); + let hermesToolGateways = initial.hermesToolGateways; + let preferredInferenceApi = initial.preferredInferenceApi; + let nimContainer = initial.nimContainer; + const webSearchConfig = initial.webSearchConfig; + let forceProviderSelection = false; + + while (true) { + const resumeProviderSelection = + !forceProviderSelection && + resume && + session?.steps?.provider_selection?.status === "complete" && + typeof provider === "string" && + typeof model === "string"; + if (resumeProviderSelection) { + deps.skippedStepMessage("provider_selection", `${provider} / ${model}`); + deps.hydrateCredentialEnv(credentialEnv); + deps.repairLocalInferenceSystemdOverrideOrExit(provider, deps.isNonInteractive); + } else { + await deps.startRecordedStep("provider_selection"); + const selection = await deps.setupNim(gpu, sandboxName, agent); + model = selection.model; + provider = selection.provider; + endpointUrl = selection.endpointUrl; + credentialEnv = selection.credentialEnv; + hermesAuthMethod = selection.hermesAuthMethod; + hermesToolGateways = selection.hermesToolGateways; + preferredInferenceApi = selection.preferredInferenceApi; + nimContainer = selection.nimContainer; + session = await deps.recordStepComplete( + 
"provider_selection", + deps.toSessionUpdates({ + provider, + model, + endpointUrl, + credentialEnv, + hermesAuthMethod, + hermesToolGateways, + preferredInferenceApi, + nimContainer, + }), + ); + } + + const selected = requireSelection(provider, model); + provider = selected.provider; + model = selected.model; + env.NEMOCLAW_OPENSHELL_BIN = deps.getOpenshellBinary(); + const needsBedrockRuntimeAdapter = deps.needsBedrockRuntimeAdapter(provider, endpointUrl); + const resumeInference = + !needsBedrockRuntimeAdapter && + !forceProviderSelection && + resume && + deps.isInferenceRouteReady(provider, model); + if (resumeInference) { + if (provider === constants.hermesProviderName) { + if (!sandboxName) sandboxName = await deps.promptValidatedSandboxName(agent); + await deps.startRecordedStep("inference", { provider, model }); + const inferenceResult = await deps.setupInference( + sandboxName, + model, + provider, + endpointUrl, + credentialEnv, + hermesAuthMethod, + hermesToolGateways, + ); + if (inferenceResult?.retry === "selection") { + forceProviderSelection = true; + continue; + } + session = await deps.recordStepComplete( + "inference", + deps.toSessionUpdates({ provider, model, hermesAuthMethod, nimContainer, hermesToolGateways }), + ); + break; + } + if (deps.isRoutedInferenceProvider(provider)) { + try { + await deps.reconcileModelRouter(); + } catch (err) { + deps.error(` ✗ Failed to reconcile model router: ${err instanceof Error ? 
err.message : String(err)}`); + deps.exitProcess(1); + } + } + deps.skippedStepMessage("inference", `${provider} / ${model}`); + if (nimContainer && sandboxName) deps.registryUpdateSandbox(sandboxName, { nimContainer }); + session = await deps.recordStepComplete( + "inference", + deps.toSessionUpdates({ provider, model, hermesAuthMethod, nimContainer, hermesToolGateways }), + ); + break; + } + + if (!sandboxName) sandboxName = await deps.promptValidatedSandboxName(agent); + const buildEstimateNote = + env.NEMOCLAW_IGNORE_RUNTIME_RESOURCES === "1" + ? null + : deps.formatSandboxBuildEstimateNote(deps.assessHost()); + deps.log( + deps.formatOnboardConfigSummary({ + provider, + model, + credentialEnv, + hermesAuthMethod, + webSearchConfig, + hermesToolGateways, + enabledChannels: selectedMessagingChannels.length > 0 ? selectedMessagingChannels : null, + sandboxName, + notes: buildEstimateNote ? [buildEstimateNote] : [], + }), + ); + deps.log(" Web search and messaging channels will be prompted next."); + if (!deps.isNonInteractive()) { + if (!(await deps.promptYesNoOrDefault(" Apply this configuration?", null, true))) { + deps.log(` Aborted. 
Re-run \`${deps.cliName()} onboard\` to start over.`); + deps.log(" Credentials entered so far were only staged in memory for this run."); + deps.log(" No new gateway credential was registered because onboarding stopped here."); + deps.exitProcess(0); + } + } + + await deps.startRecordedStep("inference", { provider, model }); + const inferenceResult = await deps.setupInference( + sandboxName, + model, + provider, + endpointUrl, + credentialEnv, + hermesAuthMethod, + hermesToolGateways, + ); + deps.deleteEnv("NVIDIA_API_KEY"); + if (inferenceResult?.retry === "selection") { + forceProviderSelection = true; + continue; + } + if (nimContainer && sandboxName) deps.registryUpdateSandbox(sandboxName, { nimContainer }); + session = await deps.recordStepComplete( + "inference", + deps.toSessionUpdates({ provider, model, hermesAuthMethod, nimContainer, hermesToolGateways }), + ); + break; + } + + return { + sandboxName, + model, + provider, + endpointUrl, + credentialEnv, + hermesAuthMethod, + hermesToolGateways, + preferredInferenceApi, + nimContainer, + webSearchConfig, + session, + }; +} From 18ef7e763923a23a0b78739e3fc3619557ab9ac1 Mon Sep 17 00:00:00 2001 From: Carlos Villela Date: Tue, 19 May 2026 23:35:26 -0700 Subject: [PATCH 10/12] refactor(cli): extract onboard sandbox handler --- src/lib/onboard.ts | 267 ++++------------ .../onboard/machine/handlers/sandbox.test.ts | 198 ++++++++++++ src/lib/onboard/machine/handlers/sandbox.ts | 287 ++++++++++++++++++ 3 files changed, 547 insertions(+), 205 deletions(-) create mode 100644 src/lib/onboard/machine/handlers/sandbox.test.ts create mode 100644 src/lib/onboard/machine/handlers/sandbox.ts diff --git a/src/lib/onboard.ts b/src/lib/onboard.ts index f7d95ae8ab..b23e7f0fa4 100644 --- a/src/lib/onboard.ts +++ b/src/lib/onboard.ts @@ -283,6 +283,7 @@ const { OnboardRuntime }: typeof import("./onboard/machine/runtime") = require(" const { handleGatewayState }: typeof import("./onboard/machine/handlers/gateway") = 
require("./onboard/machine/handlers/gateway"); const { handlePreflightState }: typeof import("./onboard/machine/handlers/preflight") = require("./onboard/machine/handlers/preflight"); const { handleProviderInferenceState }: typeof import("./onboard/machine/handlers/provider-inference") = require("./onboard/machine/handlers/provider-inference"); +const { handleSandboxState }: typeof import("./onboard/machine/handlers/sandbox") = require("./onboard/machine/handlers/sandbox"); const policies: typeof import("./policy") = require("./policy"); const tiers: typeof import("./policy/tiers") = require("./policy/tiers"); const { ensureUsageNoticeConsent } = require("./onboard/usage-notice"); @@ -9587,212 +9588,68 @@ async function onboard(opts: OnboardOptions = {}): Promise { } = providerInferenceResult; let webSearchConfig = providerInferenceResult.webSearchConfig as WebSearchConfig | null; - const webSearchSupportProbePath = fromDockerfile ? path.resolve(fromDockerfile) : null; - const webSearchSupported = agentSupportsWebSearch(agent, webSearchSupportProbePath, ROOT); - if (webSearchConfig && !webSearchSupported) { - note( - ` Web search is not yet supported by ${agent?.displayName ?? "this sandbox image"}. 
Clearing stale config.`, - ); - webSearchConfig = null; - if (session) { - session.webSearchConfig = null; - } - onboardSession.updateSession((current: Session) => { - current.webSearchConfig = null; - return current; - }); - } - - const storedMessagingChannelConfig = getStoredMessagingChannelConfig(sandboxName, session); - const effectiveMessagingChannelConfig = hydrateMessagingChannelConfig(storedMessagingChannelConfig); - const messagingChannelConfigChanged = !messagingChannelConfigsEqual( - effectiveMessagingChannelConfig, - storedMessagingChannelConfig, - ); - if (effectiveMessagingChannelConfig) { - persistMessagingChannelConfigToSession(effectiveMessagingChannelConfig); - if (session) { - session.messagingChannelConfig = effectiveMessagingChannelConfig; - } - } - - const sandboxReuseState = getSandboxReuseState(sandboxName); - const webSearchConfigChanged = Boolean(session?.webSearchConfig) !== Boolean(webSearchConfig); - // Telegram mention-mode is baked into openclaw.json at sandbox build time, so - // changes to TELEGRAM_REQUIRE_MENTION only take effect after a rebuild. Treat - // a mismatch between the recorded config and the current env value as drift - // so the reuse path forces a recreate (mirrors webSearchConfigChanged). See - // #1737 and the CodeRabbit review on #2417. - // - // Compare *effective* modes — null and false both produce groupPolicy: open - // at config-generation time (default behavior), so they collapse to the same - // bucket here. Without this, a sandbox built before TELEGRAM_REQUIRE_MENTION - // existed (recordedTelegramRequireMention === null) would be reused with the - // old groupPolicy: open even after the user sets TELEGRAM_REQUIRE_MENTION=1, - // and vice versa. - const currentTelegramRequireMention = computeTelegramRequireMention(); - const recordedTelegramRequireMention = session?.telegramConfig?.requireMention ?? null; - const effectiveCurrent = currentTelegramRequireMention ?? 
false; - const effectiveRecorded = recordedTelegramRequireMention ?? false; - const telegramConfigChanged = effectiveCurrent !== effectiveRecorded; - const sandboxGpuConfigChanged = sandboxName - ? hasSandboxGpuDrift(sandboxName, sandboxGpuConfig) - : false; - const wechatConfigChanged = hasWechatConfigDrift(session); - const recordedHermesToolGateways = sandboxName - ? normalizeHermesToolGatewaySelections(registry.getSandbox(sandboxName)?.hermesToolGateways) - : []; - const hermesToolGatewayConfigChanged = !stringSetsEqual( - recordedHermesToolGateways, + const sandboxStateResult = await handleSandboxState({ + resume, + fresh, + session, + sandboxName, + model, + provider, + nimContainer, + webSearchConfig, + selectedMessagingChannels, + fromDockerfile, + agent, + gpu, + preferredInferenceApi, + sandboxGpuConfig, hermesToolGateways, - ); - const resumeSandbox = - resume && - !webSearchConfigChanged && - !telegramConfigChanged && - !sandboxGpuConfigChanged && - !wechatConfigChanged && - !messagingChannelConfigChanged && - !hermesToolGatewayConfigChanged && - session?.steps?.sandbox?.status === "complete" && - sandboxReuseState === "ready"; - if (resumeSandbox) { - if (webSearchConfig) { - note(" [resume] Reusing Brave Search configuration already baked into the sandbox."); - } - selectedMessagingChannels = session?.messagingChannels ?? 
[]; - skippedStepMessage("sandbox", sandboxName); - } else { - if (resume && session?.steps?.sandbox?.status === "complete") { - if (webSearchConfigChanged) { - note(" [resume] Web Search configuration changed; recreating sandbox."); - if (sandboxName) { - registry.removeSandbox(sandboxName); - } - } else if (telegramConfigChanged) { - note(" [resume] TELEGRAM_REQUIRE_MENTION changed; recreating sandbox."); - if (sandboxName) { - registry.removeSandbox(sandboxName); - } - } else if (sandboxGpuConfigChanged) { - note(" [resume] Sandbox GPU settings changed; recreating sandbox."); - if (sandboxName) { - registry.removeSandbox(sandboxName); - } - } else if (wechatConfigChanged) { - note(" [resume] WeChat account metadata changed; recreating sandbox."); - if (sandboxName) { - registry.removeSandbox(sandboxName); - } - } else if (messagingChannelConfigChanged) { - note(" [resume] Messaging channel configuration changed; recreating sandbox."); - if (sandboxName) { - registry.removeSandbox(sandboxName); - } - } else if (hermesToolGatewayConfigChanged) { - note(" [resume] Hermes managed tool gateway selection changed; recreating sandbox."); - if (sandboxName) { - registry.removeSandbox(sandboxName); - } - } else if (sandboxReuseState === "not_ready") { - note( - ` [resume] Recorded sandbox '${sandboxName}' exists but is not ready; recreating it.`, - ); - repairRecordedSandbox(sandboxName); - } else { - note(" [resume] Recorded sandbox state is unavailable; recreating it."); - if (sandboxName) { - registry.removeSandbox(sandboxName); - } - } - } - let nextWebSearchConfig = webSearchConfig; - if (nextWebSearchConfig) { - note(" [resume] Revalidating Brave Search configuration for sandbox recreation."); - const braveApiKey = await ensureValidatedBraveSearchCredential(); - nextWebSearchConfig = braveApiKey ? 
{ fetchEnabled: true } : null; - if (nextWebSearchConfig) { - note(" [resume] Reusing Brave Search configuration."); - } - } else { - nextWebSearchConfig = await configureWebSearch(null, agent, webSearchSupportProbePath); - } - await startRecordedStep("sandbox", { provider, model }); - const recordedMessagingChannels = getRecordedMessagingChannelsForResume(resume, session, sandboxName); - if (recordedMessagingChannels) { - selectedMessagingChannels = recordedMessagingChannels; - if (selectedMessagingChannels.length > 0) { - note( - ` [non-interactive] Reusing messaging channel configuration: ${selectedMessagingChannels.join(", ")}`, - ); - } - } else { - const existing = sandboxName - ? registry.getSandbox(sandboxName)?.messagingChannels ?? - session?.messagingChannels ?? - null - : session?.messagingChannels ?? null; - selectedMessagingChannels = await setupMessagingChannels(agent, existing); - } - const messagingChannelConfig = readMessagingChannelConfigFromEnv(); - onboardSession.updateSession((current: Session) => { - current.messagingChannels = selectedMessagingChannels; - current.messagingChannelConfig = messagingChannelConfig; - return current; - }); - if (!sandboxName) { - sandboxName = await promptValidatedSandboxName(agent); - } - if (typeof model !== "string" || typeof provider !== "string") { - console.error(" Inference selection is incomplete; cannot create sandbox."); - process.exit(1); - } - if (fresh) { - stopStaleDashboardListenersForSandbox(registry.listSandboxes().sandboxes, sandboxName); - } - sandboxName = await createSandbox( - gpu, - model, - provider, - preferredInferenceApi, - sandboxName, - nextWebSearchConfig, - selectedMessagingChannels, - fromDockerfile, - agent, - opts.controlUiPort || null, - sandboxGpuConfig, - hermesToolGateways, - ); - webSearchConfig = nextWebSearchConfig; - registry.updateSandbox(sandboxName, { - model, - provider, - ...getSandboxAgentRegistryFields(agent, !fromDockerfile), - }); - 
registry.setDefault(sandboxName); - await recordStepComplete( - "sandbox", - toSessionUpdates({ - sandboxName, - provider, - model, - nimContainer, - webSearchConfig, - messagingChannelConfig, - hermesToolGateways, - }), - ); - } - - if ( - typeof sandboxName !== "string" || - typeof provider !== "string" || - typeof model !== "string" - ) { - console.error(" Onboarding state is incomplete after sandbox setup."); - process.exit(1); - } + controlUiPort: opts.controlUiPort || null, + rootDir: ROOT, + deps: { + resolvePath: path.resolve, + agentSupportsWebSearch, + note, + updateSession: onboardSession.updateSession, + getStoredMessagingChannelConfig, + hydrateMessagingChannelConfig, + messagingChannelConfigsEqual, + persistMessagingChannelConfigToSession, + getSandboxReuseState, + computeTelegramRequireMention, + hasSandboxGpuDrift, + hasWechatConfigDrift, + getSandboxHermesToolGateways: (name) => registry.getSandbox(name)?.hermesToolGateways, + normalizeHermesToolGatewaySelections, + stringSetsEqual, + removeSandboxFromRegistry: registry.removeSandbox.bind(registry), + repairRecordedSandbox, + ensureValidatedBraveSearchCredential, + configureWebSearch, + startRecordedStep, + getRecordedMessagingChannelsForResume, + getSandboxMessagingChannels: (name) => registry.getSandbox(name)?.messagingChannels, + setupMessagingChannels, + readMessagingChannelConfigFromEnv, + promptValidatedSandboxName, + stopStaleDashboardListenersForSandbox, + listRegistrySandboxes: registry.listSandboxes, + createSandbox, + updateSandboxRegistry: (name, updates) => registry.updateSandbox(name, updates), + setDefaultSandbox: registry.setDefault, + getSandboxAgentRegistryFields, + recordStepComplete, + toSessionUpdates: (updates) => toSessionUpdates(updates as Parameters[0]), + skippedStepMessage, + error: (message) => console.error(message), + exitProcess: (code) => process.exit(code), + }, + }); + session = sandboxStateResult.session; + sandboxName = sandboxStateResult.sandboxName; + 
webSearchConfig = sandboxStateResult.webSearchConfig ?? null; + selectedMessagingChannels = sandboxStateResult.selectedMessagingChannels; + const webSearchSupported = sandboxStateResult.webSearchSupported; if (agent) { await agentOnboard.handleAgentSetup(sandboxName, model, provider, agent, resume, session, { diff --git a/src/lib/onboard/machine/handlers/sandbox.test.ts b/src/lib/onboard/machine/handlers/sandbox.test.ts new file mode 100644 index 0000000000..eac0ffb553 --- /dev/null +++ b/src/lib/onboard/machine/handlers/sandbox.test.ts @@ -0,0 +1,198 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import { describe, expect, it, vi } from "vitest"; + +import { createSession, type Session, type SessionUpdates } from "../../../state/onboard-session"; +import { handleSandboxState, type SandboxStateOptions } from "./sandbox"; + +type Gpu = { type: string } | null; +type Agent = { displayName?: string } | null; +type WebSearchConfig = { fetchEnabled: true }; +type MessagingChannelConfig = Record; +type SandboxGpuConfig = { sandboxGpuEnabled: boolean; mode: string }; + +function createDeps(overrides: Partial["deps"]> = {}) { + let session = createSession(); + const calls = { + note: vi.fn(), + updateSession: vi.fn((mutator: (value: Session) => Session | void) => { + session = mutator(session) ?? 
session; + return session; + }), + persistMessaging: vi.fn(), + removeSandbox: vi.fn(), + repairSandbox: vi.fn(), + validateBrave: vi.fn(async () => "brave-key"), + configureWebSearch: vi.fn(async () => null as WebSearchConfig | null), + startStep: vi.fn(async () => undefined), + getRecordedChannels: vi.fn(() => null), + setupMessaging: vi.fn(async () => [] as string[]), + promptName: vi.fn(async () => "my-assistant"), + stopStale: vi.fn(), + createSandbox: vi.fn(async () => "my-assistant"), + updateSandbox: vi.fn(), + setDefault: vi.fn(), + complete: vi.fn(async () => createSession()), + skipped: vi.fn(), + error: vi.fn(), + exit: vi.fn((code: number): never => { + throw new Error(`exit ${code}`); + }), + }; + return { + calls, + deps: { + resolvePath: (value: string) => `/abs/${value}`, + agentSupportsWebSearch: () => true, + note: calls.note, + updateSession: calls.updateSession, + getStoredMessagingChannelConfig: () => null, + hydrateMessagingChannelConfig: (config: MessagingChannelConfig | null) => config, + messagingChannelConfigsEqual: () => true, + persistMessagingChannelConfigToSession: calls.persistMessaging, + getSandboxReuseState: () => "missing", + computeTelegramRequireMention: () => null, + hasSandboxGpuDrift: () => false, + hasWechatConfigDrift: () => false, + getSandboxHermesToolGateways: () => [], + normalizeHermesToolGatewaySelections: (value: unknown) => (Array.isArray(value) ? 
(value as string[]) : []), + stringSetsEqual: (left: string[], right: string[]) => left.length === right.length && left.every((value) => right.includes(value)), + removeSandboxFromRegistry: calls.removeSandbox, + repairRecordedSandbox: calls.repairSandbox, + ensureValidatedBraveSearchCredential: calls.validateBrave, + configureWebSearch: calls.configureWebSearch, + startRecordedStep: calls.startStep, + getRecordedMessagingChannelsForResume: calls.getRecordedChannels, + getSandboxMessagingChannels: () => ["telegram"], + setupMessagingChannels: calls.setupMessaging, + readMessagingChannelConfigFromEnv: () => null, + promptValidatedSandboxName: calls.promptName, + stopStaleDashboardListenersForSandbox: calls.stopStale, + listRegistrySandboxes: () => ({ sandboxes: [{ name: "old" }] }), + createSandbox: calls.createSandbox, + updateSandboxRegistry: calls.updateSandbox, + setDefaultSandbox: calls.setDefault, + getSandboxAgentRegistryFields: () => ({ agent: null }), + recordStepComplete: calls.complete, + toSessionUpdates: (updates: Record) => updates as SessionUpdates, + skippedStepMessage: calls.skipped, + error: calls.error, + exitProcess: calls.exit, + ...overrides, + }, + getSession: () => session, + }; +} + +function baseOptions( + deps: SandboxStateOptions["deps"], + session: Session | null = createSession(), +): SandboxStateOptions { + return { + resume: false, + fresh: false, + session, + sandboxName: null, + model: "model", + provider: "provider", + nimContainer: null, + webSearchConfig: null, + selectedMessagingChannels: [], + fromDockerfile: null, + agent: null, + gpu: { type: "nvidia" }, + preferredInferenceApi: "openai-completions", + sandboxGpuConfig: { sandboxGpuEnabled: false, mode: "0" }, + hermesToolGateways: [], + controlUiPort: null, + rootDir: "/repo", + deps, + }; +} + +describe("handleSandboxState", () => { + it("creates a sandbox and records messaging/web search state", async () => { + const { deps, calls } = createDeps({ + configureWebSearch: 
vi.fn(async () => ({ fetchEnabled: true as const })), + readMessagingChannelConfigFromEnv: () => ({ telegram: "polling" }), + }); + calls.setupMessaging.mockResolvedValue(["telegram"]); + + const result = await handleSandboxState(baseOptions(deps)); + + expect(calls.startStep).toHaveBeenCalledWith("sandbox", { provider: "provider", model: "model" }); + expect(calls.setupMessaging).toHaveBeenCalledWith(null, null); + expect(calls.promptName).toHaveBeenCalledWith(null); + expect(calls.createSandbox).toHaveBeenCalledWith( + { type: "nvidia" }, + "model", + "provider", + "openai-completions", + "my-assistant", + { fetchEnabled: true }, + ["telegram"], + null, + null, + null, + { sandboxGpuEnabled: false, mode: "0" }, + [], + ); + expect(calls.updateSandbox).toHaveBeenCalledWith("my-assistant", expect.objectContaining({ model: "model", provider: "provider" })); + expect(calls.setDefault).toHaveBeenCalledWith("my-assistant"); + expect(calls.complete).toHaveBeenCalledWith("sandbox", expect.objectContaining({ sandboxName: "my-assistant" })); + expect(result).toMatchObject({ sandboxName: "my-assistant", selectedMessagingChannels: ["telegram"], webSearchSupported: true }); + }); + + it("reuses a completed ready sandbox on resume", async () => { + const session = createSession({ sandboxName: "saved", messagingChannels: ["slack"] }); + session.steps.sandbox.status = "complete"; + const { deps, calls } = createDeps({ getSandboxReuseState: () => "ready" }); + + const result = await handleSandboxState({ ...baseOptions(deps, session), resume: true, sandboxName: "saved" }); + + expect(calls.createSandbox).not.toHaveBeenCalled(); + expect(calls.skipped).toHaveBeenCalledWith("sandbox", "saved"); + expect(result.selectedMessagingChannels).toEqual(["slack"]); + }); + + it("removes registry state when Telegram mention-mode drift forces sandbox recreation", async () => { + const session = createSession({ telegramConfig: { requireMention: true } }); + session.steps.sandbox.status = 
"complete"; + const { deps, calls } = createDeps({ + getSandboxReuseState: () => "ready", + computeTelegramRequireMention: () => false, + }); + + await handleSandboxState({ + ...baseOptions(deps, session), + resume: true, + sandboxName: "saved", + }); + + expect(calls.note).toHaveBeenCalledWith(" [resume] TELEGRAM_REQUIRE_MENTION changed; recreating sandbox."); + expect(calls.removeSandbox).toHaveBeenCalledWith("saved"); + expect(calls.createSandbox).toHaveBeenCalled(); + }); + + it("repairs not-ready resumed sandboxes before recreation", async () => { + const session = createSession({ sandboxName: "saved" }); + session.steps.sandbox.status = "complete"; + const { deps, calls } = createDeps({ getSandboxReuseState: () => "not_ready" }); + + await handleSandboxState({ ...baseOptions(deps, session), resume: true, sandboxName: "saved" }); + + expect(calls.repairSandbox).toHaveBeenCalledWith("saved"); + expect(calls.createSandbox).toHaveBeenCalled(); + }); + + it("uses recorded messaging channels on non-interactive resume", async () => { + const { deps, calls } = createDeps({ getRecordedMessagingChannelsForResume: vi.fn(() => ["discord"]) }); + + const result = await handleSandboxState(baseOptions(deps)); + + expect(calls.setupMessaging).not.toHaveBeenCalled(); + expect(calls.note).toHaveBeenCalledWith(" [non-interactive] Reusing messaging channel configuration: discord"); + expect(result.selectedMessagingChannels).toEqual(["discord"]); + }); +}); diff --git a/src/lib/onboard/machine/handlers/sandbox.ts b/src/lib/onboard/machine/handlers/sandbox.ts new file mode 100644 index 0000000000..8c45215ed9 --- /dev/null +++ b/src/lib/onboard/machine/handlers/sandbox.ts @@ -0,0 +1,287 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +import type { Session, SessionUpdates } from "../../../state/onboard-session"; + +export interface SandboxStateOptions { + resume: boolean; + fresh: boolean; + session: Session | null; + sandboxName: string | null; + model: string; + provider: string; + nimContainer: string | null; + webSearchConfig: WebSearchConfig | null; + selectedMessagingChannels: string[]; + fromDockerfile: string | null; + agent: Agent; + gpu: Gpu; + preferredInferenceApi: string | null; + sandboxGpuConfig: SandboxGpuConfig; + hermesToolGateways: string[]; + controlUiPort: number | null; + rootDir: string; + deps: { + resolvePath(value: string): string; + agentSupportsWebSearch(agent: Agent, dockerfilePathOverride: string | null, rootDir: string): boolean; + note(message: string): void; + updateSession(mutator: (session: Session) => Session | void): Session; + getStoredMessagingChannelConfig(sandboxName: string | null, session: Session | null): MessagingChannelConfig | null; + hydrateMessagingChannelConfig(config: MessagingChannelConfig | null): MessagingChannelConfig | null; + messagingChannelConfigsEqual(left: MessagingChannelConfig | null, right: MessagingChannelConfig | null): boolean; + persistMessagingChannelConfigToSession(config: MessagingChannelConfig | null): void; + getSandboxReuseState(sandboxName: string | null): string; + computeTelegramRequireMention(): boolean | null; + hasSandboxGpuDrift(sandboxName: string, config: SandboxGpuConfig): boolean; + hasWechatConfigDrift(session: Session | null): boolean; + getSandboxHermesToolGateways(sandboxName: string): unknown; + normalizeHermesToolGatewaySelections(value: unknown): string[]; + stringSetsEqual(left: string[], right: string[]): boolean; + removeSandboxFromRegistry(sandboxName: string): void; + repairRecordedSandbox(sandboxName: string | null): void; + ensureValidatedBraveSearchCredential(): Promise; + configureWebSearch( + existingConfig: WebSearchConfig | null, + agent: Agent, + 
dockerfilePathOverride: string | null, + ): Promise; + startRecordedStep(stepName: string, updates: { provider: string; model: string }): Promise; + getRecordedMessagingChannelsForResume( + resume: boolean, + session: Session | null, + sandboxName: string | null, + ): string[] | null; + getSandboxMessagingChannels(sandboxName: string): string[] | null | undefined; + setupMessagingChannels(agent: Agent, existingChannels: string[] | null): Promise; + readMessagingChannelConfigFromEnv(): MessagingChannelConfig | null; + promptValidatedSandboxName(agent: Agent): Promise; + stopStaleDashboardListenersForSandbox(sandboxes: unknown[], sandboxName: string): void; + listRegistrySandboxes(): { sandboxes: unknown[] }; + createSandbox( + gpu: Gpu, + model: string, + provider: string, + preferredInferenceApi: string | null, + sandboxName: string, + webSearchConfig: WebSearchConfig | null, + selectedMessagingChannels: string[], + fromDockerfile: string | null, + agent: Agent, + controlUiPort: number | null, + sandboxGpuConfig: SandboxGpuConfig, + hermesToolGateways: string[], + ): Promise; + updateSandboxRegistry(sandboxName: string, updates: Record): void; + setDefaultSandbox(sandboxName: string): void; + getSandboxAgentRegistryFields(agent: Agent, agentVersionKnown: boolean): Record; + recordStepComplete(stepName: string, updates: SessionUpdates): Promise; + toSessionUpdates(updates: Record): SessionUpdates; + skippedStepMessage(stepName: string, detail?: string | null): void; + error(message?: string): void; + exitProcess(code: number): never; + }; +} + +export interface SandboxStateResult { + sandboxName: string; + webSearchConfig: WebSearchConfig | null; + selectedMessagingChannels: string[]; + webSearchSupported: boolean; + session: Session | null; +} + +function sameEffectiveTelegramRequireMention(left: boolean | null, right: boolean | null): boolean { + return (left ?? false) === (right ?? 
false); +} + +export async function handleSandboxState({ + resume, + fresh, + session, + sandboxName, + model, + provider, + nimContainer, + webSearchConfig, + selectedMessagingChannels, + fromDockerfile, + agent, + gpu, + preferredInferenceApi, + sandboxGpuConfig, + hermesToolGateways, + controlUiPort, + rootDir, + deps, +}: SandboxStateOptions< + Gpu, + Agent, + WebSearchConfig, + MessagingChannelConfig, + SandboxGpuConfig +>): Promise> { + const webSearchSupportProbePath = fromDockerfile ? deps.resolvePath(fromDockerfile) : null; + const webSearchSupported = deps.agentSupportsWebSearch(agent, webSearchSupportProbePath, rootDir); + if (webSearchConfig && !webSearchSupported) { + deps.note( + ` Web search is not yet supported by ${(agent as { displayName?: string } | null)?.displayName ?? "this sandbox image"}. Clearing stale config.`, + ); + webSearchConfig = null; + if (session) session.webSearchConfig = null; + session = deps.updateSession((current) => { + current.webSearchConfig = null; + return current; + }); + } + + const storedMessagingChannelConfig = deps.getStoredMessagingChannelConfig(sandboxName, session); + const effectiveMessagingChannelConfig = deps.hydrateMessagingChannelConfig(storedMessagingChannelConfig); + const messagingChannelConfigChanged = !deps.messagingChannelConfigsEqual( + effectiveMessagingChannelConfig, + storedMessagingChannelConfig, + ); + if (effectiveMessagingChannelConfig) { + deps.persistMessagingChannelConfigToSession(effectiveMessagingChannelConfig); + if (session) session.messagingChannelConfig = effectiveMessagingChannelConfig as Session["messagingChannelConfig"]; + } + + const sandboxReuseState = deps.getSandboxReuseState(sandboxName); + const webSearchConfigChanged = Boolean(session?.webSearchConfig) !== Boolean(webSearchConfig); + const currentTelegramRequireMention = deps.computeTelegramRequireMention(); + const recordedTelegramRequireMention = session?.telegramConfig?.requireMention ?? 
null; + const telegramConfigChanged = !sameEffectiveTelegramRequireMention( + currentTelegramRequireMention, + recordedTelegramRequireMention, + ); + const sandboxGpuConfigChanged = sandboxName ? deps.hasSandboxGpuDrift(sandboxName, sandboxGpuConfig) : false; + const wechatConfigChanged = deps.hasWechatConfigDrift(session); + const recordedHermesToolGateways = sandboxName + ? deps.normalizeHermesToolGatewaySelections(deps.getSandboxHermesToolGateways(sandboxName)) + : []; + const hermesToolGatewayConfigChanged = !deps.stringSetsEqual(recordedHermesToolGateways, hermesToolGateways); + const resumeSandbox = + resume && + !webSearchConfigChanged && + !telegramConfigChanged && + !sandboxGpuConfigChanged && + !wechatConfigChanged && + !messagingChannelConfigChanged && + !hermesToolGatewayConfigChanged && + session?.steps?.sandbox?.status === "complete" && + sandboxReuseState === "ready"; + + if (resumeSandbox) { + if (webSearchConfig) deps.note(" [resume] Reusing Brave Search configuration already baked into the sandbox."); + selectedMessagingChannels = session?.messagingChannels ?? 
[]; + deps.skippedStepMessage("sandbox", sandboxName); + } else { + if (resume && session?.steps?.sandbox?.status === "complete") { + if (webSearchConfigChanged) { + deps.note(" [resume] Web Search configuration changed; recreating sandbox."); + if (sandboxName) deps.removeSandboxFromRegistry(sandboxName); + } else if (telegramConfigChanged) { + deps.note(" [resume] TELEGRAM_REQUIRE_MENTION changed; recreating sandbox."); + if (sandboxName) deps.removeSandboxFromRegistry(sandboxName); + } else if (sandboxGpuConfigChanged) { + deps.note(" [resume] Sandbox GPU settings changed; recreating sandbox."); + if (sandboxName) deps.removeSandboxFromRegistry(sandboxName); + } else if (wechatConfigChanged) { + deps.note(" [resume] WeChat account metadata changed; recreating sandbox."); + if (sandboxName) deps.removeSandboxFromRegistry(sandboxName); + } else if (messagingChannelConfigChanged) { + deps.note(" [resume] Messaging channel configuration changed; recreating sandbox."); + if (sandboxName) deps.removeSandboxFromRegistry(sandboxName); + } else if (hermesToolGatewayConfigChanged) { + deps.note(" [resume] Hermes managed tool gateway selection changed; recreating sandbox."); + if (sandboxName) deps.removeSandboxFromRegistry(sandboxName); + } else if (sandboxReuseState === "not_ready") { + deps.note(` [resume] Recorded sandbox '${sandboxName}' exists but is not ready; recreating it.`); + deps.repairRecordedSandbox(sandboxName); + } else { + deps.note(" [resume] Recorded sandbox state is unavailable; recreating it."); + if (sandboxName) deps.removeSandboxFromRegistry(sandboxName); + } + } + + let nextWebSearchConfig = webSearchConfig; + if (nextWebSearchConfig) { + deps.note(" [resume] Revalidating Brave Search configuration for sandbox recreation."); + const braveApiKey = await deps.ensureValidatedBraveSearchCredential(); + nextWebSearchConfig = braveApiKey ? 
webSearchConfig : null; + if (nextWebSearchConfig) deps.note(" [resume] Reusing Brave Search configuration."); + } else { + nextWebSearchConfig = await deps.configureWebSearch(null, agent, webSearchSupportProbePath); + } + + await deps.startRecordedStep("sandbox", { provider, model }); + const recordedMessagingChannels = deps.getRecordedMessagingChannelsForResume(resume, session, sandboxName); + if (recordedMessagingChannels) { + selectedMessagingChannels = recordedMessagingChannels; + if (selectedMessagingChannels.length > 0) { + deps.note(` [non-interactive] Reusing messaging channel configuration: ${selectedMessagingChannels.join(", ")}`); + } + } else { + const existing = sandboxName + ? deps.getSandboxMessagingChannels(sandboxName) ?? session?.messagingChannels ?? null + : session?.messagingChannels ?? null; + selectedMessagingChannels = await deps.setupMessagingChannels(agent, existing); + } + const messagingChannelConfig = deps.readMessagingChannelConfigFromEnv(); + session = deps.updateSession((current) => { + current.messagingChannels = selectedMessagingChannels; + current.messagingChannelConfig = messagingChannelConfig as Session["messagingChannelConfig"]; + return current; + }); + + if (!sandboxName) sandboxName = await deps.promptValidatedSandboxName(agent); + if (fresh) deps.stopStaleDashboardListenersForSandbox(deps.listRegistrySandboxes().sandboxes, sandboxName); + sandboxName = await deps.createSandbox( + gpu, + model, + provider, + preferredInferenceApi, + sandboxName, + nextWebSearchConfig, + selectedMessagingChannels, + fromDockerfile, + agent, + controlUiPort, + sandboxGpuConfig, + hermesToolGateways, + ); + webSearchConfig = nextWebSearchConfig; + deps.updateSandboxRegistry(sandboxName, { + model, + provider, + ...deps.getSandboxAgentRegistryFields(agent, !fromDockerfile), + }); + deps.setDefaultSandbox(sandboxName); + session = await deps.recordStepComplete( + "sandbox", + deps.toSessionUpdates({ + sandboxName, + provider, + model, + 
nimContainer, + webSearchConfig, + messagingChannelConfig, + hermesToolGateways, + }), + ); + } + + if (!sandboxName) { + deps.error(" Onboarding state is incomplete after sandbox setup."); + deps.exitProcess(1); + } + const completedSandboxName = sandboxName; + if (!completedSandboxName) throw new Error("Sandbox name is required after sandbox setup"); + + return { + sandboxName: completedSandboxName, + webSearchConfig, + selectedMessagingChannels, + webSearchSupported, + session, + }; +} From 7fe9e1cba937226c582e7a8ce0fb8cc26c1e7a7d Mon Sep 17 00:00:00 2001 From: Carlos Villela Date: Tue, 19 May 2026 23:44:11 -0700 Subject: [PATCH 11/12] refactor(cli): extract onboard agent setup handler --- src/lib/onboard.ts | 64 ++++----- .../machine/handlers/agent-setup.test.ts | 122 ++++++++++++++++++ .../onboard/machine/handlers/agent-setup.ts | 87 +++++++++++++ 3 files changed, 242 insertions(+), 31 deletions(-) create mode 100644 src/lib/onboard/machine/handlers/agent-setup.test.ts create mode 100644 src/lib/onboard/machine/handlers/agent-setup.ts diff --git a/src/lib/onboard.ts b/src/lib/onboard.ts index b23e7f0fa4..7462486abf 100644 --- a/src/lib/onboard.ts +++ b/src/lib/onboard.ts @@ -280,6 +280,7 @@ const { resolveSandboxImageTagFromCreateOutput } = const nim: typeof import("./inference/nim") = require("./inference/nim"); const onboardSession: typeof import("./state/onboard-session") = require("./state/onboard-session"); const { OnboardRuntime }: typeof import("./onboard/machine/runtime") = require("./onboard/machine/runtime"); +const { handleAgentSetupState }: typeof import("./onboard/machine/handlers/agent-setup") = require("./onboard/machine/handlers/agent-setup"); const { handleGatewayState }: typeof import("./onboard/machine/handlers/gateway") = require("./onboard/machine/handlers/gateway"); const { handlePreflightState }: typeof import("./onboard/machine/handlers/preflight") = require("./onboard/machine/handlers/preflight"); const { handleProviderInferenceState 
}: typeof import("./onboard/machine/handlers/provider-inference") = require("./onboard/machine/handlers/provider-inference"); @@ -9651,38 +9652,39 @@ async function onboard(opts: OnboardOptions = {}): Promise { selectedMessagingChannels = sandboxStateResult.selectedMessagingChannels; const webSearchSupported = sandboxStateResult.webSearchSupported; - if (agent) { - await agentOnboard.handleAgentSetup(sandboxName, model, provider, agent, resume, session, { - step, - runCaptureOpenshell, - openshellShellCommand, - openshellBinary: getOpenshellBinary(), - buildSandboxConfigSyncScript, - writeSandboxConfigSyncFile, - cleanupTempDir, - startRecordedStep, + const agentSetupResult = await handleAgentSetupState({ + agent, + sandboxName, + model, + provider, + resume, + session, + hermesAuthMethod, + hermesToolGateways, + deps: { + handleAgentSetup: agentOnboard.handleAgentSetup, + agentSetupContext: () => ({ + step, + runCaptureOpenshell, + openshellShellCommand, + openshellBinary: getOpenshellBinary(), + buildSandboxConfigSyncScript, + writeSandboxConfigSyncFile, + cleanupTempDir, + startRecordedStep, + skippedStepMessage, + }), + ensureAgentDashboardForward, + recordStepSkipped, + isOpenclawReady, skippedStepMessage, - }); - ensureAgentDashboardForward(sandboxName, agent); - await recordStepSkipped("openclaw"); - } else { - const resumeOpenclaw = resume && sandboxName && isOpenclawReady(sandboxName); - if (resumeOpenclaw) { - skippedStepMessage("openclaw", sandboxName); - await recordStepComplete( - "openclaw", - toSessionUpdates({ sandboxName, provider, model, hermesAuthMethod, hermesToolGateways }), - ); - } else { - await startRecordedStep("openclaw", { sandboxName, provider, model }); - await setupOpenclaw(sandboxName, model, provider); - await recordStepComplete( - "openclaw", - toSessionUpdates({ sandboxName, provider, model, hermesAuthMethod, hermesToolGateways }), - ); - } - await recordStepSkipped("agent_setup"); - } + startRecordedStep, + setupOpenclaw, + 
recordStepComplete, + toSessionUpdates: (updates) => toSessionUpdates(updates as Parameters[0]), + }, + }); + session = agentSetupResult.session; const latestSession = onboardSession.loadSession(); const recordedPolicyPresets = Array.isArray(latestSession?.policyPresets) diff --git a/src/lib/onboard/machine/handlers/agent-setup.test.ts b/src/lib/onboard/machine/handlers/agent-setup.test.ts new file mode 100644 index 0000000000..fd9f1d0410 --- /dev/null +++ b/src/lib/onboard/machine/handlers/agent-setup.test.ts @@ -0,0 +1,122 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import { describe, expect, it, vi } from "vitest"; + +import { createSession, type Session, type SessionUpdates } from "../../../state/onboard-session"; +import { handleAgentSetupState, type AgentSetupStateOptions } from "./agent-setup"; + +type Agent = { name: string; displayName: string }; + +function createDeps(overrides: Partial["deps"]> = {}) { + const calls = { + handleAgentSetup: vi.fn(async () => undefined), + context: vi.fn(() => ({ ctx: true })), + ensureDashboard: vi.fn(() => 18789), + skipped: vi.fn(async () => createSession()), + openclawReady: vi.fn(() => false), + skippedMessage: vi.fn(), + startStep: vi.fn(async () => undefined), + setupOpenclaw: vi.fn(async () => undefined), + complete: vi.fn(async () => createSession()), + }; + return { + calls, + deps: { + handleAgentSetup: calls.handleAgentSetup, + agentSetupContext: calls.context, + ensureAgentDashboardForward: calls.ensureDashboard, + recordStepSkipped: calls.skipped, + isOpenclawReady: calls.openclawReady, + skippedStepMessage: calls.skippedMessage, + startRecordedStep: calls.startStep, + setupOpenclaw: calls.setupOpenclaw, + recordStepComplete: calls.complete, + toSessionUpdates: (updates: Record) => updates as SessionUpdates, + ...overrides, + }, + }; +} + +function baseOptions( + deps: AgentSetupStateOptions["deps"], + 
agent: Agent | null = null, +): AgentSetupStateOptions { + return { + agent, + sandboxName: "my-assistant", + model: "model", + provider: "provider", + resume: false, + session: createSession(), + hermesAuthMethod: null, + hermesToolGateways: [], + deps, + }; +} + +describe("handleAgentSetupState", () => { + it("delegates non-OpenClaw agent setup and skips openclaw", async () => { + const { deps, calls } = createDeps(); + const agent = { name: "hermes", displayName: "Hermes" }; + const session = createSession(); + + await handleAgentSetupState({ ...baseOptions(deps, agent), session, resume: true }); + + expect(calls.handleAgentSetup).toHaveBeenCalledWith( + "my-assistant", + "model", + "provider", + agent, + true, + session, + { ctx: true }, + ); + expect(calls.ensureDashboard).toHaveBeenCalledWith("my-assistant", agent); + expect(calls.skipped).toHaveBeenCalledWith("openclaw"); + expect(calls.setupOpenclaw).not.toHaveBeenCalled(); + }); + + it("skips OpenClaw setup on resume when OpenClaw is ready", async () => { + const { deps, calls } = createDeps({ isOpenclawReady: vi.fn(() => true) }); + + await handleAgentSetupState({ ...baseOptions(deps), resume: true }); + + expect(calls.skippedMessage).toHaveBeenCalledWith("openclaw", "my-assistant"); + expect(calls.startStep).not.toHaveBeenCalled(); + expect(calls.setupOpenclaw).not.toHaveBeenCalled(); + expect(calls.complete).toHaveBeenCalledWith( + "openclaw", + expect.objectContaining({ sandboxName: "my-assistant", provider: "provider", model: "model" }), + ); + expect(calls.skipped).toHaveBeenCalledWith("agent_setup"); + }); + + it("runs OpenClaw setup and skips agent_setup for the default agent", async () => { + const { deps, calls } = createDeps(); + + await handleAgentSetupState({ + ...baseOptions(deps), + hermesAuthMethod: "oauth", + hermesToolGateways: ["github"], + }); + + expect(calls.startStep).toHaveBeenCalledWith("openclaw", { + sandboxName: "my-assistant", + provider: "provider", + model: "model", + }); + 
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

import type { Session, SessionUpdates } from "../../../state/onboard-session";

/**
 * Inputs for {@link handleAgentSetupState}.
 *
 * `Agent` is the caller's agent descriptor type. The handler never reads its fields; it only
 * checks the value for truthiness and forwards it to the injected dependencies.
 *
 * NOTE(review): the generic parameter and the `Promise<…>` / `Record<…>` type arguments below
 * were reconstructed from how the handler uses each dependency — confirm against the callers.
 */
export interface AgentSetupStateOptions<Agent = unknown> {
  /** Agent to set up through `deps.handleAgentSetup`; `null` selects the OpenClaw path. */
  agent: Agent | null;
  sandboxName: string;
  model: string;
  provider: string;
  /** When true, OpenClaw setup is skipped if `deps.isOpenclawReady` reports it ready. */
  resume: boolean;
  /** Session forwarded to `deps.handleAgentSetup`; not otherwise read by the handler. */
  session: Session | null;
  hermesAuthMethod: string | null;
  hermesToolGateways: string[];
  /** Side-effecting collaborators, injected so the handler can be driven by test doubles. */
  deps: {
    handleAgentSetup(
      sandboxName: string,
      model: string,
      provider: string,
      agent: Agent,
      resume: boolean,
      session: Session | null,
      context: unknown,
    ): Promise<unknown>;
    agentSetupContext(): unknown;
    ensureAgentDashboardForward(sandboxName: string, agent: Agent): number;
    recordStepSkipped(stepName: string): Promise<Session | null>;
    isOpenclawReady(sandboxName: string): boolean;
    skippedStepMessage(stepName: string, detail?: string | null): void;
    startRecordedStep(
      stepName: string,
      updates: { sandboxName: string; provider: string; model: string },
    ): Promise<unknown>;
    setupOpenclaw(sandboxName: string, model: string, provider: string): Promise<unknown>;
    recordStepComplete(stepName: string, updates: SessionUpdates): Promise<Session | null>;
    toSessionUpdates(updates: Record<string, unknown>): SessionUpdates;
  };
}

/** Outcome of {@link handleAgentSetupState}. */
export interface AgentSetupStateResult {
  /** Whatever the final `deps.recordStepSkipped` call resolved to. */
  session: Session | null;
}

/**
 * Runs exactly one of the two mutually exclusive setup branches and records the other as skipped.
 *
 * - With an `agent`: delegates to `deps.handleAgentSetup`, ensures the agent dashboard forward,
 *   and records the `"openclaw"` step as skipped.
 * - Without an `agent`: sets up OpenClaw (or only prints the skipped-step message when resuming
 *   and `deps.isOpenclawReady` is true), records `"openclaw"` as complete, and records the
 *   `"agent_setup"` step as skipped.
 *
 * @returns The session resolved by the final `recordStepSkipped` call.
 * @throws Propagates any rejection or exception raised by the injected dependencies.
 */
export async function handleAgentSetupState<Agent = unknown>({
  agent,
  sandboxName,
  model,
  provider,
  resume,
  session,
  hermesAuthMethod,
  hermesToolGateways,
  deps,
}: AgentSetupStateOptions<Agent>): Promise<AgentSetupStateResult> {
  if (agent) {
    await deps.handleAgentSetup(
      sandboxName,
      model,
      provider,
      agent,
      resume,
      session,
      deps.agentSetupContext(),
    );
    deps.ensureAgentDashboardForward(sandboxName, agent);
    return { session: await deps.recordStepSkipped("openclaw") };
  }

  // An empty sandbox name can never be "ready"; the guard also keeps the flag a real boolean.
  const openclawAlreadyReady =
    resume && Boolean(sandboxName) && deps.isOpenclawReady(sandboxName);
  if (openclawAlreadyReady) {
    deps.skippedStepMessage("openclaw", sandboxName);
  } else {
    await deps.startRecordedStep("openclaw", { sandboxName, provider, model });
    await deps.setupOpenclaw(sandboxName, model, provider);
  }

  // Both branches record the same completion payload, so it is written once here.
  await deps.recordStepComplete(
    "openclaw",
    deps.toSessionUpdates({ sandboxName, provider, model, hermesAuthMethod, hermesToolGateways }),
  );
  return { session: await deps.recordStepSkipped("agent_setup") };
}
OnboardRuntime }: typeof import("./onboard/machine/runtime") = require("./onboard/machine/runtime"); const { handleAgentSetupState }: typeof import("./onboard/machine/handlers/agent-setup") = require("./onboard/machine/handlers/agent-setup"); const { handleGatewayState }: typeof import("./onboard/machine/handlers/gateway") = require("./onboard/machine/handlers/gateway"); +const { handlePoliciesState }: typeof import("./onboard/machine/handlers/policies") = require("./onboard/machine/handlers/policies"); const { handlePreflightState }: typeof import("./onboard/machine/handlers/preflight") = require("./onboard/machine/handlers/preflight"); const { handleProviderInferenceState }: typeof import("./onboard/machine/handlers/provider-inference") = require("./onboard/machine/handlers/provider-inference"); const { handleSandboxState }: typeof import("./onboard/machine/handlers/sandbox") = require("./onboard/machine/handlers/sandbox"); @@ -9686,97 +9687,42 @@ async function onboard(opts: OnboardOptions = {}): Promise { }); session = agentSetupResult.session; - const latestSession = onboardSession.loadSession(); - const recordedPolicyPresets = Array.isArray(latestSession?.policyPresets) - ? latestSession.policyPresets - : null; - const recordedMessagingChannels = Array.isArray(latestSession?.messagingChannels) - ? latestSession.messagingChannels - : []; - const activeMessagingChannels = registry.getSandbox(sandboxName)?.messagingChannels; - verifyCompatibleEndpointSandboxSmoke({ + const policiesResult = await handlePoliciesState({ + resume, sandboxName, provider, model, - runOpenshell, - redact, endpointUrl, credentialEnv, - messagingChannels: Array.isArray(activeMessagingChannels) ? 
activeMessagingChannels : [], + selectedMessagingChannels, + webSearchConfig, + webSearchSupported, + hermesToolGateways, agent, + deps: { + loadSession: onboardSession.loadSession, + getActiveMessagingChannels: (name) => registry.getSandbox(name)?.messagingChannels, + verifyCompatibleEndpointSandboxSmoke: (options) => + verifyCompatibleEndpointSandboxSmoke({ + ...options, + runOpenshell, + redact, + }), + listSetupPolicyPresets: policies.listSetupPolicyPresets, + getAppliedPolicyPresets: policies.getAppliedPresets, + listCustomPolicyPresets: policies.listCustomPresets, + clampSetupPolicyPresetNames: policies.clampSetupPolicyPresetNames, + mergeRequiredHermesToolGatewayPolicyPresets, + arePolicyPresetsApplied, + skippedStepMessage, + startRecordedStep, + setupPoliciesWithSelection, + updateSession: onboardSession.updateSession, + recordStepComplete, + toSessionUpdates: (updates) => toSessionUpdates(updates as Parameters[0]), + }, }); - const policyPresetSupportOptions = { webSearchSupported }; - const selectablePolicyPresetsForSupport = [ - ...policies.listSetupPolicyPresets(sandboxName, policyPresetSupportOptions), - ...policies.getAppliedPresets(sandboxName).map((name) => ({ name })), - ]; - const customPolicyPresetNames = new Set( - policies.listCustomPresets(sandboxName).map((p: { name: string }) => p.name), - ); - let recordedPolicyPresetsForSupport = policies.clampSetupPolicyPresetNames( - recordedPolicyPresets || [], - selectablePolicyPresetsForSupport, - policyPresetSupportOptions, - customPolicyPresetNames, - ); - if (recordedPolicyPresets) { - recordedPolicyPresetsForSupport = mergeRequiredHermesToolGatewayPolicyPresets( - recordedPolicyPresetsForSupport, - hermesToolGateways, - selectablePolicyPresetsForSupport.map((p) => p.name), - ); - } - const recordedPolicyPresetsHaveUnsupported = - Array.isArray(recordedPolicyPresets) && - recordedPolicyPresetsForSupport.length !== recordedPolicyPresets.length; - const resumePolicies = - resume && - sandboxName && 
- !recordedPolicyPresetsHaveUnsupported && - arePolicyPresetsApplied(sandboxName, recordedPolicyPresetsForSupport); - if (resumePolicies) { - skippedStepMessage("policies", recordedPolicyPresetsForSupport.join(", ")); - await recordStepComplete( - "policies", - toSessionUpdates({ - sandboxName, - provider, - model, - policyPresets: recordedPolicyPresetsForSupport, - }), - ); - } else { - await startRecordedStep("policies", { - sandboxName, - provider, - model, - policyPresets: recordedPolicyPresetsForSupport, - }); - const appliedPolicyPresets = await setupPoliciesWithSelection(sandboxName, { - selectedPresets: - Array.isArray(recordedPolicyPresets) - ? recordedPolicyPresetsForSupport - : null, - enabledChannels: - selectedMessagingChannels.length > 0 - ? selectedMessagingChannels - : recordedMessagingChannels, - webSearchConfig, - provider, - webSearchSupported, - hermesToolGateways, - onSelection: (policyPresets) => { - onboardSession.updateSession((current: Session) => { - current.policyPresets = policyPresets; - return current; - }); - }, - }); - await recordStepComplete( - "policies", - toSessionUpdates({ sandboxName, provider, model, policyPresets: appliedPolicyPresets }), - ); - } + session = policiesResult.session; if (agent) { ensureAgentDashboardForward(sandboxName, agent); diff --git a/src/lib/onboard/machine/handlers/policies.test.ts b/src/lib/onboard/machine/handlers/policies.test.ts new file mode 100644 index 0000000000..ee315d34f0 --- /dev/null +++ b/src/lib/onboard/machine/handlers/policies.test.ts @@ -0,0 +1,182 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +import { describe, expect, it, vi } from "vitest"; + +import { createSession, type Session, type SessionUpdates } from "../../../state/onboard-session"; +import { handlePoliciesState, type PoliciesStateOptions } from "./policies"; + +type Agent = { name: string } | null; +type WebSearchConfig = { fetchEnabled: true }; + +function createDeps(overrides: Partial["deps"]> = {}) { + let session = createSession(); + const calls = { + load: vi.fn(() => session), + activeChannels: vi.fn(() => ["telegram"]), + smoke: vi.fn(), + listSetup: vi.fn(() => [{ name: "npm" }, { name: "pypi" }, { name: "github" }]), + applied: vi.fn(() => [] as string[]), + custom: vi.fn(() => [] as { name: string }[]), + clamp: vi.fn((names: string[]) => names.filter((name) => name !== "unsupported")), + mergeHermes: vi.fn((selected: string[], tools: string[]) => [...selected, ...tools]), + appliedCheck: vi.fn(() => false), + skipped: vi.fn(), + startStep: vi.fn(async () => undefined), + setupPolicies: vi.fn(async () => ["npm"]), + updateSession: vi.fn((mutator: (value: Session) => Session | void) => { + session = mutator(session) ?? 
session; + return session; + }), + complete: vi.fn(async () => session), + }; + return { + calls, + deps: { + loadSession: calls.load, + getActiveMessagingChannels: calls.activeChannels, + verifyCompatibleEndpointSandboxSmoke: calls.smoke, + listSetupPolicyPresets: calls.listSetup, + getAppliedPolicyPresets: calls.applied, + listCustomPolicyPresets: calls.custom, + clampSetupPolicyPresetNames: calls.clamp, + mergeRequiredHermesToolGatewayPolicyPresets: calls.mergeHermes, + arePolicyPresetsApplied: calls.appliedCheck, + skippedStepMessage: calls.skipped, + startRecordedStep: calls.startStep, + setupPoliciesWithSelection: calls.setupPolicies, + updateSession: calls.updateSession, + recordStepComplete: calls.complete, + toSessionUpdates: (updates: Record) => updates as SessionUpdates, + ...overrides, + }, + setSession(next: Session) { + session = next; + }, + getSession: () => session, + }; +} + +function baseOptions( + deps: PoliciesStateOptions["deps"], +): PoliciesStateOptions { + return { + resume: false, + sandboxName: "my-assistant", + provider: "provider", + model: "model", + endpointUrl: "https://example.com/v1", + credentialEnv: "NVIDIA_API_KEY", + selectedMessagingChannels: [], + webSearchConfig: null, + webSearchSupported: true, + hermesToolGateways: [], + agent: null, + deps, + }; +} + +describe("handlePoliciesState", () => { + it("runs compatible endpoint smoke before policy selection", async () => { + const { deps, calls } = createDeps(); + + await handlePoliciesState(baseOptions(deps)); + + expect(calls.smoke).toHaveBeenCalledWith({ + sandboxName: "my-assistant", + provider: "provider", + model: "model", + endpointUrl: "https://example.com/v1", + credentialEnv: "NVIDIA_API_KEY", + messagingChannels: ["telegram"], + agent: null, + }); + expect(calls.startStep).toHaveBeenCalledWith("policies", { + sandboxName: "my-assistant", + provider: "provider", + model: "model", + policyPresets: [], + }); + expect(calls.setupPolicies).toHaveBeenCalledWith( + 
"my-assistant", + expect.objectContaining({ + selectedPresets: null, + enabledChannels: [], + provider: "provider", + webSearchSupported: true, + }), + ); + expect(calls.complete).toHaveBeenCalledWith( + "policies", + expect.objectContaining({ policyPresets: ["npm"] }), + ); + }); + + it("uses recorded messaging channels when no active selection exists", async () => { + const session = createSession({ messagingChannels: ["slack"] }); + const { deps, calls, setSession } = createDeps(); + setSession(session); + + await handlePoliciesState(baseOptions(deps)); + + expect(calls.setupPolicies).toHaveBeenCalledWith( + "my-assistant", + expect.objectContaining({ enabledChannels: ["slack"] }), + ); + }); + + it("resumes policies when all recorded presets are already applied", async () => { + const session = createSession({ policyPresets: ["npm"] }); + const { deps, calls, setSession } = createDeps({ + arePolicyPresetsApplied: vi.fn(() => true), + }); + setSession(session); + + const result = await handlePoliciesState({ ...baseOptions(deps), resume: true }); + + expect(calls.skipped).toHaveBeenCalledWith("policies", "npm"); + expect(calls.setupPolicies).not.toHaveBeenCalled(); + expect(calls.complete).toHaveBeenCalledWith( + "policies", + expect.objectContaining({ policyPresets: ["npm"] }), + ); + expect(result.appliedPolicyPresets).toEqual(["npm"]); + }); + + it("clamps unsupported recorded presets before interactive setup", async () => { + const session = createSession({ policyPresets: ["npm", "unsupported"] }); + const { deps, calls, setSession } = createDeps(); + setSession(session); + + await handlePoliciesState(baseOptions(deps)); + + expect(calls.clamp).toHaveBeenCalledWith( + ["npm", "unsupported"], + expect.any(Array), + { webSearchSupported: true }, + expect.any(Set), + ); + expect(calls.setupPolicies).toHaveBeenCalledWith( + "my-assistant", + expect.objectContaining({ selectedPresets: ["npm"] }), + ); + }); + + it("merges required Hermes tool gateway presets into 
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

import type { Session, SessionUpdates } from "../../../state/onboard-session";

/** A policy preset descriptor; this handler only ever reads `name`. */
export interface PolicyPresetEntry {
  name: string;
  [key: string]: unknown;
}

/**
 * Inputs for {@link handlePoliciesState}.
 *
 * `Agent` and `WebSearchConfig` are opaque to the handler: both are only forwarded to the
 * injected dependencies.
 *
 * NOTE(review): the generic parameters and the `Promise<…>` / `Set<…>` / `Record<…>` type
 * arguments were reconstructed from how the handler uses each dependency — confirm against
 * the callers.
 */
export interface PoliciesStateOptions<Agent = unknown, WebSearchConfig = unknown> {
  /** When true, policy setup is skipped if the recorded presets are all still applied. */
  resume: boolean;
  sandboxName: string;
  provider: string;
  model: string;
  endpointUrl: string | null;
  credentialEnv: string | null;
  /** Channels chosen in this run; when empty the channels recorded in the session are used. */
  selectedMessagingChannels: string[];
  webSearchConfig: WebSearchConfig | null;
  webSearchSupported: boolean;
  hermesToolGateways: string[];
  agent: Agent;
  /** Side-effecting collaborators, injected so the handler can be driven by test doubles. */
  deps: {
    loadSession(): Session | null;
    getActiveMessagingChannels(sandboxName: string): string[] | null | undefined;
    verifyCompatibleEndpointSandboxSmoke(options: {
      sandboxName: string;
      provider: string;
      model: string;
      endpointUrl: string | null;
      credentialEnv: string | null;
      messagingChannels: string[];
      agent: Agent;
    }): void;
    listSetupPolicyPresets(
      sandboxName: string,
      options: { webSearchSupported: boolean },
    ): PolicyPresetEntry[];
    getAppliedPolicyPresets(sandboxName: string): string[];
    listCustomPolicyPresets(sandboxName: string): PolicyPresetEntry[];
    clampSetupPolicyPresetNames(
      names: string[],
      selectablePresets: PolicyPresetEntry[],
      options: { webSearchSupported: boolean },
      customPresetNames: Set<string>,
    ): string[];
    mergeRequiredHermesToolGatewayPolicyPresets(
      selectedPresets: string[],
      hermesToolGateways: string[],
      selectablePresetNames: string[],
    ): string[];
    arePolicyPresetsApplied(sandboxName: string, selectedPresets: string[]): boolean;
    skippedStepMessage(stepName: string, detail?: string | null): void;
    startRecordedStep(
      stepName: string,
      updates: { sandboxName: string; provider: string; model: string; policyPresets: string[] },
    ): Promise<unknown>;
    setupPoliciesWithSelection(
      sandboxName: string,
      options: {
        selectedPresets: string[] | null;
        enabledChannels: string[];
        webSearchConfig: WebSearchConfig | null;
        provider: string;
        webSearchSupported: boolean;
        hermesToolGateways: string[];
        onSelection: (policyPresets: string[]) => void;
      },
    ): Promise<string[]>;
    updateSession(mutator: (session: Session) => Session | void): Session;
    recordStepComplete(stepName: string, updates: SessionUpdates): Promise<Session | null>;
    toSessionUpdates(updates: Record<string, unknown>): SessionUpdates;
  };
}

/** Outcome of {@link handlePoliciesState}. */
export interface PoliciesStateResult {
  /** Whatever `deps.recordStepComplete("policies", …)` resolved to. */
  session: Session | null;
  /** `messagingChannels` from the loaded session, or `[]` when it is not an array. */
  recordedMessagingChannels: string[];
  /** Presets recorded as applied: the resumed list, or the result of interactive setup. */
  appliedPolicyPresets: string[];
}

/**
 * Runs the "policies" onboarding step.
 *
 * Sequence:
 * 1. Load the session and run the compatible-endpoint sandbox smoke check with the sandbox's
 *    active messaging channels.
 * 2. Clamp the recorded preset names to what is selectable, then (only when presets were
 *    recorded) merge in the presets required by `hermesToolGateways`.
 * 3. If resuming, nothing was dropped by the clamp, and the resulting presets are already
 *    applied, record the step as complete without running setup. Otherwise start the step,
 *    run `deps.setupPoliciesWithSelection`, and record what it returns.
 *
 * @returns The recorded session, the session's messaging channels, and the applied presets.
 * @throws Propagates any rejection or exception raised by the injected dependencies.
 */
export async function handlePoliciesState<Agent = unknown, WebSearchConfig = unknown>({
  resume,
  sandboxName,
  provider,
  model,
  endpointUrl,
  credentialEnv,
  selectedMessagingChannels,
  webSearchConfig,
  webSearchSupported,
  hermesToolGateways,
  agent,
  deps,
}: PoliciesStateOptions<Agent, WebSearchConfig>): Promise<PoliciesStateResult> {
  const latestSession = deps.loadSession();
  const recordedPolicyPresets: string[] | null = Array.isArray(latestSession?.policyPresets)
    ? latestSession.policyPresets
    : null;
  const recordedMessagingChannels: string[] = Array.isArray(latestSession?.messagingChannels)
    ? latestSession.messagingChannels
    : [];

  const activeMessagingChannels = deps.getActiveMessagingChannels(sandboxName);
  deps.verifyCompatibleEndpointSandboxSmoke({
    sandboxName,
    provider,
    model,
    endpointUrl,
    credentialEnv,
    messagingChannels: Array.isArray(activeMessagingChannels) ? activeMessagingChannels : [],
    agent,
  });

  const policyPresetSupportOptions = { webSearchSupported };
  // Already-applied presets stay selectable so a resume does not clamp them away.
  const selectablePolicyPresets: PolicyPresetEntry[] = [
    ...deps.listSetupPolicyPresets(sandboxName, policyPresetSupportOptions),
    ...deps.getAppliedPolicyPresets(sandboxName).map((name) => ({ name })),
  ];
  const customPolicyPresetNames = new Set(
    deps.listCustomPolicyPresets(sandboxName).map((preset) => preset.name),
  );
  const clampedRecordedPresets = deps.clampSetupPolicyPresetNames(
    recordedPolicyPresets ?? [],
    selectablePolicyPresets,
    policyPresetSupportOptions,
    customPolicyPresetNames,
  );

  // Detect dropped presets BEFORE merging required gateway presets: comparing lengths after the
  // merge lets an added preset hide a dropped one (and misreports a pure addition as a drop).
  const recordedPolicyPresetsHaveUnsupported =
    recordedPolicyPresets !== null &&
    clampedRecordedPresets.length !== recordedPolicyPresets.length;

  const recordedPolicyPresetsForSupport =
    recordedPolicyPresets !== null
      ? deps.mergeRequiredHermesToolGatewayPolicyPresets(
          clampedRecordedPresets,
          hermesToolGateways,
          selectablePolicyPresets.map((preset) => preset.name),
        )
      : clampedRecordedPresets;

  const resumePolicies =
    resume &&
    !recordedPolicyPresetsHaveUnsupported &&
    deps.arePolicyPresetsApplied(sandboxName, recordedPolicyPresetsForSupport);

  if (resumePolicies) {
    deps.skippedStepMessage("policies", recordedPolicyPresetsForSupport.join(", "));
    const session = await deps.recordStepComplete(
      "policies",
      deps.toSessionUpdates({
        sandboxName,
        provider,
        model,
        policyPresets: recordedPolicyPresetsForSupport,
      }),
    );
    return {
      session,
      recordedMessagingChannels,
      appliedPolicyPresets: recordedPolicyPresetsForSupport,
    };
  }

  await deps.startRecordedStep("policies", {
    sandboxName,
    provider,
    model,
    policyPresets: recordedPolicyPresetsForSupport,
  });
  const appliedPolicyPresets = await deps.setupPoliciesWithSelection(sandboxName, {
    // `null` (nothing recorded) is passed through as-is rather than as an empty selection.
    selectedPresets: recordedPolicyPresets !== null ? recordedPolicyPresetsForSupport : null,
    enabledChannels:
      selectedMessagingChannels.length > 0 ? selectedMessagingChannels : recordedMessagingChannels,
    webSearchConfig,
    provider,
    webSearchSupported,
    hermesToolGateways,
    // Persist the selection as soon as it is made, before the step is recorded as complete.
    onSelection: (policyPresets) => {
      deps.updateSession((current) => {
        current.policyPresets = policyPresets;
        return current;
      });
    },
  });
  const session = await deps.recordStepComplete(
    "policies",
    deps.toSessionUpdates({ sandboxName, provider, model, policyPresets: appliedPolicyPresets }),
  );

  return { session, recordedMessagingChannels, appliedPolicyPresets };
}