Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/colab/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,18 @@ export function variantToMachineType(variant: Variant): string {
return "TPU";
}
}

/**
* Maps a Colab {@link Shape} to a human-friendly RAM type name.
*
* @param shape - The Colab {@link Shape}.
* @returns The human-friendly RAM type name.
*/
export function shapeToRamType(shape: Shape): string {
switch (shape) {
case Shape.STANDARD:
return "Standard RAM";
case Shape.HIGHMEM:
return "High RAM";
}
}
17 changes: 15 additions & 2 deletions src/colab/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
Assignment,
CcuInfo,
Variant,
Shape,
GetAssignmentResponse,
CcuInfoSchema,
AssignmentSchema,
Expand Down Expand Up @@ -110,6 +111,7 @@ export class ColabClient {
* This value should always be a string of length 44.
* @param variant - The machine variant to assign.
* @param accelerator - The accelerator to assign.
* @param shape - The machine shape to assign.
* @param signal - Optional {@link AbortSignal} to cancel the request.
* @returns The assignment which is assigned to the user.
* @throws TooManyAssignmentsError if the user has too many assignments.
Expand All @@ -120,12 +122,14 @@ export class ColabClient {
notebookHash: UUID,
variant: Variant,
accelerator?: string,
shape?: Shape,
signal?: AbortSignal,
): Promise<{ assignment: Assignment; isNew: boolean }> {
const assignment = await this.getAssignment(
notebookHash,
variant,
accelerator,
shape,
signal,
);
switch (assignment.kind) {
Expand All @@ -143,6 +147,7 @@ export class ColabClient {
assignment.xsrfToken,
variant,
accelerator,
shape,
signal,
);
} catch (error) {
Expand Down Expand Up @@ -352,9 +357,10 @@ export class ColabClient {
notebookHash: UUID,
variant: Variant,
accelerator?: string,
shape?: Shape,
signal?: AbortSignal,
): Promise<AssignmentToken | AssignedAssignment> {
const url = this.buildAssignUrl(notebookHash, variant, accelerator);
const url = this.buildAssignUrl(notebookHash, variant, accelerator, shape);
const response = await this.issueRequest(
url,
{ method: "GET", signal },
Expand All @@ -372,9 +378,10 @@ export class ColabClient {
xsrfToken: string,
variant: Variant,
accelerator?: string,
shape?: Shape,
signal?: AbortSignal,
): Promise<PostAssignmentResponse> {
const url = this.buildAssignUrl(notebookHash, variant, accelerator);
const url = this.buildAssignUrl(notebookHash, variant, accelerator, shape);
return await this.issueRequest(
url,
{
Expand All @@ -390,6 +397,7 @@ export class ColabClient {
notebookHash: UUID,
variant: Variant,
accelerator?: string,
shape?: Shape,
): URL {
const url = new URL(`${TUN_ENDPOINT}/assign`, this.colabDomain);
url.searchParams.append("nbh", uuidToWebSafeBase64(notebookHash));
Expand All @@ -399,6 +407,11 @@ export class ColabClient {
if (accelerator) {
url.searchParams.append("accelerator", accelerator);
}
// Only include shape parameter when it's not STANDARD, as STANDARD is the
// default behavior when the parameter is omitted.
if (shape !== undefined && shape !== Shape.STANDARD) {
url.searchParams.append("shape", shape.toString());
}
return url;
}

Expand Down
98 changes: 98 additions & 0 deletions src/colab/client.unit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,104 @@ describe("ColabClient", () => {
});
}

it("creates a new assignment with highmem shape", async () => {
const variant = Variant.GPU;
const accelerator = "T4";
const postQueryParams: Record<string, string | RegExp> = {
...queryParams,
variant,
accelerator,
shape: Shape.HIGHMEM.toString(),
};
const assignmentResponse = {
...DEFAULT_ASSIGNMENT_RESPONSE,
variant,
accelerator,
machineShape: Shape.HIGHMEM,
};
fetchStub
.withArgs(
urlMatcher({
method: "POST",
host: COLAB_HOST,
path: ASSIGN_PATH,
queryParams: postQueryParams,
otherHeaders: {
[COLAB_XSRF_TOKEN_HEADER.key]: "mock-xsrf-token",
},
}),
)
.resolves(
new Response(withXSSI(JSON.stringify(assignmentResponse)), {
status: 200,
}),
);

const expectedAssignment: Assignment = {
...DEFAULT_ASSIGNMENT,
variant,
accelerator,
machineShape: Shape.HIGHMEM,
};
await expect(
client.assign(NOTEBOOK_HASH, variant, accelerator, Shape.HIGHMEM),
).to.eventually.deep.equal({
assignment: expectedAssignment,
isNew: true,
});

sinon.assert.calledTwice(fetchStub);
});

it("does not include shape param when shape is STANDARD", async () => {
const variant = Variant.GPU;
const accelerator = "T4";
// Note: no shape param expected in queryParams
const postQueryParams: Record<string, string | RegExp> = {
...queryParams,
variant,
accelerator,
};
const assignmentResponse = {
...DEFAULT_ASSIGNMENT_RESPONSE,
variant,
accelerator,
machineShape: Shape.STANDARD,
};
fetchStub
.withArgs(
urlMatcher({
method: "POST",
host: COLAB_HOST,
path: ASSIGN_PATH,
queryParams: postQueryParams,
otherHeaders: {
[COLAB_XSRF_TOKEN_HEADER.key]: "mock-xsrf-token",
},
}),
)
.resolves(
new Response(withXSSI(JSON.stringify(assignmentResponse)), {
status: 200,
}),
);

const expectedAssignment: Assignment = {
...DEFAULT_ASSIGNMENT,
variant,
accelerator,
machineShape: Shape.STANDARD,
};
await expect(
client.assign(NOTEBOOK_HASH, variant, accelerator, Shape.STANDARD),
).to.eventually.deep.equal({
assignment: expectedAssignment,
isNew: true,
});

sinon.assert.calledTwice(fetchStub);
});

it("rejects when assignments exceed limit", async () => {
fetchStub
.withArgs(
Expand Down
86 changes: 78 additions & 8 deletions src/colab/server-picker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import vscode, { QuickPickItem } from "vscode";
import { InputStep, MultiStepInput } from "../common/multi-step-quickpick";
import { AssignmentManager } from "../jupyter/assignments";
import { ColabServerDescriptor } from "../jupyter/servers";
import { Variant, variantToMachineType } from "./api";
import { Shape, shapeToRamType, Variant, variantToMachineType } from "./api";

/** Provides an explanation to the user on updating the server alias. */
export const PROMPT_SERVER_ALIAS =
Expand All @@ -32,10 +32,13 @@ export class ServerPicker {
* server type.
*
* @param availableServers - The available servers to pick from.
* @param isHighMemEligible - Whether the user is eligible to select high-mem
* machines (e.g., Colab Pro users).
* @returns The selected server, or undefined if the user cancels.
*/
async prompt(
availableServers: ColabServerDescriptor[],
isHighMemEligible: boolean = false,
): Promise<ColabServerDescriptor | undefined> {
const variantToAccelerators = new Map<Variant, Set<string>>();
for (const server of availableServers) {
Expand All @@ -50,7 +53,12 @@ export class ServerPicker {

const state: Partial<Server> = {};
await MultiStepInput.run(this.vs, (input) =>
this.promptForVariant(input, state, variantToAccelerators),
this.promptForVariant(
input,
state,
variantToAccelerators,
isHighMemEligible,
),
);
if (
state.variant === undefined ||
Expand All @@ -59,17 +67,22 @@ export class ServerPicker {
) {
return undefined;
}
return {
const result: ColabServerDescriptor = {
label: state.alias,
variant: state.variant,
accelerator: state.accelerator,
};
if (state.shape !== undefined) {
return { ...result, shape: state.shape };
}
return result;
}

private async promptForVariant(
input: MultiStepInput,
state: Partial<Server>,
acceleratorsByVariant: Map<Variant, Set<string>>,
isHighMemEligible: boolean,
): Promise<InputStep | undefined> {
const items: VariantPick[] = [];
for (const variant of acceleratorsByVariant.keys()) {
Expand All @@ -94,16 +107,27 @@ export class ServerPicker {
// Skip prompting for an accelerator for the default variant (CPU).
if (state.variant === Variant.DEFAULT) {
state.accelerator = "NONE";
return (input: MultiStepInput) => this.promptForAlias(input, state);
if (isHighMemEligible) {
return (input: MultiStepInput) =>
this.promptForShape(input, state, isHighMemEligible);
}
return (input: MultiStepInput) =>
this.promptForAlias(input, state, isHighMemEligible);
}
return (input: MultiStepInput) =>
this.promptForAccelerator(input, state, acceleratorsByVariant);
this.promptForAccelerator(
input,
state,
acceleratorsByVariant,
isHighMemEligible,
);
}

private async promptForAccelerator(
input: MultiStepInput,
state: PartialServerWith<"variant">,
acceleratorsByVariant: Map<Variant, Set<string>>,
isHighMemEligible: boolean,
): Promise<InputStep | undefined> {
const accelerators = acceleratorsByVariant.get(state.variant) ?? new Set();
const items: AcceleratorPick[] = [];
Expand All @@ -117,7 +141,7 @@ export class ServerPicker {
title: "Select an accelerator",
step: 2,
// Since we have to pick an accelerator, we've added a step.
totalSteps: 3,
totalSteps: isHighMemEligible ? 4 : 3,
items,
activeItem: items.find((item) => item.value === state.accelerator),
buttons: [input.vs.QuickInputButtons.Back],
Expand All @@ -127,18 +151,59 @@ export class ServerPicker {
return;
}

return (input: MultiStepInput) => this.promptForAlias(input, state);
if (isHighMemEligible) {
return (input: MultiStepInput) =>
this.promptForShape(input, state, isHighMemEligible);
}
return (input: MultiStepInput) =>
this.promptForAlias(input, state, isHighMemEligible);
}

private async promptForShape(
input: MultiStepInput,
state: PartialServerWith<"variant">,
isHighMemEligible: boolean,
): Promise<InputStep | undefined> {
const items: ShapePick[] = [
{ value: Shape.STANDARD, label: shapeToRamType(Shape.STANDARD) },
{ value: Shape.HIGHMEM, label: shapeToRamType(Shape.HIGHMEM) },
];
const hasAccelerator = state.accelerator && state.accelerator !== "NONE";
const step = hasAccelerator ? 3 : 2;
const pick = await input.showQuickPick({
title: "Select RAM",
step,
totalSteps: step + 1,
items,
activeItem: items.find((item) => item.value === state.shape),
buttons: [input.vs.QuickInputButtons.Back],
});
state.shape = pick.value;
if (state.shape === undefined) {
return;
}

return (input: MultiStepInput) =>
this.promptForAlias(input, state, isHighMemEligible);
}

private async promptForAlias(
input: MultiStepInput,
state: PartialServerWith<"variant">,
isHighMemEligible: boolean,
): Promise<InputStep | undefined> {
const placeholder = await this.assignments.getDefaultLabel(
state.variant,
state.accelerator,
);
const step = state.accelerator && state.accelerator !== "NONE" ? 3 : 2;
const hasAccelerator = state.accelerator && state.accelerator !== "NONE";
let step = 2;
if (hasAccelerator) {
step = 3;
}
if (isHighMemEligible) {
step = hasAccelerator ? 4 : 3;
}
const alias = await input.showInputBox({
title: "Alias your server",
step,
Expand All @@ -157,6 +222,7 @@ export class ServerPicker {
interface Server {
variant: Variant;
accelerator: string;
shape?: Shape;
alias: string;
}

Expand Down Expand Up @@ -185,3 +251,7 @@ interface VariantPick extends QuickPickItem {
interface AcceleratorPick extends QuickPickItem {
value: string;
}

interface ShapePick extends QuickPickItem {
value: Shape;
}
Loading