diff --git a/src/colab/api.ts b/src/colab/api.ts index 3b949cbd..a12d2bc5 100644 --- a/src/colab/api.ts +++ b/src/colab/api.ts @@ -391,3 +391,18 @@ export function variantToMachineType(variant: Variant): string { return "TPU"; } } + +/** + * Maps a Colab {@link Shape} to a human-friendly RAM type name. + * + * @param shape - The Colab {@link Shape}. + * @returns The human-friendly RAM type name. + */ +export function shapeToRamType(shape: Shape): string { + switch (shape) { + case Shape.STANDARD: + return "Standard RAM"; + case Shape.HIGHMEM: + return "High RAM"; + } +} diff --git a/src/colab/client.ts b/src/colab/client.ts index 6ccf3772..4aec12cf 100644 --- a/src/colab/client.ts +++ b/src/colab/client.ts @@ -15,6 +15,7 @@ import { Assignment, CcuInfo, Variant, + Shape, GetAssignmentResponse, CcuInfoSchema, AssignmentSchema, @@ -110,6 +111,7 @@ export class ColabClient { * This value should always be a string of length 44. * @param variant - The machine variant to assign. * @param accelerator - The accelerator to assign. + * @param shape - The machine shape to assign. * @param signal - Optional {@link AbortSignal} to cancel the request. * @returns The assignment which is assigned to the user. * @throws TooManyAssignmentsError if the user has too many assignments. @@ -120,12 +122,14 @@ export class ColabClient { notebookHash: UUID, variant: Variant, accelerator?: string, + shape?: Shape, signal?: AbortSignal, ): Promise<{ assignment: Assignment; isNew: boolean }> { const assignment = await this.getAssignment( notebookHash, variant, accelerator, + shape, signal, ); switch (assignment.kind) { @@ -143,6 +147,7 @@ export class ColabClient { assignment.xsrfToken, variant, accelerator, + shape, signal, ); } catch (error) { @@ -352,9 +357,10 @@ export class ColabClient { notebookHash: UUID, variant: Variant, accelerator?: string, + shape?: Shape, signal?: AbortSignal, ): Promise { - const url = this.buildAssignUrl(notebookHash, variant, accelerator); + const url = this.buildAssignUrl(notebookHash, variant, accelerator, shape); const response = await this.issueRequest( url, { method: "GET", signal }, @@ -372,9 +378,10 @@ export class ColabClient { xsrfToken: string, variant: Variant, accelerator?: string, + shape?: Shape, signal?: AbortSignal, ): Promise { - const url = this.buildAssignUrl(notebookHash, variant, accelerator); + const url = this.buildAssignUrl(notebookHash, variant, accelerator, shape); return await this.issueRequest( url, { @@ -390,6 +397,7 @@ export class ColabClient { notebookHash: UUID, variant: Variant, accelerator?: string, + shape?: Shape, ): URL { const url = new URL(`${TUN_ENDPOINT}/assign`, this.colabDomain); url.searchParams.append("nbh", uuidToWebSafeBase64(notebookHash)); @@ -399,6 +407,11 @@ export class ColabClient { if (accelerator) { url.searchParams.append("accelerator", accelerator); } + // Only include shape parameter when it's not STANDARD, as STANDARD is the + // default behavior when the parameter is omitted. + if (shape !== undefined && shape !== Shape.STANDARD) { + url.searchParams.append("shape", shape.toString()); + } return url; } diff --git a/src/colab/client.unit.test.ts b/src/colab/client.unit.test.ts index 1495ed0f..786be025 100644 --- a/src/colab/client.unit.test.ts +++ b/src/colab/client.unit.test.ts @@ -281,6 +281,104 @@ describe("ColabClient", () => { }); } + it("creates a new assignment with highmem shape", async () => { + const variant = Variant.GPU; + const accelerator = "T4"; + const postQueryParams: Record = { + ...queryParams, + variant, + accelerator, + shape: Shape.HIGHMEM.toString(), + }; + const assignmentResponse = { + ...DEFAULT_ASSIGNMENT_RESPONSE, + variant, + accelerator, + machineShape: Shape.HIGHMEM, + }; + fetchStub + .withArgs( + urlMatcher({ + method: "POST", + host: COLAB_HOST, + path: ASSIGN_PATH, + queryParams: postQueryParams, + otherHeaders: { + [COLAB_XSRF_TOKEN_HEADER.key]: "mock-xsrf-token", + }, + }), + ) + .resolves( + new Response(withXSSI(JSON.stringify(assignmentResponse)), { + status: 200, + }), + ); + + const expectedAssignment: Assignment = { + ...DEFAULT_ASSIGNMENT, + variant, + accelerator, + machineShape: Shape.HIGHMEM, + }; + await expect( + client.assign(NOTEBOOK_HASH, variant, accelerator, Shape.HIGHMEM), + ).to.eventually.deep.equal({ + assignment: expectedAssignment, + isNew: true, + }); + + sinon.assert.calledTwice(fetchStub); + }); + + it("does not include shape param when shape is STANDARD", async () => { + const variant = Variant.GPU; + const accelerator = "T4"; + // Note: no shape param expected in queryParams + const postQueryParams: Record = { + ...queryParams, + variant, + accelerator, + }; + const assignmentResponse = { + ...DEFAULT_ASSIGNMENT_RESPONSE, + variant, + accelerator, + machineShape: Shape.STANDARD, + }; + fetchStub + .withArgs( + urlMatcher({ + method: "POST", + host: COLAB_HOST, + path: ASSIGN_PATH, + queryParams: postQueryParams, + otherHeaders: { + [COLAB_XSRF_TOKEN_HEADER.key]: "mock-xsrf-token", + }, + }), + ) + .resolves( + new Response(withXSSI(JSON.stringify(assignmentResponse)), { + status: 200, + }), + ); + + const expectedAssignment: Assignment = { + ...DEFAULT_ASSIGNMENT, + variant, + accelerator, + machineShape: Shape.STANDARD, + }; + await expect( + client.assign(NOTEBOOK_HASH, variant, accelerator, Shape.STANDARD), + ).to.eventually.deep.equal({ + assignment: expectedAssignment, + isNew: true, + }); + + sinon.assert.calledTwice(fetchStub); + }); + it("rejects when assignments exceed limit", async () => { fetchStub .withArgs( diff --git a/src/colab/server-picker.ts b/src/colab/server-picker.ts index 595259ad..ec030ff9 100644 --- a/src/colab/server-picker.ts +++ b/src/colab/server-picker.ts @@ -8,7 +8,7 @@ import vscode, { QuickPickItem } from "vscode"; import { InputStep, MultiStepInput } from "../common/multi-step-quickpick"; import { AssignmentManager } from "../jupyter/assignments"; import { ColabServerDescriptor } from "../jupyter/servers"; -import { Variant, variantToMachineType } from "./api"; +import { Shape, shapeToRamType, Variant, variantToMachineType } from "./api"; /** Provides an explanation to the user on updating the server alias. */ export const PROMPT_SERVER_ALIAS = @@ -32,10 +32,13 @@ export class ServerPicker { * server type. * * @param availableServers - The available servers to pick from. + * @param isHighMemEligible - Whether the user is eligible to select high-mem + * machines (e.g., Colab Pro users). * @returns The selected server, or undefined if the user cancels. */ async prompt( availableServers: ColabServerDescriptor[], + isHighMemEligible: boolean = false, ): Promise { const variantToAccelerators = new Map>(); for (const server of availableServers) { @@ -50,7 +53,12 @@ export class ServerPicker { const state: Partial = {}; await MultiStepInput.run(this.vs, (input) => - this.promptForVariant(input, state, variantToAccelerators), + this.promptForVariant( + input, + state, + variantToAccelerators, + isHighMemEligible, + ), ); if ( state.variant === undefined || @@ -59,17 +67,22 @@ export class ServerPicker { ) { return undefined; } - return { + const result: ColabServerDescriptor = { label: state.alias, variant: state.variant, accelerator: state.accelerator, }; + if (state.shape !== undefined) { + return { ...result, shape: state.shape }; + } + return result; } private async promptForVariant( input: MultiStepInput, state: Partial, acceleratorsByVariant: Map>, + isHighMemEligible: boolean, ): Promise { const items: VariantPick[] = []; for (const variant of acceleratorsByVariant.keys()) { @@ -94,16 +107,27 @@ export class ServerPicker { // Skip prompting for an accelerator for the default variant (CPU). if (state.variant === Variant.DEFAULT) { state.accelerator = "NONE"; - return (input: MultiStepInput) => this.promptForAlias(input, state); + if (isHighMemEligible) { + return (input: MultiStepInput) => + this.promptForShape(input, state, isHighMemEligible); + } + return (input: MultiStepInput) => + this.promptForAlias(input, state, isHighMemEligible); } return (input: MultiStepInput) => - this.promptForAccelerator(input, state, acceleratorsByVariant); + this.promptForAccelerator( + input, + state, + acceleratorsByVariant, + isHighMemEligible, + ); } private async promptForAccelerator( input: MultiStepInput, state: PartialServerWith<"variant">, acceleratorsByVariant: Map>, + isHighMemEligible: boolean, ): Promise { const accelerators = acceleratorsByVariant.get(state.variant) ?? new Set(); const items: AcceleratorPick[] = []; @@ -117,7 +141,7 @@ export class ServerPicker { title: "Select an accelerator", step: 2, // Since we have to pick an accelerator, we've added a step. - totalSteps: 3, + totalSteps: isHighMemEligible ? 4 : 3, items, activeItem: items.find((item) => item.value === state.accelerator), buttons: [input.vs.QuickInputButtons.Back], @@ -127,18 +151,59 @@ export class ServerPicker { return; } - return (input: MultiStepInput) => this.promptForAlias(input, state); + if (isHighMemEligible) { + return (input: MultiStepInput) => + this.promptForShape(input, state, isHighMemEligible); + } + return (input: MultiStepInput) => + this.promptForAlias(input, state, isHighMemEligible); + } + + private async promptForShape( + input: MultiStepInput, + state: PartialServerWith<"variant">, + isHighMemEligible: boolean, + ): Promise { + const items: ShapePick[] = [ + { value: Shape.STANDARD, label: shapeToRamType(Shape.STANDARD) }, + { value: Shape.HIGHMEM, label: shapeToRamType(Shape.HIGHMEM) }, + ]; + const hasAccelerator = state.accelerator && state.accelerator !== "NONE"; + const step = hasAccelerator ? 3 : 2; + const pick = await input.showQuickPick({ + title: "Select RAM", + step, + totalSteps: step + 1, + items, + activeItem: items.find((item) => item.value === state.shape), + buttons: [input.vs.QuickInputButtons.Back], + }); + state.shape = pick.value; + if (state.shape === undefined) { + return; + } + + return (input: MultiStepInput) => + this.promptForAlias(input, state, isHighMemEligible); } private async promptForAlias( input: MultiStepInput, state: PartialServerWith<"variant">, + isHighMemEligible: boolean, ): Promise { const placeholder = await this.assignments.getDefaultLabel( state.variant, state.accelerator, ); - const step = state.accelerator && state.accelerator !== "NONE" ? 3 : 2; + const hasAccelerator = state.accelerator && state.accelerator !== "NONE"; + let step = 2; + if (hasAccelerator) { + step = 3; + } + if (isHighMemEligible) { + step = hasAccelerator ? 4 : 3; + } const alias = await input.showInputBox({ title: "Alias your server", step, @@ -157,6 +222,7 @@ export class ServerPicker { interface Server { variant: Variant; accelerator: string; + shape?: Shape; alias: string; } @@ -185,3 +251,7 @@ interface VariantPick extends QuickPickItem { interface AcceleratorPick extends QuickPickItem { value: string; } + +interface ShapePick extends QuickPickItem { + value: Shape; +} diff --git a/src/colab/server-picker.unit.test.ts b/src/colab/server-picker.unit.test.ts index 1edf98d4..d6fd698b 100644 --- a/src/colab/server-picker.unit.test.ts +++ b/src/colab/server-picker.unit.test.ts @@ -14,7 +14,7 @@ import { buildQuickPickStub, } from "../test/helpers/quick-input"; import { newVsCodeStub, VsCodeStub } from "../test/helpers/vscode"; -import { Variant } from "./api"; +import { Shape, Variant } from "./api"; import { ServerPicker } from "./server-picker"; const AVAILABLE_SERVERS = [ @@ -334,5 +334,203 @@ describe("ServerPicker", () => { expect(aliasInputBoxStub.step).to.equal(3); expect(aliasInputBoxStub.totalSteps).to.equal(3); }); + + describe("when high-mem is eligible", () => { + it("prompts for shape after variant for CPU", async () => { + const variantQuickPickStub = stubQuickPickForCall(0); + const shapeQuickPickStub = stubQuickPickForCall(1); + const aliasInputBoxStub = stubInputBoxForCall(0); + + const variantPickerShown = variantQuickPickStub.nextShow(); + void serverPicker.prompt(AVAILABLE_SERVERS, true); + await variantPickerShown; + const shapePickerShown = shapeQuickPickStub.nextShow(); + variantQuickPickStub.onDidChangeSelection.yield([ + { value: Variant.DEFAULT, label: "CPU" }, + ]); + await shapePickerShown; + const aliasInputShown = aliasInputBoxStub.nextShow(); + shapeQuickPickStub.onDidChangeSelection.yield([ + { value: Shape.HIGHMEM, label: "High RAM" }, + ]); + await aliasInputShown; + }); + + it("prompts for shape after accelerator for GPU", async () => { + const variantQuickPickStub = stubQuickPickForCall(0); + const acceleratorQuickPickStub = stubQuickPickForCall(1); + const shapeQuickPickStub = stubQuickPickForCall(2); + const aliasInputBoxStub = stubInputBoxForCall(0); + + const variantPickerShown = variantQuickPickStub.nextShow(); + void serverPicker.prompt(AVAILABLE_SERVERS, true); + await variantPickerShown; + const acceleratorPickerShown = acceleratorQuickPickStub.nextShow(); + variantQuickPickStub.onDidChangeSelection.yield([ + { value: Variant.GPU, label: "GPU" }, + ]); + await acceleratorPickerShown; + const shapePickerShown = shapeQuickPickStub.nextShow(); + acceleratorQuickPickStub.onDidChangeSelection.yield([ + { value: "T4", label: "T4" }, + ]); + await shapePickerShown; + const aliasInputShown = aliasInputBoxStub.nextShow(); + shapeQuickPickStub.onDidChangeSelection.yield([ + { value: Shape.STANDARD, label: "Standard RAM" }, + ]); + await aliasInputShown; + }); + + it("returns the server type with shape when all prompts are answered for GPU", async () => { + const variantQuickPickStub = stubQuickPickForCall(0); + const acceleratorQuickPickStub = stubQuickPickForCall(1); + const shapeQuickPickStub = stubQuickPickForCall(2); + const aliasInputBoxStub = stubInputBoxForCall(0); + + const variantPickerShown = variantQuickPickStub.nextShow(); + const prompt = serverPicker.prompt(AVAILABLE_SERVERS, true); + await variantPickerShown; + const acceleratorPickerShown = acceleratorQuickPickStub.nextShow(); + variantQuickPickStub.onDidChangeSelection.yield([ + { value: Variant.GPU, label: "GPU" }, + ]); + await acceleratorPickerShown; + const shapePickerShown = shapeQuickPickStub.nextShow(); + acceleratorQuickPickStub.onDidChangeSelection.yield([ + { value: "T4", label: "T4" }, + ]); + await shapePickerShown; + const aliasInputShown = aliasInputBoxStub.nextShow(); + shapeQuickPickStub.onDidChangeSelection.yield([ + { value: Shape.HIGHMEM, label: "High RAM" }, + ]); + await aliasInputShown; + aliasInputBoxStub.value = "foo"; + aliasInputBoxStub.onDidChangeValue.yield("foo"); + aliasInputBoxStub.onDidAccept.yield(); + + await expect(prompt).to.eventually.be.deep.equal({ + label: "foo", + variant: Variant.GPU, + accelerator: "T4", + shape: Shape.HIGHMEM, + }); + }); + + it("returns the server type with shape when all prompts are answered for CPU", async () => { + const variantQuickPickStub = stubQuickPickForCall(0); + const shapeQuickPickStub = stubQuickPickForCall(1); + const aliasInputBoxStub = stubInputBoxForCall(0); + + const variantPickerShown = variantQuickPickStub.nextShow(); + const prompt = serverPicker.prompt(AVAILABLE_SERVERS, true); + await variantPickerShown; + const shapePickerShown = shapeQuickPickStub.nextShow(); + variantQuickPickStub.onDidChangeSelection.yield([ + { value: Variant.DEFAULT, label: "CPU" }, + ]); + await shapePickerShown; + const aliasInputShown = aliasInputBoxStub.nextShow(); + shapeQuickPickStub.onDidChangeSelection.yield([ + { value: Shape.HIGHMEM, label: "High RAM" }, + ]); + await aliasInputShown; + aliasInputBoxStub.value = "my-cpu"; + aliasInputBoxStub.onDidChangeValue.yield("my-cpu"); + aliasInputBoxStub.onDidAccept.yield(); + + await expect(prompt).to.eventually.be.deep.equal({ + label: "my-cpu", + variant: Variant.DEFAULT, + accelerator: "NONE", + shape: Shape.HIGHMEM, + }); + }); + + it("sets the right step for CPU with shape selection", async () => { + const variantQuickPickStub = stubQuickPickForCall(0); + const shapeQuickPickStub = stubQuickPickForCall(1); + const aliasInputBoxStub = stubInputBoxForCall(0); + const variantPickerShown = variantQuickPickStub.nextShow(); + const shapePickerShown = shapeQuickPickStub.nextShow(); + const aliasInputShown = aliasInputBoxStub.nextShow(); + + void serverPicker.prompt(AVAILABLE_SERVERS, true); + + await variantPickerShown; + expect(variantQuickPickStub.step).to.equal(1); + expect(variantQuickPickStub.totalSteps).to.equal(2); + + variantQuickPickStub.onDidChangeSelection.yield([ + { value: Variant.DEFAULT, label: "CPU" }, + ]); + await shapePickerShown; + expect(shapeQuickPickStub.step).to.equal(2); + expect(shapeQuickPickStub.totalSteps).to.equal(3); + + shapeQuickPickStub.onDidChangeSelection.yield([ + { value: Shape.HIGHMEM, label: "High RAM" }, + ]); + await aliasInputShown; + expect(aliasInputBoxStub.step).to.equal(3); + expect(aliasInputBoxStub.totalSteps).to.equal(3); + }); + + it("sets the right step for GPU with shape selection", async () => { + const variantQuickPickStub = stubQuickPickForCall(0); + const acceleratorQuickPickStub = stubQuickPickForCall(1); + const shapeQuickPickStub = stubQuickPickForCall(2); + const aliasInputBoxStub = stubInputBoxForCall(0); + const variantPickerShown = variantQuickPickStub.nextShow(); + const acceleratorPickerShown = acceleratorQuickPickStub.nextShow(); + const shapePickerShown = shapeQuickPickStub.nextShow(); + const aliasInputShown = aliasInputBoxStub.nextShow(); + + void serverPicker.prompt(AVAILABLE_SERVERS, true); + + await variantPickerShown; + expect(variantQuickPickStub.step).to.equal(1); + expect(variantQuickPickStub.totalSteps).to.equal(2); + + variantQuickPickStub.onDidChangeSelection.yield([ + { value: Variant.GPU, label: "GPU" }, + ]); + await acceleratorPickerShown; + expect(acceleratorQuickPickStub.step).to.equal(2); + expect(acceleratorQuickPickStub.totalSteps).to.equal(4); + + acceleratorQuickPickStub.onDidChangeSelection.yield([ + { value: "T4", label: "T4" }, + ]); + await shapePickerShown; + expect(shapeQuickPickStub.step).to.equal(3); + expect(shapeQuickPickStub.totalSteps).to.equal(4); + + shapeQuickPickStub.onDidChangeSelection.yield([ + { value: Shape.STANDARD, label: "Standard RAM" }, + ]); + await aliasInputShown; + expect(aliasInputBoxStub.step).to.equal(4); + expect(aliasInputBoxStub.totalSteps).to.equal(4); + }); + + it("returns undefined when selecting shape is cancelled", async () => { + const variantQuickPickStub = stubQuickPickForCall(0); + const shapeQuickPickStub = stubQuickPickForCall(1); + + const variantPickerShown = variantQuickPickStub.nextShow(); + const prompt = serverPicker.prompt(AVAILABLE_SERVERS, true); + await variantPickerShown; + const shapePickerShown = shapeQuickPickStub.nextShow(); + variantQuickPickStub.onDidChangeSelection.yield([ + { value: Variant.DEFAULT, label: "CPU" }, + ]); + await shapePickerShown; + shapeQuickPickStub.onDidHide.yield(); + + await expect(prompt).to.eventually.be.undefined; + }); + }); }); }); diff --git a/src/jupyter/assignments.ts b/src/jupyter/assignments.ts index 4ebcbfbb..5d12cac4 100644 --- a/src/jupyter/assignments.ts +++ b/src/jupyter/assignments.ts @@ -224,6 +224,7 @@ export class AssignmentManager implements vscode.Disposable { id, descriptor.variant, descriptor.accelerator, + descriptor.shape, signal, )); } catch (error) { diff --git a/src/jupyter/provider.ts b/src/jupyter/provider.ts index c4a7f3fc..92e7cf89 100644 --- a/src/jupyter/provider.ts +++ b/src/jupyter/provider.ts @@ -221,8 +221,18 @@ export class ColabJupyterServerProvider } private async assignServer(): Promise { + // Check if user is a Pro subscriber (Pro or Pro+) to enable high-mem option + let isHighMemEligible = false; + try { + const tier = await this.client.getSubscriptionTier(); + isHighMemEligible = + tier === SubscriptionTier.PRO || tier === SubscriptionTier.PRO_PLUS; + } catch { + // If we can't determine the tier, default to no high-mem option + } const serverType = await this.serverPicker.prompt( await this.assignmentManager.getAvailableServerDescriptors(), + isHighMemEligible, ); if (!serverType) { throw new this.vs.CancellationError(); diff --git a/src/jupyter/provider.unit.test.ts b/src/jupyter/provider.unit.test.ts index 7e45ac79..1ebff155 100644 --- a/src/jupyter/provider.unit.test.ts +++ b/src/jupyter/provider.unit.test.ts @@ -445,6 +445,7 @@ describe("ColabJupyterServerProvider", () => { describe("for new Colab server", () => { it("returns undefined when navigating back out of the flow", async () => { + colabClientStub.getSubscriptionTier.resolves(SubscriptionTier.NONE); serverPickerStub.prompt.rejects(InputFlowAction.back); await expect( @@ -456,7 +457,8 @@ describe("ColabJupyterServerProvider", () => { sinon.assert.calledOnce(serverPickerStub.prompt); }); - it("completes assigning a server", async () => { + it("completes assigning a server for free user without high-mem option", async () => { + colabClientStub.getSubscriptionTier.resolves(SubscriptionTier.NONE); const availableServers = [DEFAULT_SERVER]; assignmentStub.getAvailableServerDescriptors.resolves( availableServers, @@ -467,7 +469,7 @@ describe("ColabJupyterServerProvider", () => { accelerator: DEFAULT_SERVER.accelerator, }; serverPickerStub.prompt - .withArgs(availableServers) + .withArgs(availableServers, false) .resolves(selectedServer); assignmentStub.assignServer .withArgs(selectedServer) @@ -480,9 +482,111 @@ describe("ColabJupyterServerProvider", () => { ), ).to.eventually.deep.equal(DEFAULT_SERVER); - sinon.assert.calledOnce(serverPickerStub.prompt); + sinon.assert.calledOnceWithExactly( + serverPickerStub.prompt, + availableServers, + false, + ); sinon.assert.calledOnce(assignmentStub.assignServer); }); + + it("enables high-mem option for Pro users", async () => { + colabClientStub.getSubscriptionTier.resolves(SubscriptionTier.PRO); + const availableServers = [DEFAULT_SERVER]; + assignmentStub.getAvailableServerDescriptors.resolves( + availableServers, + ); + const selectedServer: ColabServerDescriptor = { + label: "My new server", + variant: DEFAULT_SERVER.variant, + accelerator: DEFAULT_SERVER.accelerator, + }; + serverPickerStub.prompt + .withArgs(availableServers, true) + .resolves(selectedServer); + assignmentStub.assignServer + .withArgs(selectedServer) + .resolves(DEFAULT_SERVER); + + await expect( + serverProvider.handleCommand( + { label: NEW_SERVER.label }, + cancellationToken, + ), + ).to.eventually.deep.equal(DEFAULT_SERVER); + + sinon.assert.calledOnceWithExactly( + serverPickerStub.prompt, + availableServers, + true, + ); + }); + + it("enables high-mem option for Pro Plus users", async () => { + colabClientStub.getSubscriptionTier.resolves( + SubscriptionTier.PRO_PLUS, + ); + const availableServers = [DEFAULT_SERVER]; + assignmentStub.getAvailableServerDescriptors.resolves( + availableServers, + ); + const selectedServer: ColabServerDescriptor = { + label: "My new server", + variant: DEFAULT_SERVER.variant, + accelerator: DEFAULT_SERVER.accelerator, + }; + serverPickerStub.prompt + .withArgs(availableServers, true) + .resolves(selectedServer); + assignmentStub.assignServer + .withArgs(selectedServer) + .resolves(DEFAULT_SERVER); + + await expect( + serverProvider.handleCommand( + { label: NEW_SERVER.label }, + cancellationToken, + ), + ).to.eventually.deep.equal(DEFAULT_SERVER); + + sinon.assert.calledOnceWithExactly( + serverPickerStub.prompt, + availableServers, + true, + ); + }); + + it("defaults to no high-mem option when tier check fails", async () => { + colabClientStub.getSubscriptionTier.rejects(new Error("Network error")); + const availableServers = [DEFAULT_SERVER]; + assignmentStub.getAvailableServerDescriptors.resolves( + availableServers, + ); + const selectedServer: ColabServerDescriptor = { + label: "My new server", + variant: DEFAULT_SERVER.variant, + accelerator: DEFAULT_SERVER.accelerator, + }; + serverPickerStub.prompt + .withArgs(availableServers, false) + .resolves(selectedServer); + assignmentStub.assignServer + .withArgs(selectedServer) + .resolves(DEFAULT_SERVER); + + await expect( + serverProvider.handleCommand( + { label: NEW_SERVER.label }, + cancellationToken, + ), + ).to.eventually.deep.equal(DEFAULT_SERVER); + + sinon.assert.calledOnceWithExactly( + serverPickerStub.prompt, + availableServers, + false, + ); + }); }); }); }); diff --git a/src/jupyter/servers.ts b/src/jupyter/servers.ts index f2562ec6..447c76e3 100644 --- a/src/jupyter/servers.ts +++ b/src/jupyter/servers.ts @@ -9,7 +9,7 @@ import { JupyterServer, JupyterServerConnectionInformation, } from "@vscode/jupyter-extension"; -import { Variant } from "../colab/api"; +import { Shape, Variant } from "../colab/api"; /** * Colab's Jupyter server descriptor which includes machine-specific @@ -19,6 +19,7 @@ export interface ColabServerDescriptor { readonly label: string; readonly variant: Variant; readonly accelerator?: string; + readonly shape?: Shape; } /**