Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 10 additions & 24 deletions src/lib/onboard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ const {
const {
resolveRequestedProviderSelection,
}: typeof import("./onboard/provider-selection") = require("./onboard/provider-selection");
const {
promptForInferenceProviderSelection,
}: typeof import("./onboard/provider-selection-prompt") = require("./onboard/provider-selection-prompt");
const {
isLinuxDockerDriverGatewayEnabled,
}: typeof import("./onboard/docker-driver-platform") = require("./onboard/docker-driver-platform");
Expand Down Expand Up @@ -3863,31 +3866,14 @@ async function setupNim(
: ` [non-interactive] Provider: ${selected.key}`,
);
} else {
const suggestions: string[] = [];
if (vllmRunning) suggestions.push("vLLM");
if (ollamaRunning) suggestions.push("Ollama");
if (suggestions.length > 0) {
console.log(
` Detected local inference option${suggestions.length > 1 ? "s" : ""}: ${suggestions.join(", ")}`,
);
console.log("");
}

console.log("");
console.log(" Select your inference provider:");
options.forEach((o, i) => {
console.log(` ${i + 1}) ${o.label}`);
selected = await promptForInferenceProviderSelection({
options,
vllmRunning,
ollamaRunning,
prompt,
log: console.log,
selectFromNumberedMenu: selectFromNumberedMenuOrExit,
});
console.log("");

const envProviderHint = (process.env.NEMOCLAW_PROVIDER || "").trim().toLowerCase();
const envProviderIdx = envProviderHint
? options.findIndex((o) => o.key.toLowerCase() === envProviderHint)
: -1;
const defaultIdx =
(envProviderIdx >= 0 ? envProviderIdx : options.findIndex((o) => o.key === "build")) + 1;
const choice = await prompt(` Choose [${defaultIdx}]: `);
selected = selectFromNumberedMenuOrExit(choice, defaultIdx, options);
}

if (!selected) {
Expand Down
120 changes: 120 additions & 0 deletions src/lib/onboard/provider-selection-prompt.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

import assert from "node:assert/strict";
import { describe, it, vi } from "vitest";

import { promptForInferenceProviderSelection } from "../../../dist/lib/onboard/provider-selection-prompt";

const options = [
{ key: "build", label: "NVIDIA Endpoints" },
{ key: "openai", label: "OpenAI" },
{ key: "custom", label: "Other OpenAI-compatible endpoint" },
];

function makeSelectSpy() {
return vi.fn((_: string, defaultIdx: number, entries: typeof options) => entries[defaultIdx - 1]);
}

function makePrompt(reply: string) {
return vi.fn(async (_question: string) => reply);
}

function makeLog() {
return vi.fn((_message?: string) => {});
}

describe("promptForInferenceProviderSelection", () => {
it("renders the provider menu and defaults to the build provider", async () => {
const prompt = makePrompt("");
const log = makeLog();
const selectFromNumberedMenu = makeSelectSpy();

const selected = await promptForInferenceProviderSelection({
options,
vllmRunning: false,
ollamaRunning: false,
env: {},
prompt,
log,
selectFromNumberedMenu,
});

assert.equal(selected.key, "build");
assert.deepEqual(
log.mock.calls.map((call) => call[0]),
[
"",
" Select your inference provider:",
" 1) NVIDIA Endpoints",
" 2) OpenAI",
" 3) Other OpenAI-compatible endpoint",
"",
],
);
assert.equal(prompt.mock.calls[0]?.[0], " Choose [1]: ");
assert.deepEqual(selectFromNumberedMenu.mock.calls[0], ["", 1, options]);
});

it("prints detected local inference suggestions before the menu", async () => {
const prompt = makePrompt("2");
const log = makeLog();
const selectFromNumberedMenu = makeSelectSpy();

await promptForInferenceProviderSelection({
options,
vllmRunning: true,
ollamaRunning: true,
env: {},
prompt,
log,
selectFromNumberedMenu,
});

assert.deepEqual(
log.mock.calls.slice(0, 3).map((call) => call[0]),
[" Detected local inference options: vLLM, Ollama", "", ""],
);
});

it("uses NEMOCLAW_PROVIDER as the default choice when it matches an option", async () => {
const prompt = makePrompt("");
const log = makeLog();
const selectFromNumberedMenu = makeSelectSpy();

const selected = await promptForInferenceProviderSelection({
options,
vllmRunning: false,
ollamaRunning: false,
env: { NEMOCLAW_PROVIDER: "OPENAI" },
prompt,
log,
selectFromNumberedMenu,
});

assert.equal(selected.key, "openai");
assert.equal(prompt.mock.calls[0]?.[0], " Choose [2]: ");
assert.deepEqual(selectFromNumberedMenu.mock.calls[0], ["", 2, options]);
});

it("falls back to the build option when the env hint is unavailable", async () => {
const prompt = makePrompt("");
const log = makeLog();
const selectFromNumberedMenu = makeSelectSpy();
const reorderedOptions = [options[1], options[2], options[0]];

const selected = await promptForInferenceProviderSelection({
options: reorderedOptions,
vllmRunning: false,
ollamaRunning: false,
env: { NEMOCLAW_PROVIDER: "missing-provider" },
prompt,
log,
selectFromNumberedMenu,
});

assert.equal(selected.key, "build");
assert.equal(prompt.mock.calls[0]?.[0], " Choose [3]: ");
assert.deepEqual(selectFromNumberedMenu.mock.calls[0], ["", 3, reorderedOptions]);
});
});
58 changes: 58 additions & 0 deletions src/lib/onboard/provider-selection-prompt.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

import type { ProviderMenuChoice } from "./provider-menu";

export interface PromptForInferenceProviderSelectionInput<T extends ProviderMenuChoice> {
options: T[];
vllmRunning: boolean;
ollamaRunning: boolean;
env?: NodeJS.ProcessEnv;
prompt(question: string): Promise<string>;
log(message?: string): void;
selectFromNumberedMenu(rawChoice: string, defaultIdx: number, options: T[]): T;
}

function getDefaultProviderIndex(options: ProviderMenuChoice[], env: NodeJS.ProcessEnv): number {
const envProviderHint = (env.NEMOCLAW_PROVIDER || "").trim().toLowerCase();
const envProviderIdx = envProviderHint
? options.findIndex((option) => option.key.toLowerCase() === envProviderHint)
: -1;
return (
(envProviderIdx >= 0 ? envProviderIdx : options.findIndex((option) => option.key === "build")) +
1
);
}

function getDetectedLocalInferenceSuggestions(input: {
vllmRunning: boolean;
ollamaRunning: boolean;
}): string[] {
const suggestions: string[] = [];
if (input.vllmRunning) suggestions.push("vLLM");
if (input.ollamaRunning) suggestions.push("Ollama");
return suggestions;
}

export async function promptForInferenceProviderSelection<T extends ProviderMenuChoice>(
input: PromptForInferenceProviderSelectionInput<T>,
): Promise<T> {
const suggestions = getDetectedLocalInferenceSuggestions(input);
if (suggestions.length > 0) {
input.log(
` Detected local inference option${suggestions.length > 1 ? "s" : ""}: ${suggestions.join(", ")}`,
);
input.log("");
}

input.log("");
input.log(" Select your inference provider:");
input.options.forEach((option, index) => {
input.log(` ${index + 1}) ${option.label}`);
});
input.log("");

const defaultIdx = getDefaultProviderIndex(input.options, input.env ?? process.env);
const choice = await input.prompt(` Choose [${defaultIdx}]: `);
return input.selectFromNumberedMenu(choice, defaultIdx, input.options);
}
Loading