Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/lib/onboard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ const {
getProviderSelectionConfig,
parseGatewayInference,
} = inferenceConfig;
const { ensureResumeProviderReady } = require("./onboard/resume-provider-shim");

const onboardProviders = require("./onboard/providers");
const hermesProviderAuth = require("./hermes-provider-auth");
Expand Down Expand Up @@ -1503,8 +1504,6 @@ const {
shouldForceCompletionsApi,
} = validation;

// validateNvidiaApiKeyValue — see validation import above

async function replaceNamedCredential(
envName: string,
label: string,
Expand Down Expand Up @@ -9439,6 +9438,7 @@ async function onboard(opts: OnboardOptions = {}): Promise<void> {
recordStepComplete,
toSessionUpdates: (updates) => toSessionUpdates(updates as Parameters<typeof toSessionUpdates>[0]),
skippedStepMessage,
ensureResumeProviderReady,
hydrateCredentialEnv,
repairLocalInferenceSystemdOverrideOrExit,
isNonInteractive,
Expand Down Expand Up @@ -9992,7 +9992,6 @@ module.exports = {
recoverGatewayRuntime,
buildChain,
buildControlUiUrls,

startGateway,
findAvailableDashboardPort,
findDashboardForwardOwner,
Expand Down Expand Up @@ -10048,4 +10047,5 @@ module.exports = {
checkTelegramReachability,
TELEGRAM_NETWORK_CURL_CODES,
verifyCompatibleEndpointSandboxSmoke,
resumeProviderShimDeps: { isRoutedInferenceProvider, replaceNamedCredential },
};
40 changes: 40 additions & 0 deletions src/lib/onboard/machine/handlers/provider-inference.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ function createDeps(overrides: Partial<ProviderInferenceStateOptions<Gpu, Agent,
startStep: vi.fn(async () => undefined),
complete: vi.fn(async () => createSession()),
skipped: vi.fn(),
recoverProvider: vi.fn(async (_provider: string | null | undefined, credentialEnv: string | null | undefined) => ({
forceInferenceSetup: false,
credentialEnv: credentialEnv ?? null,
})),
hydrate: vi.fn(),
repair: vi.fn(),
routeReady: vi.fn(() => false),
Expand All @@ -56,6 +60,7 @@ function createDeps(overrides: Partial<ProviderInferenceStateOptions<Gpu, Agent,
recordStepComplete: calls.complete,
toSessionUpdates: (updates: Record<string, unknown>) => updates as SessionUpdates,
skippedStepMessage: calls.skipped,
ensureResumeProviderReady: calls.recoverProvider,
hydrateCredentialEnv: calls.hydrate,
repairLocalInferenceSystemdOverrideOrExit: calls.repair,
isNonInteractive: () => true,
Expand Down Expand Up @@ -166,13 +171,48 @@ describe("handleProviderInferenceState", () => {

expect(calls.setupNim).not.toHaveBeenCalled();
expect(calls.setupInference).not.toHaveBeenCalled();
expect(calls.recoverProvider).toHaveBeenCalledWith("ollama-local", null);
expect(calls.skipped).toHaveBeenCalledWith("provider_selection", "ollama-local / llama3.1");
expect(calls.hydrate).toHaveBeenCalledWith(null);
expect(calls.repair).toHaveBeenCalledWith("ollama-local", deps.isNonInteractive);
expect(calls.skipped).toHaveBeenCalledWith("inference", "ollama-local / llama3.1");
expect(result).toMatchObject({ provider: "ollama-local", model: "llama3.1" });
});

it("reruns inference setup when resumed provider recovery forces recreation", async () => {
const session = createSession({
provider: "compatible-endpoint",
model: "custom-model",
credentialEnv: null,
});
session.steps.provider_selection.status = "complete";
const { deps, calls } = createDeps({
isInferenceRouteReady: vi.fn(() => true),
ensureResumeProviderReady: vi.fn(async () => ({
forceInferenceSetup: true,
credentialEnv: "COMPATIBLE_API_KEY",
})),
});

await handleProviderInferenceState({
...baseOptions(deps, session),
resume: true,
sandboxName: "my-assistant",
});

expect(calls.setupNim).not.toHaveBeenCalled();
expect(calls.hydrate).toHaveBeenCalledWith("COMPATIBLE_API_KEY");
expect(calls.setupInference).toHaveBeenCalledWith(
"my-assistant",
"custom-model",
"compatible-endpoint",
null,
"COMPATIBLE_API_KEY",
null,
[],
);
});

it("reconciles model router on resumed routed inference", async () => {
const session = createSession({ provider: "nvidia-router", model: "router/model" });
session.steps.provider_selection.status = "complete";
Expand Down
9 changes: 9 additions & 0 deletions src/lib/onboard/machine/handlers/provider-inference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ export interface ProviderInferenceStateOptions<Gpu, Agent, Host> {
recordStepComplete(stepName: string, updates: SessionUpdates): Promise<Session>;
toSessionUpdates(updates: Record<string, unknown>): SessionUpdates;
skippedStepMessage(stepName: string, detail?: string | null): void;
ensureResumeProviderReady(
provider: string | null | undefined,
credentialEnv: string | null | undefined,
): Promise<{ forceInferenceSetup: boolean; credentialEnv: string | null }>;
hydrateCredentialEnv(credentialEnv: string | null): void;
repairLocalInferenceSystemdOverrideOrExit(provider: string | null, isNonInteractive: () => boolean): void;
isNonInteractive(): boolean;
Expand Down Expand Up @@ -143,13 +147,17 @@ export async function handleProviderInferenceState<Gpu, Agent, Host>({
let forceProviderSelection = initialForceProviderSelection;

while (true) {
let forceInferenceSetup = false;
const resumeProviderSelection =
!forceProviderSelection &&
resume &&
session?.steps?.provider_selection?.status === "complete" &&
typeof provider === "string" &&
typeof model === "string";
if (resumeProviderSelection) {
const recovery = await deps.ensureResumeProviderReady(provider, credentialEnv);
forceInferenceSetup = recovery.forceInferenceSetup;
credentialEnv = recovery.credentialEnv;
deps.skippedStepMessage("provider_selection", `${provider} / ${model}`);
deps.hydrateCredentialEnv(credentialEnv);
deps.repairLocalInferenceSystemdOverrideOrExit(provider, deps.isNonInteractive);
Expand Down Expand Up @@ -187,6 +195,7 @@ export async function handleProviderInferenceState<Gpu, Agent, Host>({
const resumeInference =
!needsBedrockRuntimeAdapter &&
!forceProviderSelection &&
!forceInferenceSetup &&
resume &&
deps.isInferenceRouteReady(provider, model);
if (resumeInference) {
Expand Down
144 changes: 144 additions & 0 deletions src/lib/onboard/resume-provider-recovery.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

import { describe, expect, it } from "vitest";

import {
ensureResumeProviderReady,
type RemoteProviderConfigEntry,
type ResumeProviderRecoveryDeps,
} from "./resume-provider-recovery";

const COMPATIBLE_ENDPOINT_CONFIG: RemoteProviderConfigEntry = {
label: "Compatible Endpoint",
providerName: "compatible-endpoint",
providerType: "openai",
credentialEnv: "COMPATIBLE_API_KEY",
endpointUrl: "https://example/v1",
helpUrl: null,
modelMode: "input",
defaultModel: "test-model",
};

type DepsRecorder = {
log: string[];
warn: string[];
note: string[];
exitCalls: number[];
replaceCalls: Array<{ env: string; label: string }>;
deps: ResumeProviderRecoveryDeps;
};

function makeDeps(overrides: {
providerExists?: boolean;
credentialValue?: string | null;
nonInteractive?: boolean;
remoteProviderConfig?: Record<string, RemoteProviderConfigEntry>;
}): DepsRecorder {
const log: string[] = [];
const warn: string[] = [];
const note: string[] = [];
const exitCalls: number[] = [];
const replaceCalls: Array<{ env: string; label: string }> = [];
const deps: ResumeProviderRecoveryDeps = {
remoteProviderConfig: overrides.remoteProviderConfig ?? {
compatible: COMPATIBLE_ENDPOINT_CONFIG,
},
defaultRouteCredentialEnv: "OPENAI_API_KEY",
isRoutedInferenceProvider: () => false,
providerExistsInGateway: () => overrides.providerExists ?? true,
hydrateCredentialEnv: () => overrides.credentialValue ?? null,
getProviderLabel: (key) => key,
isNonInteractive: () => overrides.nonInteractive ?? false,
log: (m) => log.push(m),
warn: (m) => warn.push(m),
note: (m) => note.push(m),
exit: (code) => exitCalls.push(code),
replaceNamedCredential: async (env, label) => {
replaceCalls.push({ env, label });
return "fresh-key";
},
validateNvidiaApiKeyValue: () => null,
};
return { log, warn, note, exitCalls, replaceCalls, deps };
}

describe("ensureResumeProviderReady", () => {
it("returns false-forced when no provider is set (nothing to recover)", async () => {
const { deps } = makeDeps({ providerExists: false });
const result = await ensureResumeProviderReady(null, null, deps);
expect(result.forceInferenceSetup).toBe(false);
expect(result.credentialEnv).toBeNull();
});

it("returns false-forced when the provider is unknown and not a routed provider", async () => {
const { deps } = makeDeps({ providerExists: false });
const result = await ensureResumeProviderReady("mystery-provider", null, deps);
expect(result.forceInferenceSetup).toBe(false);
expect(result.credentialEnv).toBeNull();
});

it("returns false-forced when the provider still exists in the gateway", async () => {
const { deps } = makeDeps({ providerExists: true });
const result = await ensureResumeProviderReady("compatible-endpoint", "COMPATIBLE_API_KEY", deps);
expect(result.forceInferenceSetup).toBe(false);
expect(result.credentialEnv).toBe("COMPATIBLE_API_KEY");
});

it("emits a [resume] note and forces inference setup when credential is already hydrated", async () => {
const recorder = makeDeps({
providerExists: false,
credentialValue: "already-hydrated-key",
});
const result = await ensureResumeProviderReady(
"compatible-endpoint",
"COMPATIBLE_API_KEY",
recorder.deps,
);
expect(result.forceInferenceSetup).toBe(true);
expect(result.credentialEnv).toBe("COMPATIBLE_API_KEY");
expect(recorder.note.join("\n")).toContain("[resume]");
expect(recorder.replaceCalls).toHaveLength(0);
});

it("returns the config credential env when the resumed session did not record one", async () => {
const recorder = makeDeps({
providerExists: false,
credentialValue: "already-hydrated-key",
});
const result = await ensureResumeProviderReady("compatible-endpoint", null, recorder.deps);
expect(result.forceInferenceSetup).toBe(true);
expect(result.credentialEnv).toBe("COMPATIBLE_API_KEY");
});

it("re-prompts for credentials when the provider was reset and credential is missing (#3278)", async () => {
const recorder = makeDeps({
providerExists: false,
credentialValue: null,
});
const result = await ensureResumeProviderReady(
"compatible-endpoint",
"COMPATIBLE_API_KEY",
recorder.deps,
);
expect(result.forceInferenceSetup).toBe(true);
expect(result.credentialEnv).toBe("COMPATIBLE_API_KEY");
expect(recorder.replaceCalls).toEqual([
{ env: "COMPATIBLE_API_KEY", label: "Compatible Endpoint API key" },
]);
expect(recorder.exitCalls).toEqual([]);
});

it("exits 1 in non-interactive mode when the provider is missing and no credential is set", async () => {
const recorder = makeDeps({
providerExists: false,
credentialValue: null,
nonInteractive: true,
});
await ensureResumeProviderReady("compatible-endpoint", "COMPATIBLE_API_KEY", recorder.deps);
expect(recorder.exitCalls).toEqual([1]);
expect(recorder.warn.join("\n")).toContain("COMPATIBLE_API_KEY");
expect(recorder.warn.join("\n")).toContain("during resume");
expect(recorder.replaceCalls).toHaveLength(0);
});
});
Loading
Loading