diff --git a/src/lib/onboard.ts b/src/lib/onboard.ts index 4f4fd493e0..42fe9f56dc 100644 --- a/src/lib/onboard.ts +++ b/src/lib/onboard.ts @@ -89,6 +89,9 @@ const { const { resolveRequestedProviderSelection, }: typeof import("./onboard/provider-selection") = require("./onboard/provider-selection"); +const { + promptForInferenceProviderSelection, +}: typeof import("./onboard/provider-selection-prompt") = require("./onboard/provider-selection-prompt"); const { isLinuxDockerDriverGatewayEnabled, }: typeof import("./onboard/docker-driver-platform") = require("./onboard/docker-driver-platform"); @@ -3863,31 +3866,14 @@ async function setupNim( : ` [non-interactive] Provider: ${selected.key}`, ); } else { - const suggestions: string[] = []; - if (vllmRunning) suggestions.push("vLLM"); - if (ollamaRunning) suggestions.push("Ollama"); - if (suggestions.length > 0) { - console.log( - ` Detected local inference option${suggestions.length > 1 ? "s" : ""}: ${suggestions.join(", ")}`, - ); - console.log(""); - } - - console.log(""); - console.log(" Select your inference provider:"); - options.forEach((o, i) => { - console.log(` ${i + 1}) ${o.label}`); + selected = await promptForInferenceProviderSelection({ + options, + vllmRunning, + ollamaRunning, + prompt, + log: console.log, + selectFromNumberedMenu: selectFromNumberedMenuOrExit, }); - console.log(""); - - const envProviderHint = (process.env.NEMOCLAW_PROVIDER || "").trim().toLowerCase(); - const envProviderIdx = envProviderHint - ? options.findIndex((o) => o.key.toLowerCase() === envProviderHint) - : -1; - const defaultIdx = - (envProviderIdx >= 0 ? envProviderIdx : options.findIndex((o) => o.key === "build")) + 1; - const choice = await prompt(` Choose [${defaultIdx}]: `); - selected = selectFromNumberedMenuOrExit(choice, defaultIdx, options); } if (!selected) { diff --git a/src/lib/onboard/provider-selection-prompt.test.ts b/src/lib/onboard/provider-selection-prompt.test.ts new file mode 100644 index 0000000000..1dcf1d8040 --- /dev/null +++ b/src/lib/onboard/provider-selection-prompt.test.ts @@ -0,0 +1,120 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import assert from "node:assert/strict"; +import { describe, it, vi } from "vitest"; + +import { promptForInferenceProviderSelection } from "../../../dist/lib/onboard/provider-selection-prompt"; + +const options = [ + { key: "build", label: "NVIDIA Endpoints" }, + { key: "openai", label: "OpenAI" }, + { key: "custom", label: "Other OpenAI-compatible endpoint" }, +]; + +function makeSelectSpy() { + return vi.fn((_: string, defaultIdx: number, entries: typeof options) => entries[defaultIdx - 1]); +} + +function makePrompt(reply: string) { + return vi.fn(async (_question: string) => reply); +} + +function makeLog() { + return vi.fn((_message?: string) => {}); +} + +describe("promptForInferenceProviderSelection", () => { + it("renders the provider menu and defaults to the build provider", async () => { + const prompt = makePrompt(""); + const log = makeLog(); + const selectFromNumberedMenu = makeSelectSpy(); + + const selected = await promptForInferenceProviderSelection({ + options, + vllmRunning: false, + ollamaRunning: false, + env: {}, + prompt, + log, + selectFromNumberedMenu, + }); + + assert.equal(selected.key, "build"); + assert.deepEqual( + log.mock.calls.map((call) => call[0]), + [ + "", + " Select your inference provider:", + " 1) NVIDIA Endpoints", + " 2) OpenAI", + " 3) Other OpenAI-compatible endpoint", + "", + ], + ); + assert.equal(prompt.mock.calls[0]?.[0], " Choose [1]: "); + assert.deepEqual(selectFromNumberedMenu.mock.calls[0], ["", 1, options]); + }); + + it("prints detected local inference suggestions before the menu", async () => { + const prompt = makePrompt("2"); + const log = makeLog(); + const selectFromNumberedMenu = makeSelectSpy(); + + await promptForInferenceProviderSelection({ + options, + vllmRunning: true, + ollamaRunning: true, + env: {}, + prompt, + log, + selectFromNumberedMenu, + }); + + assert.deepEqual( + log.mock.calls.slice(0, 3).map((call) => call[0]), + [" Detected local inference options: vLLM, Ollama", "", ""], + ); + }); + + it("uses NEMOCLAW_PROVIDER as the default choice when it matches an option", async () => { + const prompt = makePrompt(""); + const log = makeLog(); + const selectFromNumberedMenu = makeSelectSpy(); + + const selected = await promptForInferenceProviderSelection({ + options, + vllmRunning: false, + ollamaRunning: false, + env: { NEMOCLAW_PROVIDER: "OPENAI" }, + prompt, + log, + selectFromNumberedMenu, + }); + + assert.equal(selected.key, "openai"); + assert.equal(prompt.mock.calls[0]?.[0], " Choose [2]: "); + assert.deepEqual(selectFromNumberedMenu.mock.calls[0], ["", 2, options]); + }); + + it("falls back to the build option when the env hint is unavailable", async () => { + const prompt = makePrompt(""); + const log = makeLog(); + const selectFromNumberedMenu = makeSelectSpy(); + const reorderedOptions = [options[1], options[2], options[0]]; + + const selected = await promptForInferenceProviderSelection({ + options: reorderedOptions, + vllmRunning: false, + ollamaRunning: false, + env: { NEMOCLAW_PROVIDER: "missing-provider" }, + prompt, + log, + selectFromNumberedMenu, + }); + + assert.equal(selected.key, "build"); + assert.equal(prompt.mock.calls[0]?.[0], " Choose [3]: "); + assert.deepEqual(selectFromNumberedMenu.mock.calls[0], ["", 3, reorderedOptions]); + }); +}); diff --git a/src/lib/onboard/provider-selection-prompt.ts b/src/lib/onboard/provider-selection-prompt.ts new file mode 100644 index 0000000000..a1de4f1cc2 --- /dev/null +++ b/src/lib/onboard/provider-selection-prompt.ts @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import type { ProviderMenuChoice } from "./provider-menu"; + +export interface PromptForInferenceProviderSelectionInput { + options: T[]; + vllmRunning: boolean; + ollamaRunning: boolean; + env?: NodeJS.ProcessEnv; + prompt(question: string): Promise; + log(message?: string): void; + selectFromNumberedMenu(rawChoice: string, defaultIdx: number, options: T[]): T; +} + +function getDefaultProviderIndex(options: ProviderMenuChoice[], env: NodeJS.ProcessEnv): number { + const envProviderHint = (env.NEMOCLAW_PROVIDER || "").trim().toLowerCase(); + const envProviderIdx = envProviderHint + ? options.findIndex((option) => option.key.toLowerCase() === envProviderHint) + : -1; + return ( + (envProviderIdx >= 0 ? envProviderIdx : options.findIndex((option) => option.key === "build")) + + 1 + ); +} + +function getDetectedLocalInferenceSuggestions(input: { + vllmRunning: boolean; + ollamaRunning: boolean; +}): string[] { + const suggestions: string[] = []; + if (input.vllmRunning) suggestions.push("vLLM"); + if (input.ollamaRunning) suggestions.push("Ollama"); + return suggestions; +} + +export async function promptForInferenceProviderSelection( + input: PromptForInferenceProviderSelectionInput, +): Promise { + const suggestions = getDetectedLocalInferenceSuggestions(input); + if (suggestions.length > 0) { + input.log( + ` Detected local inference option${suggestions.length > 1 ? "s" : ""}: ${suggestions.join(", ")}`, + ); + input.log(""); + } + + input.log(""); + input.log(" Select your inference provider:"); + input.options.forEach((option, index) => { + input.log(` ${index + 1}) ${option.label}`); + }); + input.log(""); + + const defaultIdx = getDefaultProviderIndex(input.options, input.env ?? process.env); + const choice = await input.prompt(` Choose [${defaultIdx}]: `); + return input.selectFromNumberedMenu(choice, defaultIdx, input.options); +}