From 46e3b2a9d72676c947ff4f9d74e0e108f473e415 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Tue, 2 Jun 2026 18:07:02 +0900 Subject: [PATCH 01/24] feat: implement Grep, Glob, LS, and Read tools with ripgrep support and update ignore configurations --- .gitignore | 1 + .ignore | 1 + src/index.ts | 604 +++++++++++++++++++++++++++++++++++++++++++- tests/index.test.ts | 6 +- 4 files changed, 609 insertions(+), 3 deletions(-) create mode 100644 .ignore diff --git a/.gitignore b/.gitignore index 45a1e50..b76e3ab 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ node_modules/ coverage dist/ +docs/ *.tgz .DS_Store diff --git a/.ignore b/.ignore new file mode 100644 index 0000000..00a2480 --- /dev/null +++ b/.ignore @@ -0,0 +1 @@ +!docs/ \ No newline at end of file diff --git a/src/index.ts b/src/index.ts index fb2b05a..ecf6d43 100644 --- a/src/index.ts +++ b/src/index.ts @@ -12,9 +12,17 @@ * PI_GROK_CLI_OAUTH_SCOPE - Override OAuth scopes */ -import { existsSync, mkdirSync, readFileSync, writeFileSync } from "node:fs"; +import { execFile } from "node:child_process"; +import { + existsSync, + mkdirSync, + readFileSync, + unlinkSync, + writeFileSync, +} from "node:fs"; import { homedir } from "node:os"; -import { dirname, join } from "node:path"; +import { dirname, join, resolve } from "node:path"; +import { promisify } from "node:util"; import { type Api, type AssistantMessageEventStream, @@ -24,11 +32,15 @@ import { type OAuthLoginCallbacks, type SimpleStreamOptions, streamSimpleOpenAIResponses, + Type, } from "@earendil-works/pi-ai"; import type { ExtensionAPI, ProviderConfig, } from "@earendil-works/pi-coding-agent"; + +const execFileAsync = promisify(execFile); + import { XaiOAuthError } from "./errors.js"; import { type GrokCliModelConfig, resolveModels } from "./models.js"; import * as oauth from "./oauth.js"; @@ -264,6 +276,594 @@ export default function (pi: ExtensionAPI) { streamSimple: streamGrokCli, }); + // ── Register Grok/Cursor-native tools ────────────────────────────────── + + const MAX_OUTPUT_CHARS = 50_000; + const MAX_LINES = 500; + + function truncateLines(lines: string[]): string { + if (lines.length > MAX_LINES) { + return ( + lines.slice(0, MAX_LINES).join("\n") + + `\n\n[Showing first ${MAX_LINES} of ${lines.length} results. Refine your pattern to narrow results.]` + ); + } + return lines.join("\n"); + } + + function truncateChars(output: string): string { + if (output.length > MAX_OUTPUT_CHARS) { + return `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + return output; + } + + let rgAvailable: boolean | undefined; + async function hasRipgrep(): Promise { + if (rgAvailable !== undefined) return rgAvailable; + try { + await execFileAsync("rg", ["--version"]); + rgAvailable = true; + } catch { + rgAvailable = false; + } + return rgAvailable; + } + + type ToolError = { code?: number; message?: string }; + type ToolResult = { + content: [{ type: "text"; text: string }]; + details: T; + }; + + type FileDetails = { path: string; [key: string]: unknown }; + + function fileNotFound( + filePath: string, + extraDetails: Omit, + ): ToolResult { + return { + content: [{ type: "text", text: `File not found: ${filePath}` }], + details: { path: filePath, ...extraDetails } as T, + }; + } + + function fileError( + error: unknown, + toolName: string, + filePath: string, + extraDetails: Omit, + ): ToolResult { + const err = error as ToolError; + return { + content: [ + { + type: "text", + text: `${toolName} error: ${err.message ?? "Unknown error"}`, + }, + ], + details: { path: filePath, ...extraDetails } as T, + }; + } + + function toolError( + error: unknown, + toolName: string, + emptyDetails: T, + ): ToolResult { + const err = error as ToolError; + if (err.code === 1) { + return { + content: [{ type: "text", text: "No matches found" }], + details: emptyDetails, + }; + } + return { + content: [ + { + type: "text", + text: `${toolName} error: ${err.message ?? "Unknown error"}`, + }, + ], + details: emptyDetails, + }; + } + + async function execWithRgFallback( + rgArgs: string[], + grepArgs: string[], + options: { cwd: string; signal?: AbortSignal }, + ): Promise { + if (await hasRipgrep()) { + const result = await execFileAsync("rg", rgArgs, { + cwd: options.cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + signal: options.signal, + }); + return result.stdout; + } + const result = await execFileAsync("grep", grepArgs, { + cwd: options.cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + signal: options.signal, + }); + return result.stdout; + } + + const GrepParams = Type.Object({ + pattern: Type.String({ + description: "Regex pattern to search for in file contents", + }), + path: Type.Optional( + Type.String({ + description: + "Directory or file to search. Defaults to current working directory.", + }), + ), + include: Type.Optional( + Type.String({ + description: + "Glob pattern to filter which files are searched (e.g. *.ts, **/*.md)", + }), + ), + }); + + pi.registerTool({ + name: "Grep", + label: "Grep", + description: + "Search for a regex pattern in file contents. Returns matching lines with file path and line number. Use the include parameter to filter by file type.", + parameters: GrepParams, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const searchPath = resolve(ctx.cwd, params.path ?? "."); + + try { + const rgArgs = ["-n", "--no-heading", "--color=never"]; + if (params.include) rgArgs.push("--glob", params.include); + rgArgs.push(params.pattern, searchPath); + + const grepArgs = ["-r", "-n", "--color=never"]; + if (params.include) grepArgs.push(`--include=${params.include}`); + grepArgs.push(params.pattern, searchPath); + + const stdout = await execWithRgFallback(rgArgs, grepArgs, { + cwd: ctx.cwd, + signal, + }); + + const lines = stdout.trim().split("\n").filter(Boolean); + if (lines.length === 0) { + return { + content: [{ type: "text", text: "No matches found" }], + details: { matchCount: 0 }, + }; + } + + return { + content: [ + { type: "text", text: truncateChars(truncateLines(lines)) }, + ], + details: { matchCount: lines.length }, + }; + } catch (error: unknown) { + return toolError(error, "Grep", { matchCount: 0 }); + } + }, + }); + + const GlobParams = Type.Object({ + pattern: Type.String({ + description: "Glob pattern to match files (e.g. **/*.ts, src/**/*.json)", + }), + path: Type.Optional( + Type.String({ + description: + "Directory to search within. Defaults to current working directory.", + }), + ), + }); + + pi.registerTool({ + name: "Glob", + label: "Glob", + description: + "Find files matching a glob pattern. Returns a list of matching file paths sorted by modification time (newest first).", + parameters: GlobParams, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const searchPath = resolve(ctx.cwd, params.path ?? "."); + + try { + let files: string[]; + + if (await hasRipgrep()) { + const result = await execFileAsync( + "rg", + ["--files", "--color=never", "--glob", params.pattern, searchPath], + { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, + ); + files = result.stdout.trim().split("\n").filter(Boolean); + } else { + // find fallback — convert **/*.ext → -name "*.ext" + const basename = params.pattern.replace(/^(\*\*\/)+/, ""); + const result = await execFileAsync( + "find", + [searchPath, "-type", "f", "-name", basename], + { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, + ); + files = result.stdout.trim().split("\n").filter(Boolean); + } + + if (files.length === 0) { + return { + content: [{ type: "text", text: "No files found" }], + details: { fileCount: 0 }, + }; + } + + return { + content: [ + { type: "text", text: truncateChars(truncateLines(files)) }, + ], + details: { fileCount: files.length }, + }; + } catch (error: unknown) { + return toolError(error, "Glob", { fileCount: 0 }); + } + }, + }); + + // ── LS tool ────────────────────────────────────────────────────────── + + const LsParams = Type.Object({ + path: Type.String({ + description: "Directory path to list", + }), + }); + + pi.registerTool({ + name: "LS", + label: "LS", + description: "List the contents of a directory, including hidden files.", + parameters: LsParams, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const targetPath = resolve(ctx.cwd, params.path); + + try { + const { stdout } = await execFileAsync("ls", ["-la", targetPath], { + cwd: ctx.cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + signal, + }); + + let output = stdout.trim(); + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[LS: output truncated at 50KB]`; + } + + return { + content: [{ type: "text", text: output }], + details: { path: targetPath }, + }; + } catch (error: unknown) { + const err = error as ToolError; + return { + content: [ + { + type: "text", + text: `LS error: ${err.message ?? "Unknown error"}`, + }, + ], + details: { path: targetPath }, + }; + } + }, + }); + + // ── Read tool ──────────────────────────────────────────────────────── + + const ReadParams = Type.Object({ + path: Type.String({ + description: "Path to the file to read", + }), + offset: Type.Optional( + Type.Number({ + description: "Line number to start reading from (0-indexed)", + }), + ), + limit: Type.Optional( + Type.Number({ + description: "Maximum number of lines to read", + }), + ), + }); + + pi.registerTool({ + name: "Read", + label: "Read", + description: + "Read the contents of a file. Returns the file content with line numbers.", + parameters: ReadParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + if (!existsSync(filePath)) { + return fileNotFound(filePath, { exists: false, totalLines: 0 }); + } + + const content = readFileSync(filePath, "utf-8"); + const lines = content.split("\n"); + + const startLine = params.offset ?? 0; + const endLine = params.limit + ? Math.min(startLine + params.limit, lines.length) + : Math.min(startLine + 2000, lines.length); + + const selectedLines = lines.slice(startLine, endLine); + const numberedLines = selectedLines.map( + (line, i) => `${startLine + i + 1}\t${line}`, + ); + + let output = numberedLines.join("\n"); + if (endLine < lines.length) { + output += `\n\n[Showing lines ${startLine + 1}-${endLine} of ${lines.length} total lines. Use offset to see more.]`; + } + + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + + return { + content: [{ type: "text", text: output }], + details: { path: filePath, totalLines: lines.length }, + }; + } catch (error: unknown) { + return fileError(error, "Read", filePath, { + exists: false, + totalLines: 0, + }); + } + }, + }); + + // ── Write tool ─────────────────────────────────────────────────────── + + const WriteParams = Type.Object({ + path: Type.String({ + description: "Path to the file to write", + }), + content: Type.String({ + description: "Content to write to the file", + }), + }); + + pi.registerTool({ + name: "Write", + label: "Write", + description: + "Create or overwrite a file with the given content. Creates parent directories if needed.", + parameters: WriteParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + mkdirSync(dirname(filePath), { recursive: true }); + writeFileSync(filePath, params.content, "utf-8"); + + return { + content: [ + { + type: "text", + text: `Successfully wrote ${params.content.length} bytes to ${params.path}`, + }, + ], + details: { path: filePath, bytesWritten: params.content.length }, + }; + } catch (error: unknown) { + const err = error as ToolError; + return { + content: [ + { + type: "text", + text: `Write error: ${err.message ?? "Unknown error"}`, + }, + ], + details: { path: filePath, bytesWritten: 0 }, + }; + } + }, + }); + + // ── StrReplace tool ────────────────────────────────────────────────── + + const StrReplaceParams = Type.Object({ + path: Type.String({ + description: "Path to the file to modify", + }), + old_str: Type.String({ + description: "String to search for (exact match)", + }), + new_str: Type.String({ + description: "String to replace with", + }), + }); + + pi.registerTool({ + name: "StrReplace", + label: "StrReplace", + description: + "Replace all occurrences of a string in a file. The old_str must be an exact match.", + parameters: StrReplaceParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + if (!existsSync(filePath)) { + return fileNotFound(filePath, { replacements: 0 }); + } + + const content = readFileSync(filePath, "utf-8"); + const count = content.split(params.old_str).length - 1; + + if (count === 0) { + return { + content: [ + { + type: "text", + text: `String not found in ${params.path}: "${params.old_str}"`, + }, + ], + details: { path: filePath, replacements: 0 }, + }; + } + + const newContent = content.replaceAll(params.old_str, params.new_str); + writeFileSync(filePath, newContent, "utf-8"); + + return { + content: [ + { + type: "text", + text: `Replaced ${count} occurrence(s) in ${params.path}`, + }, + ], + details: { path: filePath, replacements: count }, + }; + } catch (error: unknown) { + return fileError(error, "StrReplace", filePath, { replacements: 0 }); + } + }, + }); + + // ── Delete tool ────────────────────────────────────────────────────── + + const DeleteParams = Type.Object({ + path: Type.String({ + description: "Path to the file to delete", + }), + }); + + pi.registerTool({ + name: "Delete", + label: "Delete", + description: "Delete a file from the filesystem.", + parameters: DeleteParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + if (!existsSync(filePath)) { + return fileNotFound(filePath, { deleted: false }); + } + + unlinkSync(filePath); + + return { + content: [ + { type: "text", text: `Successfully deleted ${params.path}` }, + ], + details: { path: filePath, deleted: true }, + }; + } catch (error: unknown) { + return fileError(error, "Delete", filePath, { deleted: false }); + } + }, + }); + + // ── Shell tool ─────────────────────────────────────────────────────── + + const ShellParams = Type.Object({ + command: Type.String({ + description: "Shell command to execute", + }), + working_directory: Type.Optional( + Type.String({ + description: "Working directory for the command", + }), + ), + timeout: Type.Optional( + Type.Number({ + description: "Timeout in milliseconds (default: 120000)", + }), + ), + }); + + pi.registerTool({ + name: "Shell", + label: "Shell", + description: + "Execute a shell command and return stdout, stderr, and exit code.", + parameters: ShellParams, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const cwd = params.working_directory + ? resolve(ctx.cwd, params.working_directory) + : ctx.cwd; + const timeout = params.timeout ?? 120_000; + + try { + const { stdout, stderr } = await execFileAsync( + "bash", + ["-c", params.command], + { + cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + timeout, + signal, + }, + ); + + let output = ""; + if (stdout) output += stdout; + if (stderr) output += `\n[stderr]\n${stderr}`; + + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + + return { + content: [{ type: "text", text: output || "(no output)" }], + details: { exitCode: 0, command: params.command }, + }; + } catch (error: unknown) { + const err = error as { + code?: number; + message?: string; + stdout?: string; + stderr?: string; + }; + + let output = ""; + if (err.stdout) output += err.stdout; + if (err.stderr) output += `\n[stderr]\n${err.stderr}`; + + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + + return { + content: [ + { + type: "text", + text: `Shell error (exit code ${err.code ?? "unknown"}): ${err.message ?? "Unknown error"}${output ? `\n${output}` : ""}`, + }, + ], + details: { + exitCode: err.code ?? 1, + command: params.command, + }, + }; + } + }, + }); + // ── Payload sanitization via event ──────────────────────────────────── pi.on("before_provider_request", (event, ctx) => { if (ctx.model?.provider !== "grok-cli") return; diff --git a/tests/index.test.ts b/tests/index.test.ts index a2f9b95..58d493d 100644 --- a/tests/index.test.ts +++ b/tests/index.test.ts @@ -79,6 +79,7 @@ afterEach(() => { async function setupExtension() { const commands = new Map(); const providers = new Map(); + const tools = new Map(); const registerGrokCli = (await import("../src/index.js")).default; registerGrokCli({ registerProvider(name: string, config: ProviderConfig) { @@ -88,8 +89,11 @@ async function setupExtension() { registerCommand(name: string, config: unknown) { commands.set(name, config as CommandConfig); }, + registerTool(tool: { name: string }) { + tools.set(tool.name, tool); + }, } as unknown as ExtensionAPI); - return { commands, providers }; + return { commands, providers, tools }; } function statusContext(notify: TestContext["ui"]["notify"]): TestContext { From 558138b2eb33e93bc6b6827a93920f3ad473113d Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Tue, 2 Jun 2026 18:39:13 +0900 Subject: [PATCH 02/24] feat: implement dynamic tool scoping and compact result rendering for Grok CLI tools --- bun.lock | 2 + package.json | 4 +- src/index.ts | 241 ++++++++++++++++++++++++++++++++++++++++++++ tests/index.test.ts | 232 +++++++++++++++++++++++++++++++++++++++++- 4 files changed, 473 insertions(+), 6 deletions(-) diff --git a/bun.lock b/bun.lock index ec5bcd7..88c55bc 100644 --- a/bun.lock +++ b/bun.lock @@ -8,6 +8,7 @@ "@biomejs/biome": "2.4.16", "@earendil-works/pi-ai": "^0.78.0", "@earendil-works/pi-coding-agent": "^0.78.0", + "@earendil-works/pi-tui": "^0.78.0", "@vitest/coverage-v8": "^4.1.8", "release-tools": "github:kenryu42/release-tools", "husky": "^9.1.7", @@ -20,6 +21,7 @@ "peerDependencies": { "@earendil-works/pi-ai": "*", "@earendil-works/pi-coding-agent": "*", + "@earendil-works/pi-tui": "*", }, }, }, diff --git a/package.json b/package.json index f489fdd..7da54d9 100644 --- a/package.json +++ b/package.json @@ -51,12 +51,14 @@ }, "peerDependencies": { "@earendil-works/pi-ai": "*", - "@earendil-works/pi-coding-agent": "*" + "@earendil-works/pi-coding-agent": "*", + "@earendil-works/pi-tui": "*" }, "devDependencies": { "@biomejs/biome": "2.4.16", "@earendil-works/pi-ai": "^0.78.0", "@earendil-works/pi-coding-agent": "^0.78.0", + "@earendil-works/pi-tui": "^0.78.0", "@vitest/coverage-v8": "^4.1.8", "husky": "^9.1.7", "jscpd": "^4.2.4", diff --git a/src/index.ts b/src/index.ts index ecf6d43..66fcb0a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -38,6 +38,7 @@ import type { ExtensionAPI, ProviderConfig, } from "@earendil-works/pi-coding-agent"; +import { Text } from "@earendil-works/pi-tui"; const execFileAsync = promisify(execFile); @@ -51,6 +52,16 @@ import { sanitizePayload } from "./sanitize.js"; const GROK_CLI_VERSION = "0.2.16"; const QUOTA_CACHE_FILE = "grok-cli-quota.json"; +const GROK_TOOL_NAMES = [ + "Grep", + "Glob", + "LS", + "Read", + "Write", + "StrReplace", + "Delete", + "Shell", +]; // ─── Rate limit cache (piggybacks on onResponse from normal traffic) ────────── @@ -229,6 +240,32 @@ export default function (pi: ExtensionAPI) { const baseUrl = getBaseUrl(); const models = resolveModels(); + function syncGrokTools(provider: string | undefined) { + const currentTools = pi.getActiveTools(); + const baseTools = currentTools.filter( + (toolName) => !GROK_TOOL_NAMES.includes(toolName), + ); + const nextTools = + provider === "grok-cli" ? [...baseTools, ...GROK_TOOL_NAMES] : baseTools; + + if ( + currentTools.length === nextTools.length && + currentTools.every((toolName, i) => toolName === nextTools[i]) + ) { + return; + } + + pi.setActiveTools(nextTools); + } + + pi.on("model_select", (event) => { + syncGrokTools(event.model.provider); + }); + + pi.on("before_agent_start", (_event, ctx) => { + syncGrokTools(ctx.model?.provider); + }); + // ── Register provider ───────────────────────────────────────────────── pi.registerProvider("grok-cli", { name: "Grok CLI", @@ -316,6 +353,52 @@ export default function (pi: ExtensionAPI) { details: T; }; + function text(text: string): Text { + return new Text(text, 0, 0); + } + + function firstText(result: { content: { type: string; text?: string }[] }) { + const first = result.content[0]; + if (first?.type !== "text") return undefined; + return first.text; + } + + function renderResultText( + result: { content: { type: string; text?: string }[] }, + expanded: boolean, + summary: string, + ): Text { + if (expanded) return text(firstText(result) ?? summary); + return text(summary); + } + + function renderRunning(isPartial: boolean): Text | undefined { + if (!isPartial) return undefined; + return text("Running..."); + } + + function detailRecord(result: { details: unknown }): Record { + if (!result.details || typeof result.details !== "object") return {}; + return result.details as Record; + } + + function numberDetail(result: { details: unknown }, key: string): number { + const value = detailRecord(result)[key]; + if (typeof value !== "number") return 0; + return value; + } + + function stringDetail(result: { details: unknown }, key: string): string { + const value = detailRecord(result)[key]; + if (typeof value !== "string") return ""; + return value; + } + + function booleanDetail(result: { details: unknown }, key: string): boolean { + const value = detailRecord(result)[key]; + return value === true; + } + type FileDetails = { path: string; [key: string]: unknown }; function fileNotFound( @@ -450,6 +533,28 @@ export default function (pi: ExtensionAPI) { return toolError(error, "Grep", { matchCount: 0 }); } }, + renderCall(args, theme) { + const path = args.path ? theme.fg("muted", ` in ${args.path}`) : ""; + const include = args.include ? theme.fg("dim", ` [${args.include}]`) : ""; + return text( + theme.fg("toolTitle", theme.bold("Grep ")) + + theme.fg("accent", `"${args.pattern}"`) + + path + + include, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + const matchCount = numberDetail(result, "matchCount"); + return renderResultText( + result, + expanded, + matchCount === 0 + ? theme.fg("dim", "No matches") + : theme.fg("muted", `${matchCount} match(es)`), + ); + }, }); const GlobParams = Type.Object({ @@ -512,6 +617,26 @@ export default function (pi: ExtensionAPI) { return toolError(error, "Glob", { fileCount: 0 }); } }, + renderCall(args, theme) { + const path = args.path ? theme.fg("muted", ` in ${args.path}`) : ""; + return text( + theme.fg("toolTitle", theme.bold("Glob ")) + + theme.fg("accent", args.pattern) + + path, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + const fileCount = numberDetail(result, "fileCount"); + return renderResultText( + result, + expanded, + fileCount === 0 + ? theme.fg("dim", "No files") + : theme.fg("muted", `${fileCount} file(s)`), + ); + }, }); // ── LS tool ────────────────────────────────────────────────────────── @@ -560,6 +685,21 @@ export default function (pi: ExtensionAPI) { }; } }, + renderCall(args, theme) { + return text( + theme.fg("toolTitle", theme.bold("LS ")) + + theme.fg("accent", args.path), + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + return renderResultText( + result, + expanded, + theme.fg("muted", stringDetail(result, "path")), + ); + }, }); // ── Read tool ──────────────────────────────────────────────────────── @@ -628,6 +768,31 @@ export default function (pi: ExtensionAPI) { }); } }, + renderCall(args, theme) { + const range = + args.offset !== undefined || args.limit !== undefined + ? theme.fg( + "muted", + ` (from ${args.offset ?? 0}${args.limit ? `, ${args.limit} lines` : ""})`, + ) + : ""; + return text( + theme.fg("toolTitle", theme.bold("Read ")) + + theme.fg("accent", args.path) + + range, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + return renderResultText( + result, + expanded, + detailRecord(result).exists === false + ? theme.fg("error", "File not found") + : theme.fg("muted", `${numberDetail(result, "totalLines")} line(s)`), + ); + }, }); // ── Write tool ─────────────────────────────────────────────────────── @@ -677,6 +842,24 @@ export default function (pi: ExtensionAPI) { }; } }, + renderCall(args, theme) { + return text( + theme.fg("toolTitle", theme.bold("Write ")) + + theme.fg("accent", args.path), + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + return renderResultText( + result, + expanded, + theme.fg( + "muted", + `${numberDetail(result, "bytesWritten")} bytes written`, + ), + ); + }, }); // ── StrReplace tool ────────────────────────────────────────────────── @@ -739,6 +922,26 @@ export default function (pi: ExtensionAPI) { return fileError(error, "StrReplace", filePath, { replacements: 0 }); } }, + renderCall(args, theme) { + return text( + theme.fg("toolTitle", theme.bold("StrReplace ")) + + theme.fg("accent", args.path), + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + return renderResultText( + result, + expanded, + numberDetail(result, "replacements") === 0 + ? theme.fg("dim", "No replacements") + : theme.fg( + "muted", + `${numberDetail(result, "replacements")} replacement(s)`, + ), + ); + }, }); // ── Delete tool ────────────────────────────────────────────────────── @@ -775,6 +978,23 @@ export default function (pi: ExtensionAPI) { return fileError(error, "Delete", filePath, { deleted: false }); } }, + renderCall(args, theme) { + return text( + theme.fg("toolTitle", theme.bold("Delete ")) + + theme.fg("accent", args.path), + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + return renderResultText( + result, + expanded, + booleanDetail(result, "deleted") + ? theme.fg("muted", "Deleted") + : theme.fg("error", "Not deleted"), + ); + }, }); // ── Shell tool ─────────────────────────────────────────────────────── @@ -862,6 +1082,27 @@ export default function (pi: ExtensionAPI) { }; } }, + renderCall(args, theme) { + const cwd = args.working_directory + ? theme.fg("muted", ` in ${args.working_directory}`) + : ""; + return text( + theme.fg("toolTitle", theme.bold("Shell ")) + + theme.fg("accent", args.command) + + cwd, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + return renderResultText( + result, + expanded, + numberDetail(result, "exitCode") === 0 + ? theme.fg("muted", "Exit 0") + : theme.fg("warning", `Exit ${numberDetail(result, "exitCode")}`), + ); + }, }); // ── Payload sanitization via event ──────────────────────────────────── diff --git a/tests/index.test.ts b/tests/index.test.ts index 58d493d..2c420d7 100644 --- a/tests/index.test.ts +++ b/tests/index.test.ts @@ -44,16 +44,40 @@ interface CommandConfig { handler: (args: string[], ctx: TestContext) => Promise; } +interface RegisteredTool { + name: string; + renderCall?: (...args: unknown[]) => Renderable; + renderResult?: (...args: unknown[]) => Renderable; +} + +interface Renderable { + render: (width: number) => string[]; +} + interface TestContext { modelRegistry: { getAll: () => { provider: string; id: string }[]; getApiKeyForProvider?: (provider: string) => Promise; }; + model?: { provider: string; id: string }; ui: { notify: (message: string, level: string) => void; }; } +type ExtensionHandler = (event: unknown, ctx: TestContext) => unknown; + +const grokToolNames = [ + "Grep", + "Glob", + "LS", + "Read", + "Write", + "StrReplace", + "Delete", + "Shell", +]; + const originalFetch = globalThis.fetch; const originalHome = process.env.HOME; const originalToken = process.env.GROK_CLI_OAUTH_TOKEN; @@ -76,24 +100,35 @@ afterEach(() => { for (const dir of tempDirs.splice(0)) rmSync(dir, { recursive: true }); }); -async function setupExtension() { +async function setupExtension(initialActiveTools = ["read", "bash"]) { const commands = new Map(); const providers = new Map(); - const tools = new Map(); + const tools = new Map(); + const handlers = new Map(); + let activeTools = initialActiveTools; + const setActiveTools = vi.fn((toolNames: string[]) => { + activeTools = toolNames; + }); const registerGrokCli = (await import("../src/index.js")).default; registerGrokCli({ registerProvider(name: string, config: ProviderConfig) { providers.set(name, config); }, - on() {}, + on(event: string, handler: ExtensionHandler) { + handlers.set(event, handler); + }, registerCommand(name: string, config: unknown) { commands.set(name, config as CommandConfig); }, - registerTool(tool: { name: string }) { + registerTool(tool: RegisteredTool) { tools.set(tool.name, tool); }, + getActiveTools() { + return activeTools; + }, + setActiveTools, } as unknown as ExtensionAPI); - return { commands, providers, tools }; + return { commands, providers, tools, handlers, setActiveTools }; } function statusContext(notify: TestContext["ui"]["notify"]): TestContext { @@ -108,6 +143,26 @@ function statusContext(notify: TestContext["ui"]["notify"]): TestContext { }; } +function contextForModel(provider: string): TestContext { + return { + model: { provider, id: `${provider}-model` }, + modelRegistry: { getAll: () => [] }, + ui: { notify: vi.fn() }, + }; +} + +function renderText(component: Renderable): string { + return component + .render(120) + .map((line) => line.trimEnd()) + .join("\n"); +} + +const theme = { + bold: (text: string) => text, + fg: (_name: string, text: string) => text, +}; + function setupHome() { const dir = mkdtempSync(join(tmpdir(), "pi-grok-cli-home-")); mkdirSync(join(dir, ".pi")); @@ -217,3 +272,170 @@ describe("Grok CLI status command", () => { ); }); }); + +describe("Grok CLI tool scoping", () => { + it("registers the Grok/Cursor-native tool shims", async () => { + const extension = await setupExtension(); + + expect([...extension.tools.keys()].sort()).toEqual( + [...grokToolNames].sort(), + ); + }); + + it("enables Grok tools for Grok models while preserving other active tools", async () => { + const extension = await setupExtension(["read", "custom_tool"]); + + await extension.handlers.get("model_select")?.( + { model: { provider: "grok-cli", id: "grok-build" } }, + contextForModel("grok-cli"), + ); + + expect(extension.setActiveTools).toHaveBeenLastCalledWith([ + "read", + "custom_tool", + ...grokToolNames, + ]); + }); + + it("removes Grok tools for non-Grok models while preserving other active tools", async () => { + const extension = await setupExtension([ + "read", + "Grep", + "custom_tool", + "Shell", + ]); + + await extension.handlers.get("model_select")?.( + { model: { provider: "openai", id: "gpt-4" } }, + contextForModel("openai"), + ); + + expect(extension.setActiveTools).toHaveBeenLastCalledWith([ + "read", + "custom_tool", + ]); + }); + + it("syncs tool scope before each agent turn from the current context model", async () => { + const extension = await setupExtension(["read"]); + + await extension.handlers.get("before_agent_start")?.( + {}, + contextForModel("grok-cli"), + ); + + expect(extension.setActiveTools).toHaveBeenLastCalledWith([ + "read", + ...grokToolNames, + ]); + }); + + it("does not update active tools when the selection is already correct", async () => { + const extension = await setupExtension(["read", ...grokToolNames]); + + await extension.handlers.get("before_agent_start")?.( + {}, + contextForModel("grok-cli"), + ); + + expect(extension.setActiveTools).not.toHaveBeenCalled(); + }); +}); + +describe("Grok CLI tool rendering", () => { + it("adds renderers to every Grok tool shim", async () => { + const extension = await setupExtension(); + + for (const name of grokToolNames) { + expect(extension.tools.get(name)?.renderCall).toBeTypeOf("function"); + expect(extension.tools.get(name)?.renderResult).toBeTypeOf("function"); + } + }); + + it("keeps collapsed search output compact and expands to full output", async () => { + const extension = await setupExtension(); + const grep = extension.tools.get("Grep"); + const result = { + content: [{ type: "text", text: "src/a.ts:1:match\nsrc/b.ts:2:match" }], + details: { matchCount: 2 }, + }; + + const collapsed = renderText( + grep?.renderResult?.( + result, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ); + const expanded = renderText( + grep?.renderResult?.( + result, + { expanded: true, isPartial: false }, + theme, + {}, + ) as Renderable, + ); + + expect(collapsed).toBe("2 match(es)"); + expect(collapsed).not.toContain("src/a.ts"); + expect(expanded).toContain("src/a.ts:1:match"); + }); + + it("renders compact summaries for file mutations, delete, and shell tools", async () => { + const extension = await setupExtension(); + + expect( + renderText( + extension.tools.get("Write")?.renderResult?.( + { + content: [{ type: "text", text: "long write output" }], + details: { bytesWritten: 42 }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + ).toBe("42 bytes written"); + expect( + renderText( + extension.tools.get("StrReplace")?.renderResult?.( + { + content: [{ type: "text", text: "long replace output" }], + details: { replacements: 3 }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + ).toBe("3 replacement(s)"); + expect( + renderText( + extension.tools.get("Delete")?.renderResult?.( + { + content: [{ type: "text", text: "long delete output" }], + details: { deleted: true }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + ).toBe("Deleted"); + expect( + renderText( + extension.tools.get("Shell")?.renderResult?.( + { + content: [{ type: "text", text: "long shell output" }], + details: { exitCode: 2 }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + ).toBe("Exit 2"); + }); +}); From 980fe06313c6190d6baaa724de0540a7ec64b68f Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Tue, 2 Jun 2026 19:11:12 +0900 Subject: [PATCH 03/24] refactor: modularize codebase by decomposing index into dedicated service modules for auth, tools, models, and provider logic --- src/{ => auth}/oauth.ts | 2 +- src/index.ts | 1186 +---------------- src/{models.ts => models/catalog.ts} | 0 src/{ => payload}/sanitize.ts | 2 +- src/provider/quota.ts | 144 ++ src/provider/register.ts | 104 ++ src/provider/status.ts | 66 + src/provider/stream.ts | 53 + src/provider/toolScope.ts | 33 + src/{ => shared}/errors.ts | 0 src/tools/files.ts | 385 ++++++ src/tools/register.ts | 10 + src/tools/rendering.ts | 185 +++ src/tools/search.ts | 187 +++ src/tools/shell.ts | 124 ++ tests/{ => auth}/oauth.test.ts | 4 +- .../catalog.test.ts} | 5 +- tests/{ => payload}/sanitize.test.ts | 2 +- tests/{ => provider}/package.test.ts | 50 +- .../register.test.ts} | 2 +- tests/{ => shared}/errors.test.ts | 2 +- tests/tools/register.test.ts | 32 + 22 files changed, 1366 insertions(+), 1212 deletions(-) rename src/{ => auth}/oauth.ts (99%) rename src/{models.ts => models/catalog.ts} (100%) rename src/{ => payload}/sanitize.ts (99%) create mode 100644 src/provider/quota.ts create mode 100644 src/provider/register.ts create mode 100644 src/provider/status.ts create mode 100644 src/provider/stream.ts create mode 100644 src/provider/toolScope.ts rename src/{ => shared}/errors.ts (100%) create mode 100644 src/tools/files.ts create mode 100644 src/tools/register.ts create mode 100644 src/tools/rendering.ts create mode 100644 src/tools/search.ts create mode 100644 src/tools/shell.ts rename tests/{ => auth}/oauth.test.ts (98%) rename tests/{models.test.ts => models/catalog.test.ts} (93%) rename tests/{ => payload}/sanitize.test.ts (98%) rename tests/{ => provider}/package.test.ts (52%) rename tests/{index.test.ts => provider/register.test.ts} (99%) rename tests/{ => shared}/errors.test.ts (87%) create mode 100644 tests/tools/register.test.ts diff --git a/src/oauth.ts b/src/auth/oauth.ts similarity index 99% rename from src/oauth.ts rename to src/auth/oauth.ts index 6e515e2..a4dca37 100644 --- a/src/oauth.ts +++ b/src/auth/oauth.ts @@ -10,7 +10,7 @@ */ import { createServer } from "node:http"; -import { XaiErrorCode, XaiOAuthError } from "./errors.js"; +import { XaiErrorCode, XaiOAuthError } from "../shared/errors.js"; // ─── Constants ──────────────────────────────────────────────────────────────── diff --git a/src/index.ts b/src/index.ts index 66fcb0a..1d7f42f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,1189 +1,7 @@ /** * pi-grok-cli — Grok CLI API provider for pi * - * Brings access to the Grok CLI's endpoint - * into pi. This endpoint has access to models not available on the public - * xAI API, including grok-composer-2.5-fast (Cursor's Composer 2.5 model). - * - * Environment variables: - * PI_GROK_CLI_BASE_URL - Override the API base URL - * PI_GROK_CLI_MODELS - Comma-separated model IDs to expose - * PI_GROK_CLI_OAUTH_CLIENT_ID - Override OAuth client ID - * PI_GROK_CLI_OAUTH_SCOPE - Override OAuth scopes - */ - -import { execFile } from "node:child_process"; -import { - existsSync, - mkdirSync, - readFileSync, - unlinkSync, - writeFileSync, -} from "node:fs"; -import { homedir } from "node:os"; -import { dirname, join, resolve } from "node:path"; -import { promisify } from "node:util"; -import { - type Api, - type AssistantMessageEventStream, - type Context, - type Model, - type OAuthCredentials, - type OAuthLoginCallbacks, - type SimpleStreamOptions, - streamSimpleOpenAIResponses, - Type, -} from "@earendil-works/pi-ai"; -import type { - ExtensionAPI, - ProviderConfig, -} from "@earendil-works/pi-coding-agent"; -import { Text } from "@earendil-works/pi-tui"; - -const execFileAsync = promisify(execFile); - -import { XaiOAuthError } from "./errors.js"; -import { type GrokCliModelConfig, resolveModels } from "./models.js"; -import * as oauth from "./oauth.js"; -import { getBaseUrl, type XaiOAuthCredentials } from "./oauth.js"; -import { sanitizePayload } from "./sanitize.js"; - -// ─── Grok CLI version (observed from traffic capture) ───────────────────────── - -const GROK_CLI_VERSION = "0.2.16"; -const QUOTA_CACHE_FILE = "grok-cli-quota.json"; -const GROK_TOOL_NAMES = [ - "Grep", - "Glob", - "LS", - "Read", - "Write", - "StrReplace", - "Delete", - "Shell", -]; - -// ─── Rate limit cache (piggybacks on onResponse from normal traffic) ────────── - -interface RateLimitInfo { - remainingRequests: number; - limitRequests: number; - remainingTokens: number; - limitTokens: number; - contextWindow: number; - zeroDataRetention: boolean; - capturedAt: number; -} - -const cachedRateLimits = new Map(); - -function quotaCachePath() { - return join(homedir(), ".pi", QUOTA_CACHE_FILE); -} - -function isRateLimitInfo(value: unknown): value is RateLimitInfo { - if (!value || typeof value !== "object") return false; - const info = value as Record; - return ( - typeof info.remainingRequests === "number" && - typeof info.limitRequests === "number" && - typeof info.remainingTokens === "number" && - typeof info.limitTokens === "number" && - typeof info.contextWindow === "number" && - typeof info.zeroDataRetention === "boolean" && - typeof info.capturedAt === "number" - ); -} - -function loadQuotaCache() { - cachedRateLimits.clear(); - if (!existsSync(quotaCachePath())) return; - - try { - const payload = JSON.parse( - readFileSync(quotaCachePath(), "utf8"), - ) as Record; - const models = payload.models; - if (!models || typeof models !== "object") return; - - Object.entries(models).forEach(([model, rateLimit]) => { - if (isRateLimitInfo(rateLimit)) cachedRateLimits.set(model, rateLimit); - }); - } catch { - cachedRateLimits.clear(); - } -} - -function persistQuotaCache() { - try { - mkdirSync(dirname(quotaCachePath()), { recursive: true }); - writeFileSync( - quotaCachePath(), - JSON.stringify( - { version: 1, models: Object.fromEntries(cachedRateLimits) }, - null, - "\t", - ), - ); - } catch { - // Status remains cache-only; persistence failures should not break requests. - } -} - -/** - * Extract rate limit info from response headers. - * Returns undefined if no rate limit headers are present. + * Brings access to the Grok CLI's endpoint into pi. */ -function extractRateLimit( - h: Record, -): RateLimitInfo | undefined { - const remainingReqs = Number(h["x-ratelimit-remaining-requests"]); - const limitReqs = Number(h["x-ratelimit-limit-requests"]); - const remainingTokens = Number(h["x-ratelimit-remaining-tokens"]); - const limitTokens = Number(h["x-ratelimit-limit-tokens"]); - const contextWindow = Number(h["x-grok-context-window"]); - - if (Number.isNaN(remainingReqs) && Number.isNaN(remainingTokens)) - return undefined; - - return { - remainingRequests: remainingReqs, - limitRequests: limitReqs, - remainingTokens, - limitTokens, - contextWindow: contextWindow || 512_000, - zeroDataRetention: h["x-zero-data-retention"] === "true", - capturedAt: Date.now(), - }; -} - -function formatQuota(name: string, rateLimit: RateLimitInfo | undefined) { - if (!rateLimit) { - return [ - ` ${name}:`, - " no cached quota data — make a request with this model first", - ]; - } - - const ageSec = Math.round((Date.now() - rateLimit.capturedAt) / 1000); - const ageStr = - ageSec < 60 ? `${ageSec}s ago` : `${Math.round(ageSec / 60)}m ago`; - const lines = [` ${name}:`]; - lines.push(` Cached: ${ageStr}`); - lines.push( - ` Requests: ${rateLimit.remainingRequests}/${rateLimit.limitRequests} remaining`, - ); - lines.push( - ` Tokens: ${rateLimit.remainingTokens.toLocaleString()}/${rateLimit.limitTokens.toLocaleString()} remaining`, - ); - lines.push( - ` Context Limit: ${rateLimit.contextWindow.toLocaleString()} tokens`, - ); - if (rateLimit.zeroDataRetention) { - lines.push(" Data: Zero retention ✓"); - } - return lines; -} - -// ─── Stream function ───────────────────────────────────────────────────────── - -/** - * Stream function that adds Grok CLI-specific headers to requests. - * - * The real Grok CLI sends these headers: - * - x-grok-client-identifier: grok-shell - * - x-grok-client-version: 0.2.16 - * - x-grok-conv-id: - * - x-grok-model-override: - * - x-xai-token-auth: xai-grok-cli - */ -function streamGrokCli( - model: Model, - context: Context, - options?: SimpleStreamOptions, -): AssistantMessageEventStream { - const sessionId = options?.sessionId; - const headers: Record = { - ...options?.headers, - "x-grok-client-identifier": "pi-grok-cli", - "x-grok-client-version": GROK_CLI_VERSION, - "x-xai-token-auth": "xai-grok-cli", - "x-grok-model-override": model.id, - }; - - if (sessionId) { - headers["x-grok-conv-id"] = sessionId; - } - - return streamSimpleOpenAIResponses( - model as Model<"openai-responses">, - context, - { - ...options, - headers, - onResponse(response) { - const rateLimit = extractRateLimit(response.headers); - if (rateLimit) { - cachedRateLimits.set(model.id, rateLimit); - persistQuotaCache(); - } - options?.onResponse?.(response, model); - }, - }, - ); -} - -// ─── Extension entry point ─────────────────────────────────────────────────── - -export default function (pi: ExtensionAPI) { - loadQuotaCache(); - const baseUrl = getBaseUrl(); - const models = resolveModels(); - - function syncGrokTools(provider: string | undefined) { - const currentTools = pi.getActiveTools(); - const baseTools = currentTools.filter( - (toolName) => !GROK_TOOL_NAMES.includes(toolName), - ); - const nextTools = - provider === "grok-cli" ? [...baseTools, ...GROK_TOOL_NAMES] : baseTools; - - if ( - currentTools.length === nextTools.length && - currentTools.every((toolName, i) => toolName === nextTools[i]) - ) { - return; - } - - pi.setActiveTools(nextTools); - } - - pi.on("model_select", (event) => { - syncGrokTools(event.model.provider); - }); - - pi.on("before_agent_start", (_event, ctx) => { - syncGrokTools(ctx.model?.provider); - }); - - // ── Register provider ───────────────────────────────────────────────── - pi.registerProvider("grok-cli", { - name: "Grok CLI", - baseUrl, - apiKey: "$GROK_CLI_OAUTH_TOKEN", - api: "openai-responses", - models: models.map((m: GrokCliModelConfig) => ({ - id: m.id, - name: m.name, - reasoning: m.reasoning, - thinkingLevelMap: m.thinkingLevelMap, - input: m.input, - cost: m.cost, - contextWindow: m.contextWindow, - maxTokens: m.maxTokens, - })), - oauth: { - name: "Grok CLI", - - async login(callbacks: OAuthLoginCallbacks): Promise { - return oauth.login(callbacks); - }, - - async refreshToken( - credentials: OAuthCredentials, - ): Promise { - return oauth.refresh(credentials); - }, - - getApiKey(credentials: OAuthCredentials): string { - return credentials.access; - }, - - modifyModels(models: Model[], credentials: OAuthCredentials) { - const effectiveBaseUrl = String( - (credentials as XaiOAuthCredentials).baseUrl ?? getBaseUrl(), - ).replace(/\/+$/, ""); - - return models.map((m) => - m.provider === "grok-cli" ? { ...m, baseUrl: effectiveBaseUrl } : m, - ); - }, - } satisfies ProviderConfig["oauth"], - - streamSimple: streamGrokCli, - }); - - // ── Register Grok/Cursor-native tools ────────────────────────────────── - - const MAX_OUTPUT_CHARS = 50_000; - const MAX_LINES = 500; - - function truncateLines(lines: string[]): string { - if (lines.length > MAX_LINES) { - return ( - lines.slice(0, MAX_LINES).join("\n") + - `\n\n[Showing first ${MAX_LINES} of ${lines.length} results. Refine your pattern to narrow results.]` - ); - } - return lines.join("\n"); - } - - function truncateChars(output: string): string { - if (output.length > MAX_OUTPUT_CHARS) { - return `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; - } - return output; - } - - let rgAvailable: boolean | undefined; - async function hasRipgrep(): Promise { - if (rgAvailable !== undefined) return rgAvailable; - try { - await execFileAsync("rg", ["--version"]); - rgAvailable = true; - } catch { - rgAvailable = false; - } - return rgAvailable; - } - - type ToolError = { code?: number; message?: string }; - type ToolResult = { - content: [{ type: "text"; text: string }]; - details: T; - }; - - function text(text: string): Text { - return new Text(text, 0, 0); - } - - function firstText(result: { content: { type: string; text?: string }[] }) { - const first = result.content[0]; - if (first?.type !== "text") return undefined; - return first.text; - } - - function renderResultText( - result: { content: { type: string; text?: string }[] }, - expanded: boolean, - summary: string, - ): Text { - if (expanded) return text(firstText(result) ?? summary); - return text(summary); - } - - function renderRunning(isPartial: boolean): Text | undefined { - if (!isPartial) return undefined; - return text("Running..."); - } - - function detailRecord(result: { details: unknown }): Record { - if (!result.details || typeof result.details !== "object") return {}; - return result.details as Record; - } - - function numberDetail(result: { details: unknown }, key: string): number { - const value = detailRecord(result)[key]; - if (typeof value !== "number") return 0; - return value; - } - - function stringDetail(result: { details: unknown }, key: string): string { - const value = detailRecord(result)[key]; - if (typeof value !== "string") return ""; - return value; - } - - function booleanDetail(result: { details: unknown }, key: string): boolean { - const value = detailRecord(result)[key]; - return value === true; - } - - type FileDetails = { path: string; [key: string]: unknown }; - - function fileNotFound( - filePath: string, - extraDetails: Omit, - ): ToolResult { - return { - content: [{ type: "text", text: `File not found: ${filePath}` }], - details: { path: filePath, ...extraDetails } as T, - }; - } - - function fileError( - error: unknown, - toolName: string, - filePath: string, - extraDetails: Omit, - ): ToolResult { - const err = error as ToolError; - return { - content: [ - { - type: "text", - text: `${toolName} error: ${err.message ?? "Unknown error"}`, - }, - ], - details: { path: filePath, ...extraDetails } as T, - }; - } - - function toolError( - error: unknown, - toolName: string, - emptyDetails: T, - ): ToolResult { - const err = error as ToolError; - if (err.code === 1) { - return { - content: [{ type: "text", text: "No matches found" }], - details: emptyDetails, - }; - } - return { - content: [ - { - type: "text", - text: `${toolName} error: ${err.message ?? "Unknown error"}`, - }, - ], - details: emptyDetails, - }; - } - - async function execWithRgFallback( - rgArgs: string[], - grepArgs: string[], - options: { cwd: string; signal?: AbortSignal }, - ): Promise { - if (await hasRipgrep()) { - const result = await execFileAsync("rg", rgArgs, { - cwd: options.cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, - signal: options.signal, - }); - return result.stdout; - } - const result = await execFileAsync("grep", grepArgs, { - cwd: options.cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, - signal: options.signal, - }); - return result.stdout; - } - - const GrepParams = Type.Object({ - pattern: Type.String({ - description: "Regex pattern to search for in file contents", - }), - path: Type.Optional( - Type.String({ - description: - "Directory or file to search. Defaults to current working directory.", - }), - ), - include: Type.Optional( - Type.String({ - description: - "Glob pattern to filter which files are searched (e.g. *.ts, **/*.md)", - }), - ), - }); - - pi.registerTool({ - name: "Grep", - label: "Grep", - description: - "Search for a regex pattern in file contents. Returns matching lines with file path and line number. Use the include parameter to filter by file type.", - parameters: GrepParams, - - async execute(_toolCallId, params, signal, _onUpdate, ctx) { - const searchPath = resolve(ctx.cwd, params.path ?? "."); - - try { - const rgArgs = ["-n", "--no-heading", "--color=never"]; - if (params.include) rgArgs.push("--glob", params.include); - rgArgs.push(params.pattern, searchPath); - - const grepArgs = ["-r", "-n", "--color=never"]; - if (params.include) grepArgs.push(`--include=${params.include}`); - grepArgs.push(params.pattern, searchPath); - - const stdout = await execWithRgFallback(rgArgs, grepArgs, { - cwd: ctx.cwd, - signal, - }); - - const lines = stdout.trim().split("\n").filter(Boolean); - if (lines.length === 0) { - return { - content: [{ type: "text", text: "No matches found" }], - details: { matchCount: 0 }, - }; - } - - return { - content: [ - { type: "text", text: truncateChars(truncateLines(lines)) }, - ], - details: { matchCount: lines.length }, - }; - } catch (error: unknown) { - return toolError(error, "Grep", { matchCount: 0 }); - } - }, - renderCall(args, theme) { - const path = args.path ? theme.fg("muted", ` in ${args.path}`) : ""; - const include = args.include ? theme.fg("dim", ` [${args.include}]`) : ""; - return text( - theme.fg("toolTitle", theme.bold("Grep ")) + - theme.fg("accent", `"${args.pattern}"`) + - path + - include, - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - const running = renderRunning(isPartial); - if (running) return running; - const matchCount = numberDetail(result, "matchCount"); - return renderResultText( - result, - expanded, - matchCount === 0 - ? theme.fg("dim", "No matches") - : theme.fg("muted", `${matchCount} match(es)`), - ); - }, - }); - - const GlobParams = Type.Object({ - pattern: Type.String({ - description: "Glob pattern to match files (e.g. **/*.ts, src/**/*.json)", - }), - path: Type.Optional( - Type.String({ - description: - "Directory to search within. Defaults to current working directory.", - }), - ), - }); - - pi.registerTool({ - name: "Glob", - label: "Glob", - description: - "Find files matching a glob pattern. Returns a list of matching file paths sorted by modification time (newest first).", - parameters: GlobParams, - - async execute(_toolCallId, params, signal, _onUpdate, ctx) { - const searchPath = resolve(ctx.cwd, params.path ?? "."); - - try { - let files: string[]; - - if (await hasRipgrep()) { - const result = await execFileAsync( - "rg", - ["--files", "--color=never", "--glob", params.pattern, searchPath], - { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, - ); - files = result.stdout.trim().split("\n").filter(Boolean); - } else { - // find fallback — convert **/*.ext → -name "*.ext" - const basename = params.pattern.replace(/^(\*\*\/)+/, ""); - const result = await execFileAsync( - "find", - [searchPath, "-type", "f", "-name", basename], - { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, - ); - files = result.stdout.trim().split("\n").filter(Boolean); - } - - if (files.length === 0) { - return { - content: [{ type: "text", text: "No files found" }], - details: { fileCount: 0 }, - }; - } - - return { - content: [ - { type: "text", text: truncateChars(truncateLines(files)) }, - ], - details: { fileCount: files.length }, - }; - } catch (error: unknown) { - return toolError(error, "Glob", { fileCount: 0 }); - } - }, - renderCall(args, theme) { - const path = args.path ? theme.fg("muted", ` in ${args.path}`) : ""; - return text( - theme.fg("toolTitle", theme.bold("Glob ")) + - theme.fg("accent", args.pattern) + - path, - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - const running = renderRunning(isPartial); - if (running) return running; - const fileCount = numberDetail(result, "fileCount"); - return renderResultText( - result, - expanded, - fileCount === 0 - ? theme.fg("dim", "No files") - : theme.fg("muted", `${fileCount} file(s)`), - ); - }, - }); - - // ── LS tool ────────────────────────────────────────────────────────── - - const LsParams = Type.Object({ - path: Type.String({ - description: "Directory path to list", - }), - }); - - pi.registerTool({ - name: "LS", - label: "LS", - description: "List the contents of a directory, including hidden files.", - parameters: LsParams, - - async execute(_toolCallId, params, signal, _onUpdate, ctx) { - const targetPath = resolve(ctx.cwd, params.path); - - try { - const { stdout } = await execFileAsync("ls", ["-la", targetPath], { - cwd: ctx.cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, - signal, - }); - - let output = stdout.trim(); - if (output.length > MAX_OUTPUT_CHARS) { - output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[LS: output truncated at 50KB]`; - } - - return { - content: [{ type: "text", text: output }], - details: { path: targetPath }, - }; - } catch (error: unknown) { - const err = error as ToolError; - return { - content: [ - { - type: "text", - text: `LS error: ${err.message ?? "Unknown error"}`, - }, - ], - details: { path: targetPath }, - }; - } - }, - renderCall(args, theme) { - return text( - theme.fg("toolTitle", theme.bold("LS ")) + - theme.fg("accent", args.path), - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - const running = renderRunning(isPartial); - if (running) return running; - return renderResultText( - result, - expanded, - theme.fg("muted", stringDetail(result, "path")), - ); - }, - }); - - // ── Read tool ──────────────────────────────────────────────────────── - - const ReadParams = Type.Object({ - path: Type.String({ - description: "Path to the file to read", - }), - offset: Type.Optional( - Type.Number({ - description: "Line number to start reading from (0-indexed)", - }), - ), - limit: Type.Optional( - Type.Number({ - description: "Maximum number of lines to read", - }), - ), - }); - - pi.registerTool({ - name: "Read", - label: "Read", - description: - "Read the contents of a file. Returns the file content with line numbers.", - parameters: ReadParams, - - async execute(_toolCallId, params, _signal, _onUpdate, ctx) { - const filePath = resolve(ctx.cwd, params.path); - - try { - if (!existsSync(filePath)) { - return fileNotFound(filePath, { exists: false, totalLines: 0 }); - } - - const content = readFileSync(filePath, "utf-8"); - const lines = content.split("\n"); - - const startLine = params.offset ?? 0; - const endLine = params.limit - ? Math.min(startLine + params.limit, lines.length) - : Math.min(startLine + 2000, lines.length); - - const selectedLines = lines.slice(startLine, endLine); - const numberedLines = selectedLines.map( - (line, i) => `${startLine + i + 1}\t${line}`, - ); - - let output = numberedLines.join("\n"); - if (endLine < lines.length) { - output += `\n\n[Showing lines ${startLine + 1}-${endLine} of ${lines.length} total lines. Use offset to see more.]`; - } - - if (output.length > MAX_OUTPUT_CHARS) { - output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; - } - - return { - content: [{ type: "text", text: output }], - details: { path: filePath, totalLines: lines.length }, - }; - } catch (error: unknown) { - return fileError(error, "Read", filePath, { - exists: false, - totalLines: 0, - }); - } - }, - renderCall(args, theme) { - const range = - args.offset !== undefined || args.limit !== undefined - ? theme.fg( - "muted", - ` (from ${args.offset ?? 0}${args.limit ? `, ${args.limit} lines` : ""})`, - ) - : ""; - return text( - theme.fg("toolTitle", theme.bold("Read ")) + - theme.fg("accent", args.path) + - range, - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - const running = renderRunning(isPartial); - if (running) return running; - return renderResultText( - result, - expanded, - detailRecord(result).exists === false - ? theme.fg("error", "File not found") - : theme.fg("muted", `${numberDetail(result, "totalLines")} line(s)`), - ); - }, - }); - - // ── Write tool ─────────────────────────────────────────────────────── - - const WriteParams = Type.Object({ - path: Type.String({ - description: "Path to the file to write", - }), - content: Type.String({ - description: "Content to write to the file", - }), - }); - - pi.registerTool({ - name: "Write", - label: "Write", - description: - "Create or overwrite a file with the given content. Creates parent directories if needed.", - parameters: WriteParams, - - async execute(_toolCallId, params, _signal, _onUpdate, ctx) { - const filePath = resolve(ctx.cwd, params.path); - - try { - mkdirSync(dirname(filePath), { recursive: true }); - writeFileSync(filePath, params.content, "utf-8"); - - return { - content: [ - { - type: "text", - text: `Successfully wrote ${params.content.length} bytes to ${params.path}`, - }, - ], - details: { path: filePath, bytesWritten: params.content.length }, - }; - } catch (error: unknown) { - const err = error as ToolError; - return { - content: [ - { - type: "text", - text: `Write error: ${err.message ?? "Unknown error"}`, - }, - ], - details: { path: filePath, bytesWritten: 0 }, - }; - } - }, - renderCall(args, theme) { - return text( - theme.fg("toolTitle", theme.bold("Write ")) + - theme.fg("accent", args.path), - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - const running = renderRunning(isPartial); - if (running) return running; - return renderResultText( - result, - expanded, - theme.fg( - "muted", - `${numberDetail(result, "bytesWritten")} bytes written`, - ), - ); - }, - }); - - // ── StrReplace tool ────────────────────────────────────────────────── - - const StrReplaceParams = Type.Object({ - path: Type.String({ - description: "Path to the file to modify", - }), - old_str: Type.String({ - description: "String to search for (exact match)", - }), - new_str: Type.String({ - description: "String to replace with", - }), - }); - - pi.registerTool({ - name: "StrReplace", - label: "StrReplace", - description: - "Replace all occurrences of a string in a file. The old_str must be an exact match.", - parameters: StrReplaceParams, - - async execute(_toolCallId, params, _signal, _onUpdate, ctx) { - const filePath = resolve(ctx.cwd, params.path); - - try { - if (!existsSync(filePath)) { - return fileNotFound(filePath, { replacements: 0 }); - } - - const content = readFileSync(filePath, "utf-8"); - const count = content.split(params.old_str).length - 1; - - if (count === 0) { - return { - content: [ - { - type: "text", - text: `String not found in ${params.path}: "${params.old_str}"`, - }, - ], - details: { path: filePath, replacements: 0 }, - }; - } - - const newContent = content.replaceAll(params.old_str, params.new_str); - writeFileSync(filePath, newContent, "utf-8"); - - return { - content: [ - { - type: "text", - text: `Replaced ${count} occurrence(s) in ${params.path}`, - }, - ], - details: { path: filePath, replacements: count }, - }; - } catch (error: unknown) { - return fileError(error, "StrReplace", filePath, { replacements: 0 }); - } - }, - renderCall(args, theme) { - return text( - theme.fg("toolTitle", theme.bold("StrReplace ")) + - theme.fg("accent", args.path), - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - const running = renderRunning(isPartial); - if (running) return running; - return renderResultText( - result, - expanded, - numberDetail(result, "replacements") === 0 - ? theme.fg("dim", "No replacements") - : theme.fg( - "muted", - `${numberDetail(result, "replacements")} replacement(s)`, - ), - ); - }, - }); - - // ── Delete tool ────────────────────────────────────────────────────── - - const DeleteParams = Type.Object({ - path: Type.String({ - description: "Path to the file to delete", - }), - }); - - pi.registerTool({ - name: "Delete", - label: "Delete", - description: "Delete a file from the filesystem.", - parameters: DeleteParams, - - async execute(_toolCallId, params, _signal, _onUpdate, ctx) { - const filePath = resolve(ctx.cwd, params.path); - - try { - if (!existsSync(filePath)) { - return fileNotFound(filePath, { deleted: false }); - } - - unlinkSync(filePath); - - return { - content: [ - { type: "text", text: `Successfully deleted ${params.path}` }, - ], - details: { path: filePath, deleted: true }, - }; - } catch (error: unknown) { - return fileError(error, "Delete", filePath, { deleted: false }); - } - }, - renderCall(args, theme) { - return text( - theme.fg("toolTitle", theme.bold("Delete ")) + - theme.fg("accent", args.path), - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - const running = renderRunning(isPartial); - if (running) return running; - return renderResultText( - result, - expanded, - booleanDetail(result, "deleted") - ? theme.fg("muted", "Deleted") - : theme.fg("error", "Not deleted"), - ); - }, - }); - - // ── Shell tool ─────────────────────────────────────────────────────── - - const ShellParams = Type.Object({ - command: Type.String({ - description: "Shell command to execute", - }), - working_directory: Type.Optional( - Type.String({ - description: "Working directory for the command", - }), - ), - timeout: Type.Optional( - Type.Number({ - description: "Timeout in milliseconds (default: 120000)", - }), - ), - }); - - pi.registerTool({ - name: "Shell", - label: "Shell", - description: - "Execute a shell command and return stdout, stderr, and exit code.", - parameters: ShellParams, - - async execute(_toolCallId, params, signal, _onUpdate, ctx) { - const cwd = params.working_directory - ? resolve(ctx.cwd, params.working_directory) - : ctx.cwd; - const timeout = params.timeout ?? 120_000; - - try { - const { stdout, stderr } = await execFileAsync( - "bash", - ["-c", params.command], - { - cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, - timeout, - signal, - }, - ); - - let output = ""; - if (stdout) output += stdout; - if (stderr) output += `\n[stderr]\n${stderr}`; - - if (output.length > MAX_OUTPUT_CHARS) { - output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; - } - - return { - content: [{ type: "text", text: output || "(no output)" }], - details: { exitCode: 0, command: params.command }, - }; - } catch (error: unknown) { - const err = error as { - code?: number; - message?: string; - stdout?: string; - stderr?: string; - }; - - let output = ""; - if (err.stdout) output += err.stdout; - if (err.stderr) output += `\n[stderr]\n${err.stderr}`; - - if (output.length > MAX_OUTPUT_CHARS) { - output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; - } - - return { - content: [ - { - type: "text", - text: `Shell error (exit code ${err.code ?? "unknown"}): ${err.message ?? "Unknown error"}${output ? `\n${output}` : ""}`, - }, - ], - details: { - exitCode: err.code ?? 1, - command: params.command, - }, - }; - } - }, - renderCall(args, theme) { - const cwd = args.working_directory - ? theme.fg("muted", ` in ${args.working_directory}`) - : ""; - return text( - theme.fg("toolTitle", theme.bold("Shell ")) + - theme.fg("accent", args.command) + - cwd, - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - const running = renderRunning(isPartial); - if (running) return running; - return renderResultText( - result, - expanded, - numberDetail(result, "exitCode") === 0 - ? theme.fg("muted", "Exit 0") - : theme.fg("warning", `Exit ${numberDetail(result, "exitCode")}`), - ); - }, - }); - - // ── Payload sanitization via event ──────────────────────────────────── - pi.on("before_provider_request", (event, ctx) => { - if (ctx.model?.provider !== "grok-cli") return; - - const modelId = ctx.model?.id ?? ""; - const sessionId = ctx.sessionManager?.getSessionId(); - return sanitizePayload( - event.payload as Record, - modelId, - sessionId, - ); - }); - - // ── /grok-cli-status command ───────────────────────────────────────── - pi.registerCommand("grok-cli-status", { - description: "Show Grok CLI provider status, quota, and token health", - handler: async (_args, ctx) => { - const token = process.env.GROK_CLI_OAUTH_TOKEN; - if (token) { - ctx.ui.notify( - "⚠️ Grok CLI: using GROK_CLI_OAUTH_TOKEN env bypass — no auto-refresh available", - "warning", - ); - } - - try { - const registry = ctx.modelRegistry; - const grokModels = registry - .getAll() - .filter((m: Model) => m.provider === "grok-cli"); - if (grokModels.length === 0) { - ctx.ui.notify( - "Grok CLI: no models registered. Run /login grok-cli first.", - "warning", - ); - return; - } - - const modelNames = grokModels - .slice(0, 5) - .map((m: Model) => m.id) - .join(", "); - const suffix = - grokModels.length > 5 ? ` (+${grokModels.length - 5} more)` : ""; - ctx.ui.notify( - `✓ Grok CLI: ${grokModels.length} models available (${modelNames}${suffix})`, - "info", - ); - - const lines = [ - " Quota:", - "", - ...formatQuota("grok-build", cachedRateLimits.get("grok-build")), - "", - ...formatQuota( - "grok-composer-2.5-fast", - cachedRateLimits.get("grok-composer-2.5-fast"), - ), - ]; - ctx.ui.notify(lines.join("\n"), "info"); - } catch (err) { - const msg = - err instanceof XaiOAuthError - ? `${err.message} (code: ${err.code})` - : err instanceof Error - ? err.message - : String(err); - ctx.ui.notify(`Grok CLI: ${msg}`, "warning"); - } - }, - }); - // ── Warn on env bypass ──────────────────────────────────────────────── - if (process.env.GROK_CLI_OAUTH_TOKEN) { - pi.on("session_start", async (_event, ctx) => { - ctx.ui.notify( - "[pi-grok-cli] Using GROK_CLI_OAUTH_TOKEN bypass — no auto-refresh, no model discovery", - "warning", - ); - }); - } -} +export { default } from "./provider/register.js"; diff --git a/src/models.ts b/src/models/catalog.ts similarity index 100% rename from src/models.ts rename to src/models/catalog.ts diff --git a/src/sanitize.ts b/src/payload/sanitize.ts similarity index 99% rename from src/sanitize.ts rename to src/payload/sanitize.ts index 00804ab..aa952ac 100644 --- a/src/sanitize.ts +++ b/src/payload/sanitize.ts @@ -22,7 +22,7 @@ import { existsSync, readFileSync } from "node:fs"; import { extname, isAbsolute, resolve } from "node:path"; import { fileURLToPath } from "node:url"; -import { supportsReasoningEffort } from "./models.js"; +import { supportsReasoningEffort } from "../models/catalog.js"; // ─── Content text extraction ───────────────────────────────────────────────── diff --git a/src/provider/quota.ts b/src/provider/quota.ts new file mode 100644 index 0000000..7243f20 --- /dev/null +++ b/src/provider/quota.ts @@ -0,0 +1,144 @@ +import { existsSync, mkdirSync, readFileSync, writeFileSync } from "node:fs"; +import { homedir } from "node:os"; +import { dirname, join } from "node:path"; + +const QUOTA_CACHE_FILE = "grok-cli-quota.json"; + +// ─── Rate limit cache (piggybacks on onResponse from normal traffic) ────────── + +interface RateLimitInfo { + remainingRequests: number; + limitRequests: number; + remainingTokens: number; + limitTokens: number; + contextWindow: number; + zeroDataRetention: boolean; + capturedAt: number; +} + +const cachedRateLimits = new Map(); + +function quotaCachePath() { + return join(homedir(), ".pi", QUOTA_CACHE_FILE); +} + +function isRateLimitInfo(value: unknown): value is RateLimitInfo { + if (!value || typeof value !== "object") return false; + const info = value as Record; + return ( + typeof info.remainingRequests === "number" && + typeof info.limitRequests === "number" && + typeof info.remainingTokens === "number" && + typeof info.limitTokens === "number" && + typeof info.contextWindow === "number" && + typeof info.zeroDataRetention === "boolean" && + typeof info.capturedAt === "number" + ); +} + +export function loadQuotaCache() { + cachedRateLimits.clear(); + if (!existsSync(quotaCachePath())) return; + + try { + const payload = JSON.parse( + readFileSync(quotaCachePath(), "utf8"), + ) as Record; + const models = payload.models; + if (!models || typeof models !== "object") return; + + Object.entries(models).forEach(([model, rateLimit]) => { + if (isRateLimitInfo(rateLimit)) cachedRateLimits.set(model, rateLimit); + }); + } catch { + cachedRateLimits.clear(); + } +} + +function persistQuotaCache() { + try { + mkdirSync(dirname(quotaCachePath()), { recursive: true }); + writeFileSync( + quotaCachePath(), + JSON.stringify( + { version: 1, models: Object.fromEntries(cachedRateLimits) }, + null, + "\t", + ), + ); + } catch { + // Status remains cache-only; persistence failures should not break requests. + } +} + +/** + * Extract rate limit info from response headers. + * Returns undefined if no rate limit headers are present. + */ +function extractRateLimit( + h: Record, +): RateLimitInfo | undefined { + const remainingReqs = Number(h["x-ratelimit-remaining-requests"]); + const limitReqs = Number(h["x-ratelimit-limit-requests"]); + const remainingTokens = Number(h["x-ratelimit-remaining-tokens"]); + const limitTokens = Number(h["x-ratelimit-limit-tokens"]); + const contextWindow = Number(h["x-grok-context-window"]); + + if (Number.isNaN(remainingReqs) && Number.isNaN(remainingTokens)) + return undefined; + + return { + remainingRequests: remainingReqs, + limitRequests: limitReqs, + remainingTokens, + limitTokens, + contextWindow: contextWindow || 512_000, + zeroDataRetention: h["x-zero-data-retention"] === "true", + capturedAt: Date.now(), + }; +} + +export function formatQuota( + name: string, + rateLimit: RateLimitInfo | undefined, +) { + if (!rateLimit) { + return [ + ` ${name}:`, + " no cached quota data — make a request with this model first", + ]; + } + + const ageSec = Math.round((Date.now() - rateLimit.capturedAt) / 1000); + const ageStr = + ageSec < 60 ? `${ageSec}s ago` : `${Math.round(ageSec / 60)}m ago`; + const lines = [` ${name}:`]; + lines.push(` Cached: ${ageStr}`); + lines.push( + ` Requests: ${rateLimit.remainingRequests}/${rateLimit.limitRequests} remaining`, + ); + lines.push( + ` Tokens: ${rateLimit.remainingTokens.toLocaleString()}/${rateLimit.limitTokens.toLocaleString()} remaining`, + ); + lines.push( + ` Context Limit: ${rateLimit.contextWindow.toLocaleString()} tokens`, + ); + if (rateLimit.zeroDataRetention) { + lines.push(" Data: Zero retention ✓"); + } + return lines; +} + +export function captureRateLimit( + modelId: string, + headers: Record, +) { + const rateLimit = extractRateLimit(headers); + if (!rateLimit) return; + cachedRateLimits.set(modelId, rateLimit); + persistQuotaCache(); +} + +export function getCachedRateLimit(modelId: string): RateLimitInfo | undefined { + return cachedRateLimits.get(modelId); +} diff --git a/src/provider/register.ts b/src/provider/register.ts new file mode 100644 index 0000000..06de7eb --- /dev/null +++ b/src/provider/register.ts @@ -0,0 +1,104 @@ +import type { + Api, + Model, + OAuthCredentials, + OAuthLoginCallbacks, +} from "@earendil-works/pi-ai"; +import type { + ExtensionAPI, + ProviderConfig, +} from "@earendil-works/pi-coding-agent"; +import * as oauth from "../auth/oauth.js"; +import { getBaseUrl, type XaiOAuthCredentials } from "../auth/oauth.js"; +import { type GrokCliModelConfig, resolveModels } from "../models/catalog.js"; +import { sanitizePayload } from "../payload/sanitize.js"; +import { registerGrokTools } from "../tools/register.js"; +import { loadQuotaCache } from "./quota.js"; +import { registerStatusCommand } from "./status.js"; +import { streamGrokCli } from "./stream.js"; +import { syncGrokTools } from "./toolScope.js"; + +export default function registerGrokCli(pi: ExtensionAPI) { + loadQuotaCache(); + const baseUrl = getBaseUrl(); + const models = resolveModels(); + + pi.on("model_select", (event) => { + syncGrokTools(pi, event.model.provider); + }); + + pi.on("before_agent_start", (_event, ctx) => { + syncGrokTools(pi, ctx.model?.provider); + }); + + pi.registerProvider("grok-cli", { + name: "Grok CLI", + baseUrl, + apiKey: "$GROK_CLI_OAUTH_TOKEN", + api: "openai-responses", + models: models.map((m: GrokCliModelConfig) => ({ + id: m.id, + name: m.name, + reasoning: m.reasoning, + thinkingLevelMap: m.thinkingLevelMap, + input: m.input, + cost: m.cost, + contextWindow: m.contextWindow, + maxTokens: m.maxTokens, + })), + oauth: { + name: "Grok CLI", + + async login(callbacks: OAuthLoginCallbacks): Promise { + return oauth.login(callbacks); + }, + + async refreshToken( + credentials: OAuthCredentials, + ): Promise { + return oauth.refresh(credentials); + }, + + getApiKey(credentials: OAuthCredentials): string { + return credentials.access; + }, + + modifyModels(models: Model[], credentials: OAuthCredentials) { + const effectiveBaseUrl = String( + (credentials as XaiOAuthCredentials).baseUrl ?? getBaseUrl(), + ).replace(/\/+$/, ""); + + return models.map((m) => + m.provider === "grok-cli" ? { ...m, baseUrl: effectiveBaseUrl } : m, + ); + }, + } satisfies ProviderConfig["oauth"], + + streamSimple: streamGrokCli, + }); + + registerGrokTools(pi); + + pi.on("before_provider_request", (event, ctx) => { + if (ctx.model?.provider !== "grok-cli") return; + + const modelId = ctx.model?.id ?? ""; + const sessionId = ctx.sessionManager?.getSessionId(); + return sanitizePayload( + event.payload as Record, + modelId, + sessionId, + ); + }); + + registerStatusCommand(pi); + + if (process.env.GROK_CLI_OAUTH_TOKEN) { + pi.on("session_start", async (_event, ctx) => { + ctx.ui.notify( + "[pi-grok-cli] Using GROK_CLI_OAUTH_TOKEN bypass — no auto-refresh, no model discovery", + "warning", + ); + }); + } +} diff --git a/src/provider/status.ts b/src/provider/status.ts new file mode 100644 index 0000000..b3f6e2e --- /dev/null +++ b/src/provider/status.ts @@ -0,0 +1,66 @@ +import type { Api, Model } from "@earendil-works/pi-ai"; +import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; +import { XaiOAuthError } from "../shared/errors.js"; +import { formatQuota, getCachedRateLimit } from "./quota.js"; + +export function registerStatusCommand( + pi: Pick, +) { + pi.registerCommand("grok-cli-status", { + description: "Show Grok CLI provider status, quota, and token health", + handler: async (_args, ctx) => { + const token = process.env.GROK_CLI_OAUTH_TOKEN; + if (token) { + ctx.ui.notify( + "⚠️ Grok CLI: using GROK_CLI_OAUTH_TOKEN env bypass — no auto-refresh available", + "warning", + ); + } + + try { + const registry = ctx.modelRegistry; + const grokModels = registry + .getAll() + .filter((m: Model) => m.provider === "grok-cli"); + if (grokModels.length === 0) { + ctx.ui.notify( + "Grok CLI: no models registered. Run /login grok-cli first.", + "warning", + ); + return; + } + + const modelNames = grokModels + .slice(0, 5) + .map((m: Model) => m.id) + .join(", "); + const suffix = + grokModels.length > 5 ? ` (+${grokModels.length - 5} more)` : ""; + ctx.ui.notify( + `✓ Grok CLI: ${grokModels.length} models available (${modelNames}${suffix})`, + "info", + ); + + const lines = [ + " Quota:", + "", + ...formatQuota("grok-build", getCachedRateLimit("grok-build")), + "", + ...formatQuota( + "grok-composer-2.5-fast", + getCachedRateLimit("grok-composer-2.5-fast"), + ), + ]; + ctx.ui.notify(lines.join("\n"), "info"); + } catch (err) { + const msg = + err instanceof XaiOAuthError + ? `${err.message} (code: ${err.code})` + : err instanceof Error + ? err.message + : String(err); + ctx.ui.notify(`Grok CLI: ${msg}`, "warning"); + } + }, + }); +} diff --git a/src/provider/stream.ts b/src/provider/stream.ts new file mode 100644 index 0000000..788b936 --- /dev/null +++ b/src/provider/stream.ts @@ -0,0 +1,53 @@ +import { + type Api, + type AssistantMessageEventStream, + type Context, + type Model, + type SimpleStreamOptions, + streamSimpleOpenAIResponses, +} from "@earendil-works/pi-ai"; +import { captureRateLimit } from "./quota.js"; + +const GROK_CLI_VERSION = "0.2.16"; + +/** + * Stream function that adds Grok CLI-specific headers to requests. + * + * The real Grok CLI sends these headers: + * - x-grok-client-identifier: grok-shell + * - x-grok-client-version: 0.2.16 + * - x-grok-conv-id: + * - x-grok-model-override: + * - x-xai-token-auth: xai-grok-cli + */ +export function streamGrokCli( + model: Model, + context: Context, + options?: SimpleStreamOptions, +): AssistantMessageEventStream { + const sessionId = options?.sessionId; + const headers: Record = { + ...options?.headers, + "x-grok-client-identifier": "pi-grok-cli", + "x-grok-client-version": GROK_CLI_VERSION, + "x-xai-token-auth": "xai-grok-cli", + "x-grok-model-override": model.id, + }; + + if (sessionId) { + headers["x-grok-conv-id"] = sessionId; + } + + return streamSimpleOpenAIResponses( + model as Model<"openai-responses">, + context, + { + ...options, + headers, + onResponse(response) { + captureRateLimit(model.id, response.headers); + options?.onResponse?.(response, model); + }, + }, + ); +} diff --git a/src/provider/toolScope.ts b/src/provider/toolScope.ts new file mode 100644 index 0000000..4d24ebd --- /dev/null +++ b/src/provider/toolScope.ts @@ -0,0 +1,33 @@ +import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; + +export const GROK_TOOL_NAMES = [ + "Grep", + "Glob", + "LS", + "Read", + "Write", + "StrReplace", + "Delete", + "Shell", +]; + +export function syncGrokTools( + pi: Pick, + provider: string | undefined, +) { + const currentTools = pi.getActiveTools(); + const baseTools = currentTools.filter( + (toolName) => !GROK_TOOL_NAMES.includes(toolName), + ); + const nextTools = + provider === "grok-cli" ? [...baseTools, ...GROK_TOOL_NAMES] : baseTools; + + if ( + currentTools.length === nextTools.length && + currentTools.every((toolName, i) => toolName === nextTools[i]) + ) { + return; + } + + pi.setActiveTools(nextTools); +} diff --git a/src/errors.ts b/src/shared/errors.ts similarity index 100% rename from src/errors.ts rename to src/shared/errors.ts diff --git a/src/tools/files.ts b/src/tools/files.ts new file mode 100644 index 0000000..4dc5dc1 --- /dev/null +++ b/src/tools/files.ts @@ -0,0 +1,385 @@ +import { execFile } from "node:child_process"; +import { + existsSync, + mkdirSync, + readFileSync, + unlinkSync, + writeFileSync, +} from "node:fs"; +import { dirname, resolve } from "node:path"; +import { promisify } from "node:util"; +import { Type } from "@earendil-works/pi-ai"; +import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; +import { + booleanDetail, + detailRecord, + fileError, + fileNotFound, + MAX_OUTPUT_CHARS, + numberDetail, + renderResultSummary, + stringDetail, + type ToolError, + text, +} from "./rendering.js"; + +const execFileAsync = promisify(execFile); + +type ToolTheme = { + bold: (text: string) => string; + fg: (name: "accent" | "toolTitle", text: string) => string; +}; + +function renderPathToolCall( + toolName: string, + filePath: string, + theme: ToolTheme, +) { + return text( + theme.fg("toolTitle", theme.bold(`${toolName} `)) + + theme.fg("accent", filePath), + ); +} + +export function registerFileTools(pi: ExtensionAPI) { + // ── LS tool ────────────────────────────────────────────────────────── + + const LsParams = Type.Object({ + path: Type.String({ + description: "Directory path to list", + }), + }); + + pi.registerTool({ + name: "LS", + label: "LS", + description: "List the contents of a directory, including hidden files.", + parameters: LsParams, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const targetPath = resolve(ctx.cwd, params.path); + + try { + const { stdout } = await execFileAsync("ls", ["-la", targetPath], { + cwd: ctx.cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + signal, + }); + + let output = stdout.trim(); + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[LS: output truncated at 50KB]`; + } + + return { + content: [{ type: "text", text: output }], + details: { path: targetPath }, + }; + } catch (error: unknown) { + const err = error as ToolError; + return { + content: [ + { + type: "text", + text: `LS error: ${err.message ?? "Unknown error"}`, + }, + ], + details: { path: targetPath }, + }; + } + }, + renderCall(args, theme) { + return renderPathToolCall("LS", args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + theme.fg("muted", stringDetail(result, "path")), + ); + }, + }); + + // ── Read tool ──────────────────────────────────────────────────────── + + const ReadParams = Type.Object({ + path: Type.String({ + description: "Path to the file to read", + }), + offset: Type.Optional( + Type.Number({ + description: "Line number to start reading from (0-indexed)", + }), + ), + limit: Type.Optional( + Type.Number({ + description: "Maximum number of lines to read", + }), + ), + }); + + pi.registerTool({ + name: "Read", + label: "Read", + description: + "Read the contents of a file. Returns the file content with line numbers.", + parameters: ReadParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + if (!existsSync(filePath)) { + return fileNotFound(filePath, { exists: false, totalLines: 0 }); + } + + const content = readFileSync(filePath, "utf-8"); + const lines = content.split("\n"); + + const startLine = params.offset ?? 0; + const endLine = params.limit + ? Math.min(startLine + params.limit, lines.length) + : Math.min(startLine + 2000, lines.length); + + const selectedLines = lines.slice(startLine, endLine); + const numberedLines = selectedLines.map( + (line, i) => `${startLine + i + 1}\t${line}`, + ); + + let output = numberedLines.join("\n"); + if (endLine < lines.length) { + output += `\n\n[Showing lines ${startLine + 1}-${endLine} of ${lines.length} total lines. Use offset to see more.]`; + } + + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + + return { + content: [{ type: "text", text: output }], + details: { path: filePath, totalLines: lines.length }, + }; + } catch (error: unknown) { + return fileError(error, "Read", filePath, { + exists: false, + totalLines: 0, + }); + } + }, + renderCall(args, theme) { + const range = + args.offset !== undefined || args.limit !== undefined + ? theme.fg( + "muted", + ` (from ${args.offset ?? 0}${args.limit ? `, ${args.limit} lines` : ""})`, + ) + : ""; + return text( + theme.fg("toolTitle", theme.bold("Read ")) + + theme.fg("accent", args.path) + + range, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + detailRecord(result).exists === false + ? theme.fg("error", "File not found") + : theme.fg("muted", `${numberDetail(result, "totalLines")} line(s)`), + ); + }, + }); + + // ── Write tool ─────────────────────────────────────────────────────── + + const WriteParams = Type.Object({ + path: Type.String({ + description: "Path to the file to write", + }), + content: Type.String({ + description: "Content to write to the file", + }), + }); + + pi.registerTool({ + name: "Write", + label: "Write", + description: + "Create or overwrite a file with the given content. Creates parent directories if needed.", + parameters: WriteParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + mkdirSync(dirname(filePath), { recursive: true }); + writeFileSync(filePath, params.content, "utf-8"); + + return { + content: [ + { + type: "text", + text: `Successfully wrote ${params.content.length} bytes to ${params.path}`, + }, + ], + details: { path: filePath, bytesWritten: params.content.length }, + }; + } catch (error: unknown) { + const err = error as ToolError; + return { + content: [ + { + type: "text", + text: `Write error: ${err.message ?? "Unknown error"}`, + }, + ], + details: { path: filePath, bytesWritten: 0 }, + }; + } + }, + renderCall(args, theme) { + return renderPathToolCall("Write", args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + theme.fg( + "muted", + `${numberDetail(result, "bytesWritten")} bytes written`, + ), + ); + }, + }); + + // ── StrReplace tool ────────────────────────────────────────────────── + + const StrReplaceParams = Type.Object({ + path: Type.String({ + description: "Path to the file to modify", + }), + old_str: Type.String({ + description: "String to search for (exact match)", + }), + new_str: Type.String({ + description: "String to replace with", + }), + }); + + pi.registerTool({ + name: "StrReplace", + label: "StrReplace", + description: + "Replace all occurrences of a string in a file. The old_str must be an exact match.", + parameters: StrReplaceParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + if (!existsSync(filePath)) { + return fileNotFound(filePath, { replacements: 0 }); + } + + const content = readFileSync(filePath, "utf-8"); + const count = content.split(params.old_str).length - 1; + + if (count === 0) { + return { + content: [ + { + type: "text", + text: `String not found in ${params.path}: "${params.old_str}"`, + }, + ], + details: { path: filePath, replacements: 0 }, + }; + } + + const newContent = content.replaceAll(params.old_str, params.new_str); + writeFileSync(filePath, newContent, "utf-8"); + + return { + content: [ + { + type: "text", + text: `Replaced ${count} occurrence(s) in ${params.path}`, + }, + ], + details: { path: filePath, replacements: count }, + }; + } catch (error: unknown) { + return fileError(error, "StrReplace", filePath, { replacements: 0 }); + } + }, + renderCall(args, theme) { + return renderPathToolCall("StrReplace", args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + numberDetail(result, "replacements") === 0 + ? theme.fg("dim", "No replacements") + : theme.fg( + "muted", + `${numberDetail(result, "replacements")} replacement(s)`, + ), + ); + }, + }); + + // ── Delete tool ────────────────────────────────────────────────────── + + const DeleteParams = Type.Object({ + path: Type.String({ + description: "Path to the file to delete", + }), + }); + + pi.registerTool({ + name: "Delete", + label: "Delete", + description: "Delete a file from the filesystem.", + parameters: DeleteParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + if (!existsSync(filePath)) { + return fileNotFound(filePath, { deleted: false }); + } + + unlinkSync(filePath); + + return { + content: [ + { type: "text", text: `Successfully deleted ${params.path}` }, + ], + details: { path: filePath, deleted: true }, + }; + } catch (error: unknown) { + return fileError(error, "Delete", filePath, { deleted: false }); + } + }, + renderCall(args, theme) { + return renderPathToolCall("Delete", args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + booleanDetail(result, "deleted") + ? theme.fg("muted", "Deleted") + : theme.fg("error", "Not deleted"), + ); + }, + }); +} diff --git a/src/tools/register.ts b/src/tools/register.ts new file mode 100644 index 0000000..07311b7 --- /dev/null +++ b/src/tools/register.ts @@ -0,0 +1,10 @@ +import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; +import { registerFileTools } from "./files.js"; +import { registerSearchTools } from "./search.js"; +import { registerShellTool } from "./shell.js"; + +export function registerGrokTools(pi: ExtensionAPI) { + registerSearchTools(pi); + registerFileTools(pi); + registerShellTool(pi); +} diff --git a/src/tools/rendering.ts b/src/tools/rendering.ts new file mode 100644 index 0000000..3c5f4f9 --- /dev/null +++ b/src/tools/rendering.ts @@ -0,0 +1,185 @@ +import { execFile } from "node:child_process"; +import { promisify } from "node:util"; +import { Text } from "@earendil-works/pi-tui"; + +const execFileAsync = promisify(execFile); + +export const MAX_OUTPUT_CHARS = 50_000; +export const MAX_LINES = 500; + +export function truncateLines(lines: string[]): string { + if (lines.length > MAX_LINES) { + return ( + lines.slice(0, MAX_LINES).join("\n") + + `\n\n[Showing first ${MAX_LINES} of ${lines.length} results. Refine your pattern to narrow results.]` + ); + } + return lines.join("\n"); +} + +export function truncateChars(output: string): string { + if (output.length > MAX_OUTPUT_CHARS) { + return `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + return output; +} + +let rgAvailable: boolean | undefined; +export async function hasRipgrep(): Promise { + if (rgAvailable !== undefined) return rgAvailable; + try { + await execFileAsync("rg", ["--version"]); + rgAvailable = true; + } catch { + rgAvailable = false; + } + return rgAvailable; +} + +export type ToolError = { code?: number; message?: string }; +export type ToolResult = { + content: [{ type: "text"; text: string }]; + details: T; +}; + +export function text(text: string): Text { + return new Text(text, 0, 0); +} + +function firstText(result: { content: { type: string; text?: string }[] }) { + const first = result.content[0]; + if (first?.type !== "text") return undefined; + return first.text; +} + +export function renderResultText( + result: { content: { type: string; text?: string }[] }, + expanded: boolean, + summary: string, +): Text { + if (expanded) return text(firstText(result) ?? summary); + return text(summary); +} + +export function renderRunning(isPartial: boolean): Text | undefined { + if (!isPartial) return undefined; + return text("Running..."); +} + +export function renderResultSummary( + result: { content: { type: string; text?: string }[] }, + expanded: boolean, + isPartial: boolean, + summary: string, +): Text { + const running = renderRunning(isPartial); + if (running) return running; + return renderResultText(result, expanded, summary); +} + +export function detailRecord(result: { + details: unknown; +}): Record { + if (!result.details || typeof result.details !== "object") return {}; + return result.details as Record; +} + +export function numberDetail( + result: { details: unknown }, + key: string, +): number { + const value = detailRecord(result)[key]; + if (typeof value !== "number") return 0; + return value; +} + +export function stringDetail( + result: { details: unknown }, + key: string, +): string { + const value = detailRecord(result)[key]; + if (typeof value !== "string") return ""; + return value; +} + +export function booleanDetail( + result: { details: unknown }, + key: string, +): boolean { + const value = detailRecord(result)[key]; + return value === true; +} + +type FileDetails = { path: string; [key: string]: unknown }; + +export function fileNotFound( + filePath: string, + extraDetails: Omit, +): ToolResult { + return { + content: [{ type: "text", text: `File not found: ${filePath}` }], + details: { path: filePath, ...extraDetails } as T, + }; +} + +export function fileError( + error: unknown, + toolName: string, + filePath: string, + extraDetails: Omit, +): ToolResult { + const err = error as ToolError; + return { + content: [ + { + type: "text", + text: `${toolName} error: ${err.message ?? "Unknown error"}`, + }, + ], + details: { path: filePath, ...extraDetails } as T, + }; +} + +export function toolError( + error: unknown, + toolName: string, + emptyDetails: T, +): ToolResult { + const err = error as ToolError; + if (err.code === 1) { + return { + content: [{ type: "text", text: "No matches found" }], + details: emptyDetails, + }; + } + return { + content: [ + { + type: "text", + text: `${toolName} error: ${err.message ?? "Unknown error"}`, + }, + ], + details: emptyDetails, + }; +} + +export async function execWithRgFallback( + rgArgs: string[], + grepArgs: string[], + options: { cwd: string; signal?: AbortSignal }, +): Promise { + if (await hasRipgrep()) { + const result = await execFileAsync("rg", rgArgs, { + cwd: options.cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + signal: options.signal, + }); + return result.stdout; + } + const result = await execFileAsync("grep", grepArgs, { + cwd: options.cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + signal: options.signal, + }); + return result.stdout; +} diff --git a/src/tools/search.ts b/src/tools/search.ts new file mode 100644 index 0000000..34b422d --- /dev/null +++ b/src/tools/search.ts @@ -0,0 +1,187 @@ +import { execFile } from "node:child_process"; +import { resolve } from "node:path"; +import { promisify } from "node:util"; +import { Type } from "@earendil-works/pi-ai"; +import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; +import { + execWithRgFallback, + hasRipgrep, + MAX_OUTPUT_CHARS, + numberDetail, + renderResultText, + renderRunning, + text, + toolError, + truncateChars, + truncateLines, +} from "./rendering.js"; + +const execFileAsync = promisify(execFile); + +export function registerSearchTools(pi: ExtensionAPI) { + const GrepParams = Type.Object({ + pattern: Type.String({ + description: "Regex pattern to search for in file contents", + }), + path: Type.Optional( + Type.String({ + description: + "Directory or file to search. Defaults to current working directory.", + }), + ), + include: Type.Optional( + Type.String({ + description: + "Glob pattern to filter which files are searched (e.g. *.ts, **/*.md)", + }), + ), + }); + + pi.registerTool({ + name: "Grep", + label: "Grep", + description: + "Search for a regex pattern in file contents. Returns matching lines with file path and line number. Use the include parameter to filter by file type.", + parameters: GrepParams, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const searchPath = resolve(ctx.cwd, params.path ?? "."); + + try { + const rgArgs = ["-n", "--no-heading", "--color=never"]; + if (params.include) rgArgs.push("--glob", params.include); + rgArgs.push(params.pattern, searchPath); + + const grepArgs = ["-r", "-n", "--color=never"]; + if (params.include) grepArgs.push(`--include=${params.include}`); + grepArgs.push(params.pattern, searchPath); + + const stdout = await execWithRgFallback(rgArgs, grepArgs, { + cwd: ctx.cwd, + signal, + }); + + const lines = stdout.trim().split("\n").filter(Boolean); + if (lines.length === 0) { + return { + content: [{ type: "text", text: "No matches found" }], + details: { matchCount: 0 }, + }; + } + + return { + content: [ + { type: "text", text: truncateChars(truncateLines(lines)) }, + ], + details: { matchCount: lines.length }, + }; + } catch (error: unknown) { + return toolError(error, "Grep", { matchCount: 0 }); + } + }, + renderCall(args, theme) { + const path = args.path ? theme.fg("muted", ` in ${args.path}`) : ""; + const include = args.include ? theme.fg("dim", ` [${args.include}]`) : ""; + return text( + theme.fg("toolTitle", theme.bold("Grep ")) + + theme.fg("accent", `"${args.pattern}"`) + + path + + include, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + const matchCount = numberDetail(result, "matchCount"); + return renderResultText( + result, + expanded, + matchCount === 0 + ? theme.fg("dim", "No matches") + : theme.fg("muted", `${matchCount} match(es)`), + ); + }, + }); + + const GlobParams = Type.Object({ + pattern: Type.String({ + description: "Glob pattern to match files (e.g. **/*.ts, src/**/*.json)", + }), + path: Type.Optional( + Type.String({ + description: + "Directory to search within. Defaults to current working directory.", + }), + ), + }); + + pi.registerTool({ + name: "Glob", + label: "Glob", + description: + "Find files matching a glob pattern. Returns a list of matching file paths sorted by modification time (newest first).", + parameters: GlobParams, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const searchPath = resolve(ctx.cwd, params.path ?? "."); + + try { + let files: string[]; + + if (await hasRipgrep()) { + const result = await execFileAsync( + "rg", + ["--files", "--color=never", "--glob", params.pattern, searchPath], + { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, + ); + files = result.stdout.trim().split("\n").filter(Boolean); + } else { + // find fallback — convert **/*.ext → -name "*.ext" + const basename = params.pattern.replace(/^(\*\*\/)+/, ""); + const result = await execFileAsync( + "find", + [searchPath, "-type", "f", "-name", basename], + { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, + ); + files = result.stdout.trim().split("\n").filter(Boolean); + } + + if (files.length === 0) { + return { + content: [{ type: "text", text: "No files found" }], + details: { fileCount: 0 }, + }; + } + + return { + content: [ + { type: "text", text: truncateChars(truncateLines(files)) }, + ], + details: { fileCount: files.length }, + }; + } catch (error: unknown) { + return toolError(error, "Glob", { fileCount: 0 }); + } + }, + renderCall(args, theme) { + const path = args.path ? theme.fg("muted", ` in ${args.path}`) : ""; + return text( + theme.fg("toolTitle", theme.bold("Glob ")) + + theme.fg("accent", args.pattern) + + path, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + const fileCount = numberDetail(result, "fileCount"); + return renderResultText( + result, + expanded, + fileCount === 0 + ? theme.fg("dim", "No files") + : theme.fg("muted", `${fileCount} file(s)`), + ); + }, + }); +} diff --git a/src/tools/shell.ts b/src/tools/shell.ts new file mode 100644 index 0000000..631e259 --- /dev/null +++ b/src/tools/shell.ts @@ -0,0 +1,124 @@ +import { execFile } from "node:child_process"; +import { resolve } from "node:path"; +import { promisify } from "node:util"; +import { Type } from "@earendil-works/pi-ai"; +import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; +import { + MAX_OUTPUT_CHARS, + numberDetail, + renderResultText, + renderRunning, + text, +} from "./rendering.js"; + +const execFileAsync = promisify(execFile); + +export function registerShellTool(pi: ExtensionAPI) { + // ── Shell tool ─────────────────────────────────────────────────────── + + const ShellParams = Type.Object({ + command: Type.String({ + description: "Shell command to execute", + }), + working_directory: Type.Optional( + Type.String({ + description: "Working directory for the command", + }), + ), + timeout: Type.Optional( + Type.Number({ + description: "Timeout in milliseconds (default: 120000)", + }), + ), + }); + + pi.registerTool({ + name: "Shell", + label: "Shell", + description: + "Execute a shell command and return stdout, stderr, and exit code.", + parameters: ShellParams, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const cwd = params.working_directory + ? resolve(ctx.cwd, params.working_directory) + : ctx.cwd; + const timeout = params.timeout ?? 120_000; + + try { + const { stdout, stderr } = await execFileAsync( + "bash", + ["-c", params.command], + { + cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + timeout, + signal, + }, + ); + + let output = ""; + if (stdout) output += stdout; + if (stderr) output += `\n[stderr]\n${stderr}`; + + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + + return { + content: [{ type: "text", text: output || "(no output)" }], + details: { exitCode: 0, command: params.command }, + }; + } catch (error: unknown) { + const err = error as { + code?: number; + message?: string; + stdout?: string; + stderr?: string; + }; + + let output = ""; + if (err.stdout) output += err.stdout; + if (err.stderr) output += `\n[stderr]\n${err.stderr}`; + + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + + return { + content: [ + { + type: "text", + text: `Shell error (exit code ${err.code ?? "unknown"}): ${err.message ?? "Unknown error"}${output ? `\n${output}` : ""}`, + }, + ], + details: { + exitCode: err.code ?? 1, + command: params.command, + }, + }; + } + }, + renderCall(args, theme) { + const cwd = args.working_directory + ? theme.fg("muted", ` in ${args.working_directory}`) + : ""; + return text( + theme.fg("toolTitle", theme.bold("Shell ")) + + theme.fg("accent", args.command) + + cwd, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + return renderResultText( + result, + expanded, + numberDetail(result, "exitCode") === 0 + ? theme.fg("muted", "Exit 0") + : theme.fg("warning", `Exit ${numberDetail(result, "exitCode")}`), + ); + }, + }); +} diff --git a/tests/oauth.test.ts b/tests/auth/oauth.test.ts similarity index 98% rename from tests/oauth.test.ts rename to tests/auth/oauth.test.ts index e11023e..d1ce809 100644 --- a/tests/oauth.test.ts +++ b/tests/auth/oauth.test.ts @@ -1,6 +1,6 @@ import { afterEach, describe, expect, it, vi } from "vitest"; -import { XaiErrorCode } from "../src/errors.js"; -import { getBaseUrl, login, refresh } from "../src/oauth.js"; +import { getBaseUrl, login, refresh } from "../../src/auth/oauth.js"; +import { XaiErrorCode } from "../../src/shared/errors.js"; const originalEnv = { ...process.env }; const originalFetch = globalThis.fetch; diff --git a/tests/models.test.ts b/tests/models/catalog.test.ts similarity index 93% rename from tests/models.test.ts rename to tests/models/catalog.test.ts index 932a206..407719c 100644 --- a/tests/models.test.ts +++ b/tests/models/catalog.test.ts @@ -1,5 +1,8 @@ import { afterEach, describe, expect, it } from "vitest"; -import { resolveModels, supportsReasoningEffort } from "../src/models.js"; +import { + resolveModels, + supportsReasoningEffort, +} from "../../src/models/catalog.js"; const originalEnv = { ...process.env }; diff --git a/tests/sanitize.test.ts b/tests/payload/sanitize.test.ts similarity index 98% rename from tests/sanitize.test.ts rename to tests/payload/sanitize.test.ts index 4358dc8..ff74434 100644 --- a/tests/sanitize.test.ts +++ b/tests/payload/sanitize.test.ts @@ -2,7 +2,7 @@ import { mkdtempSync, rmSync, writeFileSync } from "node:fs"; import { tmpdir } from "node:os"; import { join } from "node:path"; import { describe, expect, it } from "vitest"; -import { sanitizePayload } from "../src/sanitize.js"; +import { sanitizePayload } from "../../src/payload/sanitize.js"; describe("payload sanitization", () => { it("removes unsupported items and moves leading instructions", () => { diff --git a/tests/package.test.ts b/tests/provider/package.test.ts similarity index 52% rename from tests/package.test.ts rename to tests/provider/package.test.ts index 215dd7e..8bb6f04 100644 --- a/tests/package.test.ts +++ b/tests/provider/package.test.ts @@ -2,7 +2,7 @@ import { existsSync, globSync, readFileSync } from "node:fs"; import { describe, expect, it } from "vitest"; const packageJson = JSON.parse( - readFileSync(new URL("../package.json", import.meta.url), "utf8"), + readFileSync(new URL("../../package.json", import.meta.url), "utf8"), ); describe("npm package manifest", () => { @@ -25,34 +25,44 @@ describe("npm package manifest", () => { ); expect(packageJson.devDependencies?.vitest).toBeDefined(); expect(packageJson.devDependencies?.["@vitest/coverage-v8"]).toBeDefined(); - expect(existsSync(new URL("../vitest.config.ts", import.meta.url))).toBe( + expect(existsSync(new URL("../../vitest.config.ts", import.meta.url))).toBe( true, ); }); }); describe("repository layout", () => { - it("keeps extension source files under src", () => { - for (const file of [ - "index.ts", - "models.ts", - "oauth.ts", - "sanitize.ts", - "errors.ts", - ]) { - expect(existsSync(new URL(`../src/${file}`, import.meta.url))).toBe(true); - expect(existsSync(new URL(`../${file}`, import.meta.url))).toBe(false); - } + it("keeps the extension entrypoint at src/index.ts", () => { + expect(existsSync(new URL("../../src/index.ts", import.meta.url))).toBe( + true, + ); }); - it("contains the expected source files", () => { - const sourceFiles = globSync("src/*.ts").sort(); - expect(sourceFiles).toEqual([ - "src/errors.ts", + it("contains the expected domain source files", () => { + expect(globSync("src/**/*.ts").sort()).toEqual([ + "src/auth/oauth.ts", "src/index.ts", - "src/models.ts", - "src/oauth.ts", - "src/sanitize.ts", + "src/models/catalog.ts", + "src/payload/sanitize.ts", + "src/provider/quota.ts", + "src/provider/register.ts", + "src/provider/status.ts", + "src/provider/stream.ts", + "src/provider/toolScope.ts", + "src/shared/errors.ts", + "src/tools/files.ts", + "src/tools/register.ts", + "src/tools/rendering.ts", + "src/tools/search.ts", + "src/tools/shell.ts", ]); }); + + it("does not keep top-level helper compatibility wrappers", () => { + for (const file of ["errors.ts", "models.ts", "oauth.ts", "sanitize.ts"]) { + expect(existsSync(new URL(`../../src/${file}`, import.meta.url))).toBe( + false, + ); + } + }); }); diff --git a/tests/index.test.ts b/tests/provider/register.test.ts similarity index 99% rename from tests/index.test.ts rename to tests/provider/register.test.ts index 2c420d7..d87f3d4 100644 --- a/tests/index.test.ts +++ b/tests/provider/register.test.ts @@ -109,7 +109,7 @@ async function setupExtension(initialActiveTools = ["read", "bash"]) { const setActiveTools = vi.fn((toolNames: string[]) => { activeTools = toolNames; }); - const registerGrokCli = (await import("../src/index.js")).default; + const registerGrokCli = (await import("../../src/index.js")).default; registerGrokCli({ registerProvider(name: string, config: ProviderConfig) { providers.set(name, config); diff --git a/tests/errors.test.ts b/tests/shared/errors.test.ts similarity index 87% rename from tests/errors.test.ts rename to tests/shared/errors.test.ts index 9e35aac..8d0d27d 100644 --- a/tests/errors.test.ts +++ b/tests/shared/errors.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from "vitest"; -import { XaiErrorCode, XaiOAuthError } from "../src/errors.js"; +import { XaiErrorCode, XaiOAuthError } from "../../src/shared/errors.js"; describe("OAuth errors", () => { it("keeps machine-readable code and relogin state", () => { diff --git a/tests/tools/register.test.ts b/tests/tools/register.test.ts new file mode 100644 index 0000000..8916c62 --- /dev/null +++ b/tests/tools/register.test.ts @@ -0,0 +1,32 @@ +import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; +import { describe, expect, it } from "vitest"; +import { registerGrokTools } from "../../src/tools/register.js"; + +describe("Grok tool registration", () => { + it("registers all Grok/Cursor-native tool shims with renderers", () => { + const toolNames: string[] = []; + + registerGrokTools({ + registerTool(tool: { + name: string; + renderCall?: unknown; + renderResult?: unknown; + }) { + toolNames.push(tool.name); + expect(tool.renderCall).toBeTypeOf("function"); + expect(tool.renderResult).toBeTypeOf("function"); + }, + } as unknown as ExtensionAPI); + + expect(toolNames.sort()).toEqual([ + "Delete", + "Glob", + "Grep", + "LS", + "Read", + "Shell", + "StrReplace", + "Write", + ]); + }); +}); From 65122488fc4dd3b737366352598fbc8d9d7b8d68 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Tue, 2 Jun 2026 19:28:43 +0900 Subject: [PATCH 04/24] test: add context window validation for default model configurations in catalog tests --- src/models/catalog.ts | 4 ++-- tests/models/catalog.test.ts | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/models/catalog.ts b/src/models/catalog.ts index 300981e..31334ee 100644 --- a/src/models/catalog.ts +++ b/src/models/catalog.ts @@ -40,7 +40,7 @@ const FALLBACK_MODELS: GrokCliModelConfig[] = [ reasoning: false, input: ["text", "image"], cost: COST_COMPOSER, - contextWindow: 512_000, + contextWindow: 200_000, maxTokens: 30_000, thinkingLevelMap: { off: "none", @@ -57,7 +57,7 @@ const FALLBACK_MODELS: GrokCliModelConfig[] = [ reasoning: true, input: ["text", "image"], cost: COST_BUILD, - contextWindow: 1_000_000, + contextWindow: 512_000, maxTokens: 30_000, }, { diff --git a/tests/models/catalog.test.ts b/tests/models/catalog.test.ts index 407719c..79c8bf7 100644 --- a/tests/models/catalog.test.ts +++ b/tests/models/catalog.test.ts @@ -22,7 +22,9 @@ describe("model catalog", () => { it("uses fallback models when no override is configured", () => { delete process.env.PI_GROK_CLI_MODELS; - expect(resolveModels().map((model) => model.id)).toEqual([ + const models = resolveModels(); + + expect(models.map((model) => model.id)).toEqual([ "grok-composer-2.5-fast", "grok-build", "grok-4.3", @@ -30,6 +32,12 @@ describe("model catalog", () => { "grok-4.20-0309-non-reasoning", "grok-4.20-multi-agent-0309", ]); + expect( + models.find((model) => model.id === "grok-composer-2.5-fast"), + ).toMatchObject({ contextWindow: 200_000 }); + expect(models.find((model) => model.id === "grok-build")).toMatchObject({ + contextWindow: 512_000, + }); }); it("filters, reorders, and fills unknown model overrides", () => { From 01f31d83714bb42f0cb5d53fdc8d5c6753ec0e61 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Tue, 2 Jun 2026 19:54:35 +0900 Subject: [PATCH 05/24] test: add comprehensive test suites for file, rendering, search, and shell tools, and expand registration integration tests --- tests/provider/register.test.ts | 185 ++++++++++++++++++++++ tests/tools/files.test.ts | 264 ++++++++++++++++++++++++++++++++ tests/tools/rendering.test.ts | 98 ++++++++++++ tests/tools/search.test.ts | 194 +++++++++++++++++++++++ tests/tools/shell.test.ts | 139 +++++++++++++++++ tests/tools/toolTestHelpers.ts | 110 +++++++++++++ 6 files changed, 990 insertions(+) create mode 100644 tests/tools/files.test.ts create mode 100644 tests/tools/rendering.test.ts create mode 100644 tests/tools/search.test.ts create mode 100644 tests/tools/shell.test.ts create mode 100644 tests/tools/toolTestHelpers.ts diff --git a/tests/provider/register.test.ts b/tests/provider/register.test.ts index d87f3d4..e557651 100644 --- a/tests/provider/register.test.ts +++ b/tests/provider/register.test.ts @@ -60,6 +60,9 @@ interface TestContext { getApiKeyForProvider?: (provider: string) => Promise; }; model?: { provider: string; id: string }; + sessionManager?: { + getSessionId: () => string; + }; ui: { notify: (message: string, level: string) => void; }; @@ -143,6 +146,13 @@ function statusContext(notify: TestContext["ui"]["notify"]): TestContext { }; } +function emptyStatusContext(notify: TestContext["ui"]["notify"]): TestContext { + return { + modelRegistry: { getAll: () => [] }, + ui: { notify }, + }; +} + function contextForModel(provider: string): TestContext { return { model: { provider, id: `${provider}-model` }, @@ -271,6 +281,181 @@ describe("Grok CLI status command", () => { "Requests: 42/180 remaining", ); }); + + it("warns when no Grok models are registered", async () => { + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands + .get("grok-cli-status") + ?.handler([], emptyStatusContext(notify)); + + expect(notify).toHaveBeenCalledOnce(); + expect(notify).toHaveBeenCalledWith( + "Grok CLI: no models registered. Run /login grok-cli first.", + "warning", + ); + }); + + it("shows env-token bypass and truncates long model lists", async () => { + process.env.GROK_CLI_OAUTH_TOKEN = "token"; + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands.get("grok-cli-status")?.handler([], { + modelRegistry: { + getAll: () => + Array.from({ length: 7 }, (_value, index) => ({ + provider: "grok-cli", + id: `grok-model-${index + 1}`, + })), + }, + ui: { notify }, + }); + + expect(notify.mock.calls[0]).toEqual([ + "⚠️ Grok CLI: using GROK_CLI_OAUTH_TOKEN env bypass — no auto-refresh available", + "warning", + ]); + expect(notify.mock.calls[1]).toEqual([ + "✓ Grok CLI: 7 models available (grok-model-1, grok-model-2, grok-model-3, grok-model-4, grok-model-5 (+2 more))", + "info", + ]); + }); + + it("reports registry errors as status warnings", async () => { + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands.get("grok-cli-status")?.handler([], { + modelRegistry: { + getAll: () => { + throw new Error("registry unavailable"); + }, + }, + ui: { notify }, + }); + + expect(notify).toHaveBeenCalledWith( + "Grok CLI: registry unavailable", + "warning", + ); + }); + + it("includes OAuth error codes in status warnings", async () => { + const { XaiOAuthError } = await import("../../src/shared/errors.js"); + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands.get("grok-cli-status")?.handler([], { + modelRegistry: { + getAll: () => { + throw new XaiOAuthError("refresh failed", "refresh_failed", true); + }, + }, + ui: { notify }, + }); + + expect(notify).toHaveBeenCalledWith( + "Grok CLI: refresh failed (code: refresh_failed)", + "warning", + ); + }); +}); + +describe("Grok CLI provider registration", () => { + it("registers provider metadata and OAuth helpers", async () => { + const extension = await setupExtension(); + const provider = extension.providers.get("grok-cli"); + + expect(provider?.name).toBe("Grok CLI"); + expect(provider?.api).toBe("openai-responses"); + expect(provider?.apiKey).toBe("$GROK_CLI_OAUTH_TOKEN"); + expect(provider?.models.map((model) => model.id)).toContain("grok-build"); + expect(provider?.oauth?.getApiKey({ access: "access-token" })).toBe( + "access-token", + ); + expect( + provider?.oauth?.modifyModels( + [ + { provider: "grok-cli", id: "grok-build", baseUrl: "old" }, + { provider: "openai", id: "gpt-4", baseUrl: "keep" }, + ], + { + access: "access-token", + refresh: "refresh-token", + expires: 123, + baseUrl: "https://example.invalid/custom///", + }, + ), + ).toEqual([ + { + provider: "grok-cli", + id: "grok-build", + baseUrl: "https://example.invalid/custom", + }, + { provider: "openai", id: "gpt-4", baseUrl: "keep" }, + ]); + }); + + it("sanitizes Grok provider requests with the current session id", async () => { + const extension = await setupExtension(); + const result = extension.handlers.get("before_provider_request")?.( + { + payload: { + input: [{ role: "system", content: "system instruction" }], + }, + }, + { + model: { provider: "grok-cli", id: "grok-4.3" }, + modelRegistry: { getAll: () => [] }, + sessionManager: { getSessionId: () => "session-123" }, + ui: { notify: vi.fn() }, + }, + ); + + expect(result).toEqual({ + input: [], + instructions: "system instruction", + prompt_cache_key: "session-123", + }); + }); + + it("leaves non-Grok provider requests untouched", async () => { + const extension = await setupExtension(); + const payload = { input: [{ role: "system", content: "keep" }] }; + const result = extension.handlers.get("before_provider_request")?.( + { payload }, + { + model: { provider: "openai", id: "gpt-4" }, + modelRegistry: { getAll: () => [] }, + sessionManager: { getSessionId: () => "session-123" }, + ui: { notify: vi.fn() }, + }, + ); + + expect(result).toBeUndefined(); + expect(payload).toEqual({ input: [{ role: "system", content: "keep" }] }); + }); + + it("warns at session start when env-token bypass is active", async () => { + process.env.GROK_CLI_OAUTH_TOKEN = "token"; + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.handlers.get("session_start")?.( + {}, + { + modelRegistry: { getAll: () => [] }, + ui: { notify }, + }, + ); + + expect(notify).toHaveBeenCalledWith( + "[pi-grok-cli] Using GROK_CLI_OAUTH_TOKEN bypass — no auto-refresh, no model discovery", + "warning", + ); + }); }); describe("Grok CLI tool scoping", () => { diff --git a/tests/tools/files.test.ts b/tests/tools/files.test.ts new file mode 100644 index 0000000..c8af270 --- /dev/null +++ b/tests/tools/files.test.ts @@ -0,0 +1,264 @@ +import { existsSync, mkdirSync, readFileSync, writeFileSync } from "node:fs"; +import { join } from "node:path"; +import { describe, expect, it } from "vitest"; +import { registerFileTools } from "../../src/tools/files.js"; +import { + collectTools, + executeTool, + firstText, + renderToolCall, + renderToolResult, + type ToolResult, + tempDir, +} from "./toolTestHelpers.js"; + +function expectStoryState( + result: ToolResult, + cwd: string, + replacements: number, + content: string, +) { + expect(result.details).toEqual({ + path: join(cwd, "story.txt"), + replacements, + }); + expect(readFileSync(join(cwd, "story.txt"), "utf-8")).toBe(content); +} + +function strReplace(cwd: string, old_str: string, new_str: string) { + return executeTool( + collectTools(registerFileTools).get("StrReplace"), + { path: "story.txt", old_str, new_str }, + cwd, + ); +} + +describe("file tools", () => { + it("lists directory contents including hidden files", async () => { + const cwd = tempDir("pi-grok-cli-files-"); + writeFileSync(join(cwd, ".hidden"), "secret", "utf-8"); + writeFileSync(join(cwd, "visible.txt"), "visible", "utf-8"); + + const result = await executeTool( + collectTools(registerFileTools).get("LS"), + { path: "." }, + cwd, + ); + + expect(firstText(result)).toContain(".hidden"); + expect(firstText(result)).toContain("visible.txt"); + expect(result.details).toEqual({ path: cwd }); + }); + + it("reports filesystem errors for invalid file operations", async () => { + const cwd = tempDir("pi-grok-cli-files-"); + mkdirSync(join(cwd, "dir")); + writeFileSync(join(cwd, "blocked"), "not a directory", "utf-8"); + const tools = collectTools(registerFileTools); + + const lsResult = await executeTool( + tools.get("LS"), + { path: "missing-dir" }, + cwd, + ); + const readResult = await executeTool( + tools.get("Read"), + { path: "dir" }, + cwd, + ); + const writeResult = await executeTool( + tools.get("Write"), + { path: "blocked/file.txt", content: "content" }, + cwd, + ); + const replaceResult = await executeTool( + tools.get("StrReplace"), + { path: "dir", old_str: "old", new_str: "new" }, + cwd, + ); + const deleteResult = await executeTool( + tools.get("Delete"), + { path: "dir" }, + cwd, + ); + + expect(firstText(lsResult).startsWith("LS error:")).toBe(true); + expect(firstText(readResult).startsWith("Read error:")).toBe(true); + expect(firstText(writeResult).startsWith("Write error:")).toBe(true); + expect(firstText(replaceResult).startsWith("StrReplace error:")).toBe(true); + expect(firstText(deleteResult).startsWith("Delete error:")).toBe(true); + expect(writeResult.details).toEqual({ + path: join(cwd, "blocked", "file.txt"), + bytesWritten: 0, + }); + expect(replaceResult.details).toEqual({ + path: join(cwd, "dir"), + replacements: 0, + }); + expect(deleteResult.details).toEqual({ + path: join(cwd, "dir"), + deleted: false, + }); + }); + + it("writes a nested file and reads a requested line window", async () => { + const cwd = tempDir("pi-grok-cli-files-"); + const tools = collectTools(registerFileTools); + + const writeResult = await executeTool( + tools.get("Write"), + { path: "nested/notes.txt", content: "alpha\nbeta\ngamma\ndelta" }, + cwd, + ); + + expect(firstText(writeResult)).toBe( + "Successfully wrote 22 bytes to nested/notes.txt", + ); + expect(writeResult.details).toEqual({ + path: join(cwd, "nested/notes.txt"), + bytesWritten: 22, + }); + + const readResult = await executeTool( + tools.get("Read"), + { path: "nested/notes.txt", offset: 1, limit: 2 }, + cwd, + ); + + expect(firstText(readResult)).toBe( + "2\tbeta\n3\tgamma\n\n[Showing lines 2-3 of 4 total lines. Use offset to see more.]", + ); + expect(readResult.details).toEqual({ + path: join(cwd, "nested/notes.txt"), + totalLines: 4, + }); + }); + + it("reports missing files without throwing", async () => { + const cwd = tempDir("pi-grok-cli-files-"); + const result = await executeTool( + collectTools(registerFileTools).get("Read"), + { path: "missing.txt" }, + cwd, + ); + + expect(firstText(result)).toBe( + `File not found: ${join(cwd, "missing.txt")}`, + ); + expect(result.details).toEqual({ + path: join(cwd, "missing.txt"), + exists: false, + totalLines: 0, + }); + }); + + it("replaces every exact string occurrence", async () => { + const cwd = tempDir("pi-grok-cli-files-"); + writeFileSync(join(cwd, "story.txt"), "red blue red", "utf-8"); + + const result = await strReplace(cwd, "red", "green"); + + expect(firstText(result)).toBe("Replaced 2 occurrence(s) in story.txt"); + expectStoryState(result, cwd, 2, "green blue green"); + }); + + it("leaves files unchanged when the replacement string is absent", async () => { + const cwd = tempDir("pi-grok-cli-files-"); + writeFileSync(join(cwd, "story.txt"), "red blue red", "utf-8"); + + const result = await strReplace(cwd, "purple", "green"); + + expect(firstText(result)).toBe('String not found in story.txt: "purple"'); + expectStoryState(result, cwd, 0, "red blue red"); + }); + + it("deletes existing files and reports missing files", async () => { + const cwd = tempDir("pi-grok-cli-files-"); + writeFileSync(join(cwd, "remove.txt"), "delete me", "utf-8"); + const tools = collectTools(registerFileTools); + + const deletedResult = await executeTool( + tools.get("Delete"), + { path: "remove.txt" }, + cwd, + ); + + expect(firstText(deletedResult)).toBe("Successfully deleted remove.txt"); + expect(deletedResult.details).toEqual({ + path: join(cwd, "remove.txt"), + deleted: true, + }); + expect(existsSync(join(cwd, "remove.txt"))).toBe(false); + + const missingResult = await executeTool( + tools.get("Delete"), + { path: "remove.txt" }, + cwd, + ); + + expect(firstText(missingResult)).toBe( + `File not found: ${join(cwd, "remove.txt")}`, + ); + expect(missingResult.details).toEqual({ + path: join(cwd, "remove.txt"), + deleted: false, + }); + }); + + it("renders file tool calls and result states", () => { + const tools = collectTools(registerFileTools); + + expect(renderToolCall(tools.get("LS"), { path: "." })).toBe("LS ."); + expect( + renderToolCall(tools.get("Read"), { + path: "notes.txt", + offset: 5, + limit: 10, + }), + ).toBe("Read notes.txt (from 5, 10 lines)"); + expect(renderToolCall(tools.get("StrReplace"), { path: "notes.txt" })).toBe( + "StrReplace notes.txt", + ); + expect(renderToolCall(tools.get("Delete"), { path: "notes.txt" })).toBe( + "Delete notes.txt", + ); + expect( + renderToolResult(tools.get("Read"), { + content: [{ type: "text", text: "missing" }], + details: { exists: false, totalLines: 0 }, + }), + ).toBe("File not found"); + expect( + renderToolResult(tools.get("StrReplace"), { + content: [{ type: "text", text: "no replacement" }], + details: { replacements: 0 }, + }), + ).toBe("No replacements"); + expect( + renderToolResult(tools.get("Delete"), { + content: [{ type: "text", text: "not deleted" }], + details: { deleted: false }, + }), + ).toBe("Not deleted"); + expect( + renderToolResult( + tools.get("LS"), + { + content: [{ type: "text", text: "full listing" }], + details: { path: "/tmp/project" }, + }, + { expanded: true, isPartial: false }, + ), + ).toBe("full listing"); + expect( + renderToolResult( + tools.get("Write"), + { + content: [{ type: "text", text: "writing" }], + details: { bytesWritten: 10 }, + }, + { expanded: false, isPartial: true }, + ), + ).toBe("Running..."); + }); +}); diff --git a/tests/tools/rendering.test.ts b/tests/tools/rendering.test.ts new file mode 100644 index 0000000..fc54d81 --- /dev/null +++ b/tests/tools/rendering.test.ts @@ -0,0 +1,98 @@ +import { describe, expect, it } from "vitest"; +import { + booleanDetail, + detailRecord, + fileError, + fileNotFound, + MAX_LINES, + MAX_OUTPUT_CHARS, + numberDetail, + renderResultSummary, + renderResultText, + renderRunning, + stringDetail, + text, + toolError, + truncateChars, + truncateLines, +} from "../../src/tools/rendering.js"; +import { renderText } from "./toolTestHelpers.js"; + +describe("tool rendering helpers", () => { + it("truncates long result lists and large output", () => { + expect(truncateLines(["one", "two"])).toBe("one\ntwo"); + expect( + truncateLines(Array.from({ length: MAX_LINES + 1 }, String)).endsWith( + `\n\n[Showing first ${MAX_LINES} of ${MAX_LINES + 1} results. Refine your pattern to narrow results.]`, + ), + ).toBe(true); + expect(truncateChars("short")).toBe("short"); + expect(truncateChars("x".repeat(MAX_OUTPUT_CHARS + 1))).toBe( + `${"x".repeat(MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`, + ); + }); + + it("renders summaries, expanded text, missing text fallback, and partial state", () => { + const result = { + content: [{ type: "text", text: "full output" }], + details: {}, + }; + + expect(renderText(text("plain"))).toBe("plain"); + expect(renderText(renderResultText(result, false, "summary"))).toBe( + "summary", + ); + expect(renderText(renderResultText(result, true, "summary"))).toBe( + "full output", + ); + expect( + renderText( + renderResultText( + { content: [{ type: "image" }], details: {} }, + true, + "summary", + ), + ), + ).toBe("summary"); + expect(renderText(renderRunning(true) ?? text(""))).toBe("Running..."); + expect(renderRunning(false)).toBeUndefined(); + expect( + renderText(renderResultSummary(result, false, true, "summary")), + ).toBe("Running..."); + }); + + it("reads typed detail values with defaults for absent or invalid details", () => { + const result = { + content: [{ type: "text", text: "" }], + details: { count: 2, path: "file.txt", deleted: true, invalid: null }, + }; + + expect(detailRecord(result)).toEqual(result.details); + expect(detailRecord({ details: null })).toEqual({}); + expect(numberDetail(result, "count")).toBe(2); + expect(numberDetail(result, "path")).toBe(0); + expect(stringDetail(result, "path")).toBe("file.txt"); + expect(stringDetail(result, "count")).toBe(""); + expect(booleanDetail(result, "deleted")).toBe(true); + expect(booleanDetail(result, "invalid")).toBe(false); + }); + + it("formats file and command errors with stable empty details", () => { + expect(fileNotFound("/tmp/missing.txt", { deleted: false })).toEqual({ + content: [{ type: "text", text: "File not found: /tmp/missing.txt" }], + details: { path: "/tmp/missing.txt", deleted: false }, + }); + expect(fileError({}, "Read", "/tmp/file.txt", { totalLines: 0 })).toEqual({ + content: [{ type: "text", text: "Read error: Unknown error" }], + details: { path: "/tmp/file.txt", totalLines: 0 }, + }); + expect(toolError({ code: 1 }, "Grep", { matchCount: 0 })).toEqual({ + content: [{ type: "text", text: "No matches found" }], + details: { matchCount: 0 }, + }); + expect(toolError({}, "Grep", { matchCount: 0 })).toEqual({ + content: [{ type: "text", text: "Grep error: Unknown error" }], + details: { matchCount: 0 }, + }); + }); +}); diff --git a/tests/tools/search.test.ts b/tests/tools/search.test.ts new file mode 100644 index 0000000..70cfd00 --- /dev/null +++ b/tests/tools/search.test.ts @@ -0,0 +1,194 @@ +import { mkdirSync, writeFileSync } from "node:fs"; +import { join } from "node:path"; +import { describe, expect, it } from "vitest"; +import { registerSearchTools } from "../../src/tools/search.js"; +import { + collectTools, + executeTool, + firstText, + plainTheme, + renderText, + tempDir, +} from "./toolTestHelpers.js"; + +function setupProject() { + const dir = tempDir("pi-grok-cli-search-"); + mkdirSync(join(dir, "src")); + writeFileSync(join(dir, "src", "alpha.ts"), "needle\nhaystack\n", "utf-8"); + writeFileSync(join(dir, "src", "beta.md"), "needle in docs\n", "utf-8"); + writeFileSync(join(dir, "src", "gamma.ts"), "plain text\n", "utf-8"); + return dir; +} + +describe("search tools", () => { + it("greps matching file contents with include filters", async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get("Grep"), + { pattern: "needle", path: "src", include: "*.ts" }, + cwd, + ); + + expect(firstText(result)).toContain( + `${join(cwd, "src", "alpha.ts")}:1:needle`, + ); + expect(firstText(result)).not.toContain("beta.md"); + expect(result.details).toEqual({ matchCount: 1 }); + }); + + it("reports no grep matches as an empty result", async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get("Grep"), + { pattern: "absent", path: "src" }, + cwd, + ); + + expect(firstText(result)).toBe("No matches found"); + expect(result.details).toEqual({ matchCount: 0 }); + }); + + it("reports grep command errors with empty match details", async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get("Grep"), + { pattern: "[", path: "src" }, + cwd, + ); + + expect(firstText(result).startsWith("Grep error:")).toBe(true); + expect(result.details).toEqual({ matchCount: 0 }); + }); + + it("globs files under the requested path", async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get("Glob"), + { pattern: "**/*.ts", path: "src" }, + cwd, + ); + + expect(firstText(result)).toContain(join(cwd, "src", "alpha.ts")); + expect(firstText(result)).toContain(join(cwd, "src", "gamma.ts")); + expect(firstText(result)).not.toContain("beta.md"); + expect(result.details).toEqual({ fileCount: 2 }); + }); + + it("reports empty glob command results", async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get("Glob"), + { pattern: "**/*.json", path: "src" }, + cwd, + ); + + expect(firstText(result)).toBe("No matches found"); + expect(result.details).toEqual({ fileCount: 0 }); + }); + + it("renders grep calls and result states", () => { + const grep = collectTools(registerSearchTools).get("Grep"); + const result = { + content: [{ type: "text", text: "src/alpha.ts:1:needle" }], + details: { matchCount: 1 }, + }; + + expect( + renderText( + grep?.renderCall?.( + { pattern: "needle", path: "src", include: "*.ts" }, + plainTheme, + ) ?? { render: () => [] }, + ), + ).toBe('Grep "needle" in src [*.ts]'); + expect( + renderText( + grep?.renderResult?.( + result, + { expanded: false, isPartial: false }, + plainTheme, + {}, + ) ?? { render: () => [] }, + ), + ).toBe("1 match(es)"); + expect( + renderText( + grep?.renderResult?.( + result, + { expanded: true, isPartial: false }, + plainTheme, + {}, + ) ?? { render: () => [] }, + ), + ).toBe("src/alpha.ts:1:needle"); + expect( + renderText( + grep?.renderResult?.( + { + content: [{ type: "text", text: "No matches found" }], + details: {}, + }, + { expanded: false, isPartial: false }, + plainTheme, + {}, + ) ?? { render: () => [] }, + ), + ).toBe("No matches"); + expect( + renderText( + grep?.renderResult?.( + result, + { expanded: false, isPartial: true }, + plainTheme, + {}, + ) ?? { render: () => [] }, + ), + ).toBe("Running..."); + }); + + it("renders glob calls and result states", () => { + const glob = collectTools(registerSearchTools).get("Glob"); + const result = { + content: [{ type: "text", text: "src/alpha.ts\nsrc/gamma.ts" }], + details: { fileCount: 2 }, + }; + + expect( + renderText( + glob?.renderCall?.({ pattern: "**/*.ts", path: "src" }, plainTheme) ?? { + render: () => [], + }, + ), + ).toBe("Glob **/*.ts in src"); + expect( + renderText( + glob?.renderResult?.( + result, + { expanded: false, isPartial: false }, + plainTheme, + {}, + ) ?? { render: () => [] }, + ), + ).toBe("2 file(s)"); + expect( + renderText( + glob?.renderResult?.( + { content: [{ type: "text", text: "No files found" }], details: {} }, + { expanded: false, isPartial: false }, + plainTheme, + {}, + ) ?? { render: () => [] }, + ), + ).toBe("No files"); + expect( + renderText( + glob?.renderResult?.( + result, + { expanded: false, isPartial: true }, + plainTheme, + {}, + ) ?? { render: () => [] }, + ), + ).toBe("Running..."); + }); +}); diff --git a/tests/tools/shell.test.ts b/tests/tools/shell.test.ts new file mode 100644 index 0000000..3640fa0 --- /dev/null +++ b/tests/tools/shell.test.ts @@ -0,0 +1,139 @@ +import { writeFileSync } from "node:fs"; +import { join } from "node:path"; +import { describe, expect, it } from "vitest"; +import { registerShellTool } from "../../src/tools/shell.js"; +import { + collectTools, + executeTool, + firstText, + renderToolCall, + renderToolResult, + tempDir, +} from "./toolTestHelpers.js"; + +describe("shell tool", () => { + it("returns stdout, stderr, and exit zero details", async () => { + const cwd = tempDir("pi-grok-cli-shell-"); + const result = await executeTool( + collectTools(registerShellTool).get("Shell"), + { command: "printf stdout && printf stderr >&2" }, + cwd, + ); + + expect(firstText(result)).toBe("stdout\n[stderr]\nstderr"); + expect(result.details).toEqual({ + exitCode: 0, + command: "printf stdout && printf stderr >&2", + }); + }); + + it("runs commands in a resolved working directory", async () => { + const cwd = tempDir("pi-grok-cli-shell-"); + writeFileSync(join(cwd, "target.txt"), "from cwd", "utf-8"); + const result = await executeTool( + collectTools(registerShellTool).get("Shell"), + { command: "cat target.txt", working_directory: "." }, + cwd, + ); + + expect(firstText(result)).toBe("from cwd"); + expect(result.details).toEqual({ + exitCode: 0, + command: "cat target.txt", + }); + }); + + it("returns a clear placeholder when commands produce no output", async () => { + const cwd = tempDir("pi-grok-cli-shell-"); + const result = await executeTool( + collectTools(registerShellTool).get("Shell"), + { command: "true" }, + cwd, + ); + + expect(firstText(result)).toBe("(no output)"); + expect(result.details).toEqual({ exitCode: 0, command: "true" }); + }); + + it("includes exit code, error message, and captured output on failure", async () => { + const cwd = tempDir("pi-grok-cli-shell-"); + const result = await executeTool( + collectTools(registerShellTool).get("Shell"), + { command: "printf before && printf problem >&2 && exit 7" }, + cwd, + ); + + expect(firstText(result)).toContain("Shell error (exit code 7):"); + expect(firstText(result)).toContain("before\n[stderr]\nproblem"); + expect(result.details).toEqual({ + exitCode: 7, + command: "printf before && printf problem >&2 && exit 7", + }); + }); + + it("truncates large successful and failed output", async () => { + const cwd = tempDir("pi-grok-cli-shell-"); + const tools = collectTools(registerShellTool); + const largeOutput = "head -c 50001 /dev/zero | tr '\\0' x"; + + const successResult = await executeTool( + tools.get("Shell"), + { command: largeOutput }, + cwd, + ); + const failureResult = await executeTool( + tools.get("Shell"), + { command: `${largeOutput}; exit 9` }, + cwd, + ); + + expect(firstText(successResult)).toHaveLength( + "\n\n[Output truncated at 50KB]".length + 50_000, + ); + expect( + firstText(successResult).endsWith("[Output truncated at 50KB]"), + ).toBe(true); + expect(firstText(failureResult)).toContain("Shell error (exit code 9):"); + expect( + firstText(failureResult).endsWith("[Output truncated at 50KB]"), + ).toBe(true); + }); + + it("renders shell calls and result states", () => { + const shell = collectTools(registerShellTool).get("Shell"); + + expect( + renderToolCall(shell, { + command: "pwd", + working_directory: "src", + }), + ).toBe("Shell pwd in src"); + expect(renderToolCall(shell, { command: "pwd" })).toBe("Shell pwd"); + expect( + renderToolResult(shell, { + content: [{ type: "text", text: "full output" }], + details: { exitCode: 0 }, + }), + ).toBe("Exit 0"); + expect( + renderToolResult( + shell, + { + content: [{ type: "text", text: "full output" }], + details: { exitCode: 0 }, + }, + { expanded: true, isPartial: false }, + ), + ).toBe("full output"); + expect( + renderToolResult( + shell, + { + content: [{ type: "text", text: "still running" }], + details: { exitCode: 0 }, + }, + { expanded: false, isPartial: true }, + ), + ).toBe("Running..."); + }); +}); diff --git a/tests/tools/toolTestHelpers.ts b/tests/tools/toolTestHelpers.ts new file mode 100644 index 0000000..c3788e9 --- /dev/null +++ b/tests/tools/toolTestHelpers.ts @@ -0,0 +1,110 @@ +import { mkdtempSync, rmSync } from "node:fs"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; +import { afterEach } from "vitest"; + +const tempDirs: string[] = []; + +afterEach(() => { + for (const dir of tempDirs.splice(0)) rmSync(dir, { recursive: true }); +}); + +export type ToolResult = { + content: { type: string; text?: string }[]; + details: Record; +}; + +type Renderable = { render: (width: number) => string[] }; + +type ToolTheme = { + bold: (text: string) => string; + fg: (name: string, text: string) => string; +}; + +type RegisteredTool = { + name: string; + execute: ( + toolCallId: string, + params: Record, + signal: AbortSignal, + onUpdate: () => void, + ctx: { cwd: string }, + ) => Promise; + renderCall?: (args: Record, theme: ToolTheme) => Renderable; + renderResult?: ( + result: ToolResult, + state: { expanded: boolean; isPartial: boolean }, + theme: ToolTheme, + args: Record, + ) => Renderable; +}; + +export function collectTools(registerTools: (pi: ExtensionAPI) => void) { + const tools = new Map(); + registerTools({ + registerTool(tool: RegisteredTool) { + tools.set(tool.name, tool); + }, + } as unknown as ExtensionAPI); + return tools; +} + +export async function executeTool( + tool: RegisteredTool | undefined, + params: Record, + cwd: string, +) { + if (!tool) throw new Error("Tool was not registered"); + return tool.execute( + "tool-call-id", + params, + new AbortController().signal, + () => {}, + { + cwd, + }, + ); +} + +export function firstText(result: ToolResult) { + return result.content[0]?.text ?? ""; +} + +export function renderText(component: { render: (width: number) => string[] }) { + return component + .render(120) + .map((line) => line.trimEnd()) + .join("\n"); +} + +export const plainTheme = { + bold: (text: string) => text, + fg: (_name: string, text: string) => text, +}; + +export function renderToolCall( + tool: RegisteredTool | undefined, + args: Record, +) { + if (!tool?.renderCall) + throw new Error("Tool call renderer was not registered"); + return renderText(tool.renderCall(args, plainTheme)); +} + +export function renderToolResult( + tool: RegisteredTool | undefined, + result: ToolResult, + state = { expanded: false, isPartial: false }, +) { + if (!tool?.renderResult) { + throw new Error("Tool result renderer was not registered"); + } + return renderText(tool.renderResult(result, state, plainTheme, {})); +} + +export function tempDir(prefix: string) { + const dir = mkdtempSync(join(tmpdir(), prefix)); + tempDirs.push(dir); + return dir; +} From ceb74e411d7e9d0e8e45c5f043070ab7535ada72 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Tue, 2 Jun 2026 23:37:51 +0900 Subject: [PATCH 06/24] feat: add Edit tool and support for Cursor-style argument variations across file and search tools --- src/provider/toolScope.ts | 1 + src/tools/files.ts | 259 +++++++++++++++++++++++++++++--- src/tools/rendering.ts | 12 ++ src/tools/search.ts | 23 +++ tests/provider/register.test.ts | 1 + tests/tools/files.test.ts | 142 +++++++++++++++++ tests/tools/register.test.ts | 1 + tests/tools/search.test.ts | 48 +++++- tests/tools/toolTestHelpers.ts | 20 +++ 9 files changed, 481 insertions(+), 26 deletions(-) diff --git a/src/provider/toolScope.ts b/src/provider/toolScope.ts index 4d24ebd..fcbbee0 100644 --- a/src/provider/toolScope.ts +++ b/src/provider/toolScope.ts @@ -7,6 +7,7 @@ export const GROK_TOOL_NAMES = [ "Read", "Write", "StrReplace", + "Edit", "Delete", "Shell", ]; diff --git a/src/tools/files.ts b/src/tools/files.ts index 4dc5dc1..8abb63a 100644 --- a/src/tools/files.ts +++ b/src/tools/files.ts @@ -17,19 +17,116 @@ import { fileNotFound, MAX_OUTPUT_CHARS, numberDetail, + recordFrom, renderResultSummary, stringDetail, + stringFrom, type ToolError, text, } from "./rendering.js"; const execFileAsync = promisify(execFile); +type ReplacementEdit = { oldText: string; newText: string }; +type WriteArgs = { path: string; content: string }; +type StrReplaceArgs = { path: string; old_str: string; new_str: string }; +type EditArgs = { + path: string; + edits?: ReplacementEdit[]; + applyPatch?: { patchContent: string }; + strReplace?: ReplacementEdit; + multiStrReplace?: { edits: ReplacementEdit[] }; +}; + type ToolTheme = { bold: (text: string) => string; fg: (name: "accent" | "toolTitle", text: string) => string; }; +function parseEditList(value: unknown): ReplacementEdit[] | undefined { + const editList = typeof value === "string" ? parseJson(value) : value; + if (!Array.isArray(editList)) return undefined; + if ( + !editList.every( + (edit) => + typeof recordFrom(edit)?.oldText === "string" && + typeof recordFrom(edit)?.newText === "string", + ) + ) { + return undefined; + } + return editList.map((edit) => ({ + oldText: stringFrom(recordFrom(edit)?.oldText) ?? "", + newText: stringFrom(recordFrom(edit)?.newText) ?? "", + })); +} + +function parseJson(value: string): unknown { + try { + return JSON.parse(value); + } catch { + return undefined; + } +} + +function editFromText(oldText: unknown, newText: unknown) { + if (typeof oldText !== "string" || typeof newText !== "string") + return undefined; + return [{ oldText, newText }]; +} + +function editsFromArgs(input: Record) { + return ( + parseEditList(input.edits) ?? + parseEditList(recordFrom(input.multiStrReplace)?.edits) ?? + editFromText(input.oldText, input.newText) ?? + editFromText( + recordFrom(input.strReplace)?.oldText, + recordFrom(input.strReplace)?.newText, + ) + ); +} + +function applyEdits(content: string, edits: ReplacementEdit[]) { + return edits.reduce( + (result, edit) => { + const count = result.content.split(edit.oldText).length - 1; + return { + content: + count === 0 + ? result.content + : result.content.replaceAll(edit.oldText, edit.newText), + replacements: result.replacements + count, + }; + }, + { content, replacements: 0 }, + ); +} + +function replacementResult(text: string, filePath: string) { + return { + content: [{ type: "text" as const, text }], + details: { path: filePath, replacements: 0 }, + }; +} + +function renderReplacementResult( + result: { content: { type: string; text?: string }[]; details: unknown }, + expanded: boolean, + isPartial: boolean, + theme: { fg: (name: "dim" | "muted", text: string) => string }, +) { + const replacements = numberDetail(result, "replacements"); + return renderResultSummary( + result, + expanded, + isPartial, + replacements === 0 + ? theme.fg("dim", "No replacements") + : theme.fg("muted", `${replacements} replacement(s)`), + ); +} + function renderPathToolCall( toolName: string, filePath: string, @@ -211,6 +308,15 @@ export function registerFileTools(pi: ExtensionAPI) { "Create or overwrite a file with the given content. Creates parent directories if needed.", parameters: WriteParams, + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as WriteArgs; + return { + ...input, + content: stringFrom(input.content) ?? stringFrom(input.contents), + } as WriteArgs; + }, + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { const filePath = resolve(ctx.cwd, params.path); @@ -277,6 +383,24 @@ export function registerFileTools(pi: ExtensionAPI) { "Replace all occurrences of a string in a file. The old_str must be an exact match.", parameters: StrReplaceParams, + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as StrReplaceArgs; + return { + ...input, + old_str: + stringFrom(input.old_str) ?? + stringFrom(input.old_string) ?? + stringFrom(input.oldText) ?? + stringFrom(recordFrom(input.strReplace)?.oldText), + new_str: + stringFrom(input.new_str) ?? + stringFrom(input.new_string) ?? + stringFrom(input.newText) ?? + stringFrom(recordFrom(input.strReplace)?.newText), + } as StrReplaceArgs; + }, + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { const filePath = resolve(ctx.cwd, params.path); @@ -289,48 +413,147 @@ export function registerFileTools(pi: ExtensionAPI) { const count = content.split(params.old_str).length - 1; if (count === 0) { + return replacementResult( + `String not found in ${params.path}: "${params.old_str}"`, + filePath, + ); + } + + const newContent = content.replaceAll(params.old_str, params.new_str); + writeFileSync(filePath, newContent, "utf-8"); + + return { + content: [ + { + type: "text", + text: `Replaced ${count} occurrence(s) in ${params.path}`, + }, + ], + details: { path: filePath, replacements: count }, + }; + } catch (error: unknown) { + return fileError(error, "StrReplace", filePath, { replacements: 0 }); + } + }, + renderCall(args, theme) { + return renderPathToolCall("StrReplace", args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderReplacementResult(result, expanded, isPartial, theme); + }, + }); + + // ── Edit tool ──────────────────────────────────────────────────────── + + const EditItemParams = Type.Object({ + oldText: Type.String({ + description: "String to search for (exact match)", + }), + newText: Type.String({ + description: "String to replace with", + }), + replaceAll: Type.Optional( + Type.Boolean({ + description: + "Accepted for Cursor compatibility. Replacements are always applied to all matches.", + }), + ), + }); + + const EditParams = Type.Object({ + path: Type.String({ + description: "Path to the file to modify", + }), + edits: Type.Optional( + Type.Array(EditItemParams, { + description: "Exact text replacements to apply sequentially", + }), + ), + applyPatch: Type.Optional( + Type.Object({ + patchContent: Type.String({ + description: "Unsupported unified patch content", + }), + }), + ), + strReplace: Type.Optional(EditItemParams), + multiStrReplace: Type.Optional( + Type.Object({ + edits: Type.Array(EditItemParams), + }), + ), + }); + + pi.registerTool({ + name: "Edit", + label: "Edit", + description: + "Modify a file with exact text replacement. applyPatch is not supported by this Grok tool shim.", + parameters: EditParams, + + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as EditArgs; + return { + ...input, + edits: editsFromArgs(input), + } as EditArgs; + }, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + if (!existsSync(filePath)) { + return fileNotFound(filePath, { replacements: 0 }); + } + + try { + if (!params.edits?.length) { return { content: [ { type: "text", - text: `String not found in ${params.path}: "${params.old_str}"`, + text: params.applyPatch + ? "Edit error: applyPatch is not supported by this Grok tool shim" + : "Edit error: provide at least one exact text replacement", }, ], details: { path: filePath, replacements: 0 }, }; } - const newContent = content.replaceAll(params.old_str, params.new_str); - writeFileSync(filePath, newContent, "utf-8"); + const result = applyEdits( + readFileSync(filePath, "utf-8"), + params.edits, + ); + + if (result.replacements === 0) { + return replacementResult( + `No replacement strings found in ${params.path}`, + filePath, + ); + } + + writeFileSync(filePath, result.content, "utf-8"); return { content: [ { type: "text", - text: `Replaced ${count} occurrence(s) in ${params.path}`, + text: `Applied ${result.replacements} replacement(s) in ${params.path}`, }, ], - details: { path: filePath, replacements: count }, + details: { path: filePath, replacements: result.replacements }, }; } catch (error: unknown) { - return fileError(error, "StrReplace", filePath, { replacements: 0 }); + return fileError(error, "Edit", filePath, { replacements: 0 }); } }, renderCall(args, theme) { - return renderPathToolCall("StrReplace", args.path, theme); + return renderPathToolCall("Edit", args.path, theme); }, renderResult(result, { expanded, isPartial }, theme) { - return renderResultSummary( - result, - expanded, - isPartial, - numberDetail(result, "replacements") === 0 - ? theme.fg("dim", "No replacements") - : theme.fg( - "muted", - `${numberDetail(result, "replacements")} replacement(s)`, - ), - ); + return renderReplacementResult(result, expanded, isPartial, theme); }, }); diff --git a/src/tools/rendering.ts b/src/tools/rendering.ts index 3c5f4f9..d02ece9 100644 --- a/src/tools/rendering.ts +++ b/src/tools/rendering.ts @@ -7,6 +7,18 @@ const execFileAsync = promisify(execFile); export const MAX_OUTPUT_CHARS = 50_000; export const MAX_LINES = 500; +export function recordFrom( + value: unknown, +): Record | undefined { + if (!value || typeof value !== "object") return undefined; + return value as Record; +} + +export function stringFrom(value: unknown): string | undefined { + if (typeof value !== "string") return undefined; + return value; +} + export function truncateLines(lines: string[]): string { if (lines.length > MAX_LINES) { return ( diff --git a/src/tools/search.ts b/src/tools/search.ts index 34b422d..bd4bc7d 100644 --- a/src/tools/search.ts +++ b/src/tools/search.ts @@ -8,8 +8,10 @@ import { hasRipgrep, MAX_OUTPUT_CHARS, numberDetail, + recordFrom, renderResultText, renderRunning, + stringFrom, text, toolError, truncateChars, @@ -18,6 +20,9 @@ import { const execFileAsync = promisify(execFile); +type GrepArgs = { pattern: string; path?: string; include?: string }; +type GlobArgs = { pattern: string; path?: string }; + export function registerSearchTools(pi: ExtensionAPI) { const GrepParams = Type.Object({ pattern: Type.String({ @@ -44,6 +49,15 @@ export function registerSearchTools(pi: ExtensionAPI) { "Search for a regex pattern in file contents. Returns matching lines with file path and line number. Use the include parameter to filter by file type.", parameters: GrepParams, + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as GrepArgs; + return { + ...input, + include: stringFrom(input.include) ?? stringFrom(input.glob_filter), + } as GrepArgs; + }, + async execute(_toolCallId, params, signal, _onUpdate, ctx) { const searchPath = resolve(ctx.cwd, params.path ?? "."); @@ -122,6 +136,15 @@ export function registerSearchTools(pi: ExtensionAPI) { "Find files matching a glob pattern. Returns a list of matching file paths sorted by modification time (newest first).", parameters: GlobParams, + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as GlobArgs; + return { + ...input, + pattern: stringFrom(input.pattern) ?? stringFrom(input.glob_pattern), + } as GlobArgs; + }, + async execute(_toolCallId, params, signal, _onUpdate, ctx) { const searchPath = resolve(ctx.cwd, params.path ?? "."); diff --git a/tests/provider/register.test.ts b/tests/provider/register.test.ts index e557651..e122036 100644 --- a/tests/provider/register.test.ts +++ b/tests/provider/register.test.ts @@ -77,6 +77,7 @@ const grokToolNames = [ "Read", "Write", "StrReplace", + "Edit", "Delete", "Shell", ]; diff --git a/tests/tools/files.test.ts b/tests/tools/files.test.ts index c8af270..a036709 100644 --- a/tests/tools/files.test.ts +++ b/tests/tools/files.test.ts @@ -4,6 +4,7 @@ import { describe, expect, it } from "vitest"; import { registerFileTools } from "../../src/tools/files.js"; import { collectTools, + executePreparedTool, executeTool, firstText, renderToolCall, @@ -33,6 +34,17 @@ function strReplace(cwd: string, old_str: string, new_str: string) { ); } +function strReplaceWithPreparedArgs( + cwd: string, + params: Record, +) { + return executePreparedTool( + collectTools(registerFileTools).get("StrReplace"), + { path: "story.txt", ...params }, + cwd, + ); +} + describe("file tools", () => { it("lists directory contents including hidden files", async () => { const cwd = tempDir("pi-grok-cli-files-"); @@ -134,6 +146,27 @@ describe("file tools", () => { }); }); + it("writes Cursor-style contents arguments", async () => { + const cwd = tempDir("pi-grok-cli-files-"); + + const result = await executePreparedTool( + collectTools(registerFileTools).get("Write"), + { path: "nested/notes.txt", contents: "alpha\nbeta" }, + cwd, + ); + + expect(firstText(result)).toBe( + "Successfully wrote 10 bytes to nested/notes.txt", + ); + expect(readFileSync(join(cwd, "nested/notes.txt"), "utf-8")).toBe( + "alpha\nbeta", + ); + expect(result.details).toEqual({ + path: join(cwd, "nested/notes.txt"), + bytesWritten: 10, + }); + }); + it("reports missing files without throwing", async () => { const cwd = tempDir("pi-grok-cli-files-"); const result = await executeTool( @@ -162,6 +195,115 @@ describe("file tools", () => { expectStoryState(result, cwd, 2, "green blue green"); }); + it("replaces string occurrences with Grok and Cursor argument variants", async () => { + const oldStringCwd = tempDir("pi-grok-cli-files-"); + writeFileSync(join(oldStringCwd, "story.txt"), "red blue red", "utf-8"); + + const oldStringResult = await strReplaceWithPreparedArgs(oldStringCwd, { + old_string: "red", + new_string: "green", + }); + + expect(firstText(oldStringResult)).toBe( + "Replaced 2 occurrence(s) in story.txt", + ); + expectStoryState(oldStringResult, oldStringCwd, 2, "green blue green"); + + const oldTextCwd = tempDir("pi-grok-cli-files-"); + writeFileSync(join(oldTextCwd, "story.txt"), "red blue red", "utf-8"); + + const oldTextResult = await strReplaceWithPreparedArgs(oldTextCwd, { + oldText: "red", + newText: "green", + }); + + expect(firstText(oldTextResult)).toBe( + "Replaced 2 occurrence(s) in story.txt", + ); + expectStoryState(oldTextResult, oldTextCwd, 2, "green blue green"); + + const nestedCwd = tempDir("pi-grok-cli-files-"); + writeFileSync(join(nestedCwd, "story.txt"), "red blue red", "utf-8"); + + const nestedResult = await strReplaceWithPreparedArgs(nestedCwd, { + strReplace: { oldText: "red", newText: "green" }, + }); + + expect(firstText(nestedResult)).toBe( + "Replaced 2 occurrence(s) in story.txt", + ); + expectStoryState(nestedResult, nestedCwd, 2, "green blue green"); + }); + + it("edits files with single, multiple, and stringified replacement inputs", async () => { + const singleCwd = tempDir("pi-grok-cli-files-"); + writeFileSync(join(singleCwd, "story.txt"), "red blue red", "utf-8"); + + const singleResult = await executePreparedTool( + collectTools(registerFileTools).get("Edit"), + { path: "story.txt", oldText: "red", newText: "green" }, + singleCwd, + ); + + expect(firstText(singleResult)).toBe( + "Applied 2 replacement(s) in story.txt", + ); + expectStoryState(singleResult, singleCwd, 2, "green blue green"); + + const multipleCwd = tempDir("pi-grok-cli-files-"); + writeFileSync(join(multipleCwd, "story.txt"), "red blue red", "utf-8"); + + const multipleResult = await executePreparedTool( + collectTools(registerFileTools).get("Edit"), + { + path: "story.txt", + edits: [ + { oldText: "red", newText: "green" }, + { oldText: "blue", newText: "yellow" }, + ], + }, + multipleCwd, + ); + + expect(firstText(multipleResult)).toBe( + "Applied 3 replacement(s) in story.txt", + ); + expectStoryState(multipleResult, multipleCwd, 3, "green yellow green"); + + const stringifiedCwd = tempDir("pi-grok-cli-files-"); + writeFileSync(join(stringifiedCwd, "story.txt"), "red blue red", "utf-8"); + + const stringifiedResult = await executePreparedTool( + collectTools(registerFileTools).get("Edit"), + { + path: "story.txt", + edits: JSON.stringify([{ oldText: "red", newText: "green" }]), + }, + stringifiedCwd, + ); + + expect(firstText(stringifiedResult)).toBe( + "Applied 2 replacement(s) in story.txt", + ); + expectStoryState(stringifiedResult, stringifiedCwd, 2, "green blue green"); + }); + + it("reports unsupported edit strategies without changing files", async () => { + const cwd = tempDir("pi-grok-cli-files-"); + writeFileSync(join(cwd, "story.txt"), "red blue red", "utf-8"); + + const result = await executePreparedTool( + collectTools(registerFileTools).get("Edit"), + { path: "story.txt", applyPatch: { patchContent: "patch" } }, + cwd, + ); + + expect(firstText(result)).toBe( + "Edit error: applyPatch is not supported by this Grok tool shim", + ); + expectStoryState(result, cwd, 0, "red blue red"); + }); + it("leaves files unchanged when the replacement string is absent", async () => { const cwd = tempDir("pi-grok-cli-files-"); writeFileSync(join(cwd, "story.txt"), "red blue red", "utf-8"); diff --git a/tests/tools/register.test.ts b/tests/tools/register.test.ts index 8916c62..5e1a272 100644 --- a/tests/tools/register.test.ts +++ b/tests/tools/register.test.ts @@ -20,6 +20,7 @@ describe("Grok tool registration", () => { expect(toolNames.sort()).toEqual([ "Delete", + "Edit", "Glob", "Grep", "LS", diff --git a/tests/tools/search.test.ts b/tests/tools/search.test.ts index 70cfd00..2a34811 100644 --- a/tests/tools/search.test.ts +++ b/tests/tools/search.test.ts @@ -4,10 +4,12 @@ import { describe, expect, it } from "vitest"; import { registerSearchTools } from "../../src/tools/search.js"; import { collectTools, + executePreparedTool, executeTool, firstText, plainTheme, renderText, + type ToolResult, tempDir, } from "./toolTestHelpers.js"; @@ -20,6 +22,21 @@ function setupProject() { return dir; } +function expectGrepResult(cwd: string, result: ToolResult) { + expect(firstText(result)).toContain( + `${join(cwd, "src", "alpha.ts")}:1:needle`, + ); + expect(firstText(result)).not.toContain("beta.md"); + expect(result.details).toEqual({ matchCount: 1 }); +} + +function expectGlobResult(cwd: string, result: ToolResult) { + expect(firstText(result)).toContain(join(cwd, "src", "alpha.ts")); + expect(firstText(result)).toContain(join(cwd, "src", "gamma.ts")); + expect(firstText(result)).not.toContain("beta.md"); + expect(result.details).toEqual({ fileCount: 2 }); +} + describe("search tools", () => { it("greps matching file contents with include filters", async () => { const cwd = setupProject(); @@ -29,11 +46,18 @@ describe("search tools", () => { cwd, ); - expect(firstText(result)).toContain( - `${join(cwd, "src", "alpha.ts")}:1:needle`, + expectGrepResult(cwd, result); + }); + + it("greps matching file contents with Cursor-style glob filters", async () => { + const cwd = setupProject(); + const result = await executePreparedTool( + collectTools(registerSearchTools).get("Grep"), + { pattern: "needle", path: "src", glob_filter: "*.ts" }, + cwd, ); - expect(firstText(result)).not.toContain("beta.md"); - expect(result.details).toEqual({ matchCount: 1 }); + + expectGrepResult(cwd, result); }); it("reports no grep matches as an empty result", async () => { @@ -68,10 +92,18 @@ describe("search tools", () => { cwd, ); - expect(firstText(result)).toContain(join(cwd, "src", "alpha.ts")); - expect(firstText(result)).toContain(join(cwd, "src", "gamma.ts")); - expect(firstText(result)).not.toContain("beta.md"); - expect(result.details).toEqual({ fileCount: 2 }); + expectGlobResult(cwd, result); + }); + + it("globs files with Cursor-style glob pattern arguments", async () => { + const cwd = setupProject(); + const result = await executePreparedTool( + collectTools(registerSearchTools).get("Glob"), + { glob_pattern: "**/*.ts", path: "src" }, + cwd, + ); + + expectGlobResult(cwd, result); }); it("reports empty glob command results", async () => { diff --git a/tests/tools/toolTestHelpers.ts b/tests/tools/toolTestHelpers.ts index c3788e9..e24011f 100644 --- a/tests/tools/toolTestHelpers.ts +++ b/tests/tools/toolTestHelpers.ts @@ -24,6 +24,9 @@ type ToolTheme = { type RegisteredTool = { name: string; + prepareArguments?: ( + params: Record, + ) => Record; execute: ( toolCallId: string, params: Record, @@ -67,6 +70,23 @@ export async function executeTool( ); } +export function prepareToolArguments( + tool: RegisteredTool | undefined, + params: Record, +) { + if (!tool) throw new Error("Tool was not registered"); + return tool.prepareArguments?.(params) ?? params; +} + +export async function executePreparedTool( + tool: RegisteredTool | undefined, + params: Record, + cwd: string, +) { + if (!tool) throw new Error("Tool was not registered"); + return executeTool(tool, prepareToolArguments(tool, params), cwd); +} + export function firstText(result: ToolResult) { return result.content[0]?.text ?? ""; } From a67794a1091545c699b7e0f6a515666048ff314b Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Tue, 2 Jun 2026 23:42:01 +0900 Subject: [PATCH 07/24] chore: reformat codebase using biome with space indentation and single quotes --- .release-tools/cache.json | 30 +- .release-tools/config.ts | 8 +- biome.json | 68 +- knip.json | 6 +- package.json | 148 ++--- src/auth/oauth.ts | 737 ++++++++++----------- src/index.ts | 2 +- src/models/catalog.ts | 218 +++--- src/payload/sanitize.ts | 499 +++++++------- src/provider/quota.ts | 199 +++--- src/provider/register.ts | 160 ++--- src/provider/status.ts | 111 ++-- src/provider/stream.ts | 66 +- src/provider/toolScope.ts | 47 +- src/shared/errors.ts | 64 +- src/tools/files.ts | 1103 +++++++++++++++---------------- src/tools/register.ts | 14 +- src/tools/rendering.ts | 251 ++++--- src/tools/search.ts | 391 ++++++----- src/tools/shell.ts | 205 +++--- tests/auth/oauth.test.ts | 486 +++++++------- tests/models/catalog.test.ts | 101 ++- tests/payload/sanitize.test.ts | 348 +++++----- tests/provider/package.test.ts | 108 ++- tests/provider/register.test.ts | 1076 ++++++++++++++---------------- tests/shared/errors.test.ts | 26 +- tests/tools/files.test.ts | 738 ++++++++++----------- tests/tools/register.test.ts | 54 +- tests/tools/rendering.test.ts | 168 +++-- tests/tools/search.test.ts | 410 ++++++------ tests/tools/shell.test.ts | 238 ++++--- tests/tools/toolTestHelpers.ts | 150 ++--- tsconfig.json | 32 +- vitest.config.ts | 18 +- 34 files changed, 3955 insertions(+), 4325 deletions(-) diff --git a/.release-tools/cache.json b/.release-tools/cache.json index 658e7f1..692beef 100644 --- a/.release-tools/cache.json +++ b/.release-tools/cache.json @@ -1,17 +1,17 @@ { - "scripts": { - "typecheck": "tsc --noEmit", - "knip": "knip --production", - "lint": "biome check --write .", - "lint:ci": "biome ci .", - "check-duplicates": "bunx jscpd src tests --exitCode 1 --reporters ai --noTips", - "check": "bun run lint && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", - "check:ci": "bun run lint:ci && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", - "prepare": "husky" - }, - "lintStaged": { - "*": ["biome check --write --no-errors-on-unmatched"] - }, - "tsconfig": "{\n\t\"compilerOptions\": {\n\t\t\"target\": \"ES2022\",\n\t\t\"module\": \"ES2022\",\n\t\t\"moduleResolution\": \"bundler\",\n\t\t\"strict\": true,\n\t\t\"esModuleInterop\": true,\n\t\t\"skipLibCheck\": true,\n\t\t\"forceConsistentCasingInFileNames\": true,\n\t\t\"resolveJsonModule\": true,\n\t\t\"declaration\": true,\n\t\t\"declarationMap\": true,\n\t\t\"sourceMap\": true,\n\t\t\"outDir\": \"./dist\",\n\t\t\"rootDir\": \".\"\n\t},\n\t\"include\": [\"src/**/*.ts\"],\n\t\"exclude\": [\"node_modules\", \"dist\"]\n}\n", - "installedDeps": [] + "scripts": { + "typecheck": "tsc --noEmit", + "knip": "knip --production", + "lint": "biome check --write .", + "lint:ci": "biome ci .", + "check-duplicates": "bunx jscpd src tests --exitCode 1 --reporters ai --noTips", + "check": "bun run lint && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", + "check:ci": "bun run lint:ci && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", + "prepare": "husky" + }, + "lintStaged": { + "*": ["biome check --write --no-errors-on-unmatched"] + }, + "tsconfig": "{\n\t\"compilerOptions\": {\n\t\t\"target\": \"ES2022\",\n\t\t\"module\": \"ES2022\",\n\t\t\"moduleResolution\": \"bundler\",\n\t\t\"strict\": true,\n\t\t\"esModuleInterop\": true,\n\t\t\"skipLibCheck\": true,\n\t\t\"forceConsistentCasingInFileNames\": true,\n\t\t\"resolveJsonModule\": true,\n\t\t\"declaration\": true,\n\t\t\"declarationMap\": true,\n\t\t\"sourceMap\": true,\n\t\t\"outDir\": \"./dist\",\n\t\t\"rootDir\": \".\"\n\t},\n\t\"include\": [\"src/**/*.ts\"],\n\t\"exclude\": [\"node_modules\", \"dist\"]\n}\n", + "installedDeps": [] } diff --git a/.release-tools/config.ts b/.release-tools/config.ts index 8636522..c595f22 100644 --- a/.release-tools/config.ts +++ b/.release-tools/config.ts @@ -1,7 +1,7 @@ -import { defineConfig } from "release-tools/config"; +import { defineConfig } from 'release-tools/config'; export default defineConfig({ - packageName: "pi-grok-cli", - repo: "kenryu42/pi-grok-cli", - excludedAuthors: ["kenryu42"], + packageName: 'pi-grok-cli', + repo: 'kenryu42/pi-grok-cli', + excludedAuthors: ['kenryu42'], }); diff --git a/biome.json b/biome.json index eb66ea5..d8abd7a 100644 --- a/biome.json +++ b/biome.json @@ -1,34 +1,38 @@ { - "$schema": "https://biomejs.dev/schemas/2.4.16/schema.json", - "vcs": { - "enabled": true, - "clientKind": "git", - "useIgnoreFile": true - }, - "files": { - "ignoreUnknown": false - }, - "formatter": { - "enabled": true, - "indentStyle": "tab" - }, - "linter": { - "enabled": true, - "rules": { - "recommended": true - } - }, - "javascript": { - "formatter": { - "quoteStyle": "double" - } - }, - "assist": { - "enabled": true, - "actions": { - "source": { - "organizeImports": "on" - } - } - } + "$schema": "https://biomejs.dev/schemas/2.4.16/schema.json", + "vcs": { + "enabled": true, + "clientKind": "git", + "useIgnoreFile": true + }, + "files": { + "ignoreUnknown": false + }, + "formatter": { + "enabled": true, + "indentStyle": "space", + "indentWidth": 2, + "lineWidth": 100 + }, + "linter": { + "enabled": true, + "rules": { + "recommended": true + } + }, + "javascript": { + "formatter": { + "quoteStyle": "single", + "trailingCommas": "all", + "semicolons": "always" + } + }, + "assist": { + "enabled": true, + "actions": { + "source": { + "organizeImports": "on" + } + } + } } diff --git a/knip.json b/knip.json index 8a668a3..5c61008 100644 --- a/knip.json +++ b/knip.json @@ -1,5 +1,5 @@ { - "entry": ["src/shared/backup_worker.ts"], - "project": ["**/*.ts", "!**/*.d.ts", "!.release-tools/**"], - "ignoreDependencies": ["lint-staged"] + "entry": ["src/shared/backup_worker.ts"], + "project": ["**/*.ts", "!**/*.d.ts", "!.release-tools/**"], + "ignoreDependencies": ["lint-staged"] } diff --git a/package.json b/package.json index 7da54d9..ef87ea9 100644 --- a/package.json +++ b/package.json @@ -1,76 +1,76 @@ { - "name": "pi-grok-cli", - "version": "0.1.0", - "description": "Use Grok CLI's API endpoint in pi.", - "keywords": [ - "pi-package", - "pi-extension", - "xai", - "grok", - "grok-cli", - "oauth", - "xai-oauth" - ], - "type": "module", - "main": "./src/index.ts", - "files": [ - "README.md", - "src", - "tsconfig.json" - ], - "scripts": { - "test": "vitest run --reporter=agent", - "coverage": "vitest run --reporter=agent --coverage", - "typecheck": "tsc --noEmit", - "prepack": "bun run test && bun run coverage && bun run typecheck", - "knip": "knip --production", - "lint": "biome check --write .", - "lint:ci": "biome ci .", - "check": "bun run lint && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", - "check:ci": "bun run lint:ci && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", - "prepare": "husky", - "check-duplicates": "bunx jscpd src tests --exitCode 1 --reporters ai --noTips" - }, - "author": { - "name": "J Liew", - "email": "jliew@420024lab.com" - }, - "license": "MIT", - "repository": { - "type": "git", - "url": "git+https://github.com/kenryu42/pi-grok-cli.git" - }, - "bugs": { - "url": "https://github.com/kenryu42/pi-grok-cli/issues" - }, - "homepage": "https://github.com/kenryu42/pi-grok-cli#readme", - "pi": { - "extensions": [ - "./src/index.ts" - ] - }, - "peerDependencies": { - "@earendil-works/pi-ai": "*", - "@earendil-works/pi-coding-agent": "*", - "@earendil-works/pi-tui": "*" - }, - "devDependencies": { - "@biomejs/biome": "2.4.16", - "@earendil-works/pi-ai": "^0.78.0", - "@earendil-works/pi-coding-agent": "^0.78.0", - "@earendil-works/pi-tui": "^0.78.0", - "@vitest/coverage-v8": "^4.1.8", - "husky": "^9.1.7", - "jscpd": "^4.2.4", - "knip": "^6.15.0", - "lint-staged": "^17.0.7", - "release-tools": "github:kenryu42/release-tools", - "typescript": "^6.0.3", - "vitest": "^4.1.8" - }, - "lint-staged": { - "*": [ - "biome check --write --no-errors-on-unmatched" - ] - } + "name": "pi-grok-cli", + "version": "0.1.0", + "description": "Use Grok CLI's API endpoint in pi.", + "keywords": [ + "pi-package", + "pi-extension", + "xai", + "grok", + "grok-cli", + "oauth", + "xai-oauth" + ], + "type": "module", + "main": "./src/index.ts", + "files": [ + "README.md", + "src", + "tsconfig.json" + ], + "scripts": { + "test": "vitest run --reporter=agent", + "coverage": "vitest run --reporter=agent --coverage", + "typecheck": "tsc --noEmit", + "prepack": "bun run test && bun run coverage && bun run typecheck", + "knip": "knip --production", + "lint": "biome check --write .", + "lint:ci": "biome ci .", + "check": "bun run lint && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", + "check:ci": "bun run lint:ci && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", + "prepare": "husky", + "check-duplicates": "bunx jscpd src tests --exitCode 1 --reporters ai --noTips" + }, + "author": { + "name": "J Liew", + "email": "jliew@420024lab.com" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "git+https://github.com/kenryu42/pi-grok-cli.git" + }, + "bugs": { + "url": "https://github.com/kenryu42/pi-grok-cli/issues" + }, + "homepage": "https://github.com/kenryu42/pi-grok-cli#readme", + "pi": { + "extensions": [ + "./src/index.ts" + ] + }, + "peerDependencies": { + "@earendil-works/pi-ai": "*", + "@earendil-works/pi-coding-agent": "*", + "@earendil-works/pi-tui": "*" + }, + "devDependencies": { + "@biomejs/biome": "2.4.16", + "@earendil-works/pi-ai": "^0.78.0", + "@earendil-works/pi-coding-agent": "^0.78.0", + "@earendil-works/pi-tui": "^0.78.0", + "@vitest/coverage-v8": "^4.1.8", + "husky": "^9.1.7", + "jscpd": "^4.2.4", + "knip": "^6.15.0", + "lint-staged": "^17.0.7", + "release-tools": "github:kenryu42/release-tools", + "typescript": "^6.0.3", + "vitest": "^4.1.8" + }, + "lint-staged": { + "*": [ + "biome check --write --no-errors-on-unmatched" + ] + } } diff --git a/src/auth/oauth.ts b/src/auth/oauth.ts index a4dca37..73a5b85 100644 --- a/src/auth/oauth.ts +++ b/src/auth/oauth.ts @@ -9,447 +9,424 @@ * cli-chat-proxy.grok.com instead of api.x.ai. */ -import { createServer } from "node:http"; -import { XaiErrorCode, XaiOAuthError } from "../shared/errors.js"; +import { createServer } from 'node:http'; +import { XaiErrorCode, XaiOAuthError } from '../shared/errors.js'; // ─── Constants ──────────────────────────────────────────────────────────────── -const DEFAULT_BASE_URL = "https://cli-chat-proxy.grok.com/v1"; -const ISSUER = "https://auth.x.ai"; +const DEFAULT_BASE_URL = 'https://cli-chat-proxy.grok.com/v1'; +const ISSUER = 'https://auth.x.ai'; const DISCOVERY_URL = `${ISSUER}/.well-known/openid-configuration`; -const CLIENT_ID = - process.env.PI_GROK_CLI_OAUTH_CLIENT_ID || - "b1a00492-073a-47ea-816f-4c329264a828"; +const CLIENT_ID = process.env.PI_GROK_CLI_OAUTH_CLIENT_ID || 'b1a00492-073a-47ea-816f-4c329264a828'; const SCOPE = - process.env.PI_GROK_CLI_OAUTH_SCOPE || - "openid profile email offline_access grok-cli:access api:access"; -const CALLBACK_HOST = process.env.PI_GROK_CLI_CALLBACK_HOST || "127.0.0.1"; -const CALLBACK_PORT = Number.parseInt( - process.env.PI_GROK_CLI_CALLBACK_PORT || "56122", - 10, -); -const CALLBACK_PATH = "/callback"; + process.env.PI_GROK_CLI_OAUTH_SCOPE || + 'openid profile email offline_access grok-cli:access api:access'; +const CALLBACK_HOST = process.env.PI_GROK_CLI_CALLBACK_HOST || '127.0.0.1'; +const CALLBACK_PORT = Number.parseInt(process.env.PI_GROK_CLI_CALLBACK_PORT || '56122', 10); +const CALLBACK_PATH = '/callback'; /** Refresh 120s before actual expiry. */ const REFRESH_SKEW_MS = 120_000; // ─── Types ──────────────────────────────────────────────────────────────────── interface XaiDiscovery { - authorization_endpoint: string; - token_endpoint: string; + authorization_endpoint: string; + token_endpoint: string; } export interface XaiOAuthCredentials { - [key: string]: unknown; - refresh: string; - access: string; - expires: number; - tokenEndpoint?: string; - discovery?: XaiDiscovery; - idToken?: string; - tokenType?: string; - baseUrl?: string; + [key: string]: unknown; + refresh: string; + access: string; + expires: number; + tokenEndpoint?: string; + discovery?: XaiDiscovery; + idToken?: string; + tokenType?: string; + baseUrl?: string; } // ─── Helpers ────────────────────────────────────────────────────────────────── export function getBaseUrl(): string { - return ( - process.env.PI_GROK_CLI_BASE_URL || - process.env.GROK_CLI_BASE_URL || - DEFAULT_BASE_URL - ).replace(/\/+$/, ""); + return ( + process.env.PI_GROK_CLI_BASE_URL || + process.env.GROK_CLI_BASE_URL || + DEFAULT_BASE_URL + ).replace(/\/+$/, ''); } function base64Url(buffer: ArrayBuffer | Uint8Array): string { - const bytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer); - let binary = ""; - for (const b of bytes) binary += String.fromCharCode(b); - return btoa(binary) - .replace(/\+/g, "-") - .replace(/\//g, "_") - .replace(/=+$/, ""); + const bytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer); + let binary = ''; + for (const b of bytes) binary += String.fromCharCode(b); + return btoa(binary).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, ''); } // ─── PKCE ───────────────────────────────────────────────────────────────────── async function generatePKCE(): Promise<{ - verifier: string; - challenge: string; + verifier: string; + challenge: string; }> { - const verifier = base64Url(crypto.getRandomValues(new Uint8Array(32))); - const hash = await crypto.subtle.digest( - "SHA-256", - new TextEncoder().encode(verifier), - ); - return { verifier, challenge: base64Url(hash) }; + const verifier = base64Url(crypto.getRandomValues(new Uint8Array(32))); + const hash = await crypto.subtle.digest('SHA-256', new TextEncoder().encode(verifier)); + return { verifier, challenge: base64Url(hash) }; } // ─── Endpoint validation ────────────────────────────────────────────────────── function validateEndpoint(value: string, field: string): string { - let url: URL; - try { - url = new URL(value); - } catch { - throw new XaiOAuthError( - `xAI OAuth discovery returned invalid ${field}: ${value}`, - XaiErrorCode.DISCOVERY_INVALID_ORIGIN, - ); - } - if (url.protocol !== "https:") { - throw new XaiOAuthError( - `xAI OAuth ${field} must use HTTPS: ${value}`, - XaiErrorCode.DISCOVERY_INVALID_ORIGIN, - ); - } - const host = url.hostname.toLowerCase(); - if ( - host !== "x.ai" && - host !== "auth.x.ai" && - host !== "accounts.x.ai" && - !host.endsWith(".x.ai") - ) { - throw new XaiOAuthError( - `Refusing non-xAI OAuth ${field}: ${value}`, - XaiErrorCode.DISCOVERY_INVALID_ORIGIN, - ); - } - return url.toString(); + let url: URL; + try { + url = new URL(value); + } catch { + throw new XaiOAuthError( + `xAI OAuth discovery returned invalid ${field}: ${value}`, + XaiErrorCode.DISCOVERY_INVALID_ORIGIN, + ); + } + if (url.protocol !== 'https:') { + throw new XaiOAuthError( + `xAI OAuth ${field} must use HTTPS: ${value}`, + XaiErrorCode.DISCOVERY_INVALID_ORIGIN, + ); + } + const host = url.hostname.toLowerCase(); + if ( + host !== 'x.ai' && + host !== 'auth.x.ai' && + host !== 'accounts.x.ai' && + !host.endsWith('.x.ai') + ) { + throw new XaiOAuthError( + `Refusing non-xAI OAuth ${field}: ${value}`, + XaiErrorCode.DISCOVERY_INVALID_ORIGIN, + ); + } + return url.toString(); } // ─── OIDC Discovery ────────────────────────────────────────────────────────── async function discover(): Promise { - let response: Response; - try { - response = await fetch(DISCOVERY_URL, { - headers: { Accept: "application/json" }, - signal: AbortSignal.timeout(15_000), - }); - } catch (cause) { - throw new XaiOAuthError( - `xAI OIDC discovery failed: ${cause instanceof Error ? cause.message : String(cause)}`, - XaiErrorCode.DISCOVERY_FAILED, - ); - } - if (!response.ok) { - throw new XaiOAuthError( - `xAI OIDC discovery returned ${response.status}`, - XaiErrorCode.DISCOVERY_FAILED, - ); - } - const payload = (await response.json()) as Record; - const authorizationEndpoint = validateEndpoint( - String(payload.authorization_endpoint ?? ""), - "authorization_endpoint", - ); - const tokenEndpoint = validateEndpoint( - String(payload.token_endpoint ?? ""), - "token_endpoint", - ); - return { - authorization_endpoint: authorizationEndpoint, - token_endpoint: tokenEndpoint, - }; + let response: Response; + try { + response = await fetch(DISCOVERY_URL, { + headers: { Accept: 'application/json' }, + signal: AbortSignal.timeout(15_000), + }); + } catch (cause) { + throw new XaiOAuthError( + `xAI OIDC discovery failed: ${cause instanceof Error ? cause.message : String(cause)}`, + XaiErrorCode.DISCOVERY_FAILED, + ); + } + if (!response.ok) { + throw new XaiOAuthError( + `xAI OIDC discovery returned ${response.status}`, + XaiErrorCode.DISCOVERY_FAILED, + ); + } + const payload = (await response.json()) as Record; + const authorizationEndpoint = validateEndpoint( + String(payload.authorization_endpoint ?? ''), + 'authorization_endpoint', + ); + const tokenEndpoint = validateEndpoint(String(payload.token_endpoint ?? ''), 'token_endpoint'); + return { + authorization_endpoint: authorizationEndpoint, + token_endpoint: tokenEndpoint, + }; } // ─── Loopback callback server ──────────────────────────────────────────────── interface CallbackResult { - code?: string; - state?: string; - error?: string; - errorDescription?: string; + code?: string; + state?: string; + error?: string; + errorDescription?: string; } function startCallbackServer(): Promise<{ - server: import("node:http").Server; - redirectUri: string; - waitForCallback: (timeoutMs: number) => Promise; + server: import('node:http').Server; + redirectUri: string; + waitForCallback: (timeoutMs: number) => Promise; }> { - let settle: ((value: CallbackResult) => void) | undefined; - const callbackPromise = new Promise((resolve) => { - settle = resolve; - }); - - const server = createServer((req, res) => { - try { - const origin = req.headers.origin; - if ( - origin === "https://accounts.x.ai" || - origin === "https://auth.x.ai" - ) { - res.setHeader("Access-Control-Allow-Origin", origin); - res.setHeader("Access-Control-Allow-Methods", "GET, OPTIONS"); - res.setHeader("Access-Control-Allow-Headers", "Content-Type"); - res.setHeader("Access-Control-Allow-Private-Network", "true"); - res.setHeader("Vary", "Origin"); - } - if (req.method === "OPTIONS") { - res.statusCode = 204; - res.end(); - return; - } - - const url = new URL(req.url ?? "/", `http://${CALLBACK_HOST}`); - if (url.pathname !== CALLBACK_PATH) { - res.statusCode = 404; - res.end("Not found"); - return; - } - - const result: CallbackResult = { - code: url.searchParams.get("code") ?? undefined, - state: url.searchParams.get("state") ?? undefined, - error: url.searchParams.get("error") ?? undefined, - errorDescription: - url.searchParams.get("error_description") ?? undefined, - }; - - res.statusCode = result.error ? 400 : 200; - res.setHeader("Content-Type", "text/html; charset=utf-8"); - const html = result.error - ? "

xAI authorization failed.

You can close this tab." - : "

xAI authorization received.

You can close this tab."; - res.end(html); - settle?.(result); - } catch { - res.statusCode = 500; - res.end("Internal error"); - } - }); - - const listen = (port: number) => - new Promise((resolve, reject) => { - server.once("error", reject); - server.listen(port, CALLBACK_HOST, () => { - server.removeListener("error", reject); - const addr = server.address(); - resolve(typeof addr === "object" && addr ? addr.port : port); - }); - }); - - return (async () => { - let actualPort: number; - try { - actualPort = await listen(CALLBACK_PORT); - } catch { - actualPort = await listen(0); - } - const redirectUri = `http://${CALLBACK_HOST}:${actualPort}${CALLBACK_PATH}`; - return { - server, - redirectUri, - waitForCallback: (timeoutMs: number) => - Promise.race([ - callbackPromise, - new Promise((resolve) => - setTimeout( - () => - resolve({ - error: "timeout", - errorDescription: "Timed out waiting for xAI OAuth callback.", - }), - timeoutMs, - ), - ), - ]), - }; - })(); + let settle: ((value: CallbackResult) => void) | undefined; + const callbackPromise = new Promise((resolve) => { + settle = resolve; + }); + + const server = createServer((req, res) => { + try { + const origin = req.headers.origin; + if (origin === 'https://accounts.x.ai' || origin === 'https://auth.x.ai') { + res.setHeader('Access-Control-Allow-Origin', origin); + res.setHeader('Access-Control-Allow-Methods', 'GET, OPTIONS'); + res.setHeader('Access-Control-Allow-Headers', 'Content-Type'); + res.setHeader('Access-Control-Allow-Private-Network', 'true'); + res.setHeader('Vary', 'Origin'); + } + if (req.method === 'OPTIONS') { + res.statusCode = 204; + res.end(); + return; + } + + const url = new URL(req.url ?? '/', `http://${CALLBACK_HOST}`); + if (url.pathname !== CALLBACK_PATH) { + res.statusCode = 404; + res.end('Not found'); + return; + } + + const result: CallbackResult = { + code: url.searchParams.get('code') ?? undefined, + state: url.searchParams.get('state') ?? undefined, + error: url.searchParams.get('error') ?? undefined, + errorDescription: url.searchParams.get('error_description') ?? undefined, + }; + + res.statusCode = result.error ? 400 : 200; + res.setHeader('Content-Type', 'text/html; charset=utf-8'); + const html = result.error + ? '

xAI authorization failed.

You can close this tab.' + : '

xAI authorization received.

You can close this tab.'; + res.end(html); + settle?.(result); + } catch { + res.statusCode = 500; + res.end('Internal error'); + } + }); + + const listen = (port: number) => + new Promise((resolve, reject) => { + server.once('error', reject); + server.listen(port, CALLBACK_HOST, () => { + server.removeListener('error', reject); + const addr = server.address(); + resolve(typeof addr === 'object' && addr ? addr.port : port); + }); + }); + + return (async () => { + let actualPort: number; + try { + actualPort = await listen(CALLBACK_PORT); + } catch { + actualPort = await listen(0); + } + const redirectUri = `http://${CALLBACK_HOST}:${actualPort}${CALLBACK_PATH}`; + return { + server, + redirectUri, + waitForCallback: (timeoutMs: number) => + Promise.race([ + callbackPromise, + new Promise((resolve) => + setTimeout( + () => + resolve({ + error: 'timeout', + errorDescription: 'Timed out waiting for xAI OAuth callback.', + }), + timeoutMs, + ), + ), + ]), + }; + })(); } // ─── Token exchange ─────────────────────────────────────────────────────────── async function exchangeCode( - tokenEndpoint: string, - code: string, - redirectUri: string, - verifier: string, + tokenEndpoint: string, + code: string, + redirectUri: string, + verifier: string, ): Promise { - const response = await fetch(tokenEndpoint, { - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - Accept: "application/json", - }, - body: new URLSearchParams({ - grant_type: "authorization_code", - client_id: CLIENT_ID, - code, - redirect_uri: redirectUri, - code_verifier: verifier, - }), - }); - if (!response.ok) { - throw new XaiOAuthError( - `xAI token exchange failed: ${response.status} ${await response.text()}`, - XaiErrorCode.TOKEN_EXCHANGE_FAILED, - ); - } - const payload = (await response.json()) as Record; - const access = String(payload.access_token ?? ""); - const refresh = String(payload.refresh_token ?? ""); - if (!access) { - throw new XaiOAuthError( - "xAI token exchange did not return access_token.", - XaiErrorCode.TOKEN_EXCHANGE_INVALID, - ); - } - if (!refresh) { - throw new XaiOAuthError( - "xAI token exchange did not return refresh_token.", - XaiErrorCode.TOKEN_EXCHANGE_INVALID, - ); - } - const expiresIn = - typeof payload.expires_in === "number" - ? payload.expires_in - : Number(payload.expires_in ?? 3600); - return { - access, - refresh, - expires: Date.now() + expiresIn * 1000 - REFRESH_SKEW_MS, - tokenEndpoint, - discovery: { authorization_endpoint: "", token_endpoint: tokenEndpoint }, - idToken: String(payload.id_token ?? ""), - tokenType: String(payload.token_type ?? "Bearer"), - baseUrl: getBaseUrl(), - }; + const response = await fetch(tokenEndpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + Accept: 'application/json', + }, + body: new URLSearchParams({ + grant_type: 'authorization_code', + client_id: CLIENT_ID, + code, + redirect_uri: redirectUri, + code_verifier: verifier, + }), + }); + if (!response.ok) { + throw new XaiOAuthError( + `xAI token exchange failed: ${response.status} ${await response.text()}`, + XaiErrorCode.TOKEN_EXCHANGE_FAILED, + ); + } + const payload = (await response.json()) as Record; + const access = String(payload.access_token ?? ''); + const refresh = String(payload.refresh_token ?? ''); + if (!access) { + throw new XaiOAuthError( + 'xAI token exchange did not return access_token.', + XaiErrorCode.TOKEN_EXCHANGE_INVALID, + ); + } + if (!refresh) { + throw new XaiOAuthError( + 'xAI token exchange did not return refresh_token.', + XaiErrorCode.TOKEN_EXCHANGE_INVALID, + ); + } + const expiresIn = + typeof payload.expires_in === 'number' + ? payload.expires_in + : Number(payload.expires_in ?? 3600); + return { + access, + refresh, + expires: Date.now() + expiresIn * 1000 - REFRESH_SKEW_MS, + tokenEndpoint, + discovery: { authorization_endpoint: '', token_endpoint: tokenEndpoint }, + idToken: String(payload.id_token ?? ''), + tokenType: String(payload.token_type ?? 'Bearer'), + baseUrl: getBaseUrl(), + }; } // ─── Login (called by pi's /login flow) ────────────────────────────────────── export async function login( - callbacks: import("@earendil-works/pi-ai").OAuthLoginCallbacks, -): Promise { - const discovery = await discover(); - const { verifier, challenge } = await generatePKCE(); - const state = base64Url(crypto.getRandomValues(new Uint8Array(16))); - const nonce = base64Url(crypto.getRandomValues(new Uint8Array(16))); - const callback = await startCallbackServer(); - - try { - const authUrl = new URL(discovery.authorization_endpoint); - authUrl.searchParams.set("response_type", "code"); - authUrl.searchParams.set("client_id", CLIENT_ID); - authUrl.searchParams.set("redirect_uri", callback.redirectUri); - authUrl.searchParams.set("scope", SCOPE); - authUrl.searchParams.set("code_challenge", challenge); - authUrl.searchParams.set("code_challenge_method", "S256"); - authUrl.searchParams.set("state", state); - authUrl.searchParams.set("nonce", nonce); - authUrl.searchParams.set("plan", "generic"); - authUrl.searchParams.set("referrer", "pi-grok-cli"); - - callbacks.onAuth({ - url: authUrl.toString(), - instructions: `Authorize xAI, then return to pi. Callback listener: ${callback.redirectUri}`, - }); - - const result = await callback.waitForCallback(180_000); - - if (result.error) { - throw new XaiOAuthError( - result.errorDescription ?? result.error, - XaiErrorCode.AUTHORIZATION_FAILED, - ); - } - if (result.state !== state) { - throw new XaiOAuthError( - "xAI OAuth state mismatch — possible CSRF.", - XaiErrorCode.STATE_MISMATCH, - ); - } - if (!result.code) { - throw new XaiOAuthError( - "xAI OAuth callback did not include an authorization code.", - XaiErrorCode.CODE_MISSING, - ); - } - - const credentials = await exchangeCode( - discovery.token_endpoint, - result.code, - callback.redirectUri, - verifier, - ); - credentials.discovery = discovery; - return credentials; - } finally { - callback.server.close(); - } + callbacks: import('@earendil-works/pi-ai').OAuthLoginCallbacks, +): Promise { + const discovery = await discover(); + const { verifier, challenge } = await generatePKCE(); + const state = base64Url(crypto.getRandomValues(new Uint8Array(16))); + const nonce = base64Url(crypto.getRandomValues(new Uint8Array(16))); + const callback = await startCallbackServer(); + + try { + const authUrl = new URL(discovery.authorization_endpoint); + authUrl.searchParams.set('response_type', 'code'); + authUrl.searchParams.set('client_id', CLIENT_ID); + authUrl.searchParams.set('redirect_uri', callback.redirectUri); + authUrl.searchParams.set('scope', SCOPE); + authUrl.searchParams.set('code_challenge', challenge); + authUrl.searchParams.set('code_challenge_method', 'S256'); + authUrl.searchParams.set('state', state); + authUrl.searchParams.set('nonce', nonce); + authUrl.searchParams.set('plan', 'generic'); + authUrl.searchParams.set('referrer', 'pi-grok-cli'); + + callbacks.onAuth({ + url: authUrl.toString(), + instructions: `Authorize xAI, then return to pi. Callback listener: ${callback.redirectUri}`, + }); + + const result = await callback.waitForCallback(180_000); + + if (result.error) { + throw new XaiOAuthError( + result.errorDescription ?? result.error, + XaiErrorCode.AUTHORIZATION_FAILED, + ); + } + if (result.state !== state) { + throw new XaiOAuthError( + 'xAI OAuth state mismatch — possible CSRF.', + XaiErrorCode.STATE_MISMATCH, + ); + } + if (!result.code) { + throw new XaiOAuthError( + 'xAI OAuth callback did not include an authorization code.', + XaiErrorCode.CODE_MISSING, + ); + } + + const credentials = await exchangeCode( + discovery.token_endpoint, + result.code, + callback.redirectUri, + verifier, + ); + credentials.discovery = discovery; + return credentials; + } finally { + callback.server.close(); + } } // ─── Token refresh ──────────────────────────────────────────────────────────── export async function refresh( - credentials: import("@earendil-works/pi-ai").OAuthCredentials, -): Promise { - const xai = credentials as XaiOAuthCredentials; - const tokenEndpoint = - xai.tokenEndpoint || - xai.discovery?.token_endpoint || - (await discover()).token_endpoint; - validateEndpoint(tokenEndpoint, "token_endpoint"); - - if (!credentials.refresh) { - throw new XaiOAuthError( - "Missing refresh_token. Re-login required.", - XaiErrorCode.REFRESH_MISSING, - true, - ); - } - - const response = await fetch(tokenEndpoint, { - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - Accept: "application/json", - }, - body: new URLSearchParams({ - grant_type: "refresh_token", - client_id: CLIENT_ID, - refresh_token: credentials.refresh, - }), - }); - - if (!response.ok) { - const isFatal = - response.status === 400 || - response.status === 401 || - response.status === 403; - throw new XaiOAuthError( - `xAI token refresh failed: ${response.status} ${await response.text()}`, - XaiErrorCode.REFRESH_FAILED, - isFatal, - ); - } - - const payload = (await response.json()) as Record; - const access = String(payload.access_token ?? ""); - if (!access) { - throw new XaiOAuthError( - "xAI token refresh did not return access_token.", - XaiErrorCode.REFRESH_FAILED, - true, - ); - } - - const refresh_new = String(payload.refresh_token ?? credentials.refresh); - const expiresIn = - typeof payload.expires_in === "number" - ? payload.expires_in - : Number(payload.expires_in ?? 3600); - - return { - ...xai, - access, - refresh: refresh_new, - expires: Date.now() + expiresIn * 1000 - REFRESH_SKEW_MS, - tokenEndpoint, - idToken: String(payload.id_token ?? xai.idToken ?? ""), - tokenType: String(payload.token_type ?? xai.tokenType ?? "Bearer"), - baseUrl: getBaseUrl(), - }; + credentials: import('@earendil-works/pi-ai').OAuthCredentials, +): Promise { + const xai = credentials as XaiOAuthCredentials; + const tokenEndpoint = + xai.tokenEndpoint || xai.discovery?.token_endpoint || (await discover()).token_endpoint; + validateEndpoint(tokenEndpoint, 'token_endpoint'); + + if (!credentials.refresh) { + throw new XaiOAuthError( + 'Missing refresh_token. Re-login required.', + XaiErrorCode.REFRESH_MISSING, + true, + ); + } + + const response = await fetch(tokenEndpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + Accept: 'application/json', + }, + body: new URLSearchParams({ + grant_type: 'refresh_token', + client_id: CLIENT_ID, + refresh_token: credentials.refresh, + }), + }); + + if (!response.ok) { + const isFatal = response.status === 400 || response.status === 401 || response.status === 403; + throw new XaiOAuthError( + `xAI token refresh failed: ${response.status} ${await response.text()}`, + XaiErrorCode.REFRESH_FAILED, + isFatal, + ); + } + + const payload = (await response.json()) as Record; + const access = String(payload.access_token ?? ''); + if (!access) { + throw new XaiOAuthError( + 'xAI token refresh did not return access_token.', + XaiErrorCode.REFRESH_FAILED, + true, + ); + } + + const refresh_new = String(payload.refresh_token ?? credentials.refresh); + const expiresIn = + typeof payload.expires_in === 'number' + ? payload.expires_in + : Number(payload.expires_in ?? 3600); + + return { + ...xai, + access, + refresh: refresh_new, + expires: Date.now() + expiresIn * 1000 - REFRESH_SKEW_MS, + tokenEndpoint, + idToken: String(payload.id_token ?? xai.idToken ?? ''), + tokenType: String(payload.token_type ?? xai.tokenType ?? 'Bearer'), + baseUrl: getBaseUrl(), + }; } diff --git a/src/index.ts b/src/index.ts index 1d7f42f..ea9eb6b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -4,4 +4,4 @@ * Brings access to the Grok CLI's endpoint into pi. */ -export { default } from "./provider/register.js"; +export { default } from './provider/register.js'; diff --git a/src/models/catalog.ts b/src/models/catalog.ts index 31334ee..5fe3eb7 100644 --- a/src/models/catalog.ts +++ b/src/models/catalog.ts @@ -12,20 +12,20 @@ const COST_420 = { input: 2, output: 6, cacheRead: 0.2, cacheWrite: 0 }; // ─── Model type ─────────────────────────────────────────────────────────────── export interface GrokCliModelConfig { - id: string; - name: string; - reasoning: boolean; - input: ("text" | "image")[]; - cost: { - input: number; - output: number; - cacheRead: number; - cacheWrite: number; - }; - contextWindow: number; - maxTokens: number; - /** Models that don't support reasoning.effort get a thinkingLevelMap. */ - thinkingLevelMap?: Record; + id: string; + name: string; + reasoning: boolean; + input: ('text' | 'image')[]; + cost: { + input: number; + output: number; + cacheRead: number; + cacheWrite: number; + }; + contextWindow: number; + maxTokens: number; + /** Models that don't support reasoning.effort get a thinkingLevelMap. */ + thinkingLevelMap?: Record; } // ─── Hardcoded fallback catalog ─────────────────────────────────────────────── @@ -34,76 +34,76 @@ export interface GrokCliModelConfig { // the actual traffic captured through cli-chat-proxy.grok.com. const FALLBACK_MODELS: GrokCliModelConfig[] = [ - { - id: "grok-composer-2.5-fast", - name: "Composer 2.5 Fast (Grok CLI)", - reasoning: false, - input: ["text", "image"], - cost: COST_COMPOSER, - contextWindow: 200_000, - maxTokens: 30_000, - thinkingLevelMap: { - off: "none", - minimal: null, - low: null, - medium: null, - high: null, - xhigh: null, - }, - }, - { - id: "grok-build", - name: "Grok Build", - reasoning: true, - input: ["text", "image"], - cost: COST_BUILD, - contextWindow: 512_000, - maxTokens: 30_000, - }, - { - id: "grok-4.3", - name: "Grok 4.3", - reasoning: true, - input: ["text", "image"], - cost: COST_43, - contextWindow: 1_000_000, - maxTokens: 30_000, - }, - { - id: "grok-4.20-0309-reasoning", - name: "Grok 4.20 Reasoning", - reasoning: true, - input: ["text", "image"], - cost: COST_420, - contextWindow: 2_000_000, - maxTokens: 30_000, - }, - { - id: "grok-4.20-0309-non-reasoning", - name: "Grok 4.20 Non-Reasoning", - reasoning: false, - input: ["text", "image"], - cost: COST_420, - contextWindow: 2_000_000, - maxTokens: 30_000, - thinkingLevelMap: { - off: "none", - minimal: null, - low: null, - medium: null, - high: null, - xhigh: null, - }, - }, - { - id: "grok-4.20-multi-agent-0309", - name: "Grok 4.20 Multi-Agent", - reasoning: true, - input: ["text", "image"], - cost: COST_420, - contextWindow: 2_000_000, - maxTokens: 30_000, - }, + { + id: 'grok-composer-2.5-fast', + name: 'Composer 2.5 Fast (Grok CLI)', + reasoning: false, + input: ['text', 'image'], + cost: COST_COMPOSER, + contextWindow: 200_000, + maxTokens: 30_000, + thinkingLevelMap: { + off: 'none', + minimal: null, + low: null, + medium: null, + high: null, + xhigh: null, + }, + }, + { + id: 'grok-build', + name: 'Grok Build', + reasoning: true, + input: ['text', 'image'], + cost: COST_BUILD, + contextWindow: 512_000, + maxTokens: 30_000, + }, + { + id: 'grok-4.3', + name: 'Grok 4.3', + reasoning: true, + input: ['text', 'image'], + cost: COST_43, + contextWindow: 1_000_000, + maxTokens: 30_000, + }, + { + id: 'grok-4.20-0309-reasoning', + name: 'Grok 4.20 Reasoning', + reasoning: true, + input: ['text', 'image'], + cost: COST_420, + contextWindow: 2_000_000, + maxTokens: 30_000, + }, + { + id: 'grok-4.20-0309-non-reasoning', + name: 'Grok 4.20 Non-Reasoning', + reasoning: false, + input: ['text', 'image'], + cost: COST_420, + contextWindow: 2_000_000, + maxTokens: 30_000, + thinkingLevelMap: { + off: 'none', + minimal: null, + low: null, + medium: null, + high: null, + xhigh: null, + }, + }, + { + id: 'grok-4.20-multi-agent-0309', + name: 'Grok 4.20 Multi-Agent', + reasoning: true, + input: ['text', 'image'], + cost: COST_420, + contextWindow: 2_000_000, + maxTokens: 30_000, + }, ]; // ─── Reasoning-effort allowlist ─────────────────────────────────────────────── @@ -113,16 +113,16 @@ const FALLBACK_MODELS: GrokCliModelConfig[] = [ * Everything else gets the param stripped in the sanitizer. */ const EFFORT_CAPABLE_PREFIXES = [ - "grok-3-mini", - "grok-4.20-multi-agent", - "grok-4.3", - "grok-composer", + 'grok-3-mini', + 'grok-4.20-multi-agent', + 'grok-4.3', + 'grok-composer', ]; export function supportsReasoningEffort(modelId: string): boolean { - const parts = modelId.split("/"); - const name = parts.at(-1) ?? modelId; - return EFFORT_CAPABLE_PREFIXES.some((p) => name.toLowerCase().startsWith(p)); + const parts = modelId.split('/'); + const name = parts.at(-1) ?? modelId; + return EFFORT_CAPABLE_PREFIXES.some((p) => name.toLowerCase().startsWith(p)); } // ─── PI_GROK_CLI_MODELS env override ────────────────────────────────────────── @@ -132,23 +132,23 @@ export function supportsReasoningEffort(modelId: string): boolean { * it filters/reorders the fallback list; unknown IDs get sensible defaults. */ export function resolveModels(): GrokCliModelConfig[] { - const env = (process.env.PI_GROK_CLI_MODELS || "") - .split(",") - .map((s) => s.trim()) - .filter(Boolean); - if (env.length === 0) return FALLBACK_MODELS; + const env = (process.env.PI_GROK_CLI_MODELS || '') + .split(',') + .map((s) => s.trim()) + .filter(Boolean); + if (env.length === 0) return FALLBACK_MODELS; - const byId = new Map(FALLBACK_MODELS.map((m) => [m.id, m])); - return env.map( - (id) => - byId.get(id) ?? { - id, - name: id, - reasoning: true, - input: ["text"] as ("text" | "image")[], - cost: COST_BUILD, - contextWindow: 1_000_000, - maxTokens: 30_000, - }, - ); + const byId = new Map(FALLBACK_MODELS.map((m) => [m.id, m])); + return env.map( + (id) => + byId.get(id) ?? { + id, + name: id, + reasoning: true, + input: ['text'] as ('text' | 'image')[], + cost: COST_BUILD, + contextWindow: 1_000_000, + maxTokens: 30_000, + }, + ); } diff --git a/src/payload/sanitize.ts b/src/payload/sanitize.ts index aa952ac..2b3ccf5 100644 --- a/src/payload/sanitize.ts +++ b/src/payload/sanitize.ts @@ -19,212 +19,198 @@ * - Uses prompt_cache_key for session affinity */ -import { existsSync, readFileSync } from "node:fs"; -import { extname, isAbsolute, resolve } from "node:path"; -import { fileURLToPath } from "node:url"; -import { supportsReasoningEffort } from "../models/catalog.js"; +import { existsSync, readFileSync } from 'node:fs'; +import { extname, isAbsolute, resolve } from 'node:path'; +import { fileURLToPath } from 'node:url'; +import { supportsReasoningEffort } from '../models/catalog.js'; // ─── Content text extraction ───────────────────────────────────────────────── function textFromContent(content: unknown): string { - if (typeof content === "string") return content; - if (!Array.isArray(content)) return ""; - return content - .map((part) => { - if (typeof part === "string") return part; - if (!part || typeof part !== "object") return ""; - const item = part as Record; - const type = typeof item.type === "string" ? item.type : ""; - return ["text", "input_text", "output_text"].includes(type) && - typeof item.text === "string" - ? item.text - : ""; - }) - .filter(Boolean) - .join("\n"); + if (typeof content === 'string') return content; + if (!Array.isArray(content)) return ''; + return content + .map((part) => { + if (typeof part === 'string') return part; + if (!part || typeof part !== 'object') return ''; + const item = part as Record; + const type = typeof item.type === 'string' ? item.type : ''; + return ['text', 'input_text', 'output_text'].includes(type) && typeof item.text === 'string' + ? item.text + : ''; + }) + .filter(Boolean) + .join('\n'); } // ─── Image helpers ──────────────────────────────────────────────────────────── function stripShellQuotes(value: string): string { - const trimmed = value.trim(); - if ( - trimmed.length >= 2 && - ((trimmed.startsWith('"') && trimmed.endsWith('"')) || - (trimmed.startsWith("'") && trimmed.endsWith("'"))) - ) { - return trimmed.slice(1, -1); - } - return trimmed; + const trimmed = value.trim(); + if ( + trimmed.length >= 2 && + ((trimmed.startsWith('"') && trimmed.endsWith('"')) || + (trimmed.startsWith("'") && trimmed.endsWith("'"))) + ) { + return trimmed.slice(1, -1); + } + return trimmed; } function unescapeShellPath(value: string): string { - return stripShellQuotes(value).replace(/\\([\\\s'"()&;@])/g, "$1"); + return stripShellQuotes(value).replace(/\\([\\\s'"()&;@])/g, '$1'); } function imageMimeTypeForPath(path: string): string { - switch (extname(path).toLowerCase()) { - case ".jpg": - case ".jpeg": - return "image/jpeg"; - case ".png": - return "image/png"; - default: - throw new Error( - "xAI image understanding supports local .jpg, .jpeg, and .png files only", - ); - } + switch (extname(path).toLowerCase()) { + case '.jpg': + case '.jpeg': + return 'image/jpeg'; + case '.png': + return 'image/png'; + default: + throw new Error('xAI image understanding supports local .jpg, .jpeg, and .png files only'); + } } function resolveLocalImagePath(value: string): string | undefined { - const cleaned = unescapeShellPath(value); - if (!cleaned) return undefined; + const cleaned = unescapeShellPath(value); + if (!cleaned) return undefined; - if (cleaned.startsWith("file://")) { - try { - return fileURLToPath(cleaned); - } catch { - return undefined; - } - } + if (cleaned.startsWith('file://')) { + try { + return fileURLToPath(cleaned); + } catch { + return undefined; + } + } - const candidates = [cleaned]; - if (!isAbsolute(cleaned)) candidates.push(resolve(process.cwd(), cleaned)); + const candidates = [cleaned]; + if (!isAbsolute(cleaned)) candidates.push(resolve(process.cwd(), cleaned)); - return candidates.find((candidate) => existsSync(candidate)); + return candidates.find((candidate) => existsSync(candidate)); } function normalizeImageInput(value: unknown): string | undefined { - if (typeof value !== "string" || !value.trim()) return undefined; - const cleaned = stripShellQuotes(value); - - if (/^https?:\/\//i.test(cleaned) || /^data:image\//i.test(cleaned)) { - return cleaned; - } - - const localPath = resolveLocalImagePath(cleaned); - if (!localPath) { - throw new Error( - `Image file does not exist or is not a valid URL: ${cleaned}`, - ); - } - - const mimeType = imageMimeTypeForPath(localPath); - const data = readFileSync(localPath).toString("base64"); - return `data:${mimeType};base64,${data}`; + if (typeof value !== 'string' || !value.trim()) return undefined; + const cleaned = stripShellQuotes(value); + + if (/^https?:\/\//i.test(cleaned) || /^data:image\//i.test(cleaned)) { + return cleaned; + } + + const localPath = resolveLocalImagePath(cleaned); + if (!localPath) { + throw new Error(`Image file does not exist or is not a valid URL: ${cleaned}`); + } + + const mimeType = imageMimeTypeForPath(localPath); + const data = readFileSync(localPath).toString('base64'); + return `data:${mimeType};base64,${data}`; } // ─── Content part normalization ─────────────────────────────────────────────── function isInputImagePart(value: unknown): value is Record { - return ( - !!value && - typeof value === "object" && - (value as Record).type === "input_image" - ); + return ( + !!value && + typeof value === 'object' && + (value as Record).type === 'input_image' + ); } function getImageUrlAndDetail(obj: Record): { - imageUrl: unknown; - detail: unknown; + imageUrl: unknown; + detail: unknown; } { - if (typeof obj.image_url === "object" && obj.image_url) { - const imageUrl = obj.image_url as Record; - return { imageUrl: imageUrl.url, detail: imageUrl.detail }; - } + if (typeof obj.image_url === 'object' && obj.image_url) { + const imageUrl = obj.image_url as Record; + return { imageUrl: imageUrl.url, detail: imageUrl.detail }; + } - return { imageUrl: obj.image_url, detail: obj.detail }; + return { imageUrl: obj.image_url, detail: obj.detail }; } function normalizeImageParts(value: unknown): unknown { - if (Array.isArray(value)) return value.map(normalizeImageParts); - if (!value || typeof value !== "object") return value; - - const obj = { ...(value as Record) }; - - if ( - obj.type === "image" && - typeof obj.data === "string" && - typeof obj.mimeType === "string" - ) { - return { - type: "input_image", - image_url: `data:${obj.mimeType};base64,${obj.data}`, - detail: - typeof obj.detail === "string" && obj.detail ? obj.detail : "auto", - }; - } - - if (obj.type === "image_url") { - const { imageUrl, detail } = getImageUrlAndDetail(obj); - obj.type = "input_image"; - obj.image_url = imageUrl; - if (typeof detail === "string" && detail) obj.detail = detail; - } - - if (obj.type === "input_image") { - const { imageUrl, detail } = getImageUrlAndDetail(obj); - const normalized = normalizeImageInput(imageUrl); - if (normalized) obj.image_url = normalized; - if (typeof detail === "string" && detail) obj.detail = detail; - if (typeof obj.detail !== "string" || !obj.detail) obj.detail = "auto"; - } - - if (Array.isArray(obj.content)) - obj.content = normalizeImageParts(obj.content); - if (Array.isArray(obj.output)) obj.output = normalizeImageParts(obj.output); - return obj; + if (Array.isArray(value)) return value.map(normalizeImageParts); + if (!value || typeof value !== 'object') return value; + + const obj = { ...(value as Record) }; + + if (obj.type === 'image' && typeof obj.data === 'string' && typeof obj.mimeType === 'string') { + return { + type: 'input_image', + image_url: `data:${obj.mimeType};base64,${obj.data}`, + detail: typeof obj.detail === 'string' && obj.detail ? obj.detail : 'auto', + }; + } + + if (obj.type === 'image_url') { + const { imageUrl, detail } = getImageUrlAndDetail(obj); + obj.type = 'input_image'; + obj.image_url = imageUrl; + if (typeof detail === 'string' && detail) obj.detail = detail; + } + + if (obj.type === 'input_image') { + const { imageUrl, detail } = getImageUrlAndDetail(obj); + const normalized = normalizeImageInput(imageUrl); + if (normalized) obj.image_url = normalized; + if (typeof detail === 'string' && detail) obj.detail = detail; + if (typeof obj.detail !== 'string' || !obj.detail) obj.detail = 'auto'; + } + + if (Array.isArray(obj.content)) obj.content = normalizeImageParts(obj.content); + if (Array.isArray(obj.output)) obj.output = normalizeImageParts(obj.output); + return obj; } // ─── function_call_output rewrite ───────────────────────────────────────────── -function rewriteFunctionCallOutput( - input: Record[], -): Record[] { - const rewritten: Record[] = []; - - for (const item of input) { - if ( - !item || - typeof item !== "object" || - item.type !== "function_call_output" || - !Array.isArray(item.output) - ) { - rewritten.push(item); - continue; - } - - const outputParts = item.output as unknown[]; - const imageParts = outputParts.filter(isInputImagePart); - const textParts = outputParts.filter((p) => !isInputImagePart(p)); - - const textChunks: string[] = []; - for (const part of textParts) { - if (typeof part === "string") { - textChunks.push(part); - } else if (part && typeof part === "object") { - const p = part as Record; - if (typeof p.text === "string") textChunks.push(p.text); - } - } - let imageCount = 0; - for (const _ of imageParts) imageCount++; - - const outputText = - textChunks.join("\n") || "(tool returned no text output)"; - rewritten.push({ ...item, output: outputText }); - - if (imageCount > 0) { - const callId = item.call_id ? ` (${String(item.call_id)})` : ""; - const label = `The previous tool result${callId} included ${imageCount} image${imageCount === 1 ? "" : "s"}. Use the attached image${imageCount === 1 ? "" : "s"} as the visual output from that tool.`; - rewritten.push({ - role: "user", - content: [{ type: "input_text", text: label }, ...imageParts], - }); - } - } - - return rewritten; +function rewriteFunctionCallOutput(input: Record[]): Record[] { + const rewritten: Record[] = []; + + for (const item of input) { + if ( + !item || + typeof item !== 'object' || + item.type !== 'function_call_output' || + !Array.isArray(item.output) + ) { + rewritten.push(item); + continue; + } + + const outputParts = item.output as unknown[]; + const imageParts = outputParts.filter(isInputImagePart); + const textParts = outputParts.filter((p) => !isInputImagePart(p)); + + const textChunks: string[] = []; + for (const part of textParts) { + if (typeof part === 'string') { + textChunks.push(part); + } else if (part && typeof part === 'object') { + const p = part as Record; + if (typeof p.text === 'string') textChunks.push(p.text); + } + } + let imageCount = 0; + for (const _ of imageParts) imageCount++; + + const outputText = textChunks.join('\n') || '(tool returned no text output)'; + rewritten.push({ ...item, output: outputText }); + + if (imageCount > 0) { + const callId = item.call_id ? ` (${String(item.call_id)})` : ''; + const label = `The previous tool result${callId} included ${imageCount} image${imageCount === 1 ? '' : 's'}. Use the attached image${imageCount === 1 ? '' : 's'} as the visual output from that tool.`; + rewritten.push({ + role: 'user', + content: [{ type: 'input_text', text: label }, ...imageParts], + }); + } + } + + return rewritten; } // ─── Main sanitization ──────────────────────────────────────────────────────── @@ -236,100 +222,93 @@ function rewriteFunctionCallOutput( * Returns the modified payload. Mutates the input in place for efficiency. */ export function sanitizePayload( - params: Record, - modelId: string, - sessionId?: string, + params: Record, + modelId: string, + sessionId?: string, ): Record { - const next = params; - - // ── Sanitize input array ────────────────────────────────────────────── - if (Array.isArray(next.input)) { - let input = (next.input as unknown[]) - .map((item: unknown) => { - if (!item || typeof item !== "object") return item; - const obj = item as Record; - - // Strip replayed reasoning items - if (obj.type === "reasoning") return null; - - // Drop empty string content - if (typeof obj.content === "string" && obj.content.length === 0) - return null; - - return obj; - }) - .filter(Boolean) as Record[]; - - // Move system/developer messages to top-level instructions. - // xAI rejects role: "developer" and role: "system" in the input array. - const instructionParts: string[] = []; - while (input.length > 0) { - const first = input[0]; - if (!first || typeof first !== "object") break; - const role = (first as Record).role; - if (role !== "developer" && role !== "system") break; - const text = textFromContent( - (first as Record).content, - ).trim(); - if (text) instructionParts.push(text); - input.shift(); - } - if (instructionParts.length > 0) { - const existing = - typeof next.instructions === "string" && next.instructions - ? next.instructions - : ""; - const merged = [existing, ...instructionParts] - .filter((part) => part.length > 0) - .join("\n\n"); - next.instructions = merged; - } - - // Normalize image parts (resolve local paths, fix types) - input = normalizeImageParts(input) as Record[]; - - // Rewrite function_call_output with images - input = rewriteFunctionCallOutput(input); - - next.input = input; - } else if (typeof next.input === "string") { - // String input is valid and should stay string-shaped. - } - - // ── response_format → text.format ──────────────────────────────────── - if (next.response_format && !next.text) { - next.text = { format: next.response_format }; - delete next.response_format; - } - - // ── Reasoning effort ────────────────────────────────────────────────── - if (supportsReasoningEffort(modelId)) { - const reasoning = next.reasoning as Record | undefined; - if (reasoning && reasoning.effort === "minimal") { - next.reasoning = { ...reasoning, effort: "low" }; - } - if (reasoning && reasoning.summary !== undefined) { - next.reasoning = { effort: reasoning.effort }; - } - } else { - delete next.reasoning; - delete next.reasoningEffort; - } - - // ── Strip/filter unsupported fields ────────────────────────────────── - if (Array.isArray(next.include)) { - next.include = (next.include as unknown[]).filter( - (item) => item !== "reasoning.encrypted_content", - ); - if ((next.include as unknown[]).length === 0) delete next.include; - } - - delete next.prompt_cache_retention; - - // Add prompt_cache_key for conversation caching (routes to same server). - if (sessionId && !next.prompt_cache_key) { - next.prompt_cache_key = sessionId; - } - - return next; + const next = params; + + // ── Sanitize input array ────────────────────────────────────────────── + if (Array.isArray(next.input)) { + let input = (next.input as unknown[]) + .map((item: unknown) => { + if (!item || typeof item !== 'object') return item; + const obj = item as Record; + + // Strip replayed reasoning items + if (obj.type === 'reasoning') return null; + + // Drop empty string content + if (typeof obj.content === 'string' && obj.content.length === 0) return null; + + return obj; + }) + .filter(Boolean) as Record[]; + + // Move system/developer messages to top-level instructions. + // xAI rejects role: "developer" and role: "system" in the input array. + const instructionParts: string[] = []; + while (input.length > 0) { + const first = input[0]; + if (!first || typeof first !== 'object') break; + const role = (first as Record).role; + if (role !== 'developer' && role !== 'system') break; + const text = textFromContent((first as Record).content).trim(); + if (text) instructionParts.push(text); + input.shift(); + } + if (instructionParts.length > 0) { + const existing = + typeof next.instructions === 'string' && next.instructions ? next.instructions : ''; + const merged = [existing, ...instructionParts].filter((part) => part.length > 0).join('\n\n'); + next.instructions = merged; + } + + // Normalize image parts (resolve local paths, fix types) + input = normalizeImageParts(input) as Record[]; + + // Rewrite function_call_output with images + input = rewriteFunctionCallOutput(input); + + next.input = input; + } else if (typeof next.input === 'string') { + // String input is valid and should stay string-shaped. + } + + // ── response_format → text.format ──────────────────────────────────── + if (next.response_format && !next.text) { + next.text = { format: next.response_format }; + delete next.response_format; + } + + // ── Reasoning effort ────────────────────────────────────────────────── + if (supportsReasoningEffort(modelId)) { + const reasoning = next.reasoning as Record | undefined; + if (reasoning && reasoning.effort === 'minimal') { + next.reasoning = { ...reasoning, effort: 'low' }; + } + if (reasoning && reasoning.summary !== undefined) { + next.reasoning = { effort: reasoning.effort }; + } + } else { + delete next.reasoning; + delete next.reasoningEffort; + } + + // ── Strip/filter unsupported fields ────────────────────────────────── + if (Array.isArray(next.include)) { + next.include = (next.include as unknown[]).filter( + (item) => item !== 'reasoning.encrypted_content', + ); + if ((next.include as unknown[]).length === 0) delete next.include; + } + + delete next.prompt_cache_retention; + + // Add prompt_cache_key for conversation caching (routes to same server). + if (sessionId && !next.prompt_cache_key) { + next.prompt_cache_key = sessionId; + } + + return next; } diff --git a/src/provider/quota.ts b/src/provider/quota.ts index 7243f20..3295fcb 100644 --- a/src/provider/quota.ts +++ b/src/provider/quota.ts @@ -1,144 +1,121 @@ -import { existsSync, mkdirSync, readFileSync, writeFileSync } from "node:fs"; -import { homedir } from "node:os"; -import { dirname, join } from "node:path"; +import { existsSync, mkdirSync, readFileSync, writeFileSync } from 'node:fs'; +import { homedir } from 'node:os'; +import { dirname, join } from 'node:path'; -const QUOTA_CACHE_FILE = "grok-cli-quota.json"; +const QUOTA_CACHE_FILE = 'grok-cli-quota.json'; // ─── Rate limit cache (piggybacks on onResponse from normal traffic) ────────── interface RateLimitInfo { - remainingRequests: number; - limitRequests: number; - remainingTokens: number; - limitTokens: number; - contextWindow: number; - zeroDataRetention: boolean; - capturedAt: number; + remainingRequests: number; + limitRequests: number; + remainingTokens: number; + limitTokens: number; + contextWindow: number; + zeroDataRetention: boolean; + capturedAt: number; } const cachedRateLimits = new Map(); function quotaCachePath() { - return join(homedir(), ".pi", QUOTA_CACHE_FILE); + return join(homedir(), '.pi', QUOTA_CACHE_FILE); } function isRateLimitInfo(value: unknown): value is RateLimitInfo { - if (!value || typeof value !== "object") return false; - const info = value as Record; - return ( - typeof info.remainingRequests === "number" && - typeof info.limitRequests === "number" && - typeof info.remainingTokens === "number" && - typeof info.limitTokens === "number" && - typeof info.contextWindow === "number" && - typeof info.zeroDataRetention === "boolean" && - typeof info.capturedAt === "number" - ); + if (!value || typeof value !== 'object') return false; + const info = value as Record; + return ( + typeof info.remainingRequests === 'number' && + typeof info.limitRequests === 'number' && + typeof info.remainingTokens === 'number' && + typeof info.limitTokens === 'number' && + typeof info.contextWindow === 'number' && + typeof info.zeroDataRetention === 'boolean' && + typeof info.capturedAt === 'number' + ); } export function loadQuotaCache() { - cachedRateLimits.clear(); - if (!existsSync(quotaCachePath())) return; - - try { - const payload = JSON.parse( - readFileSync(quotaCachePath(), "utf8"), - ) as Record; - const models = payload.models; - if (!models || typeof models !== "object") return; - - Object.entries(models).forEach(([model, rateLimit]) => { - if (isRateLimitInfo(rateLimit)) cachedRateLimits.set(model, rateLimit); - }); - } catch { - cachedRateLimits.clear(); - } + cachedRateLimits.clear(); + if (!existsSync(quotaCachePath())) return; + + try { + const payload = JSON.parse(readFileSync(quotaCachePath(), 'utf8')) as Record; + const models = payload.models; + if (!models || typeof models !== 'object') return; + + Object.entries(models).forEach(([model, rateLimit]) => { + if (isRateLimitInfo(rateLimit)) cachedRateLimits.set(model, rateLimit); + }); + } catch { + cachedRateLimits.clear(); + } } function persistQuotaCache() { - try { - mkdirSync(dirname(quotaCachePath()), { recursive: true }); - writeFileSync( - quotaCachePath(), - JSON.stringify( - { version: 1, models: Object.fromEntries(cachedRateLimits) }, - null, - "\t", - ), - ); - } catch { - // Status remains cache-only; persistence failures should not break requests. - } + try { + mkdirSync(dirname(quotaCachePath()), { recursive: true }); + writeFileSync( + quotaCachePath(), + JSON.stringify({ version: 1, models: Object.fromEntries(cachedRateLimits) }, null, '\t'), + ); + } catch { + // Status remains cache-only; persistence failures should not break requests. + } } /** * Extract rate limit info from response headers. * Returns undefined if no rate limit headers are present. */ -function extractRateLimit( - h: Record, -): RateLimitInfo | undefined { - const remainingReqs = Number(h["x-ratelimit-remaining-requests"]); - const limitReqs = Number(h["x-ratelimit-limit-requests"]); - const remainingTokens = Number(h["x-ratelimit-remaining-tokens"]); - const limitTokens = Number(h["x-ratelimit-limit-tokens"]); - const contextWindow = Number(h["x-grok-context-window"]); - - if (Number.isNaN(remainingReqs) && Number.isNaN(remainingTokens)) - return undefined; - - return { - remainingRequests: remainingReqs, - limitRequests: limitReqs, - remainingTokens, - limitTokens, - contextWindow: contextWindow || 512_000, - zeroDataRetention: h["x-zero-data-retention"] === "true", - capturedAt: Date.now(), - }; +function extractRateLimit(h: Record): RateLimitInfo | undefined { + const remainingReqs = Number(h['x-ratelimit-remaining-requests']); + const limitReqs = Number(h['x-ratelimit-limit-requests']); + const remainingTokens = Number(h['x-ratelimit-remaining-tokens']); + const limitTokens = Number(h['x-ratelimit-limit-tokens']); + const contextWindow = Number(h['x-grok-context-window']); + + if (Number.isNaN(remainingReqs) && Number.isNaN(remainingTokens)) return undefined; + + return { + remainingRequests: remainingReqs, + limitRequests: limitReqs, + remainingTokens, + limitTokens, + contextWindow: contextWindow || 512_000, + zeroDataRetention: h['x-zero-data-retention'] === 'true', + capturedAt: Date.now(), + }; } -export function formatQuota( - name: string, - rateLimit: RateLimitInfo | undefined, -) { - if (!rateLimit) { - return [ - ` ${name}:`, - " no cached quota data — make a request with this model first", - ]; - } - - const ageSec = Math.round((Date.now() - rateLimit.capturedAt) / 1000); - const ageStr = - ageSec < 60 ? `${ageSec}s ago` : `${Math.round(ageSec / 60)}m ago`; - const lines = [` ${name}:`]; - lines.push(` Cached: ${ageStr}`); - lines.push( - ` Requests: ${rateLimit.remainingRequests}/${rateLimit.limitRequests} remaining`, - ); - lines.push( - ` Tokens: ${rateLimit.remainingTokens.toLocaleString()}/${rateLimit.limitTokens.toLocaleString()} remaining`, - ); - lines.push( - ` Context Limit: ${rateLimit.contextWindow.toLocaleString()} tokens`, - ); - if (rateLimit.zeroDataRetention) { - lines.push(" Data: Zero retention ✓"); - } - return lines; +export function formatQuota(name: string, rateLimit: RateLimitInfo | undefined) { + if (!rateLimit) { + return [` ${name}:`, ' no cached quota data — make a request with this model first']; + } + + const ageSec = Math.round((Date.now() - rateLimit.capturedAt) / 1000); + const ageStr = ageSec < 60 ? `${ageSec}s ago` : `${Math.round(ageSec / 60)}m ago`; + const lines = [` ${name}:`]; + lines.push(` Cached: ${ageStr}`); + lines.push(` Requests: ${rateLimit.remainingRequests}/${rateLimit.limitRequests} remaining`); + lines.push( + ` Tokens: ${rateLimit.remainingTokens.toLocaleString()}/${rateLimit.limitTokens.toLocaleString()} remaining`, + ); + lines.push(` Context Limit: ${rateLimit.contextWindow.toLocaleString()} tokens`); + if (rateLimit.zeroDataRetention) { + lines.push(' Data: Zero retention ✓'); + } + return lines; } -export function captureRateLimit( - modelId: string, - headers: Record, -) { - const rateLimit = extractRateLimit(headers); - if (!rateLimit) return; - cachedRateLimits.set(modelId, rateLimit); - persistQuotaCache(); +export function captureRateLimit(modelId: string, headers: Record) { + const rateLimit = extractRateLimit(headers); + if (!rateLimit) return; + cachedRateLimits.set(modelId, rateLimit); + persistQuotaCache(); } export function getCachedRateLimit(modelId: string): RateLimitInfo | undefined { - return cachedRateLimits.get(modelId); + return cachedRateLimits.get(modelId); } diff --git a/src/provider/register.ts b/src/provider/register.ts index 06de7eb..80fff0c 100644 --- a/src/provider/register.ts +++ b/src/provider/register.ts @@ -1,104 +1,90 @@ -import type { - Api, - Model, - OAuthCredentials, - OAuthLoginCallbacks, -} from "@earendil-works/pi-ai"; -import type { - ExtensionAPI, - ProviderConfig, -} from "@earendil-works/pi-coding-agent"; -import * as oauth from "../auth/oauth.js"; -import { getBaseUrl, type XaiOAuthCredentials } from "../auth/oauth.js"; -import { type GrokCliModelConfig, resolveModels } from "../models/catalog.js"; -import { sanitizePayload } from "../payload/sanitize.js"; -import { registerGrokTools } from "../tools/register.js"; -import { loadQuotaCache } from "./quota.js"; -import { registerStatusCommand } from "./status.js"; -import { streamGrokCli } from "./stream.js"; -import { syncGrokTools } from "./toolScope.js"; +import type { Api, Model, OAuthCredentials, OAuthLoginCallbacks } from '@earendil-works/pi-ai'; +import type { ExtensionAPI, ProviderConfig } from '@earendil-works/pi-coding-agent'; +import * as oauth from '../auth/oauth.js'; +import { getBaseUrl, type XaiOAuthCredentials } from '../auth/oauth.js'; +import { type GrokCliModelConfig, resolveModels } from '../models/catalog.js'; +import { sanitizePayload } from '../payload/sanitize.js'; +import { registerGrokTools } from '../tools/register.js'; +import { loadQuotaCache } from './quota.js'; +import { registerStatusCommand } from './status.js'; +import { streamGrokCli } from './stream.js'; +import { syncGrokTools } from './toolScope.js'; export default function registerGrokCli(pi: ExtensionAPI) { - loadQuotaCache(); - const baseUrl = getBaseUrl(); - const models = resolveModels(); + loadQuotaCache(); + const baseUrl = getBaseUrl(); + const models = resolveModels(); - pi.on("model_select", (event) => { - syncGrokTools(pi, event.model.provider); - }); + pi.on('model_select', (event) => { + syncGrokTools(pi, event.model.provider); + }); - pi.on("before_agent_start", (_event, ctx) => { - syncGrokTools(pi, ctx.model?.provider); - }); + pi.on('before_agent_start', (_event, ctx) => { + syncGrokTools(pi, ctx.model?.provider); + }); - pi.registerProvider("grok-cli", { - name: "Grok CLI", - baseUrl, - apiKey: "$GROK_CLI_OAUTH_TOKEN", - api: "openai-responses", - models: models.map((m: GrokCliModelConfig) => ({ - id: m.id, - name: m.name, - reasoning: m.reasoning, - thinkingLevelMap: m.thinkingLevelMap, - input: m.input, - cost: m.cost, - contextWindow: m.contextWindow, - maxTokens: m.maxTokens, - })), - oauth: { - name: "Grok CLI", + pi.registerProvider('grok-cli', { + name: 'Grok CLI', + baseUrl, + apiKey: '$GROK_CLI_OAUTH_TOKEN', + api: 'openai-responses', + models: models.map((m: GrokCliModelConfig) => ({ + id: m.id, + name: m.name, + reasoning: m.reasoning, + thinkingLevelMap: m.thinkingLevelMap, + input: m.input, + cost: m.cost, + contextWindow: m.contextWindow, + maxTokens: m.maxTokens, + })), + oauth: { + name: 'Grok CLI', - async login(callbacks: OAuthLoginCallbacks): Promise { - return oauth.login(callbacks); - }, + async login(callbacks: OAuthLoginCallbacks): Promise { + return oauth.login(callbacks); + }, - async refreshToken( - credentials: OAuthCredentials, - ): Promise { - return oauth.refresh(credentials); - }, + async refreshToken(credentials: OAuthCredentials): Promise { + return oauth.refresh(credentials); + }, - getApiKey(credentials: OAuthCredentials): string { - return credentials.access; - }, + getApiKey(credentials: OAuthCredentials): string { + return credentials.access; + }, - modifyModels(models: Model[], credentials: OAuthCredentials) { - const effectiveBaseUrl = String( - (credentials as XaiOAuthCredentials).baseUrl ?? getBaseUrl(), - ).replace(/\/+$/, ""); + modifyModels(models: Model[], credentials: OAuthCredentials) { + const effectiveBaseUrl = String( + (credentials as XaiOAuthCredentials).baseUrl ?? getBaseUrl(), + ).replace(/\/+$/, ''); - return models.map((m) => - m.provider === "grok-cli" ? { ...m, baseUrl: effectiveBaseUrl } : m, - ); - }, - } satisfies ProviderConfig["oauth"], + return models.map((m) => + m.provider === 'grok-cli' ? { ...m, baseUrl: effectiveBaseUrl } : m, + ); + }, + } satisfies ProviderConfig['oauth'], - streamSimple: streamGrokCli, - }); + streamSimple: streamGrokCli, + }); - registerGrokTools(pi); + registerGrokTools(pi); - pi.on("before_provider_request", (event, ctx) => { - if (ctx.model?.provider !== "grok-cli") return; + pi.on('before_provider_request', (event, ctx) => { + if (ctx.model?.provider !== 'grok-cli') return; - const modelId = ctx.model?.id ?? ""; - const sessionId = ctx.sessionManager?.getSessionId(); - return sanitizePayload( - event.payload as Record, - modelId, - sessionId, - ); - }); + const modelId = ctx.model?.id ?? ''; + const sessionId = ctx.sessionManager?.getSessionId(); + return sanitizePayload(event.payload as Record, modelId, sessionId); + }); - registerStatusCommand(pi); + registerStatusCommand(pi); - if (process.env.GROK_CLI_OAUTH_TOKEN) { - pi.on("session_start", async (_event, ctx) => { - ctx.ui.notify( - "[pi-grok-cli] Using GROK_CLI_OAUTH_TOKEN bypass — no auto-refresh, no model discovery", - "warning", - ); - }); - } + if (process.env.GROK_CLI_OAUTH_TOKEN) { + pi.on('session_start', async (_event, ctx) => { + ctx.ui.notify( + '[pi-grok-cli] Using GROK_CLI_OAUTH_TOKEN bypass — no auto-refresh, no model discovery', + 'warning', + ); + }); + } } diff --git a/src/provider/status.ts b/src/provider/status.ts index b3f6e2e..1727e9a 100644 --- a/src/provider/status.ts +++ b/src/provider/status.ts @@ -1,66 +1,55 @@ -import type { Api, Model } from "@earendil-works/pi-ai"; -import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; -import { XaiOAuthError } from "../shared/errors.js"; -import { formatQuota, getCachedRateLimit } from "./quota.js"; +import type { Api, Model } from '@earendil-works/pi-ai'; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { XaiOAuthError } from '../shared/errors.js'; +import { formatQuota, getCachedRateLimit } from './quota.js'; -export function registerStatusCommand( - pi: Pick, -) { - pi.registerCommand("grok-cli-status", { - description: "Show Grok CLI provider status, quota, and token health", - handler: async (_args, ctx) => { - const token = process.env.GROK_CLI_OAUTH_TOKEN; - if (token) { - ctx.ui.notify( - "⚠️ Grok CLI: using GROK_CLI_OAUTH_TOKEN env bypass — no auto-refresh available", - "warning", - ); - } +export function registerStatusCommand(pi: Pick) { + pi.registerCommand('grok-cli-status', { + description: 'Show Grok CLI provider status, quota, and token health', + handler: async (_args, ctx) => { + const token = process.env.GROK_CLI_OAUTH_TOKEN; + if (token) { + ctx.ui.notify( + '⚠️ Grok CLI: using GROK_CLI_OAUTH_TOKEN env bypass — no auto-refresh available', + 'warning', + ); + } - try { - const registry = ctx.modelRegistry; - const grokModels = registry - .getAll() - .filter((m: Model) => m.provider === "grok-cli"); - if (grokModels.length === 0) { - ctx.ui.notify( - "Grok CLI: no models registered. Run /login grok-cli first.", - "warning", - ); - return; - } + try { + const registry = ctx.modelRegistry; + const grokModels = registry.getAll().filter((m: Model) => m.provider === 'grok-cli'); + if (grokModels.length === 0) { + ctx.ui.notify('Grok CLI: no models registered. Run /login grok-cli first.', 'warning'); + return; + } - const modelNames = grokModels - .slice(0, 5) - .map((m: Model) => m.id) - .join(", "); - const suffix = - grokModels.length > 5 ? ` (+${grokModels.length - 5} more)` : ""; - ctx.ui.notify( - `✓ Grok CLI: ${grokModels.length} models available (${modelNames}${suffix})`, - "info", - ); + const modelNames = grokModels + .slice(0, 5) + .map((m: Model) => m.id) + .join(', '); + const suffix = grokModels.length > 5 ? ` (+${grokModels.length - 5} more)` : ''; + ctx.ui.notify( + `✓ Grok CLI: ${grokModels.length} models available (${modelNames}${suffix})`, + 'info', + ); - const lines = [ - " Quota:", - "", - ...formatQuota("grok-build", getCachedRateLimit("grok-build")), - "", - ...formatQuota( - "grok-composer-2.5-fast", - getCachedRateLimit("grok-composer-2.5-fast"), - ), - ]; - ctx.ui.notify(lines.join("\n"), "info"); - } catch (err) { - const msg = - err instanceof XaiOAuthError - ? `${err.message} (code: ${err.code})` - : err instanceof Error - ? err.message - : String(err); - ctx.ui.notify(`Grok CLI: ${msg}`, "warning"); - } - }, - }); + const lines = [ + ' Quota:', + '', + ...formatQuota('grok-build', getCachedRateLimit('grok-build')), + '', + ...formatQuota('grok-composer-2.5-fast', getCachedRateLimit('grok-composer-2.5-fast')), + ]; + ctx.ui.notify(lines.join('\n'), 'info'); + } catch (err) { + const msg = + err instanceof XaiOAuthError + ? `${err.message} (code: ${err.code})` + : err instanceof Error + ? err.message + : String(err); + ctx.ui.notify(`Grok CLI: ${msg}`, 'warning'); + } + }, + }); } diff --git a/src/provider/stream.ts b/src/provider/stream.ts index 788b936..afcafae 100644 --- a/src/provider/stream.ts +++ b/src/provider/stream.ts @@ -1,14 +1,14 @@ import { - type Api, - type AssistantMessageEventStream, - type Context, - type Model, - type SimpleStreamOptions, - streamSimpleOpenAIResponses, -} from "@earendil-works/pi-ai"; -import { captureRateLimit } from "./quota.js"; + type Api, + type AssistantMessageEventStream, + type Context, + type Model, + type SimpleStreamOptions, + streamSimpleOpenAIResponses, +} from '@earendil-works/pi-ai'; +import { captureRateLimit } from './quota.js'; -const GROK_CLI_VERSION = "0.2.16"; +const GROK_CLI_VERSION = '0.2.16'; /** * Stream function that adds Grok CLI-specific headers to requests. @@ -21,33 +21,29 @@ const GROK_CLI_VERSION = "0.2.16"; * - x-xai-token-auth: xai-grok-cli */ export function streamGrokCli( - model: Model, - context: Context, - options?: SimpleStreamOptions, + model: Model, + context: Context, + options?: SimpleStreamOptions, ): AssistantMessageEventStream { - const sessionId = options?.sessionId; - const headers: Record = { - ...options?.headers, - "x-grok-client-identifier": "pi-grok-cli", - "x-grok-client-version": GROK_CLI_VERSION, - "x-xai-token-auth": "xai-grok-cli", - "x-grok-model-override": model.id, - }; + const sessionId = options?.sessionId; + const headers: Record = { + ...options?.headers, + 'x-grok-client-identifier': 'pi-grok-cli', + 'x-grok-client-version': GROK_CLI_VERSION, + 'x-xai-token-auth': 'xai-grok-cli', + 'x-grok-model-override': model.id, + }; - if (sessionId) { - headers["x-grok-conv-id"] = sessionId; - } + if (sessionId) { + headers['x-grok-conv-id'] = sessionId; + } - return streamSimpleOpenAIResponses( - model as Model<"openai-responses">, - context, - { - ...options, - headers, - onResponse(response) { - captureRateLimit(model.id, response.headers); - options?.onResponse?.(response, model); - }, - }, - ); + return streamSimpleOpenAIResponses(model as Model<'openai-responses'>, context, { + ...options, + headers, + onResponse(response) { + captureRateLimit(model.id, response.headers); + options?.onResponse?.(response, model); + }, + }); } diff --git a/src/provider/toolScope.ts b/src/provider/toolScope.ts index fcbbee0..351058c 100644 --- a/src/provider/toolScope.ts +++ b/src/provider/toolScope.ts @@ -1,34 +1,31 @@ -import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; export const GROK_TOOL_NAMES = [ - "Grep", - "Glob", - "LS", - "Read", - "Write", - "StrReplace", - "Edit", - "Delete", - "Shell", + 'Grep', + 'Glob', + 'LS', + 'Read', + 'Write', + 'StrReplace', + 'Edit', + 'Delete', + 'Shell', ]; export function syncGrokTools( - pi: Pick, - provider: string | undefined, + pi: Pick, + provider: string | undefined, ) { - const currentTools = pi.getActiveTools(); - const baseTools = currentTools.filter( - (toolName) => !GROK_TOOL_NAMES.includes(toolName), - ); - const nextTools = - provider === "grok-cli" ? [...baseTools, ...GROK_TOOL_NAMES] : baseTools; + const currentTools = pi.getActiveTools(); + const baseTools = currentTools.filter((toolName) => !GROK_TOOL_NAMES.includes(toolName)); + const nextTools = provider === 'grok-cli' ? [...baseTools, ...GROK_TOOL_NAMES] : baseTools; - if ( - currentTools.length === nextTools.length && - currentTools.every((toolName, i) => toolName === nextTools[i]) - ) { - return; - } + if ( + currentTools.length === nextTools.length && + currentTools.every((toolName, i) => toolName === nextTools[i]) + ) { + return; + } - pi.setActiveTools(nextTools); + pi.setActiveTools(nextTools); } diff --git a/src/shared/errors.ts b/src/shared/errors.ts index cde32e6..3d2bcc5 100644 --- a/src/shared/errors.ts +++ b/src/shared/errors.ts @@ -5,40 +5,40 @@ * retryable failures (network) from fatal ones (revoked refresh token). */ export class XaiOAuthError extends Error { - constructor( - message: string, - public readonly code: string, - public readonly reloginRequired = false, - ) { - super(message); - this.name = "XaiOAuthError"; - } + constructor( + message: string, + public readonly code: string, + public readonly reloginRequired = false, + ) { + super(message); + this.name = 'XaiOAuthError'; + } } /** Well-known error codes. */ export const XaiErrorCode = { - /** OIDC discovery failed (network, invalid response). */ - DISCOVERY_FAILED: "discovery_failed", - /** Discovery endpoint returned a non-xAI origin. */ - DISCOVERY_INVALID_ORIGIN: "discovery_invalid_origin", - /** Authorization was denied or errored in the browser. */ - AUTHORIZATION_FAILED: "authorization_failed", - /** CSRF state mismatch between request and callback. */ - STATE_MISMATCH: "state_mismatch", - /** Callback did not include an authorization code. */ - CODE_MISSING: "code_missing", - /** Token exchange failed (network, invalid response). */ - TOKEN_EXCHANGE_FAILED: "token_exchange_failed", - /** Token exchange returned an invalid payload. */ - TOKEN_EXCHANGE_INVALID: "token_exchange_invalid", - /** Refresh token is missing or empty. */ - REFRESH_MISSING: "refresh_missing", - /** Token refresh failed (expired, revoked). */ - REFRESH_FAILED: "refresh_failed", - /** No credentials stored. */ - AUTH_MISSING: "auth_missing", - /** Loopback callback server could not bind. */ - CALLBACK_BIND_FAILED: "callback_bind_failed", - /** Loopback callback timed out. */ - CALLBACK_TIMEOUT: "callback_timeout", + /** OIDC discovery failed (network, invalid response). */ + DISCOVERY_FAILED: 'discovery_failed', + /** Discovery endpoint returned a non-xAI origin. */ + DISCOVERY_INVALID_ORIGIN: 'discovery_invalid_origin', + /** Authorization was denied or errored in the browser. */ + AUTHORIZATION_FAILED: 'authorization_failed', + /** CSRF state mismatch between request and callback. */ + STATE_MISMATCH: 'state_mismatch', + /** Callback did not include an authorization code. */ + CODE_MISSING: 'code_missing', + /** Token exchange failed (network, invalid response). */ + TOKEN_EXCHANGE_FAILED: 'token_exchange_failed', + /** Token exchange returned an invalid payload. */ + TOKEN_EXCHANGE_INVALID: 'token_exchange_invalid', + /** Refresh token is missing or empty. */ + REFRESH_MISSING: 'refresh_missing', + /** Token refresh failed (expired, revoked). */ + REFRESH_FAILED: 'refresh_failed', + /** No credentials stored. */ + AUTH_MISSING: 'auth_missing', + /** Loopback callback server could not bind. */ + CALLBACK_BIND_FAILED: 'callback_bind_failed', + /** Loopback callback timed out. */ + CALLBACK_TIMEOUT: 'callback_timeout', } as const; diff --git a/src/tools/files.ts b/src/tools/files.ts index 8abb63a..e5ee3ac 100644 --- a/src/tools/files.ts +++ b/src/tools/files.ts @@ -1,29 +1,23 @@ -import { execFile } from "node:child_process"; +import { execFile } from 'node:child_process'; +import { existsSync, mkdirSync, readFileSync, unlinkSync, writeFileSync } from 'node:fs'; +import { dirname, resolve } from 'node:path'; +import { promisify } from 'node:util'; +import { Type } from '@earendil-works/pi-ai'; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; import { - existsSync, - mkdirSync, - readFileSync, - unlinkSync, - writeFileSync, -} from "node:fs"; -import { dirname, resolve } from "node:path"; -import { promisify } from "node:util"; -import { Type } from "@earendil-works/pi-ai"; -import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; -import { - booleanDetail, - detailRecord, - fileError, - fileNotFound, - MAX_OUTPUT_CHARS, - numberDetail, - recordFrom, - renderResultSummary, - stringDetail, - stringFrom, - type ToolError, - text, -} from "./rendering.js"; + booleanDetail, + detailRecord, + fileError, + fileNotFound, + MAX_OUTPUT_CHARS, + numberDetail, + recordFrom, + renderResultSummary, + stringDetail, + stringFrom, + type ToolError, + text, +} from './rendering.js'; const execFileAsync = promisify(execFile); @@ -31,578 +25,549 @@ type ReplacementEdit = { oldText: string; newText: string }; type WriteArgs = { path: string; content: string }; type StrReplaceArgs = { path: string; old_str: string; new_str: string }; type EditArgs = { - path: string; - edits?: ReplacementEdit[]; - applyPatch?: { patchContent: string }; - strReplace?: ReplacementEdit; - multiStrReplace?: { edits: ReplacementEdit[] }; + path: string; + edits?: ReplacementEdit[]; + applyPatch?: { patchContent: string }; + strReplace?: ReplacementEdit; + multiStrReplace?: { edits: ReplacementEdit[] }; }; type ToolTheme = { - bold: (text: string) => string; - fg: (name: "accent" | "toolTitle", text: string) => string; + bold: (text: string) => string; + fg: (name: 'accent' | 'toolTitle', text: string) => string; }; function parseEditList(value: unknown): ReplacementEdit[] | undefined { - const editList = typeof value === "string" ? parseJson(value) : value; - if (!Array.isArray(editList)) return undefined; - if ( - !editList.every( - (edit) => - typeof recordFrom(edit)?.oldText === "string" && - typeof recordFrom(edit)?.newText === "string", - ) - ) { - return undefined; - } - return editList.map((edit) => ({ - oldText: stringFrom(recordFrom(edit)?.oldText) ?? "", - newText: stringFrom(recordFrom(edit)?.newText) ?? "", - })); + const editList = typeof value === 'string' ? parseJson(value) : value; + if (!Array.isArray(editList)) return undefined; + if ( + !editList.every( + (edit) => + typeof recordFrom(edit)?.oldText === 'string' && + typeof recordFrom(edit)?.newText === 'string', + ) + ) { + return undefined; + } + return editList.map((edit) => ({ + oldText: stringFrom(recordFrom(edit)?.oldText) ?? '', + newText: stringFrom(recordFrom(edit)?.newText) ?? '', + })); } function parseJson(value: string): unknown { - try { - return JSON.parse(value); - } catch { - return undefined; - } + try { + return JSON.parse(value); + } catch { + return undefined; + } } function editFromText(oldText: unknown, newText: unknown) { - if (typeof oldText !== "string" || typeof newText !== "string") - return undefined; - return [{ oldText, newText }]; + if (typeof oldText !== 'string' || typeof newText !== 'string') return undefined; + return [{ oldText, newText }]; } function editsFromArgs(input: Record) { - return ( - parseEditList(input.edits) ?? - parseEditList(recordFrom(input.multiStrReplace)?.edits) ?? - editFromText(input.oldText, input.newText) ?? - editFromText( - recordFrom(input.strReplace)?.oldText, - recordFrom(input.strReplace)?.newText, - ) - ); + return ( + parseEditList(input.edits) ?? + parseEditList(recordFrom(input.multiStrReplace)?.edits) ?? + editFromText(input.oldText, input.newText) ?? + editFromText(recordFrom(input.strReplace)?.oldText, recordFrom(input.strReplace)?.newText) + ); } function applyEdits(content: string, edits: ReplacementEdit[]) { - return edits.reduce( - (result, edit) => { - const count = result.content.split(edit.oldText).length - 1; - return { - content: - count === 0 - ? result.content - : result.content.replaceAll(edit.oldText, edit.newText), - replacements: result.replacements + count, - }; - }, - { content, replacements: 0 }, - ); + return edits.reduce( + (result, edit) => { + const count = result.content.split(edit.oldText).length - 1; + return { + content: + count === 0 ? result.content : result.content.replaceAll(edit.oldText, edit.newText), + replacements: result.replacements + count, + }; + }, + { content, replacements: 0 }, + ); } function replacementResult(text: string, filePath: string) { - return { - content: [{ type: "text" as const, text }], - details: { path: filePath, replacements: 0 }, - }; + return { + content: [{ type: 'text' as const, text }], + details: { path: filePath, replacements: 0 }, + }; } function renderReplacementResult( - result: { content: { type: string; text?: string }[]; details: unknown }, - expanded: boolean, - isPartial: boolean, - theme: { fg: (name: "dim" | "muted", text: string) => string }, + result: { content: { type: string; text?: string }[]; details: unknown }, + expanded: boolean, + isPartial: boolean, + theme: { fg: (name: 'dim' | 'muted', text: string) => string }, ) { - const replacements = numberDetail(result, "replacements"); - return renderResultSummary( - result, - expanded, - isPartial, - replacements === 0 - ? theme.fg("dim", "No replacements") - : theme.fg("muted", `${replacements} replacement(s)`), - ); + const replacements = numberDetail(result, 'replacements'); + return renderResultSummary( + result, + expanded, + isPartial, + replacements === 0 + ? theme.fg('dim', 'No replacements') + : theme.fg('muted', `${replacements} replacement(s)`), + ); } -function renderPathToolCall( - toolName: string, - filePath: string, - theme: ToolTheme, -) { - return text( - theme.fg("toolTitle", theme.bold(`${toolName} `)) + - theme.fg("accent", filePath), - ); +function renderPathToolCall(toolName: string, filePath: string, theme: ToolTheme) { + return text(theme.fg('toolTitle', theme.bold(`${toolName} `)) + theme.fg('accent', filePath)); } export function registerFileTools(pi: ExtensionAPI) { - // ── LS tool ────────────────────────────────────────────────────────── - - const LsParams = Type.Object({ - path: Type.String({ - description: "Directory path to list", - }), - }); - - pi.registerTool({ - name: "LS", - label: "LS", - description: "List the contents of a directory, including hidden files.", - parameters: LsParams, - - async execute(_toolCallId, params, signal, _onUpdate, ctx) { - const targetPath = resolve(ctx.cwd, params.path); - - try { - const { stdout } = await execFileAsync("ls", ["-la", targetPath], { - cwd: ctx.cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, - signal, - }); - - let output = stdout.trim(); - if (output.length > MAX_OUTPUT_CHARS) { - output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[LS: output truncated at 50KB]`; - } - - return { - content: [{ type: "text", text: output }], - details: { path: targetPath }, - }; - } catch (error: unknown) { - const err = error as ToolError; - return { - content: [ - { - type: "text", - text: `LS error: ${err.message ?? "Unknown error"}`, - }, - ], - details: { path: targetPath }, - }; - } - }, - renderCall(args, theme) { - return renderPathToolCall("LS", args.path, theme); - }, - renderResult(result, { expanded, isPartial }, theme) { - return renderResultSummary( - result, - expanded, - isPartial, - theme.fg("muted", stringDetail(result, "path")), - ); - }, - }); - - // ── Read tool ──────────────────────────────────────────────────────── - - const ReadParams = Type.Object({ - path: Type.String({ - description: "Path to the file to read", - }), - offset: Type.Optional( - Type.Number({ - description: "Line number to start reading from (0-indexed)", - }), - ), - limit: Type.Optional( - Type.Number({ - description: "Maximum number of lines to read", - }), - ), - }); - - pi.registerTool({ - name: "Read", - label: "Read", - description: - "Read the contents of a file. Returns the file content with line numbers.", - parameters: ReadParams, - - async execute(_toolCallId, params, _signal, _onUpdate, ctx) { - const filePath = resolve(ctx.cwd, params.path); - - try { - if (!existsSync(filePath)) { - return fileNotFound(filePath, { exists: false, totalLines: 0 }); - } - - const content = readFileSync(filePath, "utf-8"); - const lines = content.split("\n"); - - const startLine = params.offset ?? 0; - const endLine = params.limit - ? Math.min(startLine + params.limit, lines.length) - : Math.min(startLine + 2000, lines.length); - - const selectedLines = lines.slice(startLine, endLine); - const numberedLines = selectedLines.map( - (line, i) => `${startLine + i + 1}\t${line}`, - ); - - let output = numberedLines.join("\n"); - if (endLine < lines.length) { - output += `\n\n[Showing lines ${startLine + 1}-${endLine} of ${lines.length} total lines. Use offset to see more.]`; - } - - if (output.length > MAX_OUTPUT_CHARS) { - output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; - } - - return { - content: [{ type: "text", text: output }], - details: { path: filePath, totalLines: lines.length }, - }; - } catch (error: unknown) { - return fileError(error, "Read", filePath, { - exists: false, - totalLines: 0, - }); - } - }, - renderCall(args, theme) { - const range = - args.offset !== undefined || args.limit !== undefined - ? theme.fg( - "muted", - ` (from ${args.offset ?? 0}${args.limit ? `, ${args.limit} lines` : ""})`, - ) - : ""; - return text( - theme.fg("toolTitle", theme.bold("Read ")) + - theme.fg("accent", args.path) + - range, - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - return renderResultSummary( - result, - expanded, - isPartial, - detailRecord(result).exists === false - ? theme.fg("error", "File not found") - : theme.fg("muted", `${numberDetail(result, "totalLines")} line(s)`), - ); - }, - }); - - // ── Write tool ─────────────────────────────────────────────────────── - - const WriteParams = Type.Object({ - path: Type.String({ - description: "Path to the file to write", - }), - content: Type.String({ - description: "Content to write to the file", - }), - }); - - pi.registerTool({ - name: "Write", - label: "Write", - description: - "Create or overwrite a file with the given content. Creates parent directories if needed.", - parameters: WriteParams, - - prepareArguments(args) { - const input = recordFrom(args); - if (!input) return args as WriteArgs; - return { - ...input, - content: stringFrom(input.content) ?? stringFrom(input.contents), - } as WriteArgs; - }, - - async execute(_toolCallId, params, _signal, _onUpdate, ctx) { - const filePath = resolve(ctx.cwd, params.path); - - try { - mkdirSync(dirname(filePath), { recursive: true }); - writeFileSync(filePath, params.content, "utf-8"); - - return { - content: [ - { - type: "text", - text: `Successfully wrote ${params.content.length} bytes to ${params.path}`, - }, - ], - details: { path: filePath, bytesWritten: params.content.length }, - }; - } catch (error: unknown) { - const err = error as ToolError; - return { - content: [ - { - type: "text", - text: `Write error: ${err.message ?? "Unknown error"}`, - }, - ], - details: { path: filePath, bytesWritten: 0 }, - }; - } - }, - renderCall(args, theme) { - return renderPathToolCall("Write", args.path, theme); - }, - renderResult(result, { expanded, isPartial }, theme) { - return renderResultSummary( - result, - expanded, - isPartial, - theme.fg( - "muted", - `${numberDetail(result, "bytesWritten")} bytes written`, - ), - ); - }, - }); - - // ── StrReplace tool ────────────────────────────────────────────────── - - const StrReplaceParams = Type.Object({ - path: Type.String({ - description: "Path to the file to modify", - }), - old_str: Type.String({ - description: "String to search for (exact match)", - }), - new_str: Type.String({ - description: "String to replace with", - }), - }); - - pi.registerTool({ - name: "StrReplace", - label: "StrReplace", - description: - "Replace all occurrences of a string in a file. The old_str must be an exact match.", - parameters: StrReplaceParams, - - prepareArguments(args) { - const input = recordFrom(args); - if (!input) return args as StrReplaceArgs; - return { - ...input, - old_str: - stringFrom(input.old_str) ?? - stringFrom(input.old_string) ?? - stringFrom(input.oldText) ?? - stringFrom(recordFrom(input.strReplace)?.oldText), - new_str: - stringFrom(input.new_str) ?? - stringFrom(input.new_string) ?? - stringFrom(input.newText) ?? - stringFrom(recordFrom(input.strReplace)?.newText), - } as StrReplaceArgs; - }, - - async execute(_toolCallId, params, _signal, _onUpdate, ctx) { - const filePath = resolve(ctx.cwd, params.path); - - try { - if (!existsSync(filePath)) { - return fileNotFound(filePath, { replacements: 0 }); - } - - const content = readFileSync(filePath, "utf-8"); - const count = content.split(params.old_str).length - 1; - - if (count === 0) { - return replacementResult( - `String not found in ${params.path}: "${params.old_str}"`, - filePath, - ); - } - - const newContent = content.replaceAll(params.old_str, params.new_str); - writeFileSync(filePath, newContent, "utf-8"); - - return { - content: [ - { - type: "text", - text: `Replaced ${count} occurrence(s) in ${params.path}`, - }, - ], - details: { path: filePath, replacements: count }, - }; - } catch (error: unknown) { - return fileError(error, "StrReplace", filePath, { replacements: 0 }); - } - }, - renderCall(args, theme) { - return renderPathToolCall("StrReplace", args.path, theme); - }, - renderResult(result, { expanded, isPartial }, theme) { - return renderReplacementResult(result, expanded, isPartial, theme); - }, - }); - - // ── Edit tool ──────────────────────────────────────────────────────── - - const EditItemParams = Type.Object({ - oldText: Type.String({ - description: "String to search for (exact match)", - }), - newText: Type.String({ - description: "String to replace with", - }), - replaceAll: Type.Optional( - Type.Boolean({ - description: - "Accepted for Cursor compatibility. Replacements are always applied to all matches.", - }), - ), - }); - - const EditParams = Type.Object({ - path: Type.String({ - description: "Path to the file to modify", - }), - edits: Type.Optional( - Type.Array(EditItemParams, { - description: "Exact text replacements to apply sequentially", - }), - ), - applyPatch: Type.Optional( - Type.Object({ - patchContent: Type.String({ - description: "Unsupported unified patch content", - }), - }), - ), - strReplace: Type.Optional(EditItemParams), - multiStrReplace: Type.Optional( - Type.Object({ - edits: Type.Array(EditItemParams), - }), - ), - }); - - pi.registerTool({ - name: "Edit", - label: "Edit", - description: - "Modify a file with exact text replacement. applyPatch is not supported by this Grok tool shim.", - parameters: EditParams, - - prepareArguments(args) { - const input = recordFrom(args); - if (!input) return args as EditArgs; - return { - ...input, - edits: editsFromArgs(input), - } as EditArgs; - }, - - async execute(_toolCallId, params, _signal, _onUpdate, ctx) { - const filePath = resolve(ctx.cwd, params.path); - - if (!existsSync(filePath)) { - return fileNotFound(filePath, { replacements: 0 }); - } - - try { - if (!params.edits?.length) { - return { - content: [ - { - type: "text", - text: params.applyPatch - ? "Edit error: applyPatch is not supported by this Grok tool shim" - : "Edit error: provide at least one exact text replacement", - }, - ], - details: { path: filePath, replacements: 0 }, - }; - } - - const result = applyEdits( - readFileSync(filePath, "utf-8"), - params.edits, - ); - - if (result.replacements === 0) { - return replacementResult( - `No replacement strings found in ${params.path}`, - filePath, - ); - } - - writeFileSync(filePath, result.content, "utf-8"); - - return { - content: [ - { - type: "text", - text: `Applied ${result.replacements} replacement(s) in ${params.path}`, - }, - ], - details: { path: filePath, replacements: result.replacements }, - }; - } catch (error: unknown) { - return fileError(error, "Edit", filePath, { replacements: 0 }); - } - }, - renderCall(args, theme) { - return renderPathToolCall("Edit", args.path, theme); - }, - renderResult(result, { expanded, isPartial }, theme) { - return renderReplacementResult(result, expanded, isPartial, theme); - }, - }); - - // ── Delete tool ────────────────────────────────────────────────────── - - const DeleteParams = Type.Object({ - path: Type.String({ - description: "Path to the file to delete", - }), - }); - - pi.registerTool({ - name: "Delete", - label: "Delete", - description: "Delete a file from the filesystem.", - parameters: DeleteParams, - - async execute(_toolCallId, params, _signal, _onUpdate, ctx) { - const filePath = resolve(ctx.cwd, params.path); - - try { - if (!existsSync(filePath)) { - return fileNotFound(filePath, { deleted: false }); - } - - unlinkSync(filePath); - - return { - content: [ - { type: "text", text: `Successfully deleted ${params.path}` }, - ], - details: { path: filePath, deleted: true }, - }; - } catch (error: unknown) { - return fileError(error, "Delete", filePath, { deleted: false }); - } - }, - renderCall(args, theme) { - return renderPathToolCall("Delete", args.path, theme); - }, - renderResult(result, { expanded, isPartial }, theme) { - return renderResultSummary( - result, - expanded, - isPartial, - booleanDetail(result, "deleted") - ? theme.fg("muted", "Deleted") - : theme.fg("error", "Not deleted"), - ); - }, - }); + // ── LS tool ────────────────────────────────────────────────────────── + + const LsParams = Type.Object({ + path: Type.String({ + description: 'Directory path to list', + }), + }); + + pi.registerTool({ + name: 'LS', + label: 'LS', + description: 'List the contents of a directory, including hidden files.', + parameters: LsParams, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const targetPath = resolve(ctx.cwd, params.path); + + try { + const { stdout } = await execFileAsync('ls', ['-la', targetPath], { + cwd: ctx.cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + signal, + }); + + let output = stdout.trim(); + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[LS: output truncated at 50KB]`; + } + + return { + content: [{ type: 'text', text: output }], + details: { path: targetPath }, + }; + } catch (error: unknown) { + const err = error as ToolError; + return { + content: [ + { + type: 'text', + text: `LS error: ${err.message ?? 'Unknown error'}`, + }, + ], + details: { path: targetPath }, + }; + } + }, + renderCall(args, theme) { + return renderPathToolCall('LS', args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + theme.fg('muted', stringDetail(result, 'path')), + ); + }, + }); + + // ── Read tool ──────────────────────────────────────────────────────── + + const ReadParams = Type.Object({ + path: Type.String({ + description: 'Path to the file to read', + }), + offset: Type.Optional( + Type.Number({ + description: 'Line number to start reading from (0-indexed)', + }), + ), + limit: Type.Optional( + Type.Number({ + description: 'Maximum number of lines to read', + }), + ), + }); + + pi.registerTool({ + name: 'Read', + label: 'Read', + description: 'Read the contents of a file. Returns the file content with line numbers.', + parameters: ReadParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + if (!existsSync(filePath)) { + return fileNotFound(filePath, { exists: false, totalLines: 0 }); + } + + const content = readFileSync(filePath, 'utf-8'); + const lines = content.split('\n'); + + const startLine = params.offset ?? 0; + const endLine = params.limit + ? Math.min(startLine + params.limit, lines.length) + : Math.min(startLine + 2000, lines.length); + + const selectedLines = lines.slice(startLine, endLine); + const numberedLines = selectedLines.map((line, i) => `${startLine + i + 1}\t${line}`); + + let output = numberedLines.join('\n'); + if (endLine < lines.length) { + output += `\n\n[Showing lines ${startLine + 1}-${endLine} of ${lines.length} total lines. Use offset to see more.]`; + } + + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + + return { + content: [{ type: 'text', text: output }], + details: { path: filePath, totalLines: lines.length }, + }; + } catch (error: unknown) { + return fileError(error, 'Read', filePath, { + exists: false, + totalLines: 0, + }); + } + }, + renderCall(args, theme) { + const range = + args.offset !== undefined || args.limit !== undefined + ? theme.fg( + 'muted', + ` (from ${args.offset ?? 0}${args.limit ? `, ${args.limit} lines` : ''})`, + ) + : ''; + return text( + theme.fg('toolTitle', theme.bold('Read ')) + theme.fg('accent', args.path) + range, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + detailRecord(result).exists === false + ? theme.fg('error', 'File not found') + : theme.fg('muted', `${numberDetail(result, 'totalLines')} line(s)`), + ); + }, + }); + + // ── Write tool ─────────────────────────────────────────────────────── + + const WriteParams = Type.Object({ + path: Type.String({ + description: 'Path to the file to write', + }), + content: Type.String({ + description: 'Content to write to the file', + }), + }); + + pi.registerTool({ + name: 'Write', + label: 'Write', + description: + 'Create or overwrite a file with the given content. Creates parent directories if needed.', + parameters: WriteParams, + + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as WriteArgs; + return { + ...input, + content: stringFrom(input.content) ?? stringFrom(input.contents), + } as WriteArgs; + }, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + mkdirSync(dirname(filePath), { recursive: true }); + writeFileSync(filePath, params.content, 'utf-8'); + + return { + content: [ + { + type: 'text', + text: `Successfully wrote ${params.content.length} bytes to ${params.path}`, + }, + ], + details: { path: filePath, bytesWritten: params.content.length }, + }; + } catch (error: unknown) { + const err = error as ToolError; + return { + content: [ + { + type: 'text', + text: `Write error: ${err.message ?? 'Unknown error'}`, + }, + ], + details: { path: filePath, bytesWritten: 0 }, + }; + } + }, + renderCall(args, theme) { + return renderPathToolCall('Write', args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + theme.fg('muted', `${numberDetail(result, 'bytesWritten')} bytes written`), + ); + }, + }); + + // ── StrReplace tool ────────────────────────────────────────────────── + + const StrReplaceParams = Type.Object({ + path: Type.String({ + description: 'Path to the file to modify', + }), + old_str: Type.String({ + description: 'String to search for (exact match)', + }), + new_str: Type.String({ + description: 'String to replace with', + }), + }); + + pi.registerTool({ + name: 'StrReplace', + label: 'StrReplace', + description: + 'Replace all occurrences of a string in a file. The old_str must be an exact match.', + parameters: StrReplaceParams, + + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as StrReplaceArgs; + return { + ...input, + old_str: + stringFrom(input.old_str) ?? + stringFrom(input.old_string) ?? + stringFrom(input.oldText) ?? + stringFrom(recordFrom(input.strReplace)?.oldText), + new_str: + stringFrom(input.new_str) ?? + stringFrom(input.new_string) ?? + stringFrom(input.newText) ?? + stringFrom(recordFrom(input.strReplace)?.newText), + } as StrReplaceArgs; + }, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + if (!existsSync(filePath)) { + return fileNotFound(filePath, { replacements: 0 }); + } + + const content = readFileSync(filePath, 'utf-8'); + const count = content.split(params.old_str).length - 1; + + if (count === 0) { + return replacementResult( + `String not found in ${params.path}: "${params.old_str}"`, + filePath, + ); + } + + const newContent = content.replaceAll(params.old_str, params.new_str); + writeFileSync(filePath, newContent, 'utf-8'); + + return { + content: [ + { + type: 'text', + text: `Replaced ${count} occurrence(s) in ${params.path}`, + }, + ], + details: { path: filePath, replacements: count }, + }; + } catch (error: unknown) { + return fileError(error, 'StrReplace', filePath, { replacements: 0 }); + } + }, + renderCall(args, theme) { + return renderPathToolCall('StrReplace', args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderReplacementResult(result, expanded, isPartial, theme); + }, + }); + + // ── Edit tool ──────────────────────────────────────────────────────── + + const EditItemParams = Type.Object({ + oldText: Type.String({ + description: 'String to search for (exact match)', + }), + newText: Type.String({ + description: 'String to replace with', + }), + replaceAll: Type.Optional( + Type.Boolean({ + description: + 'Accepted for Cursor compatibility. Replacements are always applied to all matches.', + }), + ), + }); + + const EditParams = Type.Object({ + path: Type.String({ + description: 'Path to the file to modify', + }), + edits: Type.Optional( + Type.Array(EditItemParams, { + description: 'Exact text replacements to apply sequentially', + }), + ), + applyPatch: Type.Optional( + Type.Object({ + patchContent: Type.String({ + description: 'Unsupported unified patch content', + }), + }), + ), + strReplace: Type.Optional(EditItemParams), + multiStrReplace: Type.Optional( + Type.Object({ + edits: Type.Array(EditItemParams), + }), + ), + }); + + pi.registerTool({ + name: 'Edit', + label: 'Edit', + description: + 'Modify a file with exact text replacement. applyPatch is not supported by this Grok tool shim.', + parameters: EditParams, + + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as EditArgs; + return { + ...input, + edits: editsFromArgs(input), + } as EditArgs; + }, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + if (!existsSync(filePath)) { + return fileNotFound(filePath, { replacements: 0 }); + } + + try { + if (!params.edits?.length) { + return { + content: [ + { + type: 'text', + text: params.applyPatch + ? 'Edit error: applyPatch is not supported by this Grok tool shim' + : 'Edit error: provide at least one exact text replacement', + }, + ], + details: { path: filePath, replacements: 0 }, + }; + } + + const result = applyEdits(readFileSync(filePath, 'utf-8'), params.edits); + + if (result.replacements === 0) { + return replacementResult(`No replacement strings found in ${params.path}`, filePath); + } + + writeFileSync(filePath, result.content, 'utf-8'); + + return { + content: [ + { + type: 'text', + text: `Applied ${result.replacements} replacement(s) in ${params.path}`, + }, + ], + details: { path: filePath, replacements: result.replacements }, + }; + } catch (error: unknown) { + return fileError(error, 'Edit', filePath, { replacements: 0 }); + } + }, + renderCall(args, theme) { + return renderPathToolCall('Edit', args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderReplacementResult(result, expanded, isPartial, theme); + }, + }); + + // ── Delete tool ────────────────────────────────────────────────────── + + const DeleteParams = Type.Object({ + path: Type.String({ + description: 'Path to the file to delete', + }), + }); + + pi.registerTool({ + name: 'Delete', + label: 'Delete', + description: 'Delete a file from the filesystem.', + parameters: DeleteParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + if (!existsSync(filePath)) { + return fileNotFound(filePath, { deleted: false }); + } + + unlinkSync(filePath); + + return { + content: [{ type: 'text', text: `Successfully deleted ${params.path}` }], + details: { path: filePath, deleted: true }, + }; + } catch (error: unknown) { + return fileError(error, 'Delete', filePath, { deleted: false }); + } + }, + renderCall(args, theme) { + return renderPathToolCall('Delete', args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + booleanDetail(result, 'deleted') + ? theme.fg('muted', 'Deleted') + : theme.fg('error', 'Not deleted'), + ); + }, + }); } diff --git a/src/tools/register.ts b/src/tools/register.ts index 07311b7..e798f1f 100644 --- a/src/tools/register.ts +++ b/src/tools/register.ts @@ -1,10 +1,10 @@ -import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; -import { registerFileTools } from "./files.js"; -import { registerSearchTools } from "./search.js"; -import { registerShellTool } from "./shell.js"; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { registerFileTools } from './files.js'; +import { registerSearchTools } from './search.js'; +import { registerShellTool } from './shell.js'; export function registerGrokTools(pi: ExtensionAPI) { - registerSearchTools(pi); - registerFileTools(pi); - registerShellTool(pi); + registerSearchTools(pi); + registerFileTools(pi); + registerShellTool(pi); } diff --git a/src/tools/rendering.ts b/src/tools/rendering.ts index d02ece9..c442677 100644 --- a/src/tools/rendering.ts +++ b/src/tools/rendering.ts @@ -1,197 +1,180 @@ -import { execFile } from "node:child_process"; -import { promisify } from "node:util"; -import { Text } from "@earendil-works/pi-tui"; +import { execFile } from 'node:child_process'; +import { promisify } from 'node:util'; +import { Text } from '@earendil-works/pi-tui'; const execFileAsync = promisify(execFile); export const MAX_OUTPUT_CHARS = 50_000; export const MAX_LINES = 500; -export function recordFrom( - value: unknown, -): Record | undefined { - if (!value || typeof value !== "object") return undefined; - return value as Record; +export function recordFrom(value: unknown): Record | undefined { + if (!value || typeof value !== 'object') return undefined; + return value as Record; } export function stringFrom(value: unknown): string | undefined { - if (typeof value !== "string") return undefined; - return value; + if (typeof value !== 'string') return undefined; + return value; } export function truncateLines(lines: string[]): string { - if (lines.length > MAX_LINES) { - return ( - lines.slice(0, MAX_LINES).join("\n") + - `\n\n[Showing first ${MAX_LINES} of ${lines.length} results. Refine your pattern to narrow results.]` - ); - } - return lines.join("\n"); + if (lines.length > MAX_LINES) { + return ( + lines.slice(0, MAX_LINES).join('\n') + + `\n\n[Showing first ${MAX_LINES} of ${lines.length} results. Refine your pattern to narrow results.]` + ); + } + return lines.join('\n'); } export function truncateChars(output: string): string { - if (output.length > MAX_OUTPUT_CHARS) { - return `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; - } - return output; + if (output.length > MAX_OUTPUT_CHARS) { + return `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + return output; } let rgAvailable: boolean | undefined; export async function hasRipgrep(): Promise { - if (rgAvailable !== undefined) return rgAvailable; - try { - await execFileAsync("rg", ["--version"]); - rgAvailable = true; - } catch { - rgAvailable = false; - } - return rgAvailable; + if (rgAvailable !== undefined) return rgAvailable; + try { + await execFileAsync('rg', ['--version']); + rgAvailable = true; + } catch { + rgAvailable = false; + } + return rgAvailable; } export type ToolError = { code?: number; message?: string }; export type ToolResult = { - content: [{ type: "text"; text: string }]; - details: T; + content: [{ type: 'text'; text: string }]; + details: T; }; export function text(text: string): Text { - return new Text(text, 0, 0); + return new Text(text, 0, 0); } function firstText(result: { content: { type: string; text?: string }[] }) { - const first = result.content[0]; - if (first?.type !== "text") return undefined; - return first.text; + const first = result.content[0]; + if (first?.type !== 'text') return undefined; + return first.text; } export function renderResultText( - result: { content: { type: string; text?: string }[] }, - expanded: boolean, - summary: string, + result: { content: { type: string; text?: string }[] }, + expanded: boolean, + summary: string, ): Text { - if (expanded) return text(firstText(result) ?? summary); - return text(summary); + if (expanded) return text(firstText(result) ?? summary); + return text(summary); } export function renderRunning(isPartial: boolean): Text | undefined { - if (!isPartial) return undefined; - return text("Running..."); + if (!isPartial) return undefined; + return text('Running...'); } export function renderResultSummary( - result: { content: { type: string; text?: string }[] }, - expanded: boolean, - isPartial: boolean, - summary: string, + result: { content: { type: string; text?: string }[] }, + expanded: boolean, + isPartial: boolean, + summary: string, ): Text { - const running = renderRunning(isPartial); - if (running) return running; - return renderResultText(result, expanded, summary); + const running = renderRunning(isPartial); + if (running) return running; + return renderResultText(result, expanded, summary); } -export function detailRecord(result: { - details: unknown; -}): Record { - if (!result.details || typeof result.details !== "object") return {}; - return result.details as Record; +export function detailRecord(result: { details: unknown }): Record { + if (!result.details || typeof result.details !== 'object') return {}; + return result.details as Record; } -export function numberDetail( - result: { details: unknown }, - key: string, -): number { - const value = detailRecord(result)[key]; - if (typeof value !== "number") return 0; - return value; +export function numberDetail(result: { details: unknown }, key: string): number { + const value = detailRecord(result)[key]; + if (typeof value !== 'number') return 0; + return value; } -export function stringDetail( - result: { details: unknown }, - key: string, -): string { - const value = detailRecord(result)[key]; - if (typeof value !== "string") return ""; - return value; +export function stringDetail(result: { details: unknown }, key: string): string { + const value = detailRecord(result)[key]; + if (typeof value !== 'string') return ''; + return value; } -export function booleanDetail( - result: { details: unknown }, - key: string, -): boolean { - const value = detailRecord(result)[key]; - return value === true; +export function booleanDetail(result: { details: unknown }, key: string): boolean { + const value = detailRecord(result)[key]; + return value === true; } type FileDetails = { path: string; [key: string]: unknown }; export function fileNotFound( - filePath: string, - extraDetails: Omit, + filePath: string, + extraDetails: Omit, ): ToolResult { - return { - content: [{ type: "text", text: `File not found: ${filePath}` }], - details: { path: filePath, ...extraDetails } as T, - }; + return { + content: [{ type: 'text', text: `File not found: ${filePath}` }], + details: { path: filePath, ...extraDetails } as T, + }; } export function fileError( - error: unknown, - toolName: string, - filePath: string, - extraDetails: Omit, + error: unknown, + toolName: string, + filePath: string, + extraDetails: Omit, ): ToolResult { - const err = error as ToolError; - return { - content: [ - { - type: "text", - text: `${toolName} error: ${err.message ?? "Unknown error"}`, - }, - ], - details: { path: filePath, ...extraDetails } as T, - }; -} - -export function toolError( - error: unknown, - toolName: string, - emptyDetails: T, -): ToolResult { - const err = error as ToolError; - if (err.code === 1) { - return { - content: [{ type: "text", text: "No matches found" }], - details: emptyDetails, - }; - } - return { - content: [ - { - type: "text", - text: `${toolName} error: ${err.message ?? "Unknown error"}`, - }, - ], - details: emptyDetails, - }; + const err = error as ToolError; + return { + content: [ + { + type: 'text', + text: `${toolName} error: ${err.message ?? 'Unknown error'}`, + }, + ], + details: { path: filePath, ...extraDetails } as T, + }; +} + +export function toolError(error: unknown, toolName: string, emptyDetails: T): ToolResult { + const err = error as ToolError; + if (err.code === 1) { + return { + content: [{ type: 'text', text: 'No matches found' }], + details: emptyDetails, + }; + } + return { + content: [ + { + type: 'text', + text: `${toolName} error: ${err.message ?? 'Unknown error'}`, + }, + ], + details: emptyDetails, + }; } export async function execWithRgFallback( - rgArgs: string[], - grepArgs: string[], - options: { cwd: string; signal?: AbortSignal }, + rgArgs: string[], + grepArgs: string[], + options: { cwd: string; signal?: AbortSignal }, ): Promise { - if (await hasRipgrep()) { - const result = await execFileAsync("rg", rgArgs, { - cwd: options.cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, - signal: options.signal, - }); - return result.stdout; - } - const result = await execFileAsync("grep", grepArgs, { - cwd: options.cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, - signal: options.signal, - }); - return result.stdout; + if (await hasRipgrep()) { + const result = await execFileAsync('rg', rgArgs, { + cwd: options.cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + signal: options.signal, + }); + return result.stdout; + } + const result = await execFileAsync('grep', grepArgs, { + cwd: options.cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + signal: options.signal, + }); + return result.stdout; } diff --git a/src/tools/search.ts b/src/tools/search.ts index bd4bc7d..34b9b63 100644 --- a/src/tools/search.ts +++ b/src/tools/search.ts @@ -1,22 +1,22 @@ -import { execFile } from "node:child_process"; -import { resolve } from "node:path"; -import { promisify } from "node:util"; -import { Type } from "@earendil-works/pi-ai"; -import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; +import { execFile } from 'node:child_process'; +import { resolve } from 'node:path'; +import { promisify } from 'node:util'; +import { Type } from '@earendil-works/pi-ai'; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; import { - execWithRgFallback, - hasRipgrep, - MAX_OUTPUT_CHARS, - numberDetail, - recordFrom, - renderResultText, - renderRunning, - stringFrom, - text, - toolError, - truncateChars, - truncateLines, -} from "./rendering.js"; + execWithRgFallback, + hasRipgrep, + MAX_OUTPUT_CHARS, + numberDetail, + recordFrom, + renderResultText, + renderRunning, + stringFrom, + text, + toolError, + truncateChars, + truncateLines, +} from './rendering.js'; const execFileAsync = promisify(execFile); @@ -24,187 +24,176 @@ type GrepArgs = { pattern: string; path?: string; include?: string }; type GlobArgs = { pattern: string; path?: string }; export function registerSearchTools(pi: ExtensionAPI) { - const GrepParams = Type.Object({ - pattern: Type.String({ - description: "Regex pattern to search for in file contents", - }), - path: Type.Optional( - Type.String({ - description: - "Directory or file to search. Defaults to current working directory.", - }), - ), - include: Type.Optional( - Type.String({ - description: - "Glob pattern to filter which files are searched (e.g. *.ts, **/*.md)", - }), - ), - }); - - pi.registerTool({ - name: "Grep", - label: "Grep", - description: - "Search for a regex pattern in file contents. Returns matching lines with file path and line number. Use the include parameter to filter by file type.", - parameters: GrepParams, - - prepareArguments(args) { - const input = recordFrom(args); - if (!input) return args as GrepArgs; - return { - ...input, - include: stringFrom(input.include) ?? stringFrom(input.glob_filter), - } as GrepArgs; - }, - - async execute(_toolCallId, params, signal, _onUpdate, ctx) { - const searchPath = resolve(ctx.cwd, params.path ?? "."); - - try { - const rgArgs = ["-n", "--no-heading", "--color=never"]; - if (params.include) rgArgs.push("--glob", params.include); - rgArgs.push(params.pattern, searchPath); - - const grepArgs = ["-r", "-n", "--color=never"]; - if (params.include) grepArgs.push(`--include=${params.include}`); - grepArgs.push(params.pattern, searchPath); - - const stdout = await execWithRgFallback(rgArgs, grepArgs, { - cwd: ctx.cwd, - signal, - }); - - const lines = stdout.trim().split("\n").filter(Boolean); - if (lines.length === 0) { - return { - content: [{ type: "text", text: "No matches found" }], - details: { matchCount: 0 }, - }; - } - - return { - content: [ - { type: "text", text: truncateChars(truncateLines(lines)) }, - ], - details: { matchCount: lines.length }, - }; - } catch (error: unknown) { - return toolError(error, "Grep", { matchCount: 0 }); - } - }, - renderCall(args, theme) { - const path = args.path ? theme.fg("muted", ` in ${args.path}`) : ""; - const include = args.include ? theme.fg("dim", ` [${args.include}]`) : ""; - return text( - theme.fg("toolTitle", theme.bold("Grep ")) + - theme.fg("accent", `"${args.pattern}"`) + - path + - include, - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - const running = renderRunning(isPartial); - if (running) return running; - const matchCount = numberDetail(result, "matchCount"); - return renderResultText( - result, - expanded, - matchCount === 0 - ? theme.fg("dim", "No matches") - : theme.fg("muted", `${matchCount} match(es)`), - ); - }, - }); - - const GlobParams = Type.Object({ - pattern: Type.String({ - description: "Glob pattern to match files (e.g. **/*.ts, src/**/*.json)", - }), - path: Type.Optional( - Type.String({ - description: - "Directory to search within. Defaults to current working directory.", - }), - ), - }); - - pi.registerTool({ - name: "Glob", - label: "Glob", - description: - "Find files matching a glob pattern. Returns a list of matching file paths sorted by modification time (newest first).", - parameters: GlobParams, - - prepareArguments(args) { - const input = recordFrom(args); - if (!input) return args as GlobArgs; - return { - ...input, - pattern: stringFrom(input.pattern) ?? stringFrom(input.glob_pattern), - } as GlobArgs; - }, - - async execute(_toolCallId, params, signal, _onUpdate, ctx) { - const searchPath = resolve(ctx.cwd, params.path ?? "."); - - try { - let files: string[]; - - if (await hasRipgrep()) { - const result = await execFileAsync( - "rg", - ["--files", "--color=never", "--glob", params.pattern, searchPath], - { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, - ); - files = result.stdout.trim().split("\n").filter(Boolean); - } else { - // find fallback — convert **/*.ext → -name "*.ext" - const basename = params.pattern.replace(/^(\*\*\/)+/, ""); - const result = await execFileAsync( - "find", - [searchPath, "-type", "f", "-name", basename], - { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, - ); - files = result.stdout.trim().split("\n").filter(Boolean); - } - - if (files.length === 0) { - return { - content: [{ type: "text", text: "No files found" }], - details: { fileCount: 0 }, - }; - } - - return { - content: [ - { type: "text", text: truncateChars(truncateLines(files)) }, - ], - details: { fileCount: files.length }, - }; - } catch (error: unknown) { - return toolError(error, "Glob", { fileCount: 0 }); - } - }, - renderCall(args, theme) { - const path = args.path ? theme.fg("muted", ` in ${args.path}`) : ""; - return text( - theme.fg("toolTitle", theme.bold("Glob ")) + - theme.fg("accent", args.pattern) + - path, - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - const running = renderRunning(isPartial); - if (running) return running; - const fileCount = numberDetail(result, "fileCount"); - return renderResultText( - result, - expanded, - fileCount === 0 - ? theme.fg("dim", "No files") - : theme.fg("muted", `${fileCount} file(s)`), - ); - }, - }); + const GrepParams = Type.Object({ + pattern: Type.String({ + description: 'Regex pattern to search for in file contents', + }), + path: Type.Optional( + Type.String({ + description: 'Directory or file to search. Defaults to current working directory.', + }), + ), + include: Type.Optional( + Type.String({ + description: 'Glob pattern to filter which files are searched (e.g. *.ts, **/*.md)', + }), + ), + }); + + pi.registerTool({ + name: 'Grep', + label: 'Grep', + description: + 'Search for a regex pattern in file contents. Returns matching lines with file path and line number. Use the include parameter to filter by file type.', + parameters: GrepParams, + + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as GrepArgs; + return { + ...input, + include: stringFrom(input.include) ?? stringFrom(input.glob_filter), + } as GrepArgs; + }, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const searchPath = resolve(ctx.cwd, params.path ?? '.'); + + try { + const rgArgs = ['-n', '--no-heading', '--color=never']; + if (params.include) rgArgs.push('--glob', params.include); + rgArgs.push(params.pattern, searchPath); + + const grepArgs = ['-r', '-n', '--color=never']; + if (params.include) grepArgs.push(`--include=${params.include}`); + grepArgs.push(params.pattern, searchPath); + + const stdout = await execWithRgFallback(rgArgs, grepArgs, { + cwd: ctx.cwd, + signal, + }); + + const lines = stdout.trim().split('\n').filter(Boolean); + if (lines.length === 0) { + return { + content: [{ type: 'text', text: 'No matches found' }], + details: { matchCount: 0 }, + }; + } + + return { + content: [{ type: 'text', text: truncateChars(truncateLines(lines)) }], + details: { matchCount: lines.length }, + }; + } catch (error: unknown) { + return toolError(error, 'Grep', { matchCount: 0 }); + } + }, + renderCall(args, theme) { + const path = args.path ? theme.fg('muted', ` in ${args.path}`) : ''; + const include = args.include ? theme.fg('dim', ` [${args.include}]`) : ''; + return text( + theme.fg('toolTitle', theme.bold('Grep ')) + + theme.fg('accent', `"${args.pattern}"`) + + path + + include, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + const matchCount = numberDetail(result, 'matchCount'); + return renderResultText( + result, + expanded, + matchCount === 0 + ? theme.fg('dim', 'No matches') + : theme.fg('muted', `${matchCount} match(es)`), + ); + }, + }); + + const GlobParams = Type.Object({ + pattern: Type.String({ + description: 'Glob pattern to match files (e.g. **/*.ts, src/**/*.json)', + }), + path: Type.Optional( + Type.String({ + description: 'Directory to search within. Defaults to current working directory.', + }), + ), + }); + + pi.registerTool({ + name: 'Glob', + label: 'Glob', + description: + 'Find files matching a glob pattern. Returns a list of matching file paths sorted by modification time (newest first).', + parameters: GlobParams, + + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as GlobArgs; + return { + ...input, + pattern: stringFrom(input.pattern) ?? stringFrom(input.glob_pattern), + } as GlobArgs; + }, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const searchPath = resolve(ctx.cwd, params.path ?? '.'); + + try { + let files: string[]; + + if (await hasRipgrep()) { + const result = await execFileAsync( + 'rg', + ['--files', '--color=never', '--glob', params.pattern, searchPath], + { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, + ); + files = result.stdout.trim().split('\n').filter(Boolean); + } else { + // find fallback — convert **/*.ext → -name "*.ext" + const basename = params.pattern.replace(/^(\*\*\/)+/, ''); + const result = await execFileAsync( + 'find', + [searchPath, '-type', 'f', '-name', basename], + { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, + ); + files = result.stdout.trim().split('\n').filter(Boolean); + } + + if (files.length === 0) { + return { + content: [{ type: 'text', text: 'No files found' }], + details: { fileCount: 0 }, + }; + } + + return { + content: [{ type: 'text', text: truncateChars(truncateLines(files)) }], + details: { fileCount: files.length }, + }; + } catch (error: unknown) { + return toolError(error, 'Glob', { fileCount: 0 }); + } + }, + renderCall(args, theme) { + const path = args.path ? theme.fg('muted', ` in ${args.path}`) : ''; + return text( + theme.fg('toolTitle', theme.bold('Glob ')) + theme.fg('accent', args.pattern) + path, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + const fileCount = numberDetail(result, 'fileCount'); + return renderResultText( + result, + expanded, + fileCount === 0 ? theme.fg('dim', 'No files') : theme.fg('muted', `${fileCount} file(s)`), + ); + }, + }); } diff --git a/src/tools/shell.ts b/src/tools/shell.ts index 631e259..b8be923 100644 --- a/src/tools/shell.ts +++ b/src/tools/shell.ts @@ -1,124 +1,113 @@ -import { execFile } from "node:child_process"; -import { resolve } from "node:path"; -import { promisify } from "node:util"; -import { Type } from "@earendil-works/pi-ai"; -import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; +import { execFile } from 'node:child_process'; +import { resolve } from 'node:path'; +import { promisify } from 'node:util'; +import { Type } from '@earendil-works/pi-ai'; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; import { - MAX_OUTPUT_CHARS, - numberDetail, - renderResultText, - renderRunning, - text, -} from "./rendering.js"; + MAX_OUTPUT_CHARS, + numberDetail, + renderResultText, + renderRunning, + text, +} from './rendering.js'; const execFileAsync = promisify(execFile); export function registerShellTool(pi: ExtensionAPI) { - // ── Shell tool ─────────────────────────────────────────────────────── + // ── Shell tool ─────────────────────────────────────────────────────── - const ShellParams = Type.Object({ - command: Type.String({ - description: "Shell command to execute", - }), - working_directory: Type.Optional( - Type.String({ - description: "Working directory for the command", - }), - ), - timeout: Type.Optional( - Type.Number({ - description: "Timeout in milliseconds (default: 120000)", - }), - ), - }); + const ShellParams = Type.Object({ + command: Type.String({ + description: 'Shell command to execute', + }), + working_directory: Type.Optional( + Type.String({ + description: 'Working directory for the command', + }), + ), + timeout: Type.Optional( + Type.Number({ + description: 'Timeout in milliseconds (default: 120000)', + }), + ), + }); - pi.registerTool({ - name: "Shell", - label: "Shell", - description: - "Execute a shell command and return stdout, stderr, and exit code.", - parameters: ShellParams, + pi.registerTool({ + name: 'Shell', + label: 'Shell', + description: 'Execute a shell command and return stdout, stderr, and exit code.', + parameters: ShellParams, - async execute(_toolCallId, params, signal, _onUpdate, ctx) { - const cwd = params.working_directory - ? resolve(ctx.cwd, params.working_directory) - : ctx.cwd; - const timeout = params.timeout ?? 120_000; + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const cwd = params.working_directory ? resolve(ctx.cwd, params.working_directory) : ctx.cwd; + const timeout = params.timeout ?? 120_000; - try { - const { stdout, stderr } = await execFileAsync( - "bash", - ["-c", params.command], - { - cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, - timeout, - signal, - }, - ); + try { + const { stdout, stderr } = await execFileAsync('bash', ['-c', params.command], { + cwd, + maxBuffer: MAX_OUTPUT_CHARS * 2, + timeout, + signal, + }); - let output = ""; - if (stdout) output += stdout; - if (stderr) output += `\n[stderr]\n${stderr}`; + let output = ''; + if (stdout) output += stdout; + if (stderr) output += `\n[stderr]\n${stderr}`; - if (output.length > MAX_OUTPUT_CHARS) { - output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; - } + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } - return { - content: [{ type: "text", text: output || "(no output)" }], - details: { exitCode: 0, command: params.command }, - }; - } catch (error: unknown) { - const err = error as { - code?: number; - message?: string; - stdout?: string; - stderr?: string; - }; + return { + content: [{ type: 'text', text: output || '(no output)' }], + details: { exitCode: 0, command: params.command }, + }; + } catch (error: unknown) { + const err = error as { + code?: number; + message?: string; + stdout?: string; + stderr?: string; + }; - let output = ""; - if (err.stdout) output += err.stdout; - if (err.stderr) output += `\n[stderr]\n${err.stderr}`; + let output = ''; + if (err.stdout) output += err.stdout; + if (err.stderr) output += `\n[stderr]\n${err.stderr}`; - if (output.length > MAX_OUTPUT_CHARS) { - output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; - } + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } - return { - content: [ - { - type: "text", - text: `Shell error (exit code ${err.code ?? "unknown"}): ${err.message ?? "Unknown error"}${output ? `\n${output}` : ""}`, - }, - ], - details: { - exitCode: err.code ?? 1, - command: params.command, - }, - }; - } - }, - renderCall(args, theme) { - const cwd = args.working_directory - ? theme.fg("muted", ` in ${args.working_directory}`) - : ""; - return text( - theme.fg("toolTitle", theme.bold("Shell ")) + - theme.fg("accent", args.command) + - cwd, - ); - }, - renderResult(result, { expanded, isPartial }, theme) { - const running = renderRunning(isPartial); - if (running) return running; - return renderResultText( - result, - expanded, - numberDetail(result, "exitCode") === 0 - ? theme.fg("muted", "Exit 0") - : theme.fg("warning", `Exit ${numberDetail(result, "exitCode")}`), - ); - }, - }); + return { + content: [ + { + type: 'text', + text: `Shell error (exit code ${err.code ?? 'unknown'}): ${err.message ?? 'Unknown error'}${output ? `\n${output}` : ''}`, + }, + ], + details: { + exitCode: err.code ?? 1, + command: params.command, + }, + }; + } + }, + renderCall(args, theme) { + const cwd = args.working_directory ? theme.fg('muted', ` in ${args.working_directory}`) : ''; + return text( + theme.fg('toolTitle', theme.bold('Shell ')) + theme.fg('accent', args.command) + cwd, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + return renderResultText( + result, + expanded, + numberDetail(result, 'exitCode') === 0 + ? theme.fg('muted', 'Exit 0') + : theme.fg('warning', `Exit ${numberDetail(result, 'exitCode')}`), + ); + }, + }); } diff --git a/tests/auth/oauth.test.ts b/tests/auth/oauth.test.ts index d1ce809..4f0d32e 100644 --- a/tests/auth/oauth.test.ts +++ b/tests/auth/oauth.test.ts @@ -1,286 +1,280 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; -import { getBaseUrl, login, refresh } from "../../src/auth/oauth.js"; -import { XaiErrorCode } from "../../src/shared/errors.js"; +import { afterEach, describe, expect, it, vi } from 'vitest'; +import { getBaseUrl, login, refresh } from '../../src/auth/oauth.js'; +import { XaiErrorCode } from '../../src/shared/errors.js'; const originalEnv = { ...process.env }; const originalFetch = globalThis.fetch; const storedRefreshCredentials = { - access: "access-token", - refresh: "refresh-token", - expires: 0, - tokenEndpoint: "https://auth.x.ai/oauth/token", + access: 'access-token', + refresh: 'refresh-token', + expires: 0, + tokenEndpoint: 'https://auth.x.ai/oauth/token', }; const credentialsWithoutEndpoint = { - access: "old-access", - refresh: "old-refresh", - expires: 0, + access: 'old-access', + refresh: 'old-refresh', + expires: 0, }; const discoveryDocument = { - authorization_endpoint: "https://auth.x.ai/oauth/authorize", - token_endpoint: "https://auth.x.ai/oauth/token", + authorization_endpoint: 'https://auth.x.ai/oauth/authorize', + token_endpoint: 'https://auth.x.ai/oauth/token', }; afterEach(() => { - process.env = { ...originalEnv }; - globalThis.fetch = originalFetch; - vi.restoreAllMocks(); - vi.useRealTimers(); + process.env = { ...originalEnv }; + globalThis.fetch = originalFetch; + vi.restoreAllMocks(); + vi.useRealTimers(); }); -describe("OAuth helpers without network access", () => { - it("resolves and trims the configured base URL", () => { - delete process.env.GROK_CLI_BASE_URL; - delete process.env.PI_GROK_CLI_BASE_URL; - expect(getBaseUrl()).toBe("https://cli-chat-proxy.grok.com/v1"); +describe('OAuth helpers without network access', () => { + it('resolves and trims the configured base URL', () => { + delete process.env.GROK_CLI_BASE_URL; + delete process.env.PI_GROK_CLI_BASE_URL; + expect(getBaseUrl()).toBe('https://cli-chat-proxy.grok.com/v1'); - process.env.GROK_CLI_BASE_URL = "https://example.invalid/v1///"; - expect(getBaseUrl()).toBe("https://example.invalid/v1"); + process.env.GROK_CLI_BASE_URL = 'https://example.invalid/v1///'; + expect(getBaseUrl()).toBe('https://example.invalid/v1'); - process.env.PI_GROK_CLI_BASE_URL = "https://override.invalid/api//"; - expect(getBaseUrl()).toBe("https://override.invalid/api"); - }); + process.env.PI_GROK_CLI_BASE_URL = 'https://override.invalid/api//'; + expect(getBaseUrl()).toBe('https://override.invalid/api'); + }); - it("rejects refresh credentials with no refresh token before fetching", async () => { - const fetchMock = vi.fn(); - globalThis.fetch = fetchMock; + it('rejects refresh credentials with no refresh token before fetching', async () => { + const fetchMock = vi.fn(); + globalThis.fetch = fetchMock; - await expect( - refresh({ - access: "access-token", - refresh: "", - expires: 0, - tokenEndpoint: "https://auth.x.ai/oauth/token", - }), - ).rejects.toMatchObject({ - code: XaiErrorCode.REFRESH_MISSING, - reloginRequired: true, - }); - expect(fetchMock).not.toHaveBeenCalled(); - }); + await expect( + refresh({ + access: 'access-token', + refresh: '', + expires: 0, + tokenEndpoint: 'https://auth.x.ai/oauth/token', + }), + ).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_MISSING, + reloginRequired: true, + }); + expect(fetchMock).not.toHaveBeenCalled(); + }); - it("refreshes credentials with the configured token endpoint", async () => { - vi.useFakeTimers(); - vi.setSystemTime(1_700_000_000_000); - process.env.PI_GROK_CLI_BASE_URL = "https://proxy.example/v1//"; - const fetchMock = vi.fn(async () => - Response.json({ - access_token: "new-access", - refresh_token: "new-refresh", - expires_in: 600, - id_token: "new-id", - token_type: "DPoP", - }), - ); - globalThis.fetch = fetchMock; + it('refreshes credentials with the configured token endpoint', async () => { + vi.useFakeTimers(); + vi.setSystemTime(1_700_000_000_000); + process.env.PI_GROK_CLI_BASE_URL = 'https://proxy.example/v1//'; + const fetchMock = vi.fn(async () => + Response.json({ + access_token: 'new-access', + refresh_token: 'new-refresh', + expires_in: 600, + id_token: 'new-id', + token_type: 'DPoP', + }), + ); + globalThis.fetch = fetchMock; - await expect( - refresh({ - access: "old-access", - refresh: "old-refresh", - expires: 0, - tokenEndpoint: "https://auth.x.ai/oauth/token", - idToken: "old-id", - tokenType: "Bearer", - }), - ).resolves.toMatchObject({ - access: "new-access", - refresh: "new-refresh", - expires: 1_700_000_480_000, - tokenEndpoint: "https://auth.x.ai/oauth/token", - idToken: "new-id", - tokenType: "DPoP", - baseUrl: "https://proxy.example/v1", - }); + await expect( + refresh({ + access: 'old-access', + refresh: 'old-refresh', + expires: 0, + tokenEndpoint: 'https://auth.x.ai/oauth/token', + idToken: 'old-id', + tokenType: 'Bearer', + }), + ).resolves.toMatchObject({ + access: 'new-access', + refresh: 'new-refresh', + expires: 1_700_000_480_000, + tokenEndpoint: 'https://auth.x.ai/oauth/token', + idToken: 'new-id', + tokenType: 'DPoP', + baseUrl: 'https://proxy.example/v1', + }); - expect(fetchMock).toHaveBeenCalledOnce(); - expect(fetchMock.mock.calls[0]?.[0]).toBe("https://auth.x.ai/oauth/token"); - expect(fetchMock.mock.calls[0]?.[1]).toMatchObject({ - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - Accept: "application/json", - }, - }); - expect( - (fetchMock.mock.calls[0]?.[1]?.body as URLSearchParams).toString(), - ).toBe( - "grant_type=refresh_token&client_id=b1a00492-073a-47ea-816f-4c329264a828&refresh_token=old-refresh", - ); - }); + expect(fetchMock).toHaveBeenCalledOnce(); + expect(fetchMock.mock.calls[0]?.[0]).toBe('https://auth.x.ai/oauth/token'); + expect(fetchMock.mock.calls[0]?.[1]).toMatchObject({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + Accept: 'application/json', + }, + }); + expect((fetchMock.mock.calls[0]?.[1]?.body as URLSearchParams).toString()).toBe( + 'grant_type=refresh_token&client_id=b1a00492-073a-47ea-816f-4c329264a828&refresh_token=old-refresh', + ); + }); - it("keeps the existing refresh token and metadata when refresh omits optional fields", async () => { - const fetchMock = vi.fn(async () => - Response.json({ access_token: "new-access", expires_in: "900" }), - ); - globalThis.fetch = fetchMock; + it('keeps the existing refresh token and metadata when refresh omits optional fields', async () => { + const fetchMock = vi.fn(async () => + Response.json({ access_token: 'new-access', expires_in: '900' }), + ); + globalThis.fetch = fetchMock; - await expect( - refresh({ - access: "old-access", - refresh: "old-refresh", - expires: 0, - discovery: { - authorization_endpoint: "https://auth.x.ai/oauth/authorize", - token_endpoint: "https://accounts.x.ai/oauth/token", - }, - idToken: "old-id", - tokenType: "Bearer", - }), - ).resolves.toMatchObject({ - access: "new-access", - refresh: "old-refresh", - tokenEndpoint: "https://accounts.x.ai/oauth/token", - idToken: "old-id", - tokenType: "Bearer", - }); - }); + await expect( + refresh({ + access: 'old-access', + refresh: 'old-refresh', + expires: 0, + discovery: { + authorization_endpoint: 'https://auth.x.ai/oauth/authorize', + token_endpoint: 'https://accounts.x.ai/oauth/token', + }, + idToken: 'old-id', + tokenType: 'Bearer', + }), + ).resolves.toMatchObject({ + access: 'new-access', + refresh: 'old-refresh', + tokenEndpoint: 'https://accounts.x.ai/oauth/token', + idToken: 'old-id', + tokenType: 'Bearer', + }); + }); - it("marks unauthorized refresh failures as requiring login", async () => { - const fetchMock = vi.fn( - async () => new Response("revoked", { status: 401 }), - ); - globalThis.fetch = fetchMock; + it('marks unauthorized refresh failures as requiring login', async () => { + const fetchMock = vi.fn(async () => new Response('revoked', { status: 401 })); + globalThis.fetch = fetchMock; - await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ - code: XaiErrorCode.REFRESH_FAILED, - reloginRequired: true, - message: "xAI token refresh failed: 401 revoked", - }); - }); + await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_FAILED, + reloginRequired: true, + message: 'xAI token refresh failed: 401 revoked', + }); + }); - it("keeps server refresh failures retryable", async () => { - const fetchMock = vi.fn( - async () => new Response("temporarily unavailable", { status: 500 }), - ); - globalThis.fetch = fetchMock; + it('keeps server refresh failures retryable', async () => { + const fetchMock = vi.fn( + async () => new Response('temporarily unavailable', { status: 500 }), + ); + globalThis.fetch = fetchMock; - await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ - code: XaiErrorCode.REFRESH_FAILED, - reloginRequired: false, - message: "xAI token refresh failed: 500 temporarily unavailable", - }); - }); + await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_FAILED, + reloginRequired: false, + message: 'xAI token refresh failed: 500 temporarily unavailable', + }); + }); - it("rejects refresh responses without an access token", async () => { - const fetchMock = vi.fn(async () => Response.json({})); - globalThis.fetch = fetchMock; + it('rejects refresh responses without an access token', async () => { + const fetchMock = vi.fn(async () => Response.json({})); + globalThis.fetch = fetchMock; - await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ - code: XaiErrorCode.REFRESH_FAILED, - reloginRequired: true, - message: "xAI token refresh did not return access_token.", - }); - }); + await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_FAILED, + reloginRequired: true, + message: 'xAI token refresh did not return access_token.', + }); + }); - it("rejects unsafe token endpoints before fetching", async () => { - const fetchMock = vi.fn(); - globalThis.fetch = fetchMock; + it('rejects unsafe token endpoints before fetching', async () => { + const fetchMock = vi.fn(); + globalThis.fetch = fetchMock; - await expect( - refresh({ - ...storedRefreshCredentials, - tokenEndpoint: "https://evil.example/oauth/token", - }), - ).rejects.toMatchObject({ - code: XaiErrorCode.DISCOVERY_INVALID_ORIGIN, - message: - "Refusing non-xAI OAuth token_endpoint: https://evil.example/oauth/token", - }); - expect(fetchMock).not.toHaveBeenCalled(); - }); + await expect( + refresh({ + ...storedRefreshCredentials, + tokenEndpoint: 'https://evil.example/oauth/token', + }), + ).rejects.toMatchObject({ + code: XaiErrorCode.DISCOVERY_INVALID_ORIGIN, + message: 'Refusing non-xAI OAuth token_endpoint: https://evil.example/oauth/token', + }); + expect(fetchMock).not.toHaveBeenCalled(); + }); - it("discovers the token endpoint when credentials do not include it", async () => { - const fetchMock = vi.fn(async (input) => { - if (input === "https://auth.x.ai/.well-known/openid-configuration") { - return Response.json(discoveryDocument); - } - return Response.json({ access_token: "new-access" }); - }); - globalThis.fetch = fetchMock; + it('discovers the token endpoint when credentials do not include it', async () => { + const fetchMock = vi.fn(async (input) => { + if (input === 'https://auth.x.ai/.well-known/openid-configuration') { + return Response.json(discoveryDocument); + } + return Response.json({ access_token: 'new-access' }); + }); + globalThis.fetch = fetchMock; - await expect(refresh(credentialsWithoutEndpoint)).resolves.toMatchObject({ - access: "new-access", - refresh: "old-refresh", - tokenEndpoint: "https://auth.x.ai/oauth/token", - }); - expect(fetchMock.mock.calls.map((call) => call[0])).toEqual([ - "https://auth.x.ai/.well-known/openid-configuration", - "https://auth.x.ai/oauth/token", - ]); - }); + await expect(refresh(credentialsWithoutEndpoint)).resolves.toMatchObject({ + access: 'new-access', + refresh: 'old-refresh', + tokenEndpoint: 'https://auth.x.ai/oauth/token', + }); + expect(fetchMock.mock.calls.map((call) => call[0])).toEqual([ + 'https://auth.x.ai/.well-known/openid-configuration', + 'https://auth.x.ai/oauth/token', + ]); + }); - it("wraps discovery network failures", async () => { - globalThis.fetch = vi.fn(async () => { - throw new Error("network down"); - }); + it('wraps discovery network failures', async () => { + globalThis.fetch = vi.fn(async () => { + throw new Error('network down'); + }); - await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ - code: XaiErrorCode.DISCOVERY_FAILED, - message: "xAI OIDC discovery failed: network down", - }); - }); + await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ + code: XaiErrorCode.DISCOVERY_FAILED, + message: 'xAI OIDC discovery failed: network down', + }); + }); - it("rejects failed and invalid discovery responses", async () => { - globalThis.fetch = vi.fn( - async () => new Response("unavailable", { status: 503 }), - ); - await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ - code: XaiErrorCode.DISCOVERY_FAILED, - message: "xAI OIDC discovery returned 503", - }); + it('rejects failed and invalid discovery responses', async () => { + globalThis.fetch = vi.fn( + async () => new Response('unavailable', { status: 503 }), + ); + await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ + code: XaiErrorCode.DISCOVERY_FAILED, + message: 'xAI OIDC discovery returned 503', + }); - globalThis.fetch = vi.fn(async () => - Response.json({ - authorization_endpoint: "http://auth.x.ai/oauth/authorize", - token_endpoint: "https://auth.x.ai/oauth/token", - }), - ); - await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ - code: XaiErrorCode.DISCOVERY_INVALID_ORIGIN, - message: - "xAI OAuth authorization_endpoint must use HTTPS: http://auth.x.ai/oauth/authorize", - }); - }); + globalThis.fetch = vi.fn(async () => + Response.json({ + authorization_endpoint: 'http://auth.x.ai/oauth/authorize', + token_endpoint: 'https://auth.x.ai/oauth/token', + }), + ); + await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ + code: XaiErrorCode.DISCOVERY_INVALID_ORIGIN, + message: 'xAI OAuth authorization_endpoint must use HTTPS: http://auth.x.ai/oauth/authorize', + }); + }); - it("logs in with a loopback callback and exchanges the authorization code", async () => { - vi.useFakeTimers(); - vi.setSystemTime(1_700_000_000_000); - const fetchMock = vi.fn(async (input) => { - if (input === "https://auth.x.ai/.well-known/openid-configuration") { - return Response.json(discoveryDocument); - } - return Response.json({ - access_token: "login-access", - refresh_token: "login-refresh", - expires_in: 900, - id_token: "login-id", - token_type: "Bearer", - }); - }); - globalThis.fetch = fetchMock; + it('logs in with a loopback callback and exchanges the authorization code', async () => { + vi.useFakeTimers(); + vi.setSystemTime(1_700_000_000_000); + const fetchMock = vi.fn(async (input) => { + if (input === 'https://auth.x.ai/.well-known/openid-configuration') { + return Response.json(discoveryDocument); + } + return Response.json({ + access_token: 'login-access', + refresh_token: 'login-refresh', + expires_in: 900, + id_token: 'login-id', + token_type: 'Bearer', + }); + }); + globalThis.fetch = fetchMock; - await expect( - login({ - onAuth: (auth) => { - const url = new URL(auth.url); - void originalFetch( - `${url.searchParams.get("redirect_uri")}?code=callback-code&state=${url.searchParams.get("state")}`, - ); - }, - }), - ).resolves.toMatchObject({ - access: "login-access", - refresh: "login-refresh", - expires: 1_700_000_780_000, - tokenEndpoint: "https://auth.x.ai/oauth/token", - discovery: discoveryDocument, - idToken: "login-id", - tokenType: "Bearer", - }); + await expect( + login({ + onAuth: (auth) => { + const url = new URL(auth.url); + void originalFetch( + `${url.searchParams.get('redirect_uri')}?code=callback-code&state=${url.searchParams.get('state')}`, + ); + }, + }), + ).resolves.toMatchObject({ + access: 'login-access', + refresh: 'login-refresh', + expires: 1_700_000_780_000, + tokenEndpoint: 'https://auth.x.ai/oauth/token', + discovery: discoveryDocument, + idToken: 'login-id', + tokenType: 'Bearer', + }); - expect(fetchMock.mock.calls[1]?.[0]).toBe("https://auth.x.ai/oauth/token"); - expect( - (fetchMock.mock.calls[1]?.[1]?.body as URLSearchParams).get("code"), - ).toBe("callback-code"); - }); + expect(fetchMock.mock.calls[1]?.[0]).toBe('https://auth.x.ai/oauth/token'); + expect((fetchMock.mock.calls[1]?.[1]?.body as URLSearchParams).get('code')).toBe( + 'callback-code', + ); + }); }); diff --git a/tests/models/catalog.test.ts b/tests/models/catalog.test.ts index 79c8bf7..92af5a2 100644 --- a/tests/models/catalog.test.ts +++ b/tests/models/catalog.test.ts @@ -1,62 +1,53 @@ -import { afterEach, describe, expect, it } from "vitest"; -import { - resolveModels, - supportsReasoningEffort, -} from "../../src/models/catalog.js"; +import { afterEach, describe, expect, it } from 'vitest'; +import { resolveModels, supportsReasoningEffort } from '../../src/models/catalog.js'; const originalEnv = { ...process.env }; afterEach(() => { - process.env = { ...originalEnv }; + process.env = { ...originalEnv }; }); -describe("model catalog", () => { - it("reports reasoning-effort support by normalized model name", () => { - expect(supportsReasoningEffort("grok-4.3")).toBe(true); - expect(supportsReasoningEffort("grok-cli/GROK-COMPOSER-2.5-fast")).toBe( - true, - ); - expect(supportsReasoningEffort("grok-4.20-0309-non-reasoning")).toBe(false); - }); - - it("uses fallback models when no override is configured", () => { - delete process.env.PI_GROK_CLI_MODELS; - - const models = resolveModels(); - - expect(models.map((model) => model.id)).toEqual([ - "grok-composer-2.5-fast", - "grok-build", - "grok-4.3", - "grok-4.20-0309-reasoning", - "grok-4.20-0309-non-reasoning", - "grok-4.20-multi-agent-0309", - ]); - expect( - models.find((model) => model.id === "grok-composer-2.5-fast"), - ).toMatchObject({ contextWindow: 200_000 }); - expect(models.find((model) => model.id === "grok-build")).toMatchObject({ - contextWindow: 512_000, - }); - }); - - it("filters, reorders, and fills unknown model overrides", () => { - process.env.PI_GROK_CLI_MODELS = " custom-model , grok-build ,, grok-4.3 "; - - const models = resolveModels(); - - expect(models.map((model) => model.id)).toEqual([ - "custom-model", - "grok-build", - "grok-4.3", - ]); - expect(models[0]).toMatchObject({ - name: "custom-model", - reasoning: true, - input: ["text"], - contextWindow: 1_000_000, - maxTokens: 30_000, - }); - expect(models[1].name).toBe("Grok Build"); - }); +describe('model catalog', () => { + it('reports reasoning-effort support by normalized model name', () => { + expect(supportsReasoningEffort('grok-4.3')).toBe(true); + expect(supportsReasoningEffort('grok-cli/GROK-COMPOSER-2.5-fast')).toBe(true); + expect(supportsReasoningEffort('grok-4.20-0309-non-reasoning')).toBe(false); + }); + + it('uses fallback models when no override is configured', () => { + delete process.env.PI_GROK_CLI_MODELS; + + const models = resolveModels(); + + expect(models.map((model) => model.id)).toEqual([ + 'grok-composer-2.5-fast', + 'grok-build', + 'grok-4.3', + 'grok-4.20-0309-reasoning', + 'grok-4.20-0309-non-reasoning', + 'grok-4.20-multi-agent-0309', + ]); + expect(models.find((model) => model.id === 'grok-composer-2.5-fast')).toMatchObject({ + contextWindow: 200_000, + }); + expect(models.find((model) => model.id === 'grok-build')).toMatchObject({ + contextWindow: 512_000, + }); + }); + + it('filters, reorders, and fills unknown model overrides', () => { + process.env.PI_GROK_CLI_MODELS = ' custom-model , grok-build ,, grok-4.3 '; + + const models = resolveModels(); + + expect(models.map((model) => model.id)).toEqual(['custom-model', 'grok-build', 'grok-4.3']); + expect(models[0]).toMatchObject({ + name: 'custom-model', + reasoning: true, + input: ['text'], + contextWindow: 1_000_000, + maxTokens: 30_000, + }); + expect(models[1].name).toBe('Grok Build'); + }); }); diff --git a/tests/payload/sanitize.test.ts b/tests/payload/sanitize.test.ts index ff74434..6c925c4 100644 --- a/tests/payload/sanitize.test.ts +++ b/tests/payload/sanitize.test.ts @@ -1,185 +1,185 @@ -import { mkdtempSync, rmSync, writeFileSync } from "node:fs"; -import { tmpdir } from "node:os"; -import { join } from "node:path"; -import { describe, expect, it } from "vitest"; -import { sanitizePayload } from "../../src/payload/sanitize.js"; +import { mkdtempSync, rmSync, writeFileSync } from 'node:fs'; +import { tmpdir } from 'node:os'; +import { join } from 'node:path'; +import { describe, expect, it } from 'vitest'; +import { sanitizePayload } from '../../src/payload/sanitize.js'; -describe("payload sanitization", () => { - it("removes unsupported items and moves leading instructions", () => { - const payload = sanitizePayload( - { - instructions: "existing instruction", - input: [ - { role: "system", content: "system instruction" }, - { - role: "developer", - content: [ - { type: "input_text", text: "developer instruction" }, - { type: "output_text", text: "output text instruction" }, - ], - }, - { type: "reasoning", content: "cached reasoning" }, - { role: "user", content: "" }, - { role: "user", content: "hello" }, - ], - include: ["reasoning.encrypted_content", "message.output_text"], - prompt_cache_retention: "24h", - reasoning: { effort: "minimal", summary: "auto" }, - response_format: { type: "json_object" }, - }, - "grok-4.3", - "session-123", - ); +describe('payload sanitization', () => { + it('removes unsupported items and moves leading instructions', () => { + const payload = sanitizePayload( + { + instructions: 'existing instruction', + input: [ + { role: 'system', content: 'system instruction' }, + { + role: 'developer', + content: [ + { type: 'input_text', text: 'developer instruction' }, + { type: 'output_text', text: 'output text instruction' }, + ], + }, + { type: 'reasoning', content: 'cached reasoning' }, + { role: 'user', content: '' }, + { role: 'user', content: 'hello' }, + ], + include: ['reasoning.encrypted_content', 'message.output_text'], + prompt_cache_retention: '24h', + reasoning: { effort: 'minimal', summary: 'auto' }, + response_format: { type: 'json_object' }, + }, + 'grok-4.3', + 'session-123', + ); - expect(payload.instructions).toBe( - "existing instruction\n\nsystem instruction\n\ndeveloper instruction\noutput text instruction", - ); - expect(payload.input).toEqual([{ role: "user", content: "hello" }]); - expect(payload.include).toEqual(["message.output_text"]); - expect(payload.prompt_cache_retention).toBeUndefined(); - expect(payload.reasoning).toEqual({ effort: "minimal" }); - expect(payload.text).toEqual({ format: { type: "json_object" } }); - expect(payload.response_format).toBeUndefined(); - expect(payload.prompt_cache_key).toBe("session-123"); - }); + expect(payload.instructions).toBe( + 'existing instruction\n\nsystem instruction\n\ndeveloper instruction\noutput text instruction', + ); + expect(payload.input).toEqual([{ role: 'user', content: 'hello' }]); + expect(payload.include).toEqual(['message.output_text']); + expect(payload.prompt_cache_retention).toBeUndefined(); + expect(payload.reasoning).toEqual({ effort: 'minimal' }); + expect(payload.text).toEqual({ format: { type: 'json_object' } }); + expect(payload.response_format).toBeUndefined(); + expect(payload.prompt_cache_key).toBe('session-123'); + }); - it("strips reasoning fields for models that do not accept reasoning effort", () => { - const payload = sanitizePayload( - { - input: "plain prompt", - include: ["reasoning.encrypted_content"], - reasoning: { effort: "high" }, - reasoningEffort: "high", - prompt_cache_key: "existing-session", - }, - "grok-build", - "new-session", - ); + it('strips reasoning fields for models that do not accept reasoning effort', () => { + const payload = sanitizePayload( + { + input: 'plain prompt', + include: ['reasoning.encrypted_content'], + reasoning: { effort: 'high' }, + reasoningEffort: 'high', + prompt_cache_key: 'existing-session', + }, + 'grok-build', + 'new-session', + ); - expect(payload.input).toBe("plain prompt"); - expect(payload.reasoning).toBeUndefined(); - expect(payload.reasoningEffort).toBeUndefined(); - expect(payload.include).toBeUndefined(); - expect(payload.prompt_cache_key).toBe("existing-session"); - }); + expect(payload.input).toBe('plain prompt'); + expect(payload.reasoning).toBeUndefined(); + expect(payload.reasoningEffort).toBeUndefined(); + expect(payload.include).toBeUndefined(); + expect(payload.prompt_cache_key).toBe('existing-session'); + }); - it("normalizes image parts and rewrites image tool output", () => { - const payload = sanitizePayload( - { - input: [ - { - role: "user", - content: [ - { type: "image", data: "ZmFrZQ==", mimeType: "image/png" }, - { - type: "image_url", - image_url: { - url: "https://example.invalid/image.png", - detail: "high", - }, - }, - ], - }, - { - type: "function_call_output", - call_id: "call_1", - output: [ - { type: "input_text", text: "tool text" }, - { type: "input_image", image_url: "data:image/png;base64,aW1n" }, - ], - }, - ], - }, - "grok-composer-2.5-fast", - ); + it('normalizes image parts and rewrites image tool output', () => { + const payload = sanitizePayload( + { + input: [ + { + role: 'user', + content: [ + { type: 'image', data: 'ZmFrZQ==', mimeType: 'image/png' }, + { + type: 'image_url', + image_url: { + url: 'https://example.invalid/image.png', + detail: 'high', + }, + }, + ], + }, + { + type: 'function_call_output', + call_id: 'call_1', + output: [ + { type: 'input_text', text: 'tool text' }, + { type: 'input_image', image_url: 'data:image/png;base64,aW1n' }, + ], + }, + ], + }, + 'grok-composer-2.5-fast', + ); - expect(payload.input).toEqual([ - { - role: "user", - content: [ - { - type: "input_image", - image_url: "data:image/png;base64,ZmFrZQ==", - detail: "auto", - }, - { - type: "input_image", - image_url: "https://example.invalid/image.png", - detail: "high", - }, - ], - }, - { type: "function_call_output", call_id: "call_1", output: "tool text" }, - { - role: "user", - content: [ - { - type: "input_text", - text: "The previous tool result (call_1) included 1 image. Use the attached image as the visual output from that tool.", - }, - { - type: "input_image", - image_url: "data:image/png;base64,aW1n", - detail: "auto", - }, - ], - }, - ]); - }); + expect(payload.input).toEqual([ + { + role: 'user', + content: [ + { + type: 'input_image', + image_url: 'data:image/png;base64,ZmFrZQ==', + detail: 'auto', + }, + { + type: 'input_image', + image_url: 'https://example.invalid/image.png', + detail: 'high', + }, + ], + }, + { type: 'function_call_output', call_id: 'call_1', output: 'tool text' }, + { + role: 'user', + content: [ + { + type: 'input_text', + text: 'The previous tool result (call_1) included 1 image. Use the attached image as the visual output from that tool.', + }, + { + type: 'input_image', + image_url: 'data:image/png;base64,aW1n', + detail: 'auto', + }, + ], + }, + ]); + }); - it("resolves local image paths to data URLs", () => { - const dir = mkdtempSync(join(tmpdir(), "pi-grok-cli-test-")); - const imagePath = join(dir, "sample image.png"); - writeFileSync(imagePath, Buffer.from("png image bytes")); + it('resolves local image paths to data URLs', () => { + const dir = mkdtempSync(join(tmpdir(), 'pi-grok-cli-test-')); + const imagePath = join(dir, 'sample image.png'); + writeFileSync(imagePath, Buffer.from('png image bytes')); - try { - const payload = sanitizePayload( - { - input: [ - { - role: "user", - content: [ - { - type: "input_image", - image_url: `'${imagePath}'`, - }, - ], - }, - ], - }, - "grok-4.3", - ); + try { + const payload = sanitizePayload( + { + input: [ + { + role: 'user', + content: [ + { + type: 'input_image', + image_url: `'${imagePath}'`, + }, + ], + }, + ], + }, + 'grok-4.3', + ); - expect(payload.input).toEqual([ - { - role: "user", - content: [ - { - type: "input_image", - image_url: `data:image/png;base64,${Buffer.from("png image bytes").toString("base64")}`, - detail: "auto", - }, - ], - }, - ]); - } finally { - rmSync(dir, { recursive: true, force: true }); - } - }); + expect(payload.input).toEqual([ + { + role: 'user', + content: [ + { + type: 'input_image', + image_url: `data:image/png;base64,${Buffer.from('png image bytes').toString('base64')}`, + detail: 'auto', + }, + ], + }, + ]); + } finally { + rmSync(dir, { recursive: true, force: true }); + } + }); - it("rejects missing or unsupported local images", () => { - expect(() => - sanitizePayload( - { - input: [ - { - role: "user", - content: [{ type: "input_image", image_url: "missing.png" }], - }, - ], - }, - "grok-4.3", - ), - ).toThrow("Image file does not exist or is not a valid URL: missing.png"); - }); + it('rejects missing or unsupported local images', () => { + expect(() => + sanitizePayload( + { + input: [ + { + role: 'user', + content: [{ type: 'input_image', image_url: 'missing.png' }], + }, + ], + }, + 'grok-4.3', + ), + ).toThrow('Image file does not exist or is not a valid URL: missing.png'); + }); }); diff --git a/tests/provider/package.test.ts b/tests/provider/package.test.ts index 8bb6f04..3ff0c1a 100644 --- a/tests/provider/package.test.ts +++ b/tests/provider/package.test.ts @@ -1,68 +1,60 @@ -import { existsSync, globSync, readFileSync } from "node:fs"; -import { describe, expect, it } from "vitest"; +import { existsSync, globSync, readFileSync } from 'node:fs'; +import { describe, expect, it } from 'vitest'; const packageJson = JSON.parse( - readFileSync(new URL("../../package.json", import.meta.url), "utf8"), + readFileSync(new URL('../../package.json', import.meta.url), 'utf8'), ); -describe("npm package manifest", () => { - it("declares a pi package entry point", () => { - expect(packageJson.name).toBe("pi-grok-cli"); - expect(packageJson.keywords).toContain("pi-package"); - expect(packageJson.pi?.extensions).toEqual(["./src/index.ts"]); - expect(packageJson.main).toBe("./src/index.ts"); - expect(packageJson.files).toEqual(["README.md", "src", "tsconfig.json"]); - }); +describe('npm package manifest', () => { + it('declares a pi package entry point', () => { + expect(packageJson.name).toBe('pi-grok-cli'); + expect(packageJson.keywords).toContain('pi-package'); + expect(packageJson.pi?.extensions).toEqual(['./src/index.ts']); + expect(packageJson.main).toBe('./src/index.ts'); + expect(packageJson.files).toEqual(['README.md', 'src', 'tsconfig.json']); + }); - it("runs publish checks before packing", () => { - expect(packageJson.scripts?.test).toBe("vitest run --reporter=agent"); - expect(packageJson.scripts?.coverage).toBe( - "vitest run --reporter=agent --coverage", - ); - expect(packageJson.scripts?.typecheck).toBe("tsc --noEmit"); - expect(packageJson.scripts?.prepack).toBe( - "bun run test && bun run coverage && bun run typecheck", - ); - expect(packageJson.devDependencies?.vitest).toBeDefined(); - expect(packageJson.devDependencies?.["@vitest/coverage-v8"]).toBeDefined(); - expect(existsSync(new URL("../../vitest.config.ts", import.meta.url))).toBe( - true, - ); - }); + it('runs publish checks before packing', () => { + expect(packageJson.scripts?.test).toBe('vitest run --reporter=agent'); + expect(packageJson.scripts?.coverage).toBe('vitest run --reporter=agent --coverage'); + expect(packageJson.scripts?.typecheck).toBe('tsc --noEmit'); + expect(packageJson.scripts?.prepack).toBe( + 'bun run test && bun run coverage && bun run typecheck', + ); + expect(packageJson.devDependencies?.vitest).toBeDefined(); + expect(packageJson.devDependencies?.['@vitest/coverage-v8']).toBeDefined(); + expect(existsSync(new URL('../../vitest.config.ts', import.meta.url))).toBe(true); + }); }); -describe("repository layout", () => { - it("keeps the extension entrypoint at src/index.ts", () => { - expect(existsSync(new URL("../../src/index.ts", import.meta.url))).toBe( - true, - ); - }); +describe('repository layout', () => { + it('keeps the extension entrypoint at src/index.ts', () => { + expect(existsSync(new URL('../../src/index.ts', import.meta.url))).toBe(true); + }); - it("contains the expected domain source files", () => { - expect(globSync("src/**/*.ts").sort()).toEqual([ - "src/auth/oauth.ts", - "src/index.ts", - "src/models/catalog.ts", - "src/payload/sanitize.ts", - "src/provider/quota.ts", - "src/provider/register.ts", - "src/provider/status.ts", - "src/provider/stream.ts", - "src/provider/toolScope.ts", - "src/shared/errors.ts", - "src/tools/files.ts", - "src/tools/register.ts", - "src/tools/rendering.ts", - "src/tools/search.ts", - "src/tools/shell.ts", - ]); - }); + it('contains the expected domain source files', () => { + expect(globSync('src/**/*.ts').sort()).toEqual([ + 'src/auth/oauth.ts', + 'src/index.ts', + 'src/models/catalog.ts', + 'src/payload/sanitize.ts', + 'src/provider/quota.ts', + 'src/provider/register.ts', + 'src/provider/status.ts', + 'src/provider/stream.ts', + 'src/provider/toolScope.ts', + 'src/shared/errors.ts', + 'src/tools/files.ts', + 'src/tools/register.ts', + 'src/tools/rendering.ts', + 'src/tools/search.ts', + 'src/tools/shell.ts', + ]); + }); - it("does not keep top-level helper compatibility wrappers", () => { - for (const file of ["errors.ts", "models.ts", "oauth.ts", "sanitize.ts"]) { - expect(existsSync(new URL(`../../src/${file}`, import.meta.url))).toBe( - false, - ); - } - }); + it('does not keep top-level helper compatibility wrappers', () => { + for (const file of ['errors.ts', 'models.ts', 'oauth.ts', 'sanitize.ts']) { + expect(existsSync(new URL(`../../src/${file}`, import.meta.url))).toBe(false); + } + }); }); diff --git a/tests/provider/register.test.ts b/tests/provider/register.test.ts index e122036..d44045a 100644 --- a/tests/provider/register.test.ts +++ b/tests/provider/register.test.ts @@ -1,85 +1,76 @@ -import { - mkdirSync, - mkdtempSync, - readFileSync, - rmSync, - writeFileSync, -} from "node:fs"; -import { tmpdir } from "node:os"; -import { join } from "node:path"; -import type { - ExtensionAPI, - ProviderConfig, -} from "@earendil-works/pi-coding-agent"; -import { afterEach, describe, expect, it, vi } from "vitest"; +import { mkdirSync, mkdtempSync, readFileSync, rmSync, writeFileSync } from 'node:fs'; +import { tmpdir } from 'node:os'; +import { join } from 'node:path'; +import type { ExtensionAPI, ProviderConfig } from '@earendil-works/pi-coding-agent'; +import { afterEach, describe, expect, it, vi } from 'vitest'; const streamSimpleOpenAIResponses = vi.fn( - ( - _model: unknown, - _context: unknown, - options?: { - onResponse?: (response: { headers: Record }) => void; - }, - ) => { - options?.onResponse?.({ - headers: { - "x-ratelimit-remaining-requests": "179", - "x-ratelimit-limit-requests": "180", - "x-ratelimit-remaining-tokens": "7500000", - "x-ratelimit-limit-tokens": "7500000", - "x-grok-context-window": "512000", - "x-zero-data-retention": "true", - }, - }); - return {}; - }, + ( + _model: unknown, + _context: unknown, + options?: { + onResponse?: (response: { headers: Record }) => void; + }, + ) => { + options?.onResponse?.({ + headers: { + 'x-ratelimit-remaining-requests': '179', + 'x-ratelimit-limit-requests': '180', + 'x-ratelimit-remaining-tokens': '7500000', + 'x-ratelimit-limit-tokens': '7500000', + 'x-grok-context-window': '512000', + 'x-zero-data-retention': 'true', + }, + }); + return {}; + }, ); -vi.mock("@earendil-works/pi-ai", async (importOriginal) => ({ - ...(await importOriginal()), - streamSimpleOpenAIResponses, +vi.mock('@earendil-works/pi-ai', async (importOriginal) => ({ + ...(await importOriginal()), + streamSimpleOpenAIResponses, })); interface CommandConfig { - handler: (args: string[], ctx: TestContext) => Promise; + handler: (args: string[], ctx: TestContext) => Promise; } interface RegisteredTool { - name: string; - renderCall?: (...args: unknown[]) => Renderable; - renderResult?: (...args: unknown[]) => Renderable; + name: string; + renderCall?: (...args: unknown[]) => Renderable; + renderResult?: (...args: unknown[]) => Renderable; } interface Renderable { - render: (width: number) => string[]; + render: (width: number) => string[]; } interface TestContext { - modelRegistry: { - getAll: () => { provider: string; id: string }[]; - getApiKeyForProvider?: (provider: string) => Promise; - }; - model?: { provider: string; id: string }; - sessionManager?: { - getSessionId: () => string; - }; - ui: { - notify: (message: string, level: string) => void; - }; + modelRegistry: { + getAll: () => { provider: string; id: string }[]; + getApiKeyForProvider?: (provider: string) => Promise; + }; + model?: { provider: string; id: string }; + sessionManager?: { + getSessionId: () => string; + }; + ui: { + notify: (message: string, level: string) => void; + }; } type ExtensionHandler = (event: unknown, ctx: TestContext) => unknown; const grokToolNames = [ - "Grep", - "Glob", - "LS", - "Read", - "Write", - "StrReplace", - "Edit", - "Delete", - "Shell", + 'Grep', + 'Glob', + 'LS', + 'Read', + 'Write', + 'StrReplace', + 'Edit', + 'Delete', + 'Shell', ]; const originalFetch = globalThis.fetch; @@ -88,540 +79,487 @@ const originalToken = process.env.GROK_CLI_OAUTH_TOKEN; const tempDirs: string[] = []; afterEach(() => { - vi.resetModules(); - streamSimpleOpenAIResponses.mockClear(); - globalThis.fetch = originalFetch; - if (originalHome === undefined) { - delete process.env.HOME; - } else { - process.env.HOME = originalHome; - } - if (originalToken === undefined) { - delete process.env.GROK_CLI_OAUTH_TOKEN; - } else { - process.env.GROK_CLI_OAUTH_TOKEN = originalToken; - } - for (const dir of tempDirs.splice(0)) rmSync(dir, { recursive: true }); + vi.resetModules(); + streamSimpleOpenAIResponses.mockClear(); + globalThis.fetch = originalFetch; + if (originalHome === undefined) { + delete process.env.HOME; + } else { + process.env.HOME = originalHome; + } + if (originalToken === undefined) { + delete process.env.GROK_CLI_OAUTH_TOKEN; + } else { + process.env.GROK_CLI_OAUTH_TOKEN = originalToken; + } + for (const dir of tempDirs.splice(0)) rmSync(dir, { recursive: true }); }); -async function setupExtension(initialActiveTools = ["read", "bash"]) { - const commands = new Map(); - const providers = new Map(); - const tools = new Map(); - const handlers = new Map(); - let activeTools = initialActiveTools; - const setActiveTools = vi.fn((toolNames: string[]) => { - activeTools = toolNames; - }); - const registerGrokCli = (await import("../../src/index.js")).default; - registerGrokCli({ - registerProvider(name: string, config: ProviderConfig) { - providers.set(name, config); - }, - on(event: string, handler: ExtensionHandler) { - handlers.set(event, handler); - }, - registerCommand(name: string, config: unknown) { - commands.set(name, config as CommandConfig); - }, - registerTool(tool: RegisteredTool) { - tools.set(tool.name, tool); - }, - getActiveTools() { - return activeTools; - }, - setActiveTools, - } as unknown as ExtensionAPI); - return { commands, providers, tools, handlers, setActiveTools }; +async function setupExtension(initialActiveTools = ['read', 'bash']) { + const commands = new Map(); + const providers = new Map(); + const tools = new Map(); + const handlers = new Map(); + let activeTools = initialActiveTools; + const setActiveTools = vi.fn((toolNames: string[]) => { + activeTools = toolNames; + }); + const registerGrokCli = (await import('../../src/index.js')).default; + registerGrokCli({ + registerProvider(name: string, config: ProviderConfig) { + providers.set(name, config); + }, + on(event: string, handler: ExtensionHandler) { + handlers.set(event, handler); + }, + registerCommand(name: string, config: unknown) { + commands.set(name, config as CommandConfig); + }, + registerTool(tool: RegisteredTool) { + tools.set(tool.name, tool); + }, + getActiveTools() { + return activeTools; + }, + setActiveTools, + } as unknown as ExtensionAPI); + return { commands, providers, tools, handlers, setActiveTools }; } -function statusContext(notify: TestContext["ui"]["notify"]): TestContext { - return { - modelRegistry: { - getAll: () => [ - { provider: "grok-cli", id: "grok-build" }, - { provider: "grok-cli", id: "grok-composer-2.5-fast" }, - ], - }, - ui: { notify }, - }; +function statusContext(notify: TestContext['ui']['notify']): TestContext { + return { + modelRegistry: { + getAll: () => [ + { provider: 'grok-cli', id: 'grok-build' }, + { provider: 'grok-cli', id: 'grok-composer-2.5-fast' }, + ], + }, + ui: { notify }, + }; } -function emptyStatusContext(notify: TestContext["ui"]["notify"]): TestContext { - return { - modelRegistry: { getAll: () => [] }, - ui: { notify }, - }; +function emptyStatusContext(notify: TestContext['ui']['notify']): TestContext { + return { + modelRegistry: { getAll: () => [] }, + ui: { notify }, + }; } function contextForModel(provider: string): TestContext { - return { - model: { provider, id: `${provider}-model` }, - modelRegistry: { getAll: () => [] }, - ui: { notify: vi.fn() }, - }; + return { + model: { provider, id: `${provider}-model` }, + modelRegistry: { getAll: () => [] }, + ui: { notify: vi.fn() }, + }; } function renderText(component: Renderable): string { - return component - .render(120) - .map((line) => line.trimEnd()) - .join("\n"); + return component + .render(120) + .map((line) => line.trimEnd()) + .join('\n'); } const theme = { - bold: (text: string) => text, - fg: (_name: string, text: string) => text, + bold: (text: string) => text, + fg: (_name: string, text: string) => text, }; function setupHome() { - const dir = mkdtempSync(join(tmpdir(), "pi-grok-cli-home-")); - mkdirSync(join(dir, ".pi")); - tempDirs.push(dir); - process.env.HOME = dir; - return dir; + const dir = mkdtempSync(join(tmpdir(), 'pi-grok-cli-home-')); + mkdirSync(join(dir, '.pi')); + tempDirs.push(dir); + process.env.HOME = dir; + return dir; } -async function runStatus( - extension: Awaited>, -) { - const notify = vi.fn(); - await extension.commands - .get("grok-cli-status") - ?.handler([], statusContext(notify)); - return notify; +async function runStatus(extension: Awaited>) { + const notify = vi.fn(); + await extension.commands.get('grok-cli-status')?.handler([], statusContext(notify)); + return notify; } -describe("Grok CLI status command", () => { - it("uses only cached quota data and tells users to make requests first", async () => { - delete process.env.GROK_CLI_OAUTH_TOKEN; - setupHome(); - const fetchMock = vi.fn(); - globalThis.fetch = fetchMock; - const extension = await setupExtension(); - const notify = await runStatus(extension); - - expect(fetchMock).not.toHaveBeenCalled(); - expect(notify.mock.calls.at(-1)?.[0]).toBe( - [ - " Quota:", - "", - " grok-build:", - " no cached quota data — make a request with this model first", - "", - " grok-composer-2.5-fast:", - " no cached quota data — make a request with this model first", - ].join("\n"), - ); - }); - - it("shows separate cached quotas for build and composer", async () => { - delete process.env.GROK_CLI_OAUTH_TOKEN; - setupHome(); - const extension = await setupExtension(); - const provider = extension.providers.get("grok-cli"); - provider?.streamSimple?.( - { provider: "grok-cli", id: "grok-build" }, - {}, - {}, - ); - provider?.streamSimple?.( - { provider: "grok-cli", id: "grok-composer-2.5-fast" }, - {}, - {}, - ); - const notify = await runStatus(extension); - - expect(notify.mock.calls.at(-1)?.[0]).toContain("grok-build:\n Cached:"); - expect(notify.mock.calls.at(-1)?.[0]).toContain( - "grok-composer-2.5-fast:\n Cached:", - ); - expect(notify.mock.calls.at(-1)?.[0]).toContain( - "Requests: 179/180 remaining", - ); - }); - - it("persists cached quotas to the global pi config directory", async () => { - delete process.env.GROK_CLI_OAUTH_TOKEN; - const home = setupHome(); - const extension = await setupExtension(); - extension.providers - .get("grok-cli") - ?.streamSimple?.({ provider: "grok-cli", id: "grok-build" }, {}, {}); - - expect( - JSON.parse(readFileSync(join(home, ".pi", "grok-cli-quota.json"), "utf8")) - .models["grok-build"].remainingRequests, - ).toBe(179); - }); - - it("loads cached quotas from the global pi config directory", async () => { - delete process.env.GROK_CLI_OAUTH_TOKEN; - const home = setupHome(); - writeFileSync( - join(home, ".pi", "grok-cli-quota.json"), - JSON.stringify({ - version: 1, - models: { - "grok-build": { - remainingRequests: 42, - limitRequests: 180, - remainingTokens: 1_000, - limitTokens: 2_000, - contextWindow: 512_000, - zeroDataRetention: true, - capturedAt: Date.now(), - }, - }, - }), - ); - const extension = await setupExtension(); - const notify = await runStatus(extension); - - expect(notify.mock.calls.at(-1)?.[0]).toContain( - "Requests: 42/180 remaining", - ); - }); - - it("warns when no Grok models are registered", async () => { - const extension = await setupExtension(); - const notify = vi.fn(); - - await extension.commands - .get("grok-cli-status") - ?.handler([], emptyStatusContext(notify)); - - expect(notify).toHaveBeenCalledOnce(); - expect(notify).toHaveBeenCalledWith( - "Grok CLI: no models registered. Run /login grok-cli first.", - "warning", - ); - }); - - it("shows env-token bypass and truncates long model lists", async () => { - process.env.GROK_CLI_OAUTH_TOKEN = "token"; - const extension = await setupExtension(); - const notify = vi.fn(); - - await extension.commands.get("grok-cli-status")?.handler([], { - modelRegistry: { - getAll: () => - Array.from({ length: 7 }, (_value, index) => ({ - provider: "grok-cli", - id: `grok-model-${index + 1}`, - })), - }, - ui: { notify }, - }); - - expect(notify.mock.calls[0]).toEqual([ - "⚠️ Grok CLI: using GROK_CLI_OAUTH_TOKEN env bypass — no auto-refresh available", - "warning", - ]); - expect(notify.mock.calls[1]).toEqual([ - "✓ Grok CLI: 7 models available (grok-model-1, grok-model-2, grok-model-3, grok-model-4, grok-model-5 (+2 more))", - "info", - ]); - }); - - it("reports registry errors as status warnings", async () => { - const extension = await setupExtension(); - const notify = vi.fn(); - - await extension.commands.get("grok-cli-status")?.handler([], { - modelRegistry: { - getAll: () => { - throw new Error("registry unavailable"); - }, - }, - ui: { notify }, - }); - - expect(notify).toHaveBeenCalledWith( - "Grok CLI: registry unavailable", - "warning", - ); - }); - - it("includes OAuth error codes in status warnings", async () => { - const { XaiOAuthError } = await import("../../src/shared/errors.js"); - const extension = await setupExtension(); - const notify = vi.fn(); - - await extension.commands.get("grok-cli-status")?.handler([], { - modelRegistry: { - getAll: () => { - throw new XaiOAuthError("refresh failed", "refresh_failed", true); - }, - }, - ui: { notify }, - }); - - expect(notify).toHaveBeenCalledWith( - "Grok CLI: refresh failed (code: refresh_failed)", - "warning", - ); - }); +describe('Grok CLI status command', () => { + it('uses only cached quota data and tells users to make requests first', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + setupHome(); + const fetchMock = vi.fn(); + globalThis.fetch = fetchMock; + const extension = await setupExtension(); + const notify = await runStatus(extension); + + expect(fetchMock).not.toHaveBeenCalled(); + expect(notify.mock.calls.at(-1)?.[0]).toBe( + [ + ' Quota:', + '', + ' grok-build:', + ' no cached quota data — make a request with this model first', + '', + ' grok-composer-2.5-fast:', + ' no cached quota data — make a request with this model first', + ].join('\n'), + ); + }); + + it('shows separate cached quotas for build and composer', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + setupHome(); + const extension = await setupExtension(); + const provider = extension.providers.get('grok-cli'); + provider?.streamSimple?.({ provider: 'grok-cli', id: 'grok-build' }, {}, {}); + provider?.streamSimple?.({ provider: 'grok-cli', id: 'grok-composer-2.5-fast' }, {}, {}); + const notify = await runStatus(extension); + + expect(notify.mock.calls.at(-1)?.[0]).toContain('grok-build:\n Cached:'); + expect(notify.mock.calls.at(-1)?.[0]).toContain('grok-composer-2.5-fast:\n Cached:'); + expect(notify.mock.calls.at(-1)?.[0]).toContain('Requests: 179/180 remaining'); + }); + + it('persists cached quotas to the global pi config directory', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + const home = setupHome(); + const extension = await setupExtension(); + extension.providers + .get('grok-cli') + ?.streamSimple?.({ provider: 'grok-cli', id: 'grok-build' }, {}, {}); + + expect( + JSON.parse(readFileSync(join(home, '.pi', 'grok-cli-quota.json'), 'utf8')).models[ + 'grok-build' + ].remainingRequests, + ).toBe(179); + }); + + it('loads cached quotas from the global pi config directory', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + const home = setupHome(); + writeFileSync( + join(home, '.pi', 'grok-cli-quota.json'), + JSON.stringify({ + version: 1, + models: { + 'grok-build': { + remainingRequests: 42, + limitRequests: 180, + remainingTokens: 1_000, + limitTokens: 2_000, + contextWindow: 512_000, + zeroDataRetention: true, + capturedAt: Date.now(), + }, + }, + }), + ); + const extension = await setupExtension(); + const notify = await runStatus(extension); + + expect(notify.mock.calls.at(-1)?.[0]).toContain('Requests: 42/180 remaining'); + }); + + it('warns when no Grok models are registered', async () => { + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands.get('grok-cli-status')?.handler([], emptyStatusContext(notify)); + + expect(notify).toHaveBeenCalledOnce(); + expect(notify).toHaveBeenCalledWith( + 'Grok CLI: no models registered. Run /login grok-cli first.', + 'warning', + ); + }); + + it('shows env-token bypass and truncates long model lists', async () => { + process.env.GROK_CLI_OAUTH_TOKEN = 'token'; + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands.get('grok-cli-status')?.handler([], { + modelRegistry: { + getAll: () => + Array.from({ length: 7 }, (_value, index) => ({ + provider: 'grok-cli', + id: `grok-model-${index + 1}`, + })), + }, + ui: { notify }, + }); + + expect(notify.mock.calls[0]).toEqual([ + '⚠️ Grok CLI: using GROK_CLI_OAUTH_TOKEN env bypass — no auto-refresh available', + 'warning', + ]); + expect(notify.mock.calls[1]).toEqual([ + '✓ Grok CLI: 7 models available (grok-model-1, grok-model-2, grok-model-3, grok-model-4, grok-model-5 (+2 more))', + 'info', + ]); + }); + + it('reports registry errors as status warnings', async () => { + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands.get('grok-cli-status')?.handler([], { + modelRegistry: { + getAll: () => { + throw new Error('registry unavailable'); + }, + }, + ui: { notify }, + }); + + expect(notify).toHaveBeenCalledWith('Grok CLI: registry unavailable', 'warning'); + }); + + it('includes OAuth error codes in status warnings', async () => { + const { XaiOAuthError } = await import('../../src/shared/errors.js'); + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands.get('grok-cli-status')?.handler([], { + modelRegistry: { + getAll: () => { + throw new XaiOAuthError('refresh failed', 'refresh_failed', true); + }, + }, + ui: { notify }, + }); + + expect(notify).toHaveBeenCalledWith( + 'Grok CLI: refresh failed (code: refresh_failed)', + 'warning', + ); + }); }); -describe("Grok CLI provider registration", () => { - it("registers provider metadata and OAuth helpers", async () => { - const extension = await setupExtension(); - const provider = extension.providers.get("grok-cli"); - - expect(provider?.name).toBe("Grok CLI"); - expect(provider?.api).toBe("openai-responses"); - expect(provider?.apiKey).toBe("$GROK_CLI_OAUTH_TOKEN"); - expect(provider?.models.map((model) => model.id)).toContain("grok-build"); - expect(provider?.oauth?.getApiKey({ access: "access-token" })).toBe( - "access-token", - ); - expect( - provider?.oauth?.modifyModels( - [ - { provider: "grok-cli", id: "grok-build", baseUrl: "old" }, - { provider: "openai", id: "gpt-4", baseUrl: "keep" }, - ], - { - access: "access-token", - refresh: "refresh-token", - expires: 123, - baseUrl: "https://example.invalid/custom///", - }, - ), - ).toEqual([ - { - provider: "grok-cli", - id: "grok-build", - baseUrl: "https://example.invalid/custom", - }, - { provider: "openai", id: "gpt-4", baseUrl: "keep" }, - ]); - }); - - it("sanitizes Grok provider requests with the current session id", async () => { - const extension = await setupExtension(); - const result = extension.handlers.get("before_provider_request")?.( - { - payload: { - input: [{ role: "system", content: "system instruction" }], - }, - }, - { - model: { provider: "grok-cli", id: "grok-4.3" }, - modelRegistry: { getAll: () => [] }, - sessionManager: { getSessionId: () => "session-123" }, - ui: { notify: vi.fn() }, - }, - ); - - expect(result).toEqual({ - input: [], - instructions: "system instruction", - prompt_cache_key: "session-123", - }); - }); - - it("leaves non-Grok provider requests untouched", async () => { - const extension = await setupExtension(); - const payload = { input: [{ role: "system", content: "keep" }] }; - const result = extension.handlers.get("before_provider_request")?.( - { payload }, - { - model: { provider: "openai", id: "gpt-4" }, - modelRegistry: { getAll: () => [] }, - sessionManager: { getSessionId: () => "session-123" }, - ui: { notify: vi.fn() }, - }, - ); - - expect(result).toBeUndefined(); - expect(payload).toEqual({ input: [{ role: "system", content: "keep" }] }); - }); - - it("warns at session start when env-token bypass is active", async () => { - process.env.GROK_CLI_OAUTH_TOKEN = "token"; - const extension = await setupExtension(); - const notify = vi.fn(); - - await extension.handlers.get("session_start")?.( - {}, - { - modelRegistry: { getAll: () => [] }, - ui: { notify }, - }, - ); - - expect(notify).toHaveBeenCalledWith( - "[pi-grok-cli] Using GROK_CLI_OAUTH_TOKEN bypass — no auto-refresh, no model discovery", - "warning", - ); - }); +describe('Grok CLI provider registration', () => { + it('registers provider metadata and OAuth helpers', async () => { + const extension = await setupExtension(); + const provider = extension.providers.get('grok-cli'); + + expect(provider?.name).toBe('Grok CLI'); + expect(provider?.api).toBe('openai-responses'); + expect(provider?.apiKey).toBe('$GROK_CLI_OAUTH_TOKEN'); + expect(provider?.models.map((model) => model.id)).toContain('grok-build'); + expect(provider?.oauth?.getApiKey({ access: 'access-token' })).toBe('access-token'); + expect( + provider?.oauth?.modifyModels( + [ + { provider: 'grok-cli', id: 'grok-build', baseUrl: 'old' }, + { provider: 'openai', id: 'gpt-4', baseUrl: 'keep' }, + ], + { + access: 'access-token', + refresh: 'refresh-token', + expires: 123, + baseUrl: 'https://example.invalid/custom///', + }, + ), + ).toEqual([ + { + provider: 'grok-cli', + id: 'grok-build', + baseUrl: 'https://example.invalid/custom', + }, + { provider: 'openai', id: 'gpt-4', baseUrl: 'keep' }, + ]); + }); + + it('sanitizes Grok provider requests with the current session id', async () => { + const extension = await setupExtension(); + const result = extension.handlers.get('before_provider_request')?.( + { + payload: { + input: [{ role: 'system', content: 'system instruction' }], + }, + }, + { + model: { provider: 'grok-cli', id: 'grok-4.3' }, + modelRegistry: { getAll: () => [] }, + sessionManager: { getSessionId: () => 'session-123' }, + ui: { notify: vi.fn() }, + }, + ); + + expect(result).toEqual({ + input: [], + instructions: 'system instruction', + prompt_cache_key: 'session-123', + }); + }); + + it('leaves non-Grok provider requests untouched', async () => { + const extension = await setupExtension(); + const payload = { input: [{ role: 'system', content: 'keep' }] }; + const result = extension.handlers.get('before_provider_request')?.( + { payload }, + { + model: { provider: 'openai', id: 'gpt-4' }, + modelRegistry: { getAll: () => [] }, + sessionManager: { getSessionId: () => 'session-123' }, + ui: { notify: vi.fn() }, + }, + ); + + expect(result).toBeUndefined(); + expect(payload).toEqual({ input: [{ role: 'system', content: 'keep' }] }); + }); + + it('warns at session start when env-token bypass is active', async () => { + process.env.GROK_CLI_OAUTH_TOKEN = 'token'; + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.handlers.get('session_start')?.( + {}, + { + modelRegistry: { getAll: () => [] }, + ui: { notify }, + }, + ); + + expect(notify).toHaveBeenCalledWith( + '[pi-grok-cli] Using GROK_CLI_OAUTH_TOKEN bypass — no auto-refresh, no model discovery', + 'warning', + ); + }); }); -describe("Grok CLI tool scoping", () => { - it("registers the Grok/Cursor-native tool shims", async () => { - const extension = await setupExtension(); - - expect([...extension.tools.keys()].sort()).toEqual( - [...grokToolNames].sort(), - ); - }); - - it("enables Grok tools for Grok models while preserving other active tools", async () => { - const extension = await setupExtension(["read", "custom_tool"]); - - await extension.handlers.get("model_select")?.( - { model: { provider: "grok-cli", id: "grok-build" } }, - contextForModel("grok-cli"), - ); - - expect(extension.setActiveTools).toHaveBeenLastCalledWith([ - "read", - "custom_tool", - ...grokToolNames, - ]); - }); - - it("removes Grok tools for non-Grok models while preserving other active tools", async () => { - const extension = await setupExtension([ - "read", - "Grep", - "custom_tool", - "Shell", - ]); - - await extension.handlers.get("model_select")?.( - { model: { provider: "openai", id: "gpt-4" } }, - contextForModel("openai"), - ); - - expect(extension.setActiveTools).toHaveBeenLastCalledWith([ - "read", - "custom_tool", - ]); - }); - - it("syncs tool scope before each agent turn from the current context model", async () => { - const extension = await setupExtension(["read"]); - - await extension.handlers.get("before_agent_start")?.( - {}, - contextForModel("grok-cli"), - ); - - expect(extension.setActiveTools).toHaveBeenLastCalledWith([ - "read", - ...grokToolNames, - ]); - }); - - it("does not update active tools when the selection is already correct", async () => { - const extension = await setupExtension(["read", ...grokToolNames]); - - await extension.handlers.get("before_agent_start")?.( - {}, - contextForModel("grok-cli"), - ); - - expect(extension.setActiveTools).not.toHaveBeenCalled(); - }); +describe('Grok CLI tool scoping', () => { + it('registers the Grok/Cursor-native tool shims', async () => { + const extension = await setupExtension(); + + expect([...extension.tools.keys()].sort()).toEqual([...grokToolNames].sort()); + }); + + it('enables Grok tools for Grok models while preserving other active tools', async () => { + const extension = await setupExtension(['read', 'custom_tool']); + + await extension.handlers.get('model_select')?.( + { model: { provider: 'grok-cli', id: 'grok-build' } }, + contextForModel('grok-cli'), + ); + + expect(extension.setActiveTools).toHaveBeenLastCalledWith([ + 'read', + 'custom_tool', + ...grokToolNames, + ]); + }); + + it('removes Grok tools for non-Grok models while preserving other active tools', async () => { + const extension = await setupExtension(['read', 'Grep', 'custom_tool', 'Shell']); + + await extension.handlers.get('model_select')?.( + { model: { provider: 'openai', id: 'gpt-4' } }, + contextForModel('openai'), + ); + + expect(extension.setActiveTools).toHaveBeenLastCalledWith(['read', 'custom_tool']); + }); + + it('syncs tool scope before each agent turn from the current context model', async () => { + const extension = await setupExtension(['read']); + + await extension.handlers.get('before_agent_start')?.({}, contextForModel('grok-cli')); + + expect(extension.setActiveTools).toHaveBeenLastCalledWith(['read', ...grokToolNames]); + }); + + it('does not update active tools when the selection is already correct', async () => { + const extension = await setupExtension(['read', ...grokToolNames]); + + await extension.handlers.get('before_agent_start')?.({}, contextForModel('grok-cli')); + + expect(extension.setActiveTools).not.toHaveBeenCalled(); + }); }); -describe("Grok CLI tool rendering", () => { - it("adds renderers to every Grok tool shim", async () => { - const extension = await setupExtension(); - - for (const name of grokToolNames) { - expect(extension.tools.get(name)?.renderCall).toBeTypeOf("function"); - expect(extension.tools.get(name)?.renderResult).toBeTypeOf("function"); - } - }); - - it("keeps collapsed search output compact and expands to full output", async () => { - const extension = await setupExtension(); - const grep = extension.tools.get("Grep"); - const result = { - content: [{ type: "text", text: "src/a.ts:1:match\nsrc/b.ts:2:match" }], - details: { matchCount: 2 }, - }; - - const collapsed = renderText( - grep?.renderResult?.( - result, - { expanded: false, isPartial: false }, - theme, - {}, - ) as Renderable, - ); - const expanded = renderText( - grep?.renderResult?.( - result, - { expanded: true, isPartial: false }, - theme, - {}, - ) as Renderable, - ); - - expect(collapsed).toBe("2 match(es)"); - expect(collapsed).not.toContain("src/a.ts"); - expect(expanded).toContain("src/a.ts:1:match"); - }); - - it("renders compact summaries for file mutations, delete, and shell tools", async () => { - const extension = await setupExtension(); - - expect( - renderText( - extension.tools.get("Write")?.renderResult?.( - { - content: [{ type: "text", text: "long write output" }], - details: { bytesWritten: 42 }, - }, - { expanded: false, isPartial: false }, - theme, - {}, - ) as Renderable, - ), - ).toBe("42 bytes written"); - expect( - renderText( - extension.tools.get("StrReplace")?.renderResult?.( - { - content: [{ type: "text", text: "long replace output" }], - details: { replacements: 3 }, - }, - { expanded: false, isPartial: false }, - theme, - {}, - ) as Renderable, - ), - ).toBe("3 replacement(s)"); - expect( - renderText( - extension.tools.get("Delete")?.renderResult?.( - { - content: [{ type: "text", text: "long delete output" }], - details: { deleted: true }, - }, - { expanded: false, isPartial: false }, - theme, - {}, - ) as Renderable, - ), - ).toBe("Deleted"); - expect( - renderText( - extension.tools.get("Shell")?.renderResult?.( - { - content: [{ type: "text", text: "long shell output" }], - details: { exitCode: 2 }, - }, - { expanded: false, isPartial: false }, - theme, - {}, - ) as Renderable, - ), - ).toBe("Exit 2"); - }); +describe('Grok CLI tool rendering', () => { + it('adds renderers to every Grok tool shim', async () => { + const extension = await setupExtension(); + + for (const name of grokToolNames) { + expect(extension.tools.get(name)?.renderCall).toBeTypeOf('function'); + expect(extension.tools.get(name)?.renderResult).toBeTypeOf('function'); + } + }); + + it('keeps collapsed search output compact and expands to full output', async () => { + const extension = await setupExtension(); + const grep = extension.tools.get('Grep'); + const result = { + content: [{ type: 'text', text: 'src/a.ts:1:match\nsrc/b.ts:2:match' }], + details: { matchCount: 2 }, + }; + + const collapsed = renderText( + grep?.renderResult?.(result, { expanded: false, isPartial: false }, theme, {}) as Renderable, + ); + const expanded = renderText( + grep?.renderResult?.(result, { expanded: true, isPartial: false }, theme, {}) as Renderable, + ); + + expect(collapsed).toBe('2 match(es)'); + expect(collapsed).not.toContain('src/a.ts'); + expect(expanded).toContain('src/a.ts:1:match'); + }); + + it('renders compact summaries for file mutations, delete, and shell tools', async () => { + const extension = await setupExtension(); + + expect( + renderText( + extension.tools.get('Write')?.renderResult?.( + { + content: [{ type: 'text', text: 'long write output' }], + details: { bytesWritten: 42 }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + ).toBe('42 bytes written'); + expect( + renderText( + extension.tools.get('StrReplace')?.renderResult?.( + { + content: [{ type: 'text', text: 'long replace output' }], + details: { replacements: 3 }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + ).toBe('3 replacement(s)'); + expect( + renderText( + extension.tools.get('Delete')?.renderResult?.( + { + content: [{ type: 'text', text: 'long delete output' }], + details: { deleted: true }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + ).toBe('Deleted'); + expect( + renderText( + extension.tools.get('Shell')?.renderResult?.( + { + content: [{ type: 'text', text: 'long shell output' }], + details: { exitCode: 2 }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + ).toBe('Exit 2'); + }); }); diff --git a/tests/shared/errors.test.ts b/tests/shared/errors.test.ts index 8d0d27d..5b5863d 100644 --- a/tests/shared/errors.test.ts +++ b/tests/shared/errors.test.ts @@ -1,18 +1,14 @@ -import { describe, expect, it } from "vitest"; -import { XaiErrorCode, XaiOAuthError } from "../../src/shared/errors.js"; +import { describe, expect, it } from 'vitest'; +import { XaiErrorCode, XaiOAuthError } from '../../src/shared/errors.js'; -describe("OAuth errors", () => { - it("keeps machine-readable code and relogin state", () => { - const error = new XaiOAuthError( - "Refresh token was revoked", - XaiErrorCode.REFRESH_FAILED, - true, - ); +describe('OAuth errors', () => { + it('keeps machine-readable code and relogin state', () => { + const error = new XaiOAuthError('Refresh token was revoked', XaiErrorCode.REFRESH_FAILED, true); - expect(error).toBeInstanceOf(Error); - expect(error.name).toBe("XaiOAuthError"); - expect(error.message).toBe("Refresh token was revoked"); - expect(error.code).toBe("refresh_failed"); - expect(error.reloginRequired).toBe(true); - }); + expect(error).toBeInstanceOf(Error); + expect(error.name).toBe('XaiOAuthError'); + expect(error.message).toBe('Refresh token was revoked'); + expect(error.code).toBe('refresh_failed'); + expect(error.reloginRequired).toBe(true); + }); }); diff --git a/tests/tools/files.test.ts b/tests/tools/files.test.ts index a036709..3a44bbe 100644 --- a/tests/tools/files.test.ts +++ b/tests/tools/files.test.ts @@ -1,406 +1,350 @@ -import { existsSync, mkdirSync, readFileSync, writeFileSync } from "node:fs"; -import { join } from "node:path"; -import { describe, expect, it } from "vitest"; -import { registerFileTools } from "../../src/tools/files.js"; +import { existsSync, mkdirSync, readFileSync, writeFileSync } from 'node:fs'; +import { join } from 'node:path'; +import { describe, expect, it } from 'vitest'; +import { registerFileTools } from '../../src/tools/files.js'; import { - collectTools, - executePreparedTool, - executeTool, - firstText, - renderToolCall, - renderToolResult, - type ToolResult, - tempDir, -} from "./toolTestHelpers.js"; - -function expectStoryState( - result: ToolResult, - cwd: string, - replacements: number, - content: string, -) { - expect(result.details).toEqual({ - path: join(cwd, "story.txt"), - replacements, - }); - expect(readFileSync(join(cwd, "story.txt"), "utf-8")).toBe(content); + collectTools, + executePreparedTool, + executeTool, + firstText, + renderToolCall, + renderToolResult, + type ToolResult, + tempDir, +} from './toolTestHelpers.js'; + +function expectStoryState(result: ToolResult, cwd: string, replacements: number, content: string) { + expect(result.details).toEqual({ + path: join(cwd, 'story.txt'), + replacements, + }); + expect(readFileSync(join(cwd, 'story.txt'), 'utf-8')).toBe(content); } function strReplace(cwd: string, old_str: string, new_str: string) { - return executeTool( - collectTools(registerFileTools).get("StrReplace"), - { path: "story.txt", old_str, new_str }, - cwd, - ); + return executeTool( + collectTools(registerFileTools).get('StrReplace'), + { path: 'story.txt', old_str, new_str }, + cwd, + ); } -function strReplaceWithPreparedArgs( - cwd: string, - params: Record, -) { - return executePreparedTool( - collectTools(registerFileTools).get("StrReplace"), - { path: "story.txt", ...params }, - cwd, - ); +function strReplaceWithPreparedArgs(cwd: string, params: Record) { + return executePreparedTool( + collectTools(registerFileTools).get('StrReplace'), + { path: 'story.txt', ...params }, + cwd, + ); } -describe("file tools", () => { - it("lists directory contents including hidden files", async () => { - const cwd = tempDir("pi-grok-cli-files-"); - writeFileSync(join(cwd, ".hidden"), "secret", "utf-8"); - writeFileSync(join(cwd, "visible.txt"), "visible", "utf-8"); - - const result = await executeTool( - collectTools(registerFileTools).get("LS"), - { path: "." }, - cwd, - ); - - expect(firstText(result)).toContain(".hidden"); - expect(firstText(result)).toContain("visible.txt"); - expect(result.details).toEqual({ path: cwd }); - }); - - it("reports filesystem errors for invalid file operations", async () => { - const cwd = tempDir("pi-grok-cli-files-"); - mkdirSync(join(cwd, "dir")); - writeFileSync(join(cwd, "blocked"), "not a directory", "utf-8"); - const tools = collectTools(registerFileTools); - - const lsResult = await executeTool( - tools.get("LS"), - { path: "missing-dir" }, - cwd, - ); - const readResult = await executeTool( - tools.get("Read"), - { path: "dir" }, - cwd, - ); - const writeResult = await executeTool( - tools.get("Write"), - { path: "blocked/file.txt", content: "content" }, - cwd, - ); - const replaceResult = await executeTool( - tools.get("StrReplace"), - { path: "dir", old_str: "old", new_str: "new" }, - cwd, - ); - const deleteResult = await executeTool( - tools.get("Delete"), - { path: "dir" }, - cwd, - ); - - expect(firstText(lsResult).startsWith("LS error:")).toBe(true); - expect(firstText(readResult).startsWith("Read error:")).toBe(true); - expect(firstText(writeResult).startsWith("Write error:")).toBe(true); - expect(firstText(replaceResult).startsWith("StrReplace error:")).toBe(true); - expect(firstText(deleteResult).startsWith("Delete error:")).toBe(true); - expect(writeResult.details).toEqual({ - path: join(cwd, "blocked", "file.txt"), - bytesWritten: 0, - }); - expect(replaceResult.details).toEqual({ - path: join(cwd, "dir"), - replacements: 0, - }); - expect(deleteResult.details).toEqual({ - path: join(cwd, "dir"), - deleted: false, - }); - }); - - it("writes a nested file and reads a requested line window", async () => { - const cwd = tempDir("pi-grok-cli-files-"); - const tools = collectTools(registerFileTools); - - const writeResult = await executeTool( - tools.get("Write"), - { path: "nested/notes.txt", content: "alpha\nbeta\ngamma\ndelta" }, - cwd, - ); - - expect(firstText(writeResult)).toBe( - "Successfully wrote 22 bytes to nested/notes.txt", - ); - expect(writeResult.details).toEqual({ - path: join(cwd, "nested/notes.txt"), - bytesWritten: 22, - }); - - const readResult = await executeTool( - tools.get("Read"), - { path: "nested/notes.txt", offset: 1, limit: 2 }, - cwd, - ); - - expect(firstText(readResult)).toBe( - "2\tbeta\n3\tgamma\n\n[Showing lines 2-3 of 4 total lines. Use offset to see more.]", - ); - expect(readResult.details).toEqual({ - path: join(cwd, "nested/notes.txt"), - totalLines: 4, - }); - }); - - it("writes Cursor-style contents arguments", async () => { - const cwd = tempDir("pi-grok-cli-files-"); - - const result = await executePreparedTool( - collectTools(registerFileTools).get("Write"), - { path: "nested/notes.txt", contents: "alpha\nbeta" }, - cwd, - ); - - expect(firstText(result)).toBe( - "Successfully wrote 10 bytes to nested/notes.txt", - ); - expect(readFileSync(join(cwd, "nested/notes.txt"), "utf-8")).toBe( - "alpha\nbeta", - ); - expect(result.details).toEqual({ - path: join(cwd, "nested/notes.txt"), - bytesWritten: 10, - }); - }); - - it("reports missing files without throwing", async () => { - const cwd = tempDir("pi-grok-cli-files-"); - const result = await executeTool( - collectTools(registerFileTools).get("Read"), - { path: "missing.txt" }, - cwd, - ); - - expect(firstText(result)).toBe( - `File not found: ${join(cwd, "missing.txt")}`, - ); - expect(result.details).toEqual({ - path: join(cwd, "missing.txt"), - exists: false, - totalLines: 0, - }); - }); - - it("replaces every exact string occurrence", async () => { - const cwd = tempDir("pi-grok-cli-files-"); - writeFileSync(join(cwd, "story.txt"), "red blue red", "utf-8"); - - const result = await strReplace(cwd, "red", "green"); - - expect(firstText(result)).toBe("Replaced 2 occurrence(s) in story.txt"); - expectStoryState(result, cwd, 2, "green blue green"); - }); - - it("replaces string occurrences with Grok and Cursor argument variants", async () => { - const oldStringCwd = tempDir("pi-grok-cli-files-"); - writeFileSync(join(oldStringCwd, "story.txt"), "red blue red", "utf-8"); - - const oldStringResult = await strReplaceWithPreparedArgs(oldStringCwd, { - old_string: "red", - new_string: "green", - }); - - expect(firstText(oldStringResult)).toBe( - "Replaced 2 occurrence(s) in story.txt", - ); - expectStoryState(oldStringResult, oldStringCwd, 2, "green blue green"); - - const oldTextCwd = tempDir("pi-grok-cli-files-"); - writeFileSync(join(oldTextCwd, "story.txt"), "red blue red", "utf-8"); - - const oldTextResult = await strReplaceWithPreparedArgs(oldTextCwd, { - oldText: "red", - newText: "green", - }); - - expect(firstText(oldTextResult)).toBe( - "Replaced 2 occurrence(s) in story.txt", - ); - expectStoryState(oldTextResult, oldTextCwd, 2, "green blue green"); - - const nestedCwd = tempDir("pi-grok-cli-files-"); - writeFileSync(join(nestedCwd, "story.txt"), "red blue red", "utf-8"); - - const nestedResult = await strReplaceWithPreparedArgs(nestedCwd, { - strReplace: { oldText: "red", newText: "green" }, - }); - - expect(firstText(nestedResult)).toBe( - "Replaced 2 occurrence(s) in story.txt", - ); - expectStoryState(nestedResult, nestedCwd, 2, "green blue green"); - }); - - it("edits files with single, multiple, and stringified replacement inputs", async () => { - const singleCwd = tempDir("pi-grok-cli-files-"); - writeFileSync(join(singleCwd, "story.txt"), "red blue red", "utf-8"); - - const singleResult = await executePreparedTool( - collectTools(registerFileTools).get("Edit"), - { path: "story.txt", oldText: "red", newText: "green" }, - singleCwd, - ); - - expect(firstText(singleResult)).toBe( - "Applied 2 replacement(s) in story.txt", - ); - expectStoryState(singleResult, singleCwd, 2, "green blue green"); - - const multipleCwd = tempDir("pi-grok-cli-files-"); - writeFileSync(join(multipleCwd, "story.txt"), "red blue red", "utf-8"); - - const multipleResult = await executePreparedTool( - collectTools(registerFileTools).get("Edit"), - { - path: "story.txt", - edits: [ - { oldText: "red", newText: "green" }, - { oldText: "blue", newText: "yellow" }, - ], - }, - multipleCwd, - ); - - expect(firstText(multipleResult)).toBe( - "Applied 3 replacement(s) in story.txt", - ); - expectStoryState(multipleResult, multipleCwd, 3, "green yellow green"); - - const stringifiedCwd = tempDir("pi-grok-cli-files-"); - writeFileSync(join(stringifiedCwd, "story.txt"), "red blue red", "utf-8"); - - const stringifiedResult = await executePreparedTool( - collectTools(registerFileTools).get("Edit"), - { - path: "story.txt", - edits: JSON.stringify([{ oldText: "red", newText: "green" }]), - }, - stringifiedCwd, - ); - - expect(firstText(stringifiedResult)).toBe( - "Applied 2 replacement(s) in story.txt", - ); - expectStoryState(stringifiedResult, stringifiedCwd, 2, "green blue green"); - }); - - it("reports unsupported edit strategies without changing files", async () => { - const cwd = tempDir("pi-grok-cli-files-"); - writeFileSync(join(cwd, "story.txt"), "red blue red", "utf-8"); - - const result = await executePreparedTool( - collectTools(registerFileTools).get("Edit"), - { path: "story.txt", applyPatch: { patchContent: "patch" } }, - cwd, - ); - - expect(firstText(result)).toBe( - "Edit error: applyPatch is not supported by this Grok tool shim", - ); - expectStoryState(result, cwd, 0, "red blue red"); - }); - - it("leaves files unchanged when the replacement string is absent", async () => { - const cwd = tempDir("pi-grok-cli-files-"); - writeFileSync(join(cwd, "story.txt"), "red blue red", "utf-8"); - - const result = await strReplace(cwd, "purple", "green"); - - expect(firstText(result)).toBe('String not found in story.txt: "purple"'); - expectStoryState(result, cwd, 0, "red blue red"); - }); - - it("deletes existing files and reports missing files", async () => { - const cwd = tempDir("pi-grok-cli-files-"); - writeFileSync(join(cwd, "remove.txt"), "delete me", "utf-8"); - const tools = collectTools(registerFileTools); - - const deletedResult = await executeTool( - tools.get("Delete"), - { path: "remove.txt" }, - cwd, - ); - - expect(firstText(deletedResult)).toBe("Successfully deleted remove.txt"); - expect(deletedResult.details).toEqual({ - path: join(cwd, "remove.txt"), - deleted: true, - }); - expect(existsSync(join(cwd, "remove.txt"))).toBe(false); - - const missingResult = await executeTool( - tools.get("Delete"), - { path: "remove.txt" }, - cwd, - ); - - expect(firstText(missingResult)).toBe( - `File not found: ${join(cwd, "remove.txt")}`, - ); - expect(missingResult.details).toEqual({ - path: join(cwd, "remove.txt"), - deleted: false, - }); - }); - - it("renders file tool calls and result states", () => { - const tools = collectTools(registerFileTools); - - expect(renderToolCall(tools.get("LS"), { path: "." })).toBe("LS ."); - expect( - renderToolCall(tools.get("Read"), { - path: "notes.txt", - offset: 5, - limit: 10, - }), - ).toBe("Read notes.txt (from 5, 10 lines)"); - expect(renderToolCall(tools.get("StrReplace"), { path: "notes.txt" })).toBe( - "StrReplace notes.txt", - ); - expect(renderToolCall(tools.get("Delete"), { path: "notes.txt" })).toBe( - "Delete notes.txt", - ); - expect( - renderToolResult(tools.get("Read"), { - content: [{ type: "text", text: "missing" }], - details: { exists: false, totalLines: 0 }, - }), - ).toBe("File not found"); - expect( - renderToolResult(tools.get("StrReplace"), { - content: [{ type: "text", text: "no replacement" }], - details: { replacements: 0 }, - }), - ).toBe("No replacements"); - expect( - renderToolResult(tools.get("Delete"), { - content: [{ type: "text", text: "not deleted" }], - details: { deleted: false }, - }), - ).toBe("Not deleted"); - expect( - renderToolResult( - tools.get("LS"), - { - content: [{ type: "text", text: "full listing" }], - details: { path: "/tmp/project" }, - }, - { expanded: true, isPartial: false }, - ), - ).toBe("full listing"); - expect( - renderToolResult( - tools.get("Write"), - { - content: [{ type: "text", text: "writing" }], - details: { bytesWritten: 10 }, - }, - { expanded: false, isPartial: true }, - ), - ).toBe("Running..."); - }); +describe('file tools', () => { + it('lists directory contents including hidden files', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, '.hidden'), 'secret', 'utf-8'); + writeFileSync(join(cwd, 'visible.txt'), 'visible', 'utf-8'); + + const result = await executeTool(collectTools(registerFileTools).get('LS'), { path: '.' }, cwd); + + expect(firstText(result)).toContain('.hidden'); + expect(firstText(result)).toContain('visible.txt'); + expect(result.details).toEqual({ path: cwd }); + }); + + it('reports filesystem errors for invalid file operations', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + mkdirSync(join(cwd, 'dir')); + writeFileSync(join(cwd, 'blocked'), 'not a directory', 'utf-8'); + const tools = collectTools(registerFileTools); + + const lsResult = await executeTool(tools.get('LS'), { path: 'missing-dir' }, cwd); + const readResult = await executeTool(tools.get('Read'), { path: 'dir' }, cwd); + const writeResult = await executeTool( + tools.get('Write'), + { path: 'blocked/file.txt', content: 'content' }, + cwd, + ); + const replaceResult = await executeTool( + tools.get('StrReplace'), + { path: 'dir', old_str: 'old', new_str: 'new' }, + cwd, + ); + const deleteResult = await executeTool(tools.get('Delete'), { path: 'dir' }, cwd); + + expect(firstText(lsResult).startsWith('LS error:')).toBe(true); + expect(firstText(readResult).startsWith('Read error:')).toBe(true); + expect(firstText(writeResult).startsWith('Write error:')).toBe(true); + expect(firstText(replaceResult).startsWith('StrReplace error:')).toBe(true); + expect(firstText(deleteResult).startsWith('Delete error:')).toBe(true); + expect(writeResult.details).toEqual({ + path: join(cwd, 'blocked', 'file.txt'), + bytesWritten: 0, + }); + expect(replaceResult.details).toEqual({ + path: join(cwd, 'dir'), + replacements: 0, + }); + expect(deleteResult.details).toEqual({ + path: join(cwd, 'dir'), + deleted: false, + }); + }); + + it('writes a nested file and reads a requested line window', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + const tools = collectTools(registerFileTools); + + const writeResult = await executeTool( + tools.get('Write'), + { path: 'nested/notes.txt', content: 'alpha\nbeta\ngamma\ndelta' }, + cwd, + ); + + expect(firstText(writeResult)).toBe('Successfully wrote 22 bytes to nested/notes.txt'); + expect(writeResult.details).toEqual({ + path: join(cwd, 'nested/notes.txt'), + bytesWritten: 22, + }); + + const readResult = await executeTool( + tools.get('Read'), + { path: 'nested/notes.txt', offset: 1, limit: 2 }, + cwd, + ); + + expect(firstText(readResult)).toBe( + '2\tbeta\n3\tgamma\n\n[Showing lines 2-3 of 4 total lines. Use offset to see more.]', + ); + expect(readResult.details).toEqual({ + path: join(cwd, 'nested/notes.txt'), + totalLines: 4, + }); + }); + + it('writes Cursor-style contents arguments', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + + const result = await executePreparedTool( + collectTools(registerFileTools).get('Write'), + { path: 'nested/notes.txt', contents: 'alpha\nbeta' }, + cwd, + ); + + expect(firstText(result)).toBe('Successfully wrote 10 bytes to nested/notes.txt'); + expect(readFileSync(join(cwd, 'nested/notes.txt'), 'utf-8')).toBe('alpha\nbeta'); + expect(result.details).toEqual({ + path: join(cwd, 'nested/notes.txt'), + bytesWritten: 10, + }); + }); + + it('reports missing files without throwing', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + const result = await executeTool( + collectTools(registerFileTools).get('Read'), + { path: 'missing.txt' }, + cwd, + ); + + expect(firstText(result)).toBe(`File not found: ${join(cwd, 'missing.txt')}`); + expect(result.details).toEqual({ + path: join(cwd, 'missing.txt'), + exists: false, + totalLines: 0, + }); + }); + + it('replaces every exact string occurrence', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'red blue red', 'utf-8'); + + const result = await strReplace(cwd, 'red', 'green'); + + expect(firstText(result)).toBe('Replaced 2 occurrence(s) in story.txt'); + expectStoryState(result, cwd, 2, 'green blue green'); + }); + + it('replaces string occurrences with Grok and Cursor argument variants', async () => { + const oldStringCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(oldStringCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const oldStringResult = await strReplaceWithPreparedArgs(oldStringCwd, { + old_string: 'red', + new_string: 'green', + }); + + expect(firstText(oldStringResult)).toBe('Replaced 2 occurrence(s) in story.txt'); + expectStoryState(oldStringResult, oldStringCwd, 2, 'green blue green'); + + const oldTextCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(oldTextCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const oldTextResult = await strReplaceWithPreparedArgs(oldTextCwd, { + oldText: 'red', + newText: 'green', + }); + + expect(firstText(oldTextResult)).toBe('Replaced 2 occurrence(s) in story.txt'); + expectStoryState(oldTextResult, oldTextCwd, 2, 'green blue green'); + + const nestedCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(nestedCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const nestedResult = await strReplaceWithPreparedArgs(nestedCwd, { + strReplace: { oldText: 'red', newText: 'green' }, + }); + + expect(firstText(nestedResult)).toBe('Replaced 2 occurrence(s) in story.txt'); + expectStoryState(nestedResult, nestedCwd, 2, 'green blue green'); + }); + + it('edits files with single, multiple, and stringified replacement inputs', async () => { + const singleCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(singleCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const singleResult = await executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { path: 'story.txt', oldText: 'red', newText: 'green' }, + singleCwd, + ); + + expect(firstText(singleResult)).toBe('Applied 2 replacement(s) in story.txt'); + expectStoryState(singleResult, singleCwd, 2, 'green blue green'); + + const multipleCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(multipleCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const multipleResult = await executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { + path: 'story.txt', + edits: [ + { oldText: 'red', newText: 'green' }, + { oldText: 'blue', newText: 'yellow' }, + ], + }, + multipleCwd, + ); + + expect(firstText(multipleResult)).toBe('Applied 3 replacement(s) in story.txt'); + expectStoryState(multipleResult, multipleCwd, 3, 'green yellow green'); + + const stringifiedCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(stringifiedCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const stringifiedResult = await executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { + path: 'story.txt', + edits: JSON.stringify([{ oldText: 'red', newText: 'green' }]), + }, + stringifiedCwd, + ); + + expect(firstText(stringifiedResult)).toBe('Applied 2 replacement(s) in story.txt'); + expectStoryState(stringifiedResult, stringifiedCwd, 2, 'green blue green'); + }); + + it('reports unsupported edit strategies without changing files', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'red blue red', 'utf-8'); + + const result = await executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { path: 'story.txt', applyPatch: { patchContent: 'patch' } }, + cwd, + ); + + expect(firstText(result)).toBe( + 'Edit error: applyPatch is not supported by this Grok tool shim', + ); + expectStoryState(result, cwd, 0, 'red blue red'); + }); + + it('leaves files unchanged when the replacement string is absent', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'red blue red', 'utf-8'); + + const result = await strReplace(cwd, 'purple', 'green'); + + expect(firstText(result)).toBe('String not found in story.txt: "purple"'); + expectStoryState(result, cwd, 0, 'red blue red'); + }); + + it('deletes existing files and reports missing files', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'remove.txt'), 'delete me', 'utf-8'); + const tools = collectTools(registerFileTools); + + const deletedResult = await executeTool(tools.get('Delete'), { path: 'remove.txt' }, cwd); + + expect(firstText(deletedResult)).toBe('Successfully deleted remove.txt'); + expect(deletedResult.details).toEqual({ + path: join(cwd, 'remove.txt'), + deleted: true, + }); + expect(existsSync(join(cwd, 'remove.txt'))).toBe(false); + + const missingResult = await executeTool(tools.get('Delete'), { path: 'remove.txt' }, cwd); + + expect(firstText(missingResult)).toBe(`File not found: ${join(cwd, 'remove.txt')}`); + expect(missingResult.details).toEqual({ + path: join(cwd, 'remove.txt'), + deleted: false, + }); + }); + + it('renders file tool calls and result states', () => { + const tools = collectTools(registerFileTools); + + expect(renderToolCall(tools.get('LS'), { path: '.' })).toBe('LS .'); + expect( + renderToolCall(tools.get('Read'), { + path: 'notes.txt', + offset: 5, + limit: 10, + }), + ).toBe('Read notes.txt (from 5, 10 lines)'); + expect(renderToolCall(tools.get('StrReplace'), { path: 'notes.txt' })).toBe( + 'StrReplace notes.txt', + ); + expect(renderToolCall(tools.get('Delete'), { path: 'notes.txt' })).toBe('Delete notes.txt'); + expect( + renderToolResult(tools.get('Read'), { + content: [{ type: 'text', text: 'missing' }], + details: { exists: false, totalLines: 0 }, + }), + ).toBe('File not found'); + expect( + renderToolResult(tools.get('StrReplace'), { + content: [{ type: 'text', text: 'no replacement' }], + details: { replacements: 0 }, + }), + ).toBe('No replacements'); + expect( + renderToolResult(tools.get('Delete'), { + content: [{ type: 'text', text: 'not deleted' }], + details: { deleted: false }, + }), + ).toBe('Not deleted'); + expect( + renderToolResult( + tools.get('LS'), + { + content: [{ type: 'text', text: 'full listing' }], + details: { path: '/tmp/project' }, + }, + { expanded: true, isPartial: false }, + ), + ).toBe('full listing'); + expect( + renderToolResult( + tools.get('Write'), + { + content: [{ type: 'text', text: 'writing' }], + details: { bytesWritten: 10 }, + }, + { expanded: false, isPartial: true }, + ), + ).toBe('Running...'); + }); }); diff --git a/tests/tools/register.test.ts b/tests/tools/register.test.ts index 5e1a272..e7eb446 100644 --- a/tests/tools/register.test.ts +++ b/tests/tools/register.test.ts @@ -1,33 +1,29 @@ -import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; -import { describe, expect, it } from "vitest"; -import { registerGrokTools } from "../../src/tools/register.js"; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { describe, expect, it } from 'vitest'; +import { registerGrokTools } from '../../src/tools/register.js'; -describe("Grok tool registration", () => { - it("registers all Grok/Cursor-native tool shims with renderers", () => { - const toolNames: string[] = []; +describe('Grok tool registration', () => { + it('registers all Grok/Cursor-native tool shims with renderers', () => { + const toolNames: string[] = []; - registerGrokTools({ - registerTool(tool: { - name: string; - renderCall?: unknown; - renderResult?: unknown; - }) { - toolNames.push(tool.name); - expect(tool.renderCall).toBeTypeOf("function"); - expect(tool.renderResult).toBeTypeOf("function"); - }, - } as unknown as ExtensionAPI); + registerGrokTools({ + registerTool(tool: { name: string; renderCall?: unknown; renderResult?: unknown }) { + toolNames.push(tool.name); + expect(tool.renderCall).toBeTypeOf('function'); + expect(tool.renderResult).toBeTypeOf('function'); + }, + } as unknown as ExtensionAPI); - expect(toolNames.sort()).toEqual([ - "Delete", - "Edit", - "Glob", - "Grep", - "LS", - "Read", - "Shell", - "StrReplace", - "Write", - ]); - }); + expect(toolNames.sort()).toEqual([ + 'Delete', + 'Edit', + 'Glob', + 'Grep', + 'LS', + 'Read', + 'Shell', + 'StrReplace', + 'Write', + ]); + }); }); diff --git a/tests/tools/rendering.test.ts b/tests/tools/rendering.test.ts index fc54d81..c4db165 100644 --- a/tests/tools/rendering.test.ts +++ b/tests/tools/rendering.test.ts @@ -1,98 +1,86 @@ -import { describe, expect, it } from "vitest"; +import { describe, expect, it } from 'vitest'; import { - booleanDetail, - detailRecord, - fileError, - fileNotFound, - MAX_LINES, - MAX_OUTPUT_CHARS, - numberDetail, - renderResultSummary, - renderResultText, - renderRunning, - stringDetail, - text, - toolError, - truncateChars, - truncateLines, -} from "../../src/tools/rendering.js"; -import { renderText } from "./toolTestHelpers.js"; + booleanDetail, + detailRecord, + fileError, + fileNotFound, + MAX_LINES, + MAX_OUTPUT_CHARS, + numberDetail, + renderResultSummary, + renderResultText, + renderRunning, + stringDetail, + text, + toolError, + truncateChars, + truncateLines, +} from '../../src/tools/rendering.js'; +import { renderText } from './toolTestHelpers.js'; -describe("tool rendering helpers", () => { - it("truncates long result lists and large output", () => { - expect(truncateLines(["one", "two"])).toBe("one\ntwo"); - expect( - truncateLines(Array.from({ length: MAX_LINES + 1 }, String)).endsWith( - `\n\n[Showing first ${MAX_LINES} of ${MAX_LINES + 1} results. Refine your pattern to narrow results.]`, - ), - ).toBe(true); - expect(truncateChars("short")).toBe("short"); - expect(truncateChars("x".repeat(MAX_OUTPUT_CHARS + 1))).toBe( - `${"x".repeat(MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`, - ); - }); +describe('tool rendering helpers', () => { + it('truncates long result lists and large output', () => { + expect(truncateLines(['one', 'two'])).toBe('one\ntwo'); + expect( + truncateLines(Array.from({ length: MAX_LINES + 1 }, String)).endsWith( + `\n\n[Showing first ${MAX_LINES} of ${MAX_LINES + 1} results. Refine your pattern to narrow results.]`, + ), + ).toBe(true); + expect(truncateChars('short')).toBe('short'); + expect(truncateChars('x'.repeat(MAX_OUTPUT_CHARS + 1))).toBe( + `${'x'.repeat(MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`, + ); + }); - it("renders summaries, expanded text, missing text fallback, and partial state", () => { - const result = { - content: [{ type: "text", text: "full output" }], - details: {}, - }; + it('renders summaries, expanded text, missing text fallback, and partial state', () => { + const result = { + content: [{ type: 'text', text: 'full output' }], + details: {}, + }; - expect(renderText(text("plain"))).toBe("plain"); - expect(renderText(renderResultText(result, false, "summary"))).toBe( - "summary", - ); - expect(renderText(renderResultText(result, true, "summary"))).toBe( - "full output", - ); - expect( - renderText( - renderResultText( - { content: [{ type: "image" }], details: {} }, - true, - "summary", - ), - ), - ).toBe("summary"); - expect(renderText(renderRunning(true) ?? text(""))).toBe("Running..."); - expect(renderRunning(false)).toBeUndefined(); - expect( - renderText(renderResultSummary(result, false, true, "summary")), - ).toBe("Running..."); - }); + expect(renderText(text('plain'))).toBe('plain'); + expect(renderText(renderResultText(result, false, 'summary'))).toBe('summary'); + expect(renderText(renderResultText(result, true, 'summary'))).toBe('full output'); + expect( + renderText(renderResultText({ content: [{ type: 'image' }], details: {} }, true, 'summary')), + ).toBe('summary'); + expect(renderText(renderRunning(true) ?? text(''))).toBe('Running...'); + expect(renderRunning(false)).toBeUndefined(); + expect(renderText(renderResultSummary(result, false, true, 'summary'))).toBe('Running...'); + }); - it("reads typed detail values with defaults for absent or invalid details", () => { - const result = { - content: [{ type: "text", text: "" }], - details: { count: 2, path: "file.txt", deleted: true, invalid: null }, - }; + it('reads typed detail values with defaults for absent or invalid details', () => { + const result = { + content: [{ type: 'text', text: '' }], + details: { count: 2, path: 'file.txt', deleted: true, invalid: null }, + }; - expect(detailRecord(result)).toEqual(result.details); - expect(detailRecord({ details: null })).toEqual({}); - expect(numberDetail(result, "count")).toBe(2); - expect(numberDetail(result, "path")).toBe(0); - expect(stringDetail(result, "path")).toBe("file.txt"); - expect(stringDetail(result, "count")).toBe(""); - expect(booleanDetail(result, "deleted")).toBe(true); - expect(booleanDetail(result, "invalid")).toBe(false); - }); + expect(detailRecord(result)).toEqual(result.details); + expect(detailRecord({ details: null })).toEqual({}); + expect(numberDetail(result, 'count')).toBe(2); + expect(numberDetail(result, 'path')).toBe(0); + expect(stringDetail(result, 'path')).toBe('file.txt'); + expect(stringDetail(result, 'count')).toBe(''); + expect(booleanDetail(result, 'deleted')).toBe(true); + expect(booleanDetail(result, 'invalid')).toBe(false); + }); - it("formats file and command errors with stable empty details", () => { - expect(fileNotFound("/tmp/missing.txt", { deleted: false })).toEqual({ - content: [{ type: "text", text: "File not found: /tmp/missing.txt" }], - details: { path: "/tmp/missing.txt", deleted: false }, - }); - expect(fileError({}, "Read", "/tmp/file.txt", { totalLines: 0 })).toEqual({ - content: [{ type: "text", text: "Read error: Unknown error" }], - details: { path: "/tmp/file.txt", totalLines: 0 }, - }); - expect(toolError({ code: 1 }, "Grep", { matchCount: 0 })).toEqual({ - content: [{ type: "text", text: "No matches found" }], - details: { matchCount: 0 }, - }); - expect(toolError({}, "Grep", { matchCount: 0 })).toEqual({ - content: [{ type: "text", text: "Grep error: Unknown error" }], - details: { matchCount: 0 }, - }); - }); + it('formats file and command errors with stable empty details', () => { + expect(fileNotFound('/tmp/missing.txt', { deleted: false })).toEqual({ + content: [{ type: 'text', text: 'File not found: /tmp/missing.txt' }], + details: { path: '/tmp/missing.txt', deleted: false }, + }); + expect(fileError({}, 'Read', '/tmp/file.txt', { totalLines: 0 })).toEqual({ + content: [{ type: 'text', text: 'Read error: Unknown error' }], + details: { path: '/tmp/file.txt', totalLines: 0 }, + }); + expect(toolError({ code: 1 }, 'Grep', { matchCount: 0 })).toEqual({ + content: [{ type: 'text', text: 'No matches found' }], + details: { matchCount: 0 }, + }); + expect(toolError({}, 'Grep', { matchCount: 0 })).toEqual({ + content: [{ type: 'text', text: 'Grep error: Unknown error' }], + details: { matchCount: 0 }, + }); + }); }); diff --git a/tests/tools/search.test.ts b/tests/tools/search.test.ts index 2a34811..fb865b3 100644 --- a/tests/tools/search.test.ts +++ b/tests/tools/search.test.ts @@ -1,226 +1,208 @@ -import { mkdirSync, writeFileSync } from "node:fs"; -import { join } from "node:path"; -import { describe, expect, it } from "vitest"; -import { registerSearchTools } from "../../src/tools/search.js"; +import { mkdirSync, writeFileSync } from 'node:fs'; +import { join } from 'node:path'; +import { describe, expect, it } from 'vitest'; +import { registerSearchTools } from '../../src/tools/search.js'; import { - collectTools, - executePreparedTool, - executeTool, - firstText, - plainTheme, - renderText, - type ToolResult, - tempDir, -} from "./toolTestHelpers.js"; + collectTools, + executePreparedTool, + executeTool, + firstText, + plainTheme, + renderText, + type ToolResult, + tempDir, +} from './toolTestHelpers.js'; function setupProject() { - const dir = tempDir("pi-grok-cli-search-"); - mkdirSync(join(dir, "src")); - writeFileSync(join(dir, "src", "alpha.ts"), "needle\nhaystack\n", "utf-8"); - writeFileSync(join(dir, "src", "beta.md"), "needle in docs\n", "utf-8"); - writeFileSync(join(dir, "src", "gamma.ts"), "plain text\n", "utf-8"); - return dir; + const dir = tempDir('pi-grok-cli-search-'); + mkdirSync(join(dir, 'src')); + writeFileSync(join(dir, 'src', 'alpha.ts'), 'needle\nhaystack\n', 'utf-8'); + writeFileSync(join(dir, 'src', 'beta.md'), 'needle in docs\n', 'utf-8'); + writeFileSync(join(dir, 'src', 'gamma.ts'), 'plain text\n', 'utf-8'); + return dir; } function expectGrepResult(cwd: string, result: ToolResult) { - expect(firstText(result)).toContain( - `${join(cwd, "src", "alpha.ts")}:1:needle`, - ); - expect(firstText(result)).not.toContain("beta.md"); - expect(result.details).toEqual({ matchCount: 1 }); + expect(firstText(result)).toContain(`${join(cwd, 'src', 'alpha.ts')}:1:needle`); + expect(firstText(result)).not.toContain('beta.md'); + expect(result.details).toEqual({ matchCount: 1 }); } function expectGlobResult(cwd: string, result: ToolResult) { - expect(firstText(result)).toContain(join(cwd, "src", "alpha.ts")); - expect(firstText(result)).toContain(join(cwd, "src", "gamma.ts")); - expect(firstText(result)).not.toContain("beta.md"); - expect(result.details).toEqual({ fileCount: 2 }); + expect(firstText(result)).toContain(join(cwd, 'src', 'alpha.ts')); + expect(firstText(result)).toContain(join(cwd, 'src', 'gamma.ts')); + expect(firstText(result)).not.toContain('beta.md'); + expect(result.details).toEqual({ fileCount: 2 }); } -describe("search tools", () => { - it("greps matching file contents with include filters", async () => { - const cwd = setupProject(); - const result = await executeTool( - collectTools(registerSearchTools).get("Grep"), - { pattern: "needle", path: "src", include: "*.ts" }, - cwd, - ); - - expectGrepResult(cwd, result); - }); - - it("greps matching file contents with Cursor-style glob filters", async () => { - const cwd = setupProject(); - const result = await executePreparedTool( - collectTools(registerSearchTools).get("Grep"), - { pattern: "needle", path: "src", glob_filter: "*.ts" }, - cwd, - ); - - expectGrepResult(cwd, result); - }); - - it("reports no grep matches as an empty result", async () => { - const cwd = setupProject(); - const result = await executeTool( - collectTools(registerSearchTools).get("Grep"), - { pattern: "absent", path: "src" }, - cwd, - ); - - expect(firstText(result)).toBe("No matches found"); - expect(result.details).toEqual({ matchCount: 0 }); - }); - - it("reports grep command errors with empty match details", async () => { - const cwd = setupProject(); - const result = await executeTool( - collectTools(registerSearchTools).get("Grep"), - { pattern: "[", path: "src" }, - cwd, - ); - - expect(firstText(result).startsWith("Grep error:")).toBe(true); - expect(result.details).toEqual({ matchCount: 0 }); - }); - - it("globs files under the requested path", async () => { - const cwd = setupProject(); - const result = await executeTool( - collectTools(registerSearchTools).get("Glob"), - { pattern: "**/*.ts", path: "src" }, - cwd, - ); - - expectGlobResult(cwd, result); - }); - - it("globs files with Cursor-style glob pattern arguments", async () => { - const cwd = setupProject(); - const result = await executePreparedTool( - collectTools(registerSearchTools).get("Glob"), - { glob_pattern: "**/*.ts", path: "src" }, - cwd, - ); - - expectGlobResult(cwd, result); - }); - - it("reports empty glob command results", async () => { - const cwd = setupProject(); - const result = await executeTool( - collectTools(registerSearchTools).get("Glob"), - { pattern: "**/*.json", path: "src" }, - cwd, - ); - - expect(firstText(result)).toBe("No matches found"); - expect(result.details).toEqual({ fileCount: 0 }); - }); - - it("renders grep calls and result states", () => { - const grep = collectTools(registerSearchTools).get("Grep"); - const result = { - content: [{ type: "text", text: "src/alpha.ts:1:needle" }], - details: { matchCount: 1 }, - }; - - expect( - renderText( - grep?.renderCall?.( - { pattern: "needle", path: "src", include: "*.ts" }, - plainTheme, - ) ?? { render: () => [] }, - ), - ).toBe('Grep "needle" in src [*.ts]'); - expect( - renderText( - grep?.renderResult?.( - result, - { expanded: false, isPartial: false }, - plainTheme, - {}, - ) ?? { render: () => [] }, - ), - ).toBe("1 match(es)"); - expect( - renderText( - grep?.renderResult?.( - result, - { expanded: true, isPartial: false }, - plainTheme, - {}, - ) ?? { render: () => [] }, - ), - ).toBe("src/alpha.ts:1:needle"); - expect( - renderText( - grep?.renderResult?.( - { - content: [{ type: "text", text: "No matches found" }], - details: {}, - }, - { expanded: false, isPartial: false }, - plainTheme, - {}, - ) ?? { render: () => [] }, - ), - ).toBe("No matches"); - expect( - renderText( - grep?.renderResult?.( - result, - { expanded: false, isPartial: true }, - plainTheme, - {}, - ) ?? { render: () => [] }, - ), - ).toBe("Running..."); - }); - - it("renders glob calls and result states", () => { - const glob = collectTools(registerSearchTools).get("Glob"); - const result = { - content: [{ type: "text", text: "src/alpha.ts\nsrc/gamma.ts" }], - details: { fileCount: 2 }, - }; - - expect( - renderText( - glob?.renderCall?.({ pattern: "**/*.ts", path: "src" }, plainTheme) ?? { - render: () => [], - }, - ), - ).toBe("Glob **/*.ts in src"); - expect( - renderText( - glob?.renderResult?.( - result, - { expanded: false, isPartial: false }, - plainTheme, - {}, - ) ?? { render: () => [] }, - ), - ).toBe("2 file(s)"); - expect( - renderText( - glob?.renderResult?.( - { content: [{ type: "text", text: "No files found" }], details: {} }, - { expanded: false, isPartial: false }, - plainTheme, - {}, - ) ?? { render: () => [] }, - ), - ).toBe("No files"); - expect( - renderText( - glob?.renderResult?.( - result, - { expanded: false, isPartial: true }, - plainTheme, - {}, - ) ?? { render: () => [] }, - ), - ).toBe("Running..."); - }); +describe('search tools', () => { + it('greps matching file contents with include filters', async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: 'needle', path: 'src', include: '*.ts' }, + cwd, + ); + + expectGrepResult(cwd, result); + }); + + it('greps matching file contents with Cursor-style glob filters', async () => { + const cwd = setupProject(); + const result = await executePreparedTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: 'needle', path: 'src', glob_filter: '*.ts' }, + cwd, + ); + + expectGrepResult(cwd, result); + }); + + it('reports no grep matches as an empty result', async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: 'absent', path: 'src' }, + cwd, + ); + + expect(firstText(result)).toBe('No matches found'); + expect(result.details).toEqual({ matchCount: 0 }); + }); + + it('reports grep command errors with empty match details', async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: '[', path: 'src' }, + cwd, + ); + + expect(firstText(result).startsWith('Grep error:')).toBe(true); + expect(result.details).toEqual({ matchCount: 0 }); + }); + + it('globs files under the requested path', async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get('Glob'), + { pattern: '**/*.ts', path: 'src' }, + cwd, + ); + + expectGlobResult(cwd, result); + }); + + it('globs files with Cursor-style glob pattern arguments', async () => { + const cwd = setupProject(); + const result = await executePreparedTool( + collectTools(registerSearchTools).get('Glob'), + { glob_pattern: '**/*.ts', path: 'src' }, + cwd, + ); + + expectGlobResult(cwd, result); + }); + + it('reports empty glob command results', async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get('Glob'), + { pattern: '**/*.json', path: 'src' }, + cwd, + ); + + expect(firstText(result)).toBe('No matches found'); + expect(result.details).toEqual({ fileCount: 0 }); + }); + + it('renders grep calls and result states', () => { + const grep = collectTools(registerSearchTools).get('Grep'); + const result = { + content: [{ type: 'text', text: 'src/alpha.ts:1:needle' }], + details: { matchCount: 1 }, + }; + + expect( + renderText( + grep?.renderCall?.({ pattern: 'needle', path: 'src', include: '*.ts' }, plainTheme) ?? { + render: () => [], + }, + ), + ).toBe('Grep "needle" in src [*.ts]'); + expect( + renderText( + grep?.renderResult?.(result, { expanded: false, isPartial: false }, plainTheme, {}) ?? { + render: () => [], + }, + ), + ).toBe('1 match(es)'); + expect( + renderText( + grep?.renderResult?.(result, { expanded: true, isPartial: false }, plainTheme, {}) ?? { + render: () => [], + }, + ), + ).toBe('src/alpha.ts:1:needle'); + expect( + renderText( + grep?.renderResult?.( + { + content: [{ type: 'text', text: 'No matches found' }], + details: {}, + }, + { expanded: false, isPartial: false }, + plainTheme, + {}, + ) ?? { render: () => [] }, + ), + ).toBe('No matches'); + expect( + renderText( + grep?.renderResult?.(result, { expanded: false, isPartial: true }, plainTheme, {}) ?? { + render: () => [], + }, + ), + ).toBe('Running...'); + }); + + it('renders glob calls and result states', () => { + const glob = collectTools(registerSearchTools).get('Glob'); + const result = { + content: [{ type: 'text', text: 'src/alpha.ts\nsrc/gamma.ts' }], + details: { fileCount: 2 }, + }; + + expect( + renderText( + glob?.renderCall?.({ pattern: '**/*.ts', path: 'src' }, plainTheme) ?? { + render: () => [], + }, + ), + ).toBe('Glob **/*.ts in src'); + expect( + renderText( + glob?.renderResult?.(result, { expanded: false, isPartial: false }, plainTheme, {}) ?? { + render: () => [], + }, + ), + ).toBe('2 file(s)'); + expect( + renderText( + glob?.renderResult?.( + { content: [{ type: 'text', text: 'No files found' }], details: {} }, + { expanded: false, isPartial: false }, + plainTheme, + {}, + ) ?? { render: () => [] }, + ), + ).toBe('No files'); + expect( + renderText( + glob?.renderResult?.(result, { expanded: false, isPartial: true }, plainTheme, {}) ?? { + render: () => [], + }, + ), + ).toBe('Running...'); + }); }); diff --git a/tests/tools/shell.test.ts b/tests/tools/shell.test.ts index 3640fa0..8612306 100644 --- a/tests/tools/shell.test.ts +++ b/tests/tools/shell.test.ts @@ -1,139 +1,129 @@ -import { writeFileSync } from "node:fs"; -import { join } from "node:path"; -import { describe, expect, it } from "vitest"; -import { registerShellTool } from "../../src/tools/shell.js"; +import { writeFileSync } from 'node:fs'; +import { join } from 'node:path'; +import { describe, expect, it } from 'vitest'; +import { registerShellTool } from '../../src/tools/shell.js'; import { - collectTools, - executeTool, - firstText, - renderToolCall, - renderToolResult, - tempDir, -} from "./toolTestHelpers.js"; + collectTools, + executeTool, + firstText, + renderToolCall, + renderToolResult, + tempDir, +} from './toolTestHelpers.js'; -describe("shell tool", () => { - it("returns stdout, stderr, and exit zero details", async () => { - const cwd = tempDir("pi-grok-cli-shell-"); - const result = await executeTool( - collectTools(registerShellTool).get("Shell"), - { command: "printf stdout && printf stderr >&2" }, - cwd, - ); +describe('shell tool', () => { + it('returns stdout, stderr, and exit zero details', async () => { + const cwd = tempDir('pi-grok-cli-shell-'); + const result = await executeTool( + collectTools(registerShellTool).get('Shell'), + { command: 'printf stdout && printf stderr >&2' }, + cwd, + ); - expect(firstText(result)).toBe("stdout\n[stderr]\nstderr"); - expect(result.details).toEqual({ - exitCode: 0, - command: "printf stdout && printf stderr >&2", - }); - }); + expect(firstText(result)).toBe('stdout\n[stderr]\nstderr'); + expect(result.details).toEqual({ + exitCode: 0, + command: 'printf stdout && printf stderr >&2', + }); + }); - it("runs commands in a resolved working directory", async () => { - const cwd = tempDir("pi-grok-cli-shell-"); - writeFileSync(join(cwd, "target.txt"), "from cwd", "utf-8"); - const result = await executeTool( - collectTools(registerShellTool).get("Shell"), - { command: "cat target.txt", working_directory: "." }, - cwd, - ); + it('runs commands in a resolved working directory', async () => { + const cwd = tempDir('pi-grok-cli-shell-'); + writeFileSync(join(cwd, 'target.txt'), 'from cwd', 'utf-8'); + const result = await executeTool( + collectTools(registerShellTool).get('Shell'), + { command: 'cat target.txt', working_directory: '.' }, + cwd, + ); - expect(firstText(result)).toBe("from cwd"); - expect(result.details).toEqual({ - exitCode: 0, - command: "cat target.txt", - }); - }); + expect(firstText(result)).toBe('from cwd'); + expect(result.details).toEqual({ + exitCode: 0, + command: 'cat target.txt', + }); + }); - it("returns a clear placeholder when commands produce no output", async () => { - const cwd = tempDir("pi-grok-cli-shell-"); - const result = await executeTool( - collectTools(registerShellTool).get("Shell"), - { command: "true" }, - cwd, - ); + it('returns a clear placeholder when commands produce no output', async () => { + const cwd = tempDir('pi-grok-cli-shell-'); + const result = await executeTool( + collectTools(registerShellTool).get('Shell'), + { command: 'true' }, + cwd, + ); - expect(firstText(result)).toBe("(no output)"); - expect(result.details).toEqual({ exitCode: 0, command: "true" }); - }); + expect(firstText(result)).toBe('(no output)'); + expect(result.details).toEqual({ exitCode: 0, command: 'true' }); + }); - it("includes exit code, error message, and captured output on failure", async () => { - const cwd = tempDir("pi-grok-cli-shell-"); - const result = await executeTool( - collectTools(registerShellTool).get("Shell"), - { command: "printf before && printf problem >&2 && exit 7" }, - cwd, - ); + it('includes exit code, error message, and captured output on failure', async () => { + const cwd = tempDir('pi-grok-cli-shell-'); + const result = await executeTool( + collectTools(registerShellTool).get('Shell'), + { command: 'printf before && printf problem >&2 && exit 7' }, + cwd, + ); - expect(firstText(result)).toContain("Shell error (exit code 7):"); - expect(firstText(result)).toContain("before\n[stderr]\nproblem"); - expect(result.details).toEqual({ - exitCode: 7, - command: "printf before && printf problem >&2 && exit 7", - }); - }); + expect(firstText(result)).toContain('Shell error (exit code 7):'); + expect(firstText(result)).toContain('before\n[stderr]\nproblem'); + expect(result.details).toEqual({ + exitCode: 7, + command: 'printf before && printf problem >&2 && exit 7', + }); + }); - it("truncates large successful and failed output", async () => { - const cwd = tempDir("pi-grok-cli-shell-"); - const tools = collectTools(registerShellTool); - const largeOutput = "head -c 50001 /dev/zero | tr '\\0' x"; + it('truncates large successful and failed output', async () => { + const cwd = tempDir('pi-grok-cli-shell-'); + const tools = collectTools(registerShellTool); + const largeOutput = "head -c 50001 /dev/zero | tr '\\0' x"; - const successResult = await executeTool( - tools.get("Shell"), - { command: largeOutput }, - cwd, - ); - const failureResult = await executeTool( - tools.get("Shell"), - { command: `${largeOutput}; exit 9` }, - cwd, - ); + const successResult = await executeTool(tools.get('Shell'), { command: largeOutput }, cwd); + const failureResult = await executeTool( + tools.get('Shell'), + { command: `${largeOutput}; exit 9` }, + cwd, + ); - expect(firstText(successResult)).toHaveLength( - "\n\n[Output truncated at 50KB]".length + 50_000, - ); - expect( - firstText(successResult).endsWith("[Output truncated at 50KB]"), - ).toBe(true); - expect(firstText(failureResult)).toContain("Shell error (exit code 9):"); - expect( - firstText(failureResult).endsWith("[Output truncated at 50KB]"), - ).toBe(true); - }); + expect(firstText(successResult)).toHaveLength('\n\n[Output truncated at 50KB]'.length + 50_000); + expect(firstText(successResult).endsWith('[Output truncated at 50KB]')).toBe(true); + expect(firstText(failureResult)).toContain('Shell error (exit code 9):'); + expect(firstText(failureResult).endsWith('[Output truncated at 50KB]')).toBe(true); + }); - it("renders shell calls and result states", () => { - const shell = collectTools(registerShellTool).get("Shell"); + it('renders shell calls and result states', () => { + const shell = collectTools(registerShellTool).get('Shell'); - expect( - renderToolCall(shell, { - command: "pwd", - working_directory: "src", - }), - ).toBe("Shell pwd in src"); - expect(renderToolCall(shell, { command: "pwd" })).toBe("Shell pwd"); - expect( - renderToolResult(shell, { - content: [{ type: "text", text: "full output" }], - details: { exitCode: 0 }, - }), - ).toBe("Exit 0"); - expect( - renderToolResult( - shell, - { - content: [{ type: "text", text: "full output" }], - details: { exitCode: 0 }, - }, - { expanded: true, isPartial: false }, - ), - ).toBe("full output"); - expect( - renderToolResult( - shell, - { - content: [{ type: "text", text: "still running" }], - details: { exitCode: 0 }, - }, - { expanded: false, isPartial: true }, - ), - ).toBe("Running..."); - }); + expect( + renderToolCall(shell, { + command: 'pwd', + working_directory: 'src', + }), + ).toBe('Shell pwd in src'); + expect(renderToolCall(shell, { command: 'pwd' })).toBe('Shell pwd'); + expect( + renderToolResult(shell, { + content: [{ type: 'text', text: 'full output' }], + details: { exitCode: 0 }, + }), + ).toBe('Exit 0'); + expect( + renderToolResult( + shell, + { + content: [{ type: 'text', text: 'full output' }], + details: { exitCode: 0 }, + }, + { expanded: true, isPartial: false }, + ), + ).toBe('full output'); + expect( + renderToolResult( + shell, + { + content: [{ type: 'text', text: 'still running' }], + details: { exitCode: 0 }, + }, + { expanded: false, isPartial: true }, + ), + ).toBe('Running...'); + }); }); diff --git a/tests/tools/toolTestHelpers.ts b/tests/tools/toolTestHelpers.ts index e24011f..eca9206 100644 --- a/tests/tools/toolTestHelpers.ts +++ b/tests/tools/toolTestHelpers.ts @@ -1,130 +1,118 @@ -import { mkdtempSync, rmSync } from "node:fs"; -import { tmpdir } from "node:os"; -import { join } from "node:path"; -import type { ExtensionAPI } from "@earendil-works/pi-coding-agent"; -import { afterEach } from "vitest"; +import { mkdtempSync, rmSync } from 'node:fs'; +import { tmpdir } from 'node:os'; +import { join } from 'node:path'; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { afterEach } from 'vitest'; const tempDirs: string[] = []; afterEach(() => { - for (const dir of tempDirs.splice(0)) rmSync(dir, { recursive: true }); + for (const dir of tempDirs.splice(0)) rmSync(dir, { recursive: true }); }); export type ToolResult = { - content: { type: string; text?: string }[]; - details: Record; + content: { type: string; text?: string }[]; + details: Record; }; type Renderable = { render: (width: number) => string[] }; type ToolTheme = { - bold: (text: string) => string; - fg: (name: string, text: string) => string; + bold: (text: string) => string; + fg: (name: string, text: string) => string; }; type RegisteredTool = { - name: string; - prepareArguments?: ( - params: Record, - ) => Record; - execute: ( - toolCallId: string, - params: Record, - signal: AbortSignal, - onUpdate: () => void, - ctx: { cwd: string }, - ) => Promise; - renderCall?: (args: Record, theme: ToolTheme) => Renderable; - renderResult?: ( - result: ToolResult, - state: { expanded: boolean; isPartial: boolean }, - theme: ToolTheme, - args: Record, - ) => Renderable; + name: string; + prepareArguments?: (params: Record) => Record; + execute: ( + toolCallId: string, + params: Record, + signal: AbortSignal, + onUpdate: () => void, + ctx: { cwd: string }, + ) => Promise; + renderCall?: (args: Record, theme: ToolTheme) => Renderable; + renderResult?: ( + result: ToolResult, + state: { expanded: boolean; isPartial: boolean }, + theme: ToolTheme, + args: Record, + ) => Renderable; }; export function collectTools(registerTools: (pi: ExtensionAPI) => void) { - const tools = new Map(); - registerTools({ - registerTool(tool: RegisteredTool) { - tools.set(tool.name, tool); - }, - } as unknown as ExtensionAPI); - return tools; + const tools = new Map(); + registerTools({ + registerTool(tool: RegisteredTool) { + tools.set(tool.name, tool); + }, + } as unknown as ExtensionAPI); + return tools; } export async function executeTool( - tool: RegisteredTool | undefined, - params: Record, - cwd: string, + tool: RegisteredTool | undefined, + params: Record, + cwd: string, ) { - if (!tool) throw new Error("Tool was not registered"); - return tool.execute( - "tool-call-id", - params, - new AbortController().signal, - () => {}, - { - cwd, - }, - ); + if (!tool) throw new Error('Tool was not registered'); + return tool.execute('tool-call-id', params, new AbortController().signal, () => {}, { + cwd, + }); } export function prepareToolArguments( - tool: RegisteredTool | undefined, - params: Record, + tool: RegisteredTool | undefined, + params: Record, ) { - if (!tool) throw new Error("Tool was not registered"); - return tool.prepareArguments?.(params) ?? params; + if (!tool) throw new Error('Tool was not registered'); + return tool.prepareArguments?.(params) ?? params; } export async function executePreparedTool( - tool: RegisteredTool | undefined, - params: Record, - cwd: string, + tool: RegisteredTool | undefined, + params: Record, + cwd: string, ) { - if (!tool) throw new Error("Tool was not registered"); - return executeTool(tool, prepareToolArguments(tool, params), cwd); + if (!tool) throw new Error('Tool was not registered'); + return executeTool(tool, prepareToolArguments(tool, params), cwd); } export function firstText(result: ToolResult) { - return result.content[0]?.text ?? ""; + return result.content[0]?.text ?? ''; } export function renderText(component: { render: (width: number) => string[] }) { - return component - .render(120) - .map((line) => line.trimEnd()) - .join("\n"); + return component + .render(120) + .map((line) => line.trimEnd()) + .join('\n'); } export const plainTheme = { - bold: (text: string) => text, - fg: (_name: string, text: string) => text, + bold: (text: string) => text, + fg: (_name: string, text: string) => text, }; -export function renderToolCall( - tool: RegisteredTool | undefined, - args: Record, -) { - if (!tool?.renderCall) - throw new Error("Tool call renderer was not registered"); - return renderText(tool.renderCall(args, plainTheme)); +export function renderToolCall(tool: RegisteredTool | undefined, args: Record) { + if (!tool?.renderCall) throw new Error('Tool call renderer was not registered'); + return renderText(tool.renderCall(args, plainTheme)); } export function renderToolResult( - tool: RegisteredTool | undefined, - result: ToolResult, - state = { expanded: false, isPartial: false }, + tool: RegisteredTool | undefined, + result: ToolResult, + state = { expanded: false, isPartial: false }, ) { - if (!tool?.renderResult) { - throw new Error("Tool result renderer was not registered"); - } - return renderText(tool.renderResult(result, state, plainTheme, {})); + if (!tool?.renderResult) { + throw new Error('Tool result renderer was not registered'); + } + return renderText(tool.renderResult(result, state, plainTheme, {})); } export function tempDir(prefix: string) { - const dir = mkdtempSync(join(tmpdir(), prefix)); - tempDirs.push(dir); - return dir; + const dir = mkdtempSync(join(tmpdir(), prefix)); + tempDirs.push(dir); + return dir; } diff --git a/tsconfig.json b/tsconfig.json index 73fa4c9..e9b2d7a 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,18 +1,18 @@ { - "compilerOptions": { - "target": "ES2022", - "module": "ES2022", - "moduleResolution": "bundler", - "strict": true, - "esModuleInterop": true, - "skipLibCheck": true, - "forceConsistentCasingInFileNames": true, - "resolveJsonModule": true, - "declaration": true, - "declarationMap": true, - "sourceMap": true, - "outDir": "./dist" - }, - "include": ["src/**/*.ts"], - "exclude": ["node_modules", "dist"] + "compilerOptions": { + "target": "ES2022", + "module": "ES2022", + "moduleResolution": "bundler", + "strict": true, + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true, + "declaration": true, + "declarationMap": true, + "sourceMap": true, + "outDir": "./dist" + }, + "include": ["src/**/*.ts"], + "exclude": ["node_modules", "dist"] } diff --git a/vitest.config.ts b/vitest.config.ts index 3f9296d..0e6e4d1 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -1,12 +1,12 @@ -import { defineConfig } from "vitest/config"; +import { defineConfig } from 'vitest/config'; export default defineConfig({ - test: { - coverage: { - provider: "v8", - reporter: ["text", "lcov"], - include: ["src/**/*.ts"], - exclude: ["src/index.ts"], - }, - }, + test: { + coverage: { + provider: 'v8', + reporter: ['text', 'lcov'], + include: ['src/**/*.ts'], + exclude: ['src/index.ts'], + }, + }, }); From dff7f828807676cc8909c7f400bf715472c435b8 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Tue, 2 Jun 2026 23:46:33 +0900 Subject: [PATCH 08/24] docs: add Cursor tool compatibility documentation to README --- README.md | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 3570791..6d5ed50 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,8 @@ [![Version](https://img.shields.io/github/v/tag/kenryu42/pi-grok-cli?label=version&color=blue)](https://github.com/kenryu42/pi-grok-cli) [![License: MIT](https://img.shields.io/badge/License-MIT-red.svg)](https://opensource.org/licenses/MIT) -A pi extension that connects to **Grok CLI's API endpoint** . - -## Why? - -The Grok CLI uhas access to models **not available** on the public `api.x.ai` API yet: +A pi extension that connects to **Grok CLI's API endpoint**. +The Grok CLI has access to models **not available** on the public `api.x.ai` API yet: | Model | Public API (`api.x.ai`) | Grok CLI | |---|---|---| @@ -18,6 +15,16 @@ The Grok CLI uhas access to models **not available** on the public `api.x.ai` AP `grok-composer-2.5-fast` is Cursor's Composer 2.5 model, a purpose-built agentic coding model optimized for long-horizon coding tasks. +## Cursor Tool Compatibility + +Grok CLI models are trained to use Cursor-style coding tools. This extension includes compatibility shims so those models can keep using familiar tool calls inside pi: + +- File tools: `Read`, `Write`, `StrReplace`, `Edit`, `Delete`, and `LS` +- Search tools: `Grep` and `Glob` +- Terminal tool: `Shell` + +The shims also normalize common Cursor/Grok argument shapes, such as `contents` for writes, `glob_pattern` for file search, `glob_filter` for grep filters, and `old_string`/`new_string` or `oldText`/`newText` for exact replacements. This keeps agentic coding workflows moving instead of failing on tool schema mismatches. + ## Requirements You need an active Grok subscription or an X Premium subscription with Grok access to use this extension. @@ -66,4 +73,4 @@ Select **"Grok CLI"** from the provider list. This opens the xAI OAuth page in y | `PI_GROK_CLI_MODELS` | (all models) | Comma-separated model IDs to expose | | `PI_GROK_CLI_OAUTH_CLIENT_ID` | `b1a00492-...` | Override OAuth client ID | | `PI_GROK_CLI_OAUTH_SCOPE` | `openid profile email offline_access grok-cli:access api:access` | Override OAuth scopes | -| `GROK_CLI_OAUTH_TOKEN` | — | Direct token bypass (no auto-refresh) | \ No newline at end of file +| `GROK_CLI_OAUTH_TOKEN` | — | Direct token bypass (no auto-refresh) | From 0f79e0d169be65567b9eda08198427a0acb71d06 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 00:04:31 +0900 Subject: [PATCH 09/24] fix: improve file editing, grep robustness, and tool error handling while sanitizing reasoning efforts --- src/payload/sanitize.ts | 8 +++--- src/tools/files.ts | 13 +++++++-- src/tools/rendering.ts | 2 +- src/tools/search.ts | 15 ++++++++--- src/tools/shell.ts | 13 ++++----- tests/payload/sanitize.test.ts | 2 +- tests/tools/files.test.ts | 48 ++++++++++++++++++++++++++++++++++ tests/tools/rendering.test.ts | 6 +++++ tests/tools/search.test.ts | 28 +++++++++++++++++++- tests/tools/shell.test.ts | 6 +++++ 10 files changed, 121 insertions(+), 20 deletions(-) diff --git a/src/payload/sanitize.ts b/src/payload/sanitize.ts index 2b3ccf5..75539e4 100644 --- a/src/payload/sanitize.ts +++ b/src/payload/sanitize.ts @@ -284,11 +284,9 @@ export function sanitizePayload( // ── Reasoning effort ────────────────────────────────────────────────── if (supportsReasoningEffort(modelId)) { const reasoning = next.reasoning as Record | undefined; - if (reasoning && reasoning.effort === 'minimal') { - next.reasoning = { ...reasoning, effort: 'low' }; - } - if (reasoning && reasoning.summary !== undefined) { - next.reasoning = { effort: reasoning.effort }; + if (reasoning) { + const effort = reasoning.effort === 'minimal' ? 'low' : reasoning.effort; + next.reasoning = reasoning.summary !== undefined ? { effort } : { ...reasoning, effort }; } } else { delete next.reasoning; diff --git a/src/tools/files.ts b/src/tools/files.ts index e5ee3ac..719220d 100644 --- a/src/tools/files.ts +++ b/src/tools/files.ts @@ -83,7 +83,9 @@ function applyEdits(content: string, edits: ReplacementEdit[]) { const count = result.content.split(edit.oldText).length - 1; return { content: - count === 0 ? result.content : result.content.replaceAll(edit.oldText, edit.newText), + count === 0 + ? result.content + : result.content.replaceAll(edit.oldText, () => edit.newText), replacements: result.replacements + count, }; }, @@ -383,6 +385,10 @@ export function registerFileTools(pi: ExtensionAPI) { } const content = readFileSync(filePath, 'utf-8'); + if (params.old_str === '') { + return replacementResult('StrReplace error: old_str must not be empty', filePath); + } + const count = content.split(params.old_str).length - 1; if (count === 0) { @@ -392,7 +398,7 @@ export function registerFileTools(pi: ExtensionAPI) { ); } - const newContent = content.replaceAll(params.old_str, params.new_str); + const newContent = content.replaceAll(params.old_str, () => params.new_str); writeFileSync(filePath, newContent, 'utf-8'); return { @@ -494,6 +500,9 @@ export function registerFileTools(pi: ExtensionAPI) { details: { path: filePath, replacements: 0 }, }; } + if (params.edits.some((edit) => edit.oldText === '')) { + return replacementResult('Edit error: oldText must not be empty', filePath); + } const result = applyEdits(readFileSync(filePath, 'utf-8'), params.edits); diff --git a/src/tools/rendering.ts b/src/tools/rendering.ts index c442677..4e7eaf5 100644 --- a/src/tools/rendering.ts +++ b/src/tools/rendering.ts @@ -141,7 +141,7 @@ export function fileError( export function toolError(error: unknown, toolName: string, emptyDetails: T): ToolResult { const err = error as ToolError; - if (err.code === 1) { + if (toolName === 'Grep' && err.code === 1) { return { content: [{ type: 'text', text: 'No matches found' }], details: emptyDetails, diff --git a/src/tools/search.ts b/src/tools/search.ts index 34b9b63..9273e29 100644 --- a/src/tools/search.ts +++ b/src/tools/search.ts @@ -60,13 +60,13 @@ export function registerSearchTools(pi: ExtensionAPI) { const searchPath = resolve(ctx.cwd, params.path ?? '.'); try { - const rgArgs = ['-n', '--no-heading', '--color=never']; + const rgArgs = ['-n', '-H', '--no-heading', '--color=never']; if (params.include) rgArgs.push('--glob', params.include); - rgArgs.push(params.pattern, searchPath); + rgArgs.push('--', params.pattern, searchPath); - const grepArgs = ['-r', '-n', '--color=never']; + const grepArgs = ['-r', '-n', '-H', '--color=never']; if (params.include) grepArgs.push(`--include=${params.include}`); - grepArgs.push(params.pattern, searchPath); + grepArgs.push('--', params.pattern, searchPath); const stdout = await execWithRgFallback(rgArgs, grepArgs, { cwd: ctx.cwd, @@ -176,6 +176,13 @@ export function registerSearchTools(pi: ExtensionAPI) { details: { fileCount: files.length }, }; } catch (error: unknown) { + const err = error as { code?: unknown; stderr?: string }; + if (err.code === 1 && !err.stderr) { + return { + content: [{ type: 'text', text: 'No files found' }], + details: { fileCount: 0 }, + }; + } return toolError(error, 'Glob', { fileCount: 0 }); } }, diff --git a/src/tools/shell.ts b/src/tools/shell.ts index b8be923..23278aa 100644 --- a/src/tools/shell.ts +++ b/src/tools/shell.ts @@ -4,8 +4,8 @@ import { promisify } from 'node:util'; import { Type } from '@earendil-works/pi-ai'; import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; import { + detailRecord, MAX_OUTPUT_CHARS, - numberDetail, renderResultText, renderRunning, text, @@ -64,11 +64,12 @@ export function registerShellTool(pi: ExtensionAPI) { }; } catch (error: unknown) { const err = error as { - code?: number; + code?: unknown; message?: string; stdout?: string; stderr?: string; }; + const exitCode = typeof err.code === 'number' ? err.code : 1; let output = ''; if (err.stdout) output += err.stdout; @@ -86,7 +87,7 @@ export function registerShellTool(pi: ExtensionAPI) { }, ], details: { - exitCode: err.code ?? 1, + exitCode, command: params.command, }, }; @@ -101,12 +102,12 @@ export function registerShellTool(pi: ExtensionAPI) { renderResult(result, { expanded, isPartial }, theme) { const running = renderRunning(isPartial); if (running) return running; + const exitCode = + typeof detailRecord(result).exitCode === 'number' ? detailRecord(result).exitCode : 1; return renderResultText( result, expanded, - numberDetail(result, 'exitCode') === 0 - ? theme.fg('muted', 'Exit 0') - : theme.fg('warning', `Exit ${numberDetail(result, 'exitCode')}`), + exitCode === 0 ? theme.fg('muted', 'Exit 0') : theme.fg('warning', `Exit ${exitCode}`), ); }, }); diff --git a/tests/payload/sanitize.test.ts b/tests/payload/sanitize.test.ts index 6c925c4..a7063ec 100644 --- a/tests/payload/sanitize.test.ts +++ b/tests/payload/sanitize.test.ts @@ -37,7 +37,7 @@ describe('payload sanitization', () => { expect(payload.input).toEqual([{ role: 'user', content: 'hello' }]); expect(payload.include).toEqual(['message.output_text']); expect(payload.prompt_cache_retention).toBeUndefined(); - expect(payload.reasoning).toEqual({ effort: 'minimal' }); + expect(payload.reasoning).toEqual({ effort: 'low' }); expect(payload.text).toEqual({ format: { type: 'json_object' } }); expect(payload.response_format).toBeUndefined(); expect(payload.prompt_cache_key).toBe('session-123'); diff --git a/tests/tools/files.test.ts b/tests/tools/files.test.ts index 3a44bbe..d6d5da1 100644 --- a/tests/tools/files.test.ts +++ b/tests/tools/files.test.ts @@ -163,6 +163,26 @@ describe('file tools', () => { expectStoryState(result, cwd, 2, 'green blue green'); }); + it('rejects empty replacement search strings without changing files', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'red blue red', 'utf-8'); + + const result = await strReplace(cwd, '', 'green'); + + expect(firstText(result)).toBe('StrReplace error: old_str must not be empty'); + expectStoryState(result, cwd, 0, 'red blue red'); + }); + + it('treats replacement text as a literal string', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'abc', 'utf-8'); + + const result = await strReplace(cwd, 'a', '$&'); + + expect(firstText(result)).toBe('Replaced 1 occurrence(s) in story.txt'); + expectStoryState(result, cwd, 1, '$&bc'); + }); + it('replaces string occurrences with Grok and Cursor argument variants', async () => { const oldStringCwd = tempDir('pi-grok-cli-files-'); writeFileSync(join(oldStringCwd, 'story.txt'), 'red blue red', 'utf-8'); @@ -244,6 +264,34 @@ describe('file tools', () => { expectStoryState(stringifiedResult, stringifiedCwd, 2, 'green blue green'); }); + it('edits files with literal replacement text', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'abc', 'utf-8'); + + const result = await executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { path: 'story.txt', oldText: 'a', newText: '$&' }, + cwd, + ); + + expect(firstText(result)).toBe('Applied 1 replacement(s) in story.txt'); + expectStoryState(result, cwd, 1, '$&bc'); + }); + + it('rejects empty edit search strings without changing files', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'red blue red', 'utf-8'); + + const result = await executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { path: 'story.txt', oldText: '', newText: 'green' }, + cwd, + ); + + expect(firstText(result)).toBe('Edit error: oldText must not be empty'); + expectStoryState(result, cwd, 0, 'red blue red'); + }); + it('reports unsupported edit strategies without changing files', async () => { const cwd = tempDir('pi-grok-cli-files-'); writeFileSync(join(cwd, 'story.txt'), 'red blue red', 'utf-8'); diff --git a/tests/tools/rendering.test.ts b/tests/tools/rendering.test.ts index c4db165..e54809c 100644 --- a/tests/tools/rendering.test.ts +++ b/tests/tools/rendering.test.ts @@ -78,6 +78,12 @@ describe('tool rendering helpers', () => { content: [{ type: 'text', text: 'No matches found' }], details: { matchCount: 0 }, }); + expect( + toolError({ code: 1, message: 'find: missing: No such file' }, 'Glob', { fileCount: 0 }), + ).toEqual({ + content: [{ type: 'text', text: 'Glob error: find: missing: No such file' }], + details: { fileCount: 0 }, + }); expect(toolError({}, 'Grep', { matchCount: 0 })).toEqual({ content: [{ type: 'text', text: 'Grep error: Unknown error' }], details: { matchCount: 0 }, diff --git a/tests/tools/search.test.ts b/tests/tools/search.test.ts index fb865b3..3fefbe3 100644 --- a/tests/tools/search.test.ts +++ b/tests/tools/search.test.ts @@ -58,6 +58,32 @@ describe('search tools', () => { expectGrepResult(cwd, result); }); + it('greps patterns that start with a dash', async () => { + const cwd = setupProject(); + writeFileSync(join(cwd, 'src', 'dash.ts'), '-export const value = 1\n', 'utf-8'); + + const result = await executeTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: '-export', path: 'src/dash.ts' }, + cwd, + ); + + expect(firstText(result)).toBe(`${join(cwd, 'src', 'dash.ts')}:1:-export const value = 1`); + expect(result.details).toEqual({ matchCount: 1 }); + }); + + it('includes file paths when grepping a single file', async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: 'needle', path: 'src/alpha.ts' }, + cwd, + ); + + expect(firstText(result)).toBe(`${join(cwd, 'src', 'alpha.ts')}:1:needle`); + expect(result.details).toEqual({ matchCount: 1 }); + }); + it('reports no grep matches as an empty result', async () => { const cwd = setupProject(); const result = await executeTool( @@ -112,7 +138,7 @@ describe('search tools', () => { cwd, ); - expect(firstText(result)).toBe('No matches found'); + expect(firstText(result)).toBe('No files found'); expect(result.details).toEqual({ fileCount: 0 }); }); diff --git a/tests/tools/shell.test.ts b/tests/tools/shell.test.ts index 8612306..b4e99c8 100644 --- a/tests/tools/shell.test.ts +++ b/tests/tools/shell.test.ts @@ -105,6 +105,12 @@ describe('shell tool', () => { details: { exitCode: 0 }, }), ).toBe('Exit 0'); + expect( + renderToolResult(shell, { + content: [{ type: 'text', text: 'spawn failed' }], + details: { exitCode: 'ENOENT' }, + }), + ).toBe('Exit 1'); expect( renderToolResult( shell, From feafff5fc6133553799627873c9937aa7c31af5b Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 00:21:10 +0900 Subject: [PATCH 10/24] refactor: update catalog model structure to support dynamic provider configurations --- src/models/catalog.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/models/catalog.ts b/src/models/catalog.ts index 5fe3eb7..50e3bb8 100644 --- a/src/models/catalog.ts +++ b/src/models/catalog.ts @@ -4,8 +4,8 @@ // ─── Cost constants ($/M tokens) ────────────────────────────────────────────── -const COST_BUILD = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }; -const COST_COMPOSER = { input: 3, output: 15, cacheRead: 0, cacheWrite: 0 }; +const COST_BUILD = { input: 1, output: 2, cacheRead: 0.2, cacheWrite: 0.2 }; +const COST_COMPOSER_FAST = { input: 3, output: 15, cacheRead: 0.5, cacheWrite: 0 }; const COST_43 = { input: 1.25, output: 2.5, cacheRead: 0.2, cacheWrite: 0 }; const COST_420 = { input: 2, output: 6, cacheRead: 0.2, cacheWrite: 0 }; @@ -39,7 +39,7 @@ const FALLBACK_MODELS: GrokCliModelConfig[] = [ name: 'Composer 2.5 Fast (Grok CLI)', reasoning: false, input: ['text', 'image'], - cost: COST_COMPOSER, + cost: COST_COMPOSER_FAST, contextWindow: 200_000, maxTokens: 30_000, thinkingLevelMap: { From 2096679835211a642569d7b3a693a8b772c6c71f Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 01:58:04 +0900 Subject: [PATCH 11/24] fix: improve tool output handling, file system robustness, quota caching, and glob search functionality --- src/provider/quota.ts | 8 ++++- src/tools/files.ts | 24 ++++++++----- src/tools/rendering.ts | 5 +-- src/tools/search.ts | 61 +++++++++++++++++++++++++------ src/tools/shell.ts | 3 +- tests/provider/register.test.ts | 27 +++++++++++++- tests/tools/files.test.ts | 64 +++++++++++++++++++++++++++++++++ tests/tools/search.test.ts | 43 ++++++++++++++++++++-- tests/tools/shell.test.ts | 13 +++++++ 9 files changed, 221 insertions(+), 27 deletions(-) diff --git a/src/provider/quota.ts b/src/provider/quota.ts index 3295fcb..ccbc038 100644 --- a/src/provider/quota.ts +++ b/src/provider/quota.ts @@ -76,7 +76,13 @@ function extractRateLimit(h: Record): RateLimitInfo | undefined const limitTokens = Number(h['x-ratelimit-limit-tokens']); const contextWindow = Number(h['x-grok-context-window']); - if (Number.isNaN(remainingReqs) && Number.isNaN(remainingTokens)) return undefined; + if ( + [remainingReqs, limitReqs, remainingTokens, limitTokens].some( + (value) => !Number.isFinite(value), + ) + ) { + return undefined; + } return { remainingRequests: remainingReqs, diff --git a/src/tools/files.ts b/src/tools/files.ts index 719220d..88260e4 100644 --- a/src/tools/files.ts +++ b/src/tools/files.ts @@ -9,6 +9,7 @@ import { detailRecord, fileError, fileNotFound, + MAX_OUTPUT_BYTES, MAX_OUTPUT_CHARS, numberDetail, recordFrom, @@ -142,7 +143,7 @@ export function registerFileTools(pi: ExtensionAPI) { try { const { stdout } = await execFileAsync('ls', ['-la', targetPath], { cwd: ctx.cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, + maxBuffer: MAX_OUTPUT_BYTES, signal, }); @@ -214,12 +215,15 @@ export function registerFileTools(pi: ExtensionAPI) { } const content = readFileSync(filePath, 'utf-8'); - const lines = content.split('\n'); + const lines = content.endsWith('\n') + ? content.slice(0, -1).split('\n') + : content.split('\n'); const startLine = params.offset ?? 0; - const endLine = params.limit - ? Math.min(startLine + params.limit, lines.length) - : Math.min(startLine + 2000, lines.length); + const endLine = + params.limit !== undefined + ? Math.min(startLine + params.limit, lines.length) + : Math.min(startLine + 2000, lines.length); const selectedLines = lines.slice(startLine, endLine); const numberedLines = selectedLines.map((line, i) => `${startLine + i + 1}\t${line}`); @@ -238,8 +242,9 @@ export function registerFileTools(pi: ExtensionAPI) { details: { path: filePath, totalLines: lines.length }, }; } catch (error: unknown) { + const err = error as { code?: string }; return fileError(error, 'Read', filePath, { - exists: false, + exists: err.code !== 'ENOENT', totalLines: 0, }); } @@ -249,7 +254,7 @@ export function registerFileTools(pi: ExtensionAPI) { args.offset !== undefined || args.limit !== undefined ? theme.fg( 'muted', - ` (from ${args.offset ?? 0}${args.limit ? `, ${args.limit} lines` : ''})`, + ` (from ${args.offset ?? 0}${args.limit !== undefined ? `, ${args.limit} lines` : ''})`, ) : ''; return text( @@ -301,15 +306,16 @@ export function registerFileTools(pi: ExtensionAPI) { try { mkdirSync(dirname(filePath), { recursive: true }); writeFileSync(filePath, params.content, 'utf-8'); + const bytesWritten = Buffer.byteLength(params.content, 'utf8'); return { content: [ { type: 'text', - text: `Successfully wrote ${params.content.length} bytes to ${params.path}`, + text: `Successfully wrote ${bytesWritten} bytes to ${params.path}`, }, ], - details: { path: filePath, bytesWritten: params.content.length }, + details: { path: filePath, bytesWritten }, }; } catch (error: unknown) { const err = error as ToolError; diff --git a/src/tools/rendering.ts b/src/tools/rendering.ts index 4e7eaf5..0a3e2d4 100644 --- a/src/tools/rendering.ts +++ b/src/tools/rendering.ts @@ -5,6 +5,7 @@ import { Text } from '@earendil-works/pi-tui'; const execFileAsync = promisify(execFile); export const MAX_OUTPUT_CHARS = 50_000; +export const MAX_OUTPUT_BYTES = MAX_OUTPUT_CHARS * 4 + 1024; export const MAX_LINES = 500; export function recordFrom(value: unknown): Record | undefined { @@ -166,14 +167,14 @@ export async function execWithRgFallback( if (await hasRipgrep()) { const result = await execFileAsync('rg', rgArgs, { cwd: options.cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, + maxBuffer: MAX_OUTPUT_BYTES, signal: options.signal, }); return result.stdout; } const result = await execFileAsync('grep', grepArgs, { cwd: options.cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, + maxBuffer: MAX_OUTPUT_BYTES, signal: options.signal, }); return result.stdout; diff --git a/src/tools/search.ts b/src/tools/search.ts index 9273e29..b068cb2 100644 --- a/src/tools/search.ts +++ b/src/tools/search.ts @@ -1,12 +1,13 @@ import { execFile } from 'node:child_process'; -import { resolve } from 'node:path'; +import { statSync } from 'node:fs'; +import { relative, resolve } from 'node:path'; import { promisify } from 'node:util'; import { Type } from '@earendil-works/pi-ai'; import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; import { execWithRgFallback, hasRipgrep, - MAX_OUTPUT_CHARS, + MAX_OUTPUT_BYTES, numberDetail, recordFrom, renderResultText, @@ -23,6 +24,40 @@ const execFileAsync = promisify(execFile); type GrepArgs = { pattern: string; path?: string; include?: string }; type GlobArgs = { pattern: string; path?: string }; +function globToRegExp(pattern: string) { + let source = '^'; + for (let i = 0; i < pattern.length; i += 1) { + const char = pattern[i]; + const next = pattern[i + 1]; + if (char === '*' && next === '*' && pattern[i + 2] === '/') { + source += '(?:.*/)?'; + i += 2; + } else if (char === '*' && next === '*') { + source += '.*'; + i += 1; + } else if (char === '*') { + source += '[^/]*'; + } else if (char === '?') { + source += '[^/]'; + } else { + source += char.replace(/[|\\{}()[\]^$+?.]/g, '\\$&'); + } + } + return new RegExp(`${source}$`); +} + +function normalizePath(filePath: string) { + return filePath.replaceAll('\\', '/'); +} + +function sortByModifiedNewest(files: string[]) { + return files.sort((a, b) => { + const delta = statSync(b).mtimeMs - statSync(a).mtimeMs; + if (delta !== 0) return delta; + return a.localeCompare(b); + }); +} + export function registerSearchTools(pi: ExtensionAPI) { const GrepParams = Type.Object({ pattern: Type.String({ @@ -150,19 +185,23 @@ export function registerSearchTools(pi: ExtensionAPI) { const result = await execFileAsync( 'rg', ['--files', '--color=never', '--glob', params.pattern, searchPath], - { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, + { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_BYTES, signal }, ); files = result.stdout.trim().split('\n').filter(Boolean); } else { - // find fallback — convert **/*.ext → -name "*.ext" - const basename = params.pattern.replace(/^(\*\*\/)+/, ''); - const result = await execFileAsync( - 'find', - [searchPath, '-type', 'f', '-name', basename], - { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_CHARS * 2, signal }, - ); - files = result.stdout.trim().split('\n').filter(Boolean); + const matcher = globToRegExp(normalizePath(params.pattern)); + const result = await execFileAsync('find', [searchPath, '-type', 'f'], { + cwd: ctx.cwd, + maxBuffer: MAX_OUTPUT_BYTES, + signal, + }); + files = result.stdout + .trim() + .split('\n') + .filter(Boolean) + .filter((file) => matcher.test(normalizePath(relative(ctx.cwd, file)))); } + files = sortByModifiedNewest(files); if (files.length === 0) { return { diff --git a/src/tools/shell.ts b/src/tools/shell.ts index 23278aa..a0d5dcf 100644 --- a/src/tools/shell.ts +++ b/src/tools/shell.ts @@ -5,6 +5,7 @@ import { Type } from '@earendil-works/pi-ai'; import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; import { detailRecord, + MAX_OUTPUT_BYTES, MAX_OUTPUT_CHARS, renderResultText, renderRunning, @@ -45,7 +46,7 @@ export function registerShellTool(pi: ExtensionAPI) { try { const { stdout, stderr } = await execFileAsync('bash', ['-c', params.command], { cwd, - maxBuffer: MAX_OUTPUT_CHARS * 2, + maxBuffer: MAX_OUTPUT_BYTES, timeout, signal, }); diff --git a/tests/provider/register.test.ts b/tests/provider/register.test.ts index d44045a..8b0873c 100644 --- a/tests/provider/register.test.ts +++ b/tests/provider/register.test.ts @@ -1,4 +1,4 @@ -import { mkdirSync, mkdtempSync, readFileSync, rmSync, writeFileSync } from 'node:fs'; +import { existsSync, mkdirSync, mkdtempSync, readFileSync, rmSync, writeFileSync } from 'node:fs'; import { tmpdir } from 'node:os'; import { join } from 'node:path'; import type { ExtensionAPI, ProviderConfig } from '@earendil-works/pi-coding-agent'; @@ -231,6 +231,31 @@ describe('Grok CLI status command', () => { ).toBe(179); }); + it('ignores incomplete quota headers instead of caching NaN values', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + const home = setupHome(); + streamSimpleOpenAIResponses.mockImplementationOnce((_model, _context, options) => { + options?.onResponse?.({ + headers: { + 'x-ratelimit-remaining-tokens': '7500000', + 'x-ratelimit-limit-tokens': '7500000', + }, + }); + return {}; + }); + const extension = await setupExtension(); + extension.providers + .get('grok-cli') + ?.streamSimple?.({ provider: 'grok-cli', id: 'grok-build' }, {}, {}); + const notify = await runStatus(extension); + + expect(existsSync(join(home, '.pi', 'grok-cli-quota.json'))).toBe(false); + expect(notify.mock.calls.at(-1)?.[0]).not.toContain('NaN'); + expect(notify.mock.calls.at(-1)?.[0]).toContain( + 'no cached quota data — make a request with this model first', + ); + }); + it('loads cached quotas from the global pi config directory', async () => { delete process.env.GROK_CLI_OAUTH_TOKEN; const home = setupHome(); diff --git a/tests/tools/files.test.ts b/tests/tools/files.test.ts index d6d5da1..baa2a9e 100644 --- a/tests/tools/files.test.ts +++ b/tests/tools/files.test.ts @@ -137,6 +137,55 @@ describe('file tools', () => { }); }); + it('reports UTF-8 bytes written for multibyte content', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + const result = await executeTool( + collectTools(registerFileTools).get('Write'), + { path: 'emoji.txt', content: 'a🙂漢' }, + cwd, + ); + + expect(firstText(result)).toBe('Successfully wrote 8 bytes to emoji.txt'); + expect(result.details).toEqual({ + path: join(cwd, 'emoji.txt'), + bytesWritten: 8, + }); + }); + + it('honors a zero read limit', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'notes.txt'), 'alpha\nbeta', 'utf-8'); + const result = await executeTool( + collectTools(registerFileTools).get('Read'), + { path: 'notes.txt', limit: 0 }, + cwd, + ); + + expect(firstText(result)).toBe( + '\n\n[Showing lines 1-0 of 2 total lines. Use offset to see more.]', + ); + expect(result.details).toEqual({ + path: join(cwd, 'notes.txt'), + totalLines: 2, + }); + }); + + it('does not add a blank numbered line for files ending with a newline', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'notes.txt'), 'alpha\nbeta\n', 'utf-8'); + const result = await executeTool( + collectTools(registerFileTools).get('Read'), + { path: 'notes.txt' }, + cwd, + ); + + expect(firstText(result)).toBe('1\talpha\n2\tbeta'); + expect(result.details).toEqual({ + path: join(cwd, 'notes.txt'), + totalLines: 2, + }); + }); + it('reports missing files without throwing', async () => { const cwd = tempDir('pi-grok-cli-files-'); const result = await executeTool( @@ -153,6 +202,21 @@ describe('file tools', () => { }); }); + it('renders read errors for existing paths without claiming the file is missing', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + mkdirSync(join(cwd, 'dir')); + const tools = collectTools(registerFileTools); + const result = await executeTool(tools.get('Read'), { path: 'dir' }, cwd); + + expect(firstText(result).startsWith('Read error:')).toBe(true); + expect(result.details).toEqual({ + path: join(cwd, 'dir'), + exists: true, + totalLines: 0, + }); + expect(renderToolResult(tools.get('Read'), result)).toBe('0 line(s)'); + }); + it('replaces every exact string occurrence', async () => { const cwd = tempDir('pi-grok-cli-files-'); writeFileSync(join(cwd, 'story.txt'), 'red blue red', 'utf-8'); diff --git a/tests/tools/search.test.ts b/tests/tools/search.test.ts index 3fefbe3..c96f233 100644 --- a/tests/tools/search.test.ts +++ b/tests/tools/search.test.ts @@ -1,6 +1,6 @@ -import { mkdirSync, writeFileSync } from 'node:fs'; +import { mkdirSync, symlinkSync, utimesSync, writeFileSync } from 'node:fs'; import { join } from 'node:path'; -import { describe, expect, it } from 'vitest'; +import { describe, expect, it, vi } from 'vitest'; import { registerSearchTools } from '../../src/tools/search.js'; import { collectTools, @@ -13,6 +13,8 @@ import { tempDir, } from './toolTestHelpers.js'; +const originalPath = process.env.PATH; + function setupProject() { const dir = tempDir('pi-grok-cli-search-'); mkdirSync(join(dir, 'src')); @@ -142,6 +144,43 @@ describe('search tools', () => { expect(result.details).toEqual({ fileCount: 0 }); }); + it('globs path-containing patterns through the find fallback', async () => { + const cwd = setupProject(); + const bin = tempDir('pi-grok-cli-search-bin-'); + symlinkSync('/usr/bin/find', join(bin, 'find')); + process.env.PATH = bin; + vi.resetModules(); + const fallbackTools = collectTools( + (await import('../../src/tools/search.js')).registerSearchTools, + ); + + try { + const result = await executeTool(fallbackTools.get('Glob'), { pattern: 'src/**/*.ts' }, cwd); + + expectGlobResult(cwd, result); + } finally { + process.env.PATH = originalPath; + } + }); + + it('sorts glob results by modification time newest first', async () => { + const cwd = setupProject(); + const oldTime = new Date('2024-01-01T00:00:00.000Z'); + const newTime = new Date('2024-01-02T00:00:00.000Z'); + utimesSync(join(cwd, 'src', 'alpha.ts'), oldTime, oldTime); + utimesSync(join(cwd, 'src', 'gamma.ts'), newTime, newTime); + const result = await executeTool( + collectTools(registerSearchTools).get('Glob'), + { pattern: '**/*.ts', path: 'src' }, + cwd, + ); + + expect(firstText(result).split('\n')).toEqual([ + join(cwd, 'src', 'gamma.ts'), + join(cwd, 'src', 'alpha.ts'), + ]); + }); + it('renders grep calls and result states', () => { const grep = collectTools(registerSearchTools).get('Grep'); const result = { diff --git a/tests/tools/shell.test.ts b/tests/tools/shell.test.ts index b4e99c8..2fc519b 100644 --- a/tests/tools/shell.test.ts +++ b/tests/tools/shell.test.ts @@ -89,6 +89,19 @@ describe('shell tool', () => { expect(firstText(failureResult).endsWith('[Output truncated at 50KB]')).toBe(true); }); + it('truncates multibyte output by characters without hitting exec buffer limits', async () => { + const cwd = tempDir('pi-grok-cli-shell-'); + const result = await executeTool( + collectTools(registerShellTool).get('Shell'), + { command: 'perl -e \'print "漢" x 50001\'' }, + cwd, + ); + + expect(firstText(result)).toHaveLength('\n\n[Output truncated at 50KB]'.length + 50_000); + expect(firstText(result).startsWith('Shell error')).toBe(false); + expect(firstText(result).endsWith('[Output truncated at 50KB]')).toBe(true); + }); + it('renders shell calls and result states', () => { const shell = collectTools(registerShellTool).get('Shell'); From d60be860a7d07f82da98470a572ebf3bfc3ab217 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:04:20 +0900 Subject: [PATCH 12/24] fix: add failed and error metadata to tool error details --- src/tools/rendering.ts | 10 ++++++---- tests/tools/rendering.test.ts | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/tools/rendering.ts b/src/tools/rendering.ts index 0a3e2d4..724e90a 100644 --- a/src/tools/rendering.ts +++ b/src/tools/rendering.ts @@ -129,14 +129,15 @@ export function fileError( extraDetails: Omit, ): ToolResult { const err = error as ToolError; + const message = err.message ?? 'Unknown error'; return { content: [ { type: 'text', - text: `${toolName} error: ${err.message ?? 'Unknown error'}`, + text: `${toolName} error: ${message}`, }, ], - details: { path: filePath, ...extraDetails } as T, + details: { path: filePath, ...extraDetails, failed: true, error: message } as unknown as T, }; } @@ -148,14 +149,15 @@ export function toolError(error: unknown, toolName: string, emptyDetails: T): details: emptyDetails, }; } + const message = err.message ?? 'Unknown error'; return { content: [ { type: 'text', - text: `${toolName} error: ${err.message ?? 'Unknown error'}`, + text: `${toolName} error: ${message}`, }, ], - details: emptyDetails, + details: { ...emptyDetails, failed: true, error: message } as T, }; } diff --git a/tests/tools/rendering.test.ts b/tests/tools/rendering.test.ts index e54809c..2f673f2 100644 --- a/tests/tools/rendering.test.ts +++ b/tests/tools/rendering.test.ts @@ -72,7 +72,7 @@ describe('tool rendering helpers', () => { }); expect(fileError({}, 'Read', '/tmp/file.txt', { totalLines: 0 })).toEqual({ content: [{ type: 'text', text: 'Read error: Unknown error' }], - details: { path: '/tmp/file.txt', totalLines: 0 }, + details: { path: '/tmp/file.txt', totalLines: 0, failed: true, error: 'Unknown error' }, }); expect(toolError({ code: 1 }, 'Grep', { matchCount: 0 })).toEqual({ content: [{ type: 'text', text: 'No matches found' }], @@ -82,11 +82,11 @@ describe('tool rendering helpers', () => { toolError({ code: 1, message: 'find: missing: No such file' }, 'Glob', { fileCount: 0 }), ).toEqual({ content: [{ type: 'text', text: 'Glob error: find: missing: No such file' }], - details: { fileCount: 0 }, + details: { fileCount: 0, failed: true, error: 'find: missing: No such file' }, }); expect(toolError({}, 'Grep', { matchCount: 0 })).toEqual({ content: [{ type: 'text', text: 'Grep error: Unknown error' }], - details: { matchCount: 0 }, + details: { matchCount: 0, failed: true, error: 'Unknown error' }, }); }); }); From ff169164623e205a7ea03815b66fe657de25f89d Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:04:24 +0900 Subject: [PATCH 13/24] fix: enforce workspace boundary for file tool paths --- src/tools/files.ts | 133 ++++++++++++++++++++++++++------------ tests/tools/files.test.ts | 74 ++++++++++++++++++--- 2 files changed, 157 insertions(+), 50 deletions(-) diff --git a/src/tools/files.ts b/src/tools/files.ts index 88260e4..18c3cfb 100644 --- a/src/tools/files.ts +++ b/src/tools/files.ts @@ -1,6 +1,13 @@ import { execFile } from 'node:child_process'; -import { existsSync, mkdirSync, readFileSync, unlinkSync, writeFileSync } from 'node:fs'; -import { dirname, resolve } from 'node:path'; +import { + existsSync, + promises as fs, + mkdirSync, + readFileSync, + unlinkSync, + writeFileSync, +} from 'node:fs'; +import { basename, dirname, join, resolve, sep } from 'node:path'; import { promisify } from 'node:util'; import { Type } from '@earendil-works/pi-ai'; import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; @@ -23,6 +30,7 @@ import { const execFileAsync = promisify(execFile); type ReplacementEdit = { oldText: string; newText: string }; +type FileDetails = { path: string; [key: string]: unknown }; type WriteArgs = { path: string; content: string }; type StrReplaceArgs = { path: string; old_str: string; new_str: string }; type EditArgs = { @@ -122,6 +130,48 @@ function renderPathToolCall(toolName: string, filePath: string, theme: ToolTheme return text(theme.fg('toolTitle', theme.bold(`${toolName} `)) + theme.fg('accent', filePath)); } +async function canonicalizeWithinWorkspace(cwd: string, requestedPath: string) { + const targetPath = resolve(cwd, requestedPath); + const realCwd = await fs.realpath(cwd); + const missingParts: string[] = []; + let currentPath = targetPath; + let realTarget: string | undefined; + while (!realTarget) { + try { + realTarget = join(await fs.realpath(currentPath), ...[...missingParts].reverse()); + } catch (error) { + const parentPath = dirname(currentPath); + if (parentPath === currentPath) throw error; + missingParts.push(basename(currentPath)); + currentPath = parentPath; + } + } + if (realTarget !== realCwd && !realTarget.startsWith(`${realCwd}${sep}`)) { + throw new Error('Path is outside the workspace'); + } + return realTarget; +} + +async function existingPathWithinWorkspace(cwd: string, requestedPath: string) { + const safePath = await canonicalizeWithinWorkspace(cwd, requestedPath); + return existsSync(safePath) ? safePath : undefined; +} + +async function existingPathOrNotFound( + cwd: string, + requestedPath: string, + extraDetails: Omit, +) { + return ( + (await existingPathWithinWorkspace(cwd, requestedPath)) ?? + fileNotFound(resolve(cwd, requestedPath), extraDetails) + ); +} + +function replacementPathOrNotFound(cwd: string, requestedPath: string) { + return existingPathOrNotFound(cwd, requestedPath, { replacements: 0 }); +} + export function registerFileTools(pi: ExtensionAPI) { // ── LS tool ────────────────────────────────────────────────────────── @@ -141,7 +191,8 @@ export function registerFileTools(pi: ExtensionAPI) { const targetPath = resolve(ctx.cwd, params.path); try { - const { stdout } = await execFileAsync('ls', ['-la', targetPath], { + const safePath = await canonicalizeWithinWorkspace(ctx.cwd, params.path); + const { stdout } = await execFileAsync('ls', ['-la', safePath], { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_BYTES, signal, @@ -154,18 +205,19 @@ export function registerFileTools(pi: ExtensionAPI) { return { content: [{ type: 'text', text: output }], - details: { path: targetPath }, + details: { path: safePath }, }; } catch (error: unknown) { const err = error as ToolError; + const message = err.message ?? 'Unknown error'; return { content: [ { type: 'text', - text: `LS error: ${err.message ?? 'Unknown error'}`, + text: `LS error: ${message}`, }, ], - details: { path: targetPath }, + details: { path: targetPath, failed: true, error: message }, }; } }, @@ -210,11 +262,13 @@ export function registerFileTools(pi: ExtensionAPI) { const filePath = resolve(ctx.cwd, params.path); try { - if (!existsSync(filePath)) { - return fileNotFound(filePath, { exists: false, totalLines: 0 }); - } + const safePath = await existingPathOrNotFound(ctx.cwd, params.path, { + exists: false, + totalLines: 0, + }); + if (typeof safePath !== 'string') return safePath; - const content = readFileSync(filePath, 'utf-8'); + const content = readFileSync(safePath, 'utf-8'); const lines = content.endsWith('\n') ? content.slice(0, -1).split('\n') : content.split('\n'); @@ -239,7 +293,7 @@ export function registerFileTools(pi: ExtensionAPI) { return { content: [{ type: 'text', text: output }], - details: { path: filePath, totalLines: lines.length }, + details: { path: safePath, totalLines: lines.length }, }; } catch (error: unknown) { const err = error as { code?: string }; @@ -304,8 +358,9 @@ export function registerFileTools(pi: ExtensionAPI) { const filePath = resolve(ctx.cwd, params.path); try { - mkdirSync(dirname(filePath), { recursive: true }); - writeFileSync(filePath, params.content, 'utf-8'); + const safePath = await canonicalizeWithinWorkspace(ctx.cwd, params.path); + mkdirSync(dirname(safePath), { recursive: true }); + writeFileSync(safePath, params.content, 'utf-8'); const bytesWritten = Buffer.byteLength(params.content, 'utf8'); return { @@ -315,18 +370,19 @@ export function registerFileTools(pi: ExtensionAPI) { text: `Successfully wrote ${bytesWritten} bytes to ${params.path}`, }, ], - details: { path: filePath, bytesWritten }, + details: { path: safePath, bytesWritten }, }; } catch (error: unknown) { const err = error as ToolError; + const message = err.message ?? 'Unknown error'; return { content: [ { type: 'text', - text: `Write error: ${err.message ?? 'Unknown error'}`, + text: `Write error: ${message}`, }, ], - details: { path: filePath, bytesWritten: 0 }, + details: { path: filePath, bytesWritten: 0, failed: true, error: message }, }; } }, @@ -383,16 +439,16 @@ export function registerFileTools(pi: ExtensionAPI) { }, async execute(_toolCallId, params, _signal, _onUpdate, ctx) { - const filePath = resolve(ctx.cwd, params.path); + const requestedPath = params.path; + const filePath = resolve(ctx.cwd, requestedPath); try { - if (!existsSync(filePath)) { - return fileNotFound(filePath, { replacements: 0 }); - } + const safePath = await replacementPathOrNotFound(ctx.cwd, requestedPath); + if (typeof safePath !== 'string') return safePath; - const content = readFileSync(filePath, 'utf-8'); + const content = readFileSync(safePath, 'utf-8'); if (params.old_str === '') { - return replacementResult('StrReplace error: old_str must not be empty', filePath); + return replacementResult('StrReplace error: old_str must not be empty', safePath); } const count = content.split(params.old_str).length - 1; @@ -400,12 +456,12 @@ export function registerFileTools(pi: ExtensionAPI) { if (count === 0) { return replacementResult( `String not found in ${params.path}: "${params.old_str}"`, - filePath, + safePath, ); } const newContent = content.replaceAll(params.old_str, () => params.new_str); - writeFileSync(filePath, newContent, 'utf-8'); + writeFileSync(safePath, newContent, 'utf-8'); return { content: [ @@ -414,7 +470,7 @@ export function registerFileTools(pi: ExtensionAPI) { text: `Replaced ${count} occurrence(s) in ${params.path}`, }, ], - details: { path: filePath, replacements: count }, + details: { path: safePath, replacements: count }, }; } catch (error: unknown) { return fileError(error, 'StrReplace', filePath, { replacements: 0 }); @@ -488,11 +544,9 @@ export function registerFileTools(pi: ExtensionAPI) { async execute(_toolCallId, params, _signal, _onUpdate, ctx) { const filePath = resolve(ctx.cwd, params.path); - if (!existsSync(filePath)) { - return fileNotFound(filePath, { replacements: 0 }); - } - try { + const safePath = await replacementPathOrNotFound(ctx.cwd, params.path); + if (typeof safePath !== 'string') return safePath; if (!params.edits?.length) { return { content: [ @@ -503,20 +557,20 @@ export function registerFileTools(pi: ExtensionAPI) { : 'Edit error: provide at least one exact text replacement', }, ], - details: { path: filePath, replacements: 0 }, + details: { path: safePath, replacements: 0 }, }; } if (params.edits.some((edit) => edit.oldText === '')) { - return replacementResult('Edit error: oldText must not be empty', filePath); + return replacementResult('Edit error: oldText must not be empty', safePath); } - const result = applyEdits(readFileSync(filePath, 'utf-8'), params.edits); + const result = applyEdits(readFileSync(safePath, 'utf-8'), params.edits); if (result.replacements === 0) { - return replacementResult(`No replacement strings found in ${params.path}`, filePath); + return replacementResult(`No replacement strings found in ${params.path}`, safePath); } - writeFileSync(filePath, result.content, 'utf-8'); + writeFileSync(safePath, result.content, 'utf-8'); return { content: [ @@ -525,7 +579,7 @@ export function registerFileTools(pi: ExtensionAPI) { text: `Applied ${result.replacements} replacement(s) in ${params.path}`, }, ], - details: { path: filePath, replacements: result.replacements }, + details: { path: safePath, replacements: result.replacements }, }; } catch (error: unknown) { return fileError(error, 'Edit', filePath, { replacements: 0 }); @@ -557,15 +611,14 @@ export function registerFileTools(pi: ExtensionAPI) { const filePath = resolve(ctx.cwd, params.path); try { - if (!existsSync(filePath)) { - return fileNotFound(filePath, { deleted: false }); - } + const safePath = await existingPathOrNotFound(ctx.cwd, params.path, { deleted: false }); + if (typeof safePath !== 'string') return safePath; - unlinkSync(filePath); + unlinkSync(safePath); return { content: [{ type: 'text', text: `Successfully deleted ${params.path}` }], - details: { path: filePath, deleted: true }, + details: { path: safePath, deleted: true }, }; } catch (error: unknown) { return fileError(error, 'Delete', filePath, { deleted: false }); diff --git a/tests/tools/files.test.ts b/tests/tools/files.test.ts index baa2a9e..b7fd725 100644 --- a/tests/tools/files.test.ts +++ b/tests/tools/files.test.ts @@ -1,4 +1,11 @@ -import { existsSync, mkdirSync, readFileSync, writeFileSync } from 'node:fs'; +import { + existsSync, + mkdirSync, + readFileSync, + realpathSync, + symlinkSync, + writeFileSync, +} from 'node:fs'; import { join } from 'node:path'; import { describe, expect, it } from 'vitest'; import { registerFileTools } from '../../src/tools/files.js'; @@ -15,12 +22,16 @@ import { function expectStoryState(result: ToolResult, cwd: string, replacements: number, content: string) { expect(result.details).toEqual({ - path: join(cwd, 'story.txt'), + path: expectedPath(cwd, 'story.txt'), replacements, }); expect(readFileSync(join(cwd, 'story.txt'), 'utf-8')).toBe(content); } +function expectedPath(cwd: string, ...parts: string[]) { + return join(realpathSync(cwd), ...parts); +} + function strReplace(cwd: string, old_str: string, new_str: string) { return executeTool( collectTools(registerFileTools).get('StrReplace'), @@ -47,7 +58,7 @@ describe('file tools', () => { expect(firstText(result)).toContain('.hidden'); expect(firstText(result)).toContain('visible.txt'); - expect(result.details).toEqual({ path: cwd }); + expect(result.details).toEqual({ path: realpathSync(cwd) }); }); it('reports filesystem errors for invalid file operations', async () => { @@ -78,14 +89,20 @@ describe('file tools', () => { expect(writeResult.details).toEqual({ path: join(cwd, 'blocked', 'file.txt'), bytesWritten: 0, + failed: true, + error: expect.stringContaining('EEXIST: file already exists, mkdir'), }); expect(replaceResult.details).toEqual({ path: join(cwd, 'dir'), replacements: 0, + failed: true, + error: expect.stringContaining('EISDIR: illegal operation on a directory, read'), }); expect(deleteResult.details).toEqual({ path: join(cwd, 'dir'), deleted: false, + failed: true, + error: expect.stringContaining('operation not permitted'), }); }); @@ -101,7 +118,7 @@ describe('file tools', () => { expect(firstText(writeResult)).toBe('Successfully wrote 22 bytes to nested/notes.txt'); expect(writeResult.details).toEqual({ - path: join(cwd, 'nested/notes.txt'), + path: expectedPath(cwd, 'nested/notes.txt'), bytesWritten: 22, }); @@ -115,7 +132,7 @@ describe('file tools', () => { '2\tbeta\n3\tgamma\n\n[Showing lines 2-3 of 4 total lines. Use offset to see more.]', ); expect(readResult.details).toEqual({ - path: join(cwd, 'nested/notes.txt'), + path: expectedPath(cwd, 'nested/notes.txt'), totalLines: 4, }); }); @@ -132,7 +149,7 @@ describe('file tools', () => { expect(firstText(result)).toBe('Successfully wrote 10 bytes to nested/notes.txt'); expect(readFileSync(join(cwd, 'nested/notes.txt'), 'utf-8')).toBe('alpha\nbeta'); expect(result.details).toEqual({ - path: join(cwd, 'nested/notes.txt'), + path: expectedPath(cwd, 'nested/notes.txt'), bytesWritten: 10, }); }); @@ -147,7 +164,7 @@ describe('file tools', () => { expect(firstText(result)).toBe('Successfully wrote 8 bytes to emoji.txt'); expect(result.details).toEqual({ - path: join(cwd, 'emoji.txt'), + path: expectedPath(cwd, 'emoji.txt'), bytesWritten: 8, }); }); @@ -165,7 +182,7 @@ describe('file tools', () => { '\n\n[Showing lines 1-0 of 2 total lines. Use offset to see more.]', ); expect(result.details).toEqual({ - path: join(cwd, 'notes.txt'), + path: expectedPath(cwd, 'notes.txt'), totalLines: 2, }); }); @@ -181,7 +198,7 @@ describe('file tools', () => { expect(firstText(result)).toBe('1\talpha\n2\tbeta'); expect(result.details).toEqual({ - path: join(cwd, 'notes.txt'), + path: expectedPath(cwd, 'notes.txt'), totalLines: 2, }); }); @@ -202,6 +219,41 @@ describe('file tools', () => { }); }); + it('rejects paths that escape the workspace', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + const outside = tempDir('pi-grok-cli-files-outside-'); + writeFileSync(join(outside, 'secret.txt'), 'secret', 'utf-8'); + symlinkSync(outside, join(cwd, 'outside')); + + const readResult = await executeTool( + collectTools(registerFileTools).get('Read'), + { path: 'outside/secret.txt' }, + cwd, + ); + const writeResult = await executeTool( + collectTools(registerFileTools).get('Write'), + { path: '../escape.txt', content: 'escape' }, + cwd, + ); + + expect(firstText(readResult)).toBe('Read error: Path is outside the workspace'); + expect(readResult.details).toEqual({ + path: join(cwd, 'outside', 'secret.txt'), + exists: true, + totalLines: 0, + failed: true, + error: 'Path is outside the workspace', + }); + expect(firstText(writeResult)).toBe('Write error: Path is outside the workspace'); + expect(writeResult.details).toEqual({ + path: join(cwd, '..', 'escape.txt'), + bytesWritten: 0, + failed: true, + error: 'Path is outside the workspace', + }); + expect(existsSync(join(cwd, '..', 'escape.txt'))).toBe(false); + }); + it('renders read errors for existing paths without claiming the file is missing', async () => { const cwd = tempDir('pi-grok-cli-files-'); mkdirSync(join(cwd, 'dir')); @@ -213,6 +265,8 @@ describe('file tools', () => { path: join(cwd, 'dir'), exists: true, totalLines: 0, + failed: true, + error: expect.stringContaining('EISDIR: illegal operation on a directory, read'), }); expect(renderToolResult(tools.get('Read'), result)).toBe('0 line(s)'); }); @@ -391,7 +445,7 @@ describe('file tools', () => { expect(firstText(deletedResult)).toBe('Successfully deleted remove.txt'); expect(deletedResult.details).toEqual({ - path: join(cwd, 'remove.txt'), + path: expectedPath(cwd, 'remove.txt'), deleted: true, }); expect(existsSync(join(cwd, 'remove.txt'))).toBe(false); From e7ae2f788a052340cdc3953d01e91fd2bf4a2377 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:04:28 +0900 Subject: [PATCH 14/24] fix: support basename-only glob patterns in find fallback --- src/tools/search.ts | 14 ++++++------ tests/tools/search.test.ts | 47 ++++++++++++++++++++++++++------------ 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/src/tools/search.ts b/src/tools/search.ts index b068cb2..d605602 100644 --- a/src/tools/search.ts +++ b/src/tools/search.ts @@ -1,6 +1,6 @@ import { execFile } from 'node:child_process'; import { statSync } from 'node:fs'; -import { relative, resolve } from 'node:path'; +import { basename, relative, resolve } from 'node:path'; import { promisify } from 'node:util'; import { Type } from '@earendil-works/pi-ai'; import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; @@ -189,17 +189,17 @@ export function registerSearchTools(pi: ExtensionAPI) { ); files = result.stdout.trim().split('\n').filter(Boolean); } else { - const matcher = globToRegExp(normalizePath(params.pattern)); + const normalizedPattern = normalizePath(params.pattern); + const matcher = globToRegExp(normalizedPattern); + const matchesFile = normalizedPattern.includes('/') + ? (file: string) => matcher.test(normalizePath(relative(ctx.cwd, file))) + : (file: string) => matcher.test(basename(file)); const result = await execFileAsync('find', [searchPath, '-type', 'f'], { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_BYTES, signal, }); - files = result.stdout - .trim() - .split('\n') - .filter(Boolean) - .filter((file) => matcher.test(normalizePath(relative(ctx.cwd, file)))); + files = result.stdout.trim().split('\n').filter(Boolean).filter(matchesFile); } files = sortByModifiedNewest(files); diff --git a/tests/tools/search.test.ts b/tests/tools/search.test.ts index c96f233..a9d71af 100644 --- a/tests/tools/search.test.ts +++ b/tests/tools/search.test.ts @@ -13,8 +13,6 @@ import { tempDir, } from './toolTestHelpers.js'; -const originalPath = process.env.PATH; - function setupProject() { const dir = tempDir('pi-grok-cli-search-'); mkdirSync(join(dir, 'src')); @@ -37,6 +35,22 @@ function expectGlobResult(cwd: string, result: ToolResult) { expect(result.details).toEqual({ fileCount: 2 }); } +async function withFindFallbackTools( + run: (tools: ReturnType) => Promise, +) { + const bin = tempDir('pi-grok-cli-search-bin-'); + symlinkSync('/usr/bin/find', join(bin, 'find')); + const oldPath = process.env.PATH; + process.env.PATH = bin; + vi.resetModules(); + try { + await run(collectTools((await import('../../src/tools/search.js')).registerSearchTools)); + } finally { + process.env.PATH = oldPath; + vi.resetModules(); + } +} + describe('search tools', () => { it('greps matching file contents with include filters', async () => { const cwd = setupProject(); @@ -107,7 +121,11 @@ describe('search tools', () => { ); expect(firstText(result).startsWith('Grep error:')).toBe(true); - expect(result.details).toEqual({ matchCount: 0 }); + expect(result.details).toEqual({ + matchCount: 0, + failed: true, + error: expect.stringContaining('regex parse error'), + }); }); it('globs files under the requested path', async () => { @@ -146,21 +164,20 @@ describe('search tools', () => { it('globs path-containing patterns through the find fallback', async () => { const cwd = setupProject(); - const bin = tempDir('pi-grok-cli-search-bin-'); - symlinkSync('/usr/bin/find', join(bin, 'find')); - process.env.PATH = bin; - vi.resetModules(); - const fallbackTools = collectTools( - (await import('../../src/tools/search.js')).registerSearchTools, - ); - - try { + await withFindFallbackTools(async (fallbackTools) => { const result = await executeTool(fallbackTools.get('Glob'), { pattern: 'src/**/*.ts' }, cwd); expectGlobResult(cwd, result); - } finally { - process.env.PATH = originalPath; - } + }); + }); + + it('globs basename-only patterns through the find fallback', async () => { + const cwd = setupProject(); + await withFindFallbackTools(async (fallbackTools) => { + const result = await executeTool(fallbackTools.get('Glob'), { pattern: '*.ts' }, cwd); + + expectGlobResult(cwd, result); + }); }); it('sorts glob results by modification time newest first', async () => { From 95a73d724f9ff2683f1c9899c2edd036de9e7504 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:04:32 +0900 Subject: [PATCH 15/24] fix: improve OAuth error handling, timeouts, and transport robustness --- src/auth/oauth.ts | 133 ++++++++++++++++++++++++++++++--------- tests/auth/oauth.test.ts | 98 +++++++++++++++++++++++++++-- 2 files changed, 197 insertions(+), 34 deletions(-) diff --git a/src/auth/oauth.ts b/src/auth/oauth.ts index 73a5b85..aba62a6 100644 --- a/src/auth/oauth.ts +++ b/src/auth/oauth.ts @@ -26,6 +26,10 @@ const CALLBACK_PORT = Number.parseInt(process.env.PI_GROK_CLI_CALLBACK_PORT || ' const CALLBACK_PATH = '/callback'; /** Refresh 120s before actual expiry. */ const REFRESH_SKEW_MS = 120_000; +const TOKEN_REQUEST_TIMEOUT_MS = Number.parseInt( + process.env.PI_GROK_CLI_TOKEN_TIMEOUT_MS || '30000', + 10, +); // ─── Types ──────────────────────────────────────────────────────────────────── @@ -128,7 +132,15 @@ async function discover(): Promise { XaiErrorCode.DISCOVERY_FAILED, ); } - const payload = (await response.json()) as Record; + let payload: Record; + try { + payload = (await response.json()) as Record; + } catch (cause) { + throw new XaiOAuthError( + `xAI OIDC discovery returned invalid JSON: ${cause instanceof Error ? cause.message : String(cause)}`, + XaiErrorCode.DISCOVERY_FAILED, + ); + } const authorizationEndpoint = validateEndpoint( String(payload.authorization_endpoint ?? ''), 'authorization_endpoint', @@ -216,8 +228,20 @@ function startCallbackServer(): Promise<{ let actualPort: number; try { actualPort = await listen(CALLBACK_PORT); - } catch { - actualPort = await listen(0); + } catch (firstError) { + try { + actualPort = await listen(0); + } catch (secondError) { + const errorDescription = `Could not bind xAI OAuth callback server on ${CALLBACK_HOST}:${CALLBACK_PORT} or an ephemeral port: ${secondError instanceof Error ? secondError.message : String(secondError)} (initial error: ${firstError instanceof Error ? firstError.message : String(firstError)})`; + return { + server, + redirectUri: `http://${CALLBACK_HOST}:${CALLBACK_PORT}${CALLBACK_PATH}`, + waitForCallback: async () => ({ + error: XaiErrorCode.CALLBACK_BIND_FAILED, + errorDescription, + }), + }; + } } const redirectUri = `http://${CALLBACK_HOST}:${actualPort}${CALLBACK_PATH}`; return { @@ -230,7 +254,7 @@ function startCallbackServer(): Promise<{ setTimeout( () => resolve({ - error: 'timeout', + error: XaiErrorCode.CALLBACK_TIMEOUT, errorDescription: 'Timed out waiting for xAI OAuth callback.', }), timeoutMs, @@ -243,33 +267,86 @@ function startCallbackServer(): Promise<{ // ─── Token exchange ─────────────────────────────────────────────────────────── +async function fetchTokenResponse( + tokenEndpoint: string, + body: URLSearchParams, + errorCode: string, + label: string, +): Promise { + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), TOKEN_REQUEST_TIMEOUT_MS); + try { + return await fetch(tokenEndpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + Accept: 'application/json', + }, + body, + signal: controller.signal, + }); + } catch (cause) { + throw new XaiOAuthError( + `xAI ${label} failed: ${cause instanceof Error ? cause.message : String(cause)}`, + errorCode, + ); + } finally { + clearTimeout(timeout); + } +} + +async function tokenResponseText(response: Response) { + try { + return await response.text(); + } catch (cause) { + return `unable to read response body: ${cause instanceof Error ? cause.message : String(cause)}`; + } +} + +async function tokenResponseJson( + response: Response, + errorCode: string, + label: string, +): Promise> { + try { + return (await response.json()) as Record; + } catch (cause) { + throw new XaiOAuthError( + `xAI ${label} returned invalid JSON: ${cause instanceof Error ? cause.message : String(cause)}`, + errorCode, + ); + } +} + async function exchangeCode( tokenEndpoint: string, code: string, redirectUri: string, verifier: string, ): Promise { - const response = await fetch(tokenEndpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - Accept: 'application/json', - }, - body: new URLSearchParams({ + const response = await fetchTokenResponse( + tokenEndpoint, + new URLSearchParams({ grant_type: 'authorization_code', client_id: CLIENT_ID, code, redirect_uri: redirectUri, code_verifier: verifier, }), - }); + XaiErrorCode.TOKEN_EXCHANGE_FAILED, + 'token exchange', + ); if (!response.ok) { throw new XaiOAuthError( - `xAI token exchange failed: ${response.status} ${await response.text()}`, + `xAI token exchange failed: ${response.status} ${await tokenResponseText(response)}`, XaiErrorCode.TOKEN_EXCHANGE_FAILED, ); } - const payload = (await response.json()) as Record; + const payload = await tokenResponseJson( + response, + XaiErrorCode.TOKEN_EXCHANGE_FAILED, + 'token exchange', + ); const access = String(payload.access_token ?? ''); const refresh = String(payload.refresh_token ?? ''); if (!access) { @@ -332,10 +409,12 @@ export async function login( const result = await callback.waitForCallback(180_000); if (result.error) { - throw new XaiOAuthError( - result.errorDescription ?? result.error, - XaiErrorCode.AUTHORIZATION_FAILED, - ); + const code = + result.error === XaiErrorCode.CALLBACK_BIND_FAILED || + result.error === XaiErrorCode.CALLBACK_TIMEOUT + ? result.error + : XaiErrorCode.AUTHORIZATION_FAILED; + throw new XaiOAuthError(result.errorDescription ?? result.error, code); } if (result.state !== state) { throw new XaiOAuthError( @@ -381,29 +460,27 @@ export async function refresh( ); } - const response = await fetch(tokenEndpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - Accept: 'application/json', - }, - body: new URLSearchParams({ + const response = await fetchTokenResponse( + tokenEndpoint, + new URLSearchParams({ grant_type: 'refresh_token', client_id: CLIENT_ID, refresh_token: credentials.refresh, }), - }); + XaiErrorCode.REFRESH_FAILED, + 'token refresh', + ); if (!response.ok) { const isFatal = response.status === 400 || response.status === 401 || response.status === 403; throw new XaiOAuthError( - `xAI token refresh failed: ${response.status} ${await response.text()}`, + `xAI token refresh failed: ${response.status} ${await tokenResponseText(response)}`, XaiErrorCode.REFRESH_FAILED, isFatal, ); } - const payload = (await response.json()) as Record; + const payload = await tokenResponseJson(response, XaiErrorCode.REFRESH_FAILED, 'token refresh'); const access = String(payload.access_token ?? ''); if (!access) { throw new XaiOAuthError( diff --git a/tests/auth/oauth.test.ts b/tests/auth/oauth.test.ts index 4f0d32e..57ea9d2 100644 --- a/tests/auth/oauth.test.ts +++ b/tests/auth/oauth.test.ts @@ -20,6 +20,13 @@ const discoveryDocument = { token_endpoint: 'https://auth.x.ai/oauth/token', }; +function authorizeCallback(auth: { url: string }) { + const url = new URL(auth.url); + void originalFetch( + `${url.searchParams.get('redirect_uri')}?code=callback-code&state=${url.searchParams.get('state')}`, + ); +} + afterEach(() => { process.env = { ...originalEnv }; globalThis.fetch = originalFetch; @@ -168,6 +175,26 @@ describe('OAuth helpers without network access', () => { }); }); + it('wraps refresh transport and JSON failures', async () => { + globalThis.fetch = vi.fn(async () => { + throw new Error('socket closed'); + }); + + await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_FAILED, + message: 'xAI token refresh failed: socket closed', + }); + + globalThis.fetch = vi.fn( + async () => new Response('proxy error', { status: 200 }), + ); + + await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_FAILED, + message: expect.stringContaining('xAI token refresh returned invalid JSON:'), + }); + }); + it('rejects unsafe token endpoints before fetching', async () => { const fetchMock = vi.fn(); globalThis.fetch = fetchMock; @@ -215,6 +242,17 @@ describe('OAuth helpers without network access', () => { }); }); + it('wraps malformed discovery JSON as discovery failure', async () => { + globalThis.fetch = vi.fn( + async () => new Response('proxy error', { status: 200 }), + ); + + await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ + code: XaiErrorCode.DISCOVERY_FAILED, + message: expect.stringContaining('xAI OIDC discovery returned invalid JSON:'), + }); + }); + it('rejects failed and invalid discovery responses', async () => { globalThis.fetch = vi.fn( async () => new Response('unavailable', { status: 503 }), @@ -255,12 +293,7 @@ describe('OAuth helpers without network access', () => { await expect( login({ - onAuth: (auth) => { - const url = new URL(auth.url); - void originalFetch( - `${url.searchParams.get('redirect_uri')}?code=callback-code&state=${url.searchParams.get('state')}`, - ); - }, + onAuth: authorizeCallback, }), ).resolves.toMatchObject({ access: 'login-access', @@ -277,4 +310,57 @@ describe('OAuth helpers without network access', () => { 'callback-code', ); }); + + it('reports callback timeouts with a dedicated error code', async () => { + vi.useFakeTimers(); + globalThis.fetch = vi.fn(async () => Response.json(discoveryDocument)); + const onAuth = vi.fn(); + const resultPromise = login({ onAuth }).then( + () => undefined, + (error: unknown) => error, + ); + + await vi.waitFor(() => expect(onAuth).toHaveBeenCalledOnce()); + await vi.advanceTimersByTimeAsync(180_000); + + await expect(resultPromise).resolves.toMatchObject({ + code: XaiErrorCode.CALLBACK_TIMEOUT, + message: 'Timed out waiting for xAI OAuth callback.', + }); + }); + + it('wraps token exchange transport and JSON failures', async () => { + const fetchMock = vi.fn(async (input) => { + if (input === 'https://auth.x.ai/.well-known/openid-configuration') { + return Response.json(discoveryDocument); + } + throw new Error('exchange socket closed'); + }); + globalThis.fetch = fetchMock; + + await expect( + login({ + onAuth: authorizeCallback, + }), + ).rejects.toMatchObject({ + code: XaiErrorCode.TOKEN_EXCHANGE_FAILED, + message: 'xAI token exchange failed: exchange socket closed', + }); + + globalThis.fetch = vi.fn(async (input) => { + if (input === 'https://auth.x.ai/.well-known/openid-configuration') { + return Response.json(discoveryDocument); + } + return new Response('proxy error', { status: 200 }); + }); + + await expect( + login({ + onAuth: authorizeCallback, + }), + ).rejects.toMatchObject({ + code: XaiErrorCode.TOKEN_EXCHANGE_FAILED, + message: expect.stringContaining('xAI token exchange returned invalid JSON:'), + }); + }); }); From 13b1b0ff03b8fd233158c37e2cfd45d882e46295 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:04:35 +0900 Subject: [PATCH 16/24] fix: add cross-platform shell command detection --- src/tools/shell.ts | 44 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/src/tools/shell.ts b/src/tools/shell.ts index a0d5dcf..5078826 100644 --- a/src/tools/shell.ts +++ b/src/tools/shell.ts @@ -1,4 +1,5 @@ import { execFile } from 'node:child_process'; +import { existsSync } from 'node:fs'; import { resolve } from 'node:path'; import { promisify } from 'node:util'; import { Type } from '@earendil-works/pi-ai'; @@ -14,6 +15,35 @@ import { const execFileAsync = promisify(execFile); +function shellCommand(command: string): { file: string; args: string[] } | undefined { + if (process.platform === 'win32') { + if (existsSync('C:\\Windows\\System32\\cmd.exe')) { + return { file: 'cmd.exe', args: ['/d', '/s', '/c', command] }; + } + if (existsSync('C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe')) { + return { + file: 'powershell.exe', + args: ['-NoLogo', '-NoProfile', '-Command', command], + }; + } + return undefined; + } + + if ( + process.platform !== 'darwin' && + process.platform !== 'linux' && + process.platform !== 'freebsd' + ) { + return undefined; + } + + if (existsSync('/bin/bash')) return { file: '/bin/bash', args: ['-c', command] }; + if (existsSync('/usr/bin/bash')) return { file: '/usr/bin/bash', args: ['-c', command] }; + if (existsSync('/bin/sh')) return { file: '/bin/sh', args: ['-c', command] }; + if (existsSync('/usr/bin/sh')) return { file: '/usr/bin/sh', args: ['-c', command] }; + return undefined; +} + export function registerShellTool(pi: ExtensionAPI) { // ── Shell tool ─────────────────────────────────────────────────────── @@ -44,7 +74,19 @@ export function registerShellTool(pi: ExtensionAPI) { const timeout = params.timeout ?? 120_000; try { - const { stdout, stderr } = await execFileAsync('bash', ['-c', params.command], { + const shell = shellCommand(params.command); + if (!shell) { + return { + content: [ + { + type: 'text', + text: 'Shell error: unsupported platform or shell not found', + }, + ], + details: { exitCode: 1, command: params.command }, + }; + } + const { stdout, stderr } = await execFileAsync(shell.file, shell.args, { cwd, maxBuffer: MAX_OUTPUT_BYTES, timeout, From 294e966b623b05f623c0e59f6dad553c8c0a6a13 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:04:39 +0900 Subject: [PATCH 17/24] fix: refine reasoning effort support check with model-level config --- src/models/catalog.ts | 21 ++++++++------------- tests/models/catalog.test.ts | 2 +- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/models/catalog.ts b/src/models/catalog.ts index 50e3bb8..73f4858 100644 --- a/src/models/catalog.ts +++ b/src/models/catalog.ts @@ -106,23 +106,18 @@ const FALLBACK_MODELS: GrokCliModelConfig[] = [ }, ]; -// ─── Reasoning-effort allowlist ─────────────────────────────────────────────── - -/** - * Only these model prefixes support `reasoning.effort` in the Responses API. - * Everything else gets the param stripped in the sanitizer. - */ -const EFFORT_CAPABLE_PREFIXES = [ - 'grok-3-mini', - 'grok-4.20-multi-agent', - 'grok-4.3', - 'grok-composer', -]; +const EFFORT_CAPABLE_PREFIXES = ['grok-3-mini', 'grok-4.20-multi-agent', 'grok-4.3']; export function supportsReasoningEffort(modelId: string): boolean { const parts = modelId.split('/'); const name = parts.at(-1) ?? modelId; - return EFFORT_CAPABLE_PREFIXES.some((p) => name.toLowerCase().startsWith(p)); + const model = resolveModels().find((entry) => entry.id.toLowerCase() === name.toLowerCase()); + if (!EFFORT_CAPABLE_PREFIXES.some((prefix) => name.toLowerCase().startsWith(prefix))) { + return false; + } + if (!model?.reasoning) return false; + if (!model.thinkingLevelMap) return true; + return Object.values(model.thinkingLevelMap).some((level) => level !== null && level !== 'none'); } // ─── PI_GROK_CLI_MODELS env override ────────────────────────────────────────── diff --git a/tests/models/catalog.test.ts b/tests/models/catalog.test.ts index 92af5a2..0f9ea87 100644 --- a/tests/models/catalog.test.ts +++ b/tests/models/catalog.test.ts @@ -10,7 +10,7 @@ afterEach(() => { describe('model catalog', () => { it('reports reasoning-effort support by normalized model name', () => { expect(supportsReasoningEffort('grok-4.3')).toBe(true); - expect(supportsReasoningEffort('grok-cli/GROK-COMPOSER-2.5-fast')).toBe(true); + expect(supportsReasoningEffort('grok-cli/GROK-COMPOSER-2.5-fast')).toBe(false); expect(supportsReasoningEffort('grok-4.20-0309-non-reasoning')).toBe(false); }); From d8267a8e9222c41322402cc7a5825da589f84513 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:04:43 +0900 Subject: [PATCH 18/24] refactor: centralize GROK_TOOL_NAMES in register module --- src/provider/toolScope.ts | 13 +------------ src/tools/register.ts | 12 ++++++++++++ tests/tools/register.test.ts | 14 ++------------ 3 files changed, 15 insertions(+), 24 deletions(-) diff --git a/src/provider/toolScope.ts b/src/provider/toolScope.ts index 351058c..5730cc6 100644 --- a/src/provider/toolScope.ts +++ b/src/provider/toolScope.ts @@ -1,16 +1,5 @@ import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; - -export const GROK_TOOL_NAMES = [ - 'Grep', - 'Glob', - 'LS', - 'Read', - 'Write', - 'StrReplace', - 'Edit', - 'Delete', - 'Shell', -]; +import { GROK_TOOL_NAMES } from '../tools/register.js'; export function syncGrokTools( pi: Pick, diff --git a/src/tools/register.ts b/src/tools/register.ts index e798f1f..b981582 100644 --- a/src/tools/register.ts +++ b/src/tools/register.ts @@ -3,6 +3,18 @@ import { registerFileTools } from './files.js'; import { registerSearchTools } from './search.js'; import { registerShellTool } from './shell.js'; +export const GROK_TOOL_NAMES = [ + 'Grep', + 'Glob', + 'LS', + 'Read', + 'Write', + 'StrReplace', + 'Edit', + 'Delete', + 'Shell', +]; + export function registerGrokTools(pi: ExtensionAPI) { registerSearchTools(pi); registerFileTools(pi); diff --git a/tests/tools/register.test.ts b/tests/tools/register.test.ts index e7eb446..04b8690 100644 --- a/tests/tools/register.test.ts +++ b/tests/tools/register.test.ts @@ -1,6 +1,6 @@ import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; import { describe, expect, it } from 'vitest'; -import { registerGrokTools } from '../../src/tools/register.js'; +import { GROK_TOOL_NAMES, registerGrokTools } from '../../src/tools/register.js'; describe('Grok tool registration', () => { it('registers all Grok/Cursor-native tool shims with renderers', () => { @@ -14,16 +14,6 @@ describe('Grok tool registration', () => { }, } as unknown as ExtensionAPI); - expect(toolNames.sort()).toEqual([ - 'Delete', - 'Edit', - 'Glob', - 'Grep', - 'LS', - 'Read', - 'Shell', - 'StrReplace', - 'Write', - ]); + expect(toolNames.sort()).toEqual([...GROK_TOOL_NAMES].sort()); }); }); From 93bc1f4ecb4b45c983891c9e21444995a6820aae Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:04:47 +0900 Subject: [PATCH 19/24] fix: extract all system/developer instructions and preserve existing text format --- src/payload/sanitize.ts | 18 ++++++++---------- tests/payload/sanitize.test.ts | 19 +++++++++++++++++-- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/src/payload/sanitize.ts b/src/payload/sanitize.ts index 75539e4..1ec371b 100644 --- a/src/payload/sanitize.ts +++ b/src/payload/sanitize.ts @@ -248,15 +248,13 @@ export function sanitizePayload( // Move system/developer messages to top-level instructions. // xAI rejects role: "developer" and role: "system" in the input array. const instructionParts: string[] = []; - while (input.length > 0) { - const first = input[0]; - if (!first || typeof first !== 'object') break; - const role = (first as Record).role; - if (role !== 'developer' && role !== 'system') break; - const text = textFromContent((first as Record).content).trim(); + input = input.filter((item) => { + const role = (item as Record).role; + if (role !== 'developer' && role !== 'system') return true; + const text = textFromContent((item as Record).content).trim(); if (text) instructionParts.push(text); - input.shift(); - } + return false; + }); if (instructionParts.length > 0) { const existing = typeof next.instructions === 'string' && next.instructions ? next.instructions : ''; @@ -276,8 +274,8 @@ export function sanitizePayload( } // ── response_format → text.format ──────────────────────────────────── - if (next.response_format && !next.text) { - next.text = { format: next.response_format }; + if (next.response_format) { + if (!next.text) next.text = { format: next.response_format }; delete next.response_format; } diff --git a/tests/payload/sanitize.test.ts b/tests/payload/sanitize.test.ts index a7063ec..0a19e91 100644 --- a/tests/payload/sanitize.test.ts +++ b/tests/payload/sanitize.test.ts @@ -5,7 +5,7 @@ import { describe, expect, it } from 'vitest'; import { sanitizePayload } from '../../src/payload/sanitize.js'; describe('payload sanitization', () => { - it('removes unsupported items and moves leading instructions', () => { + it('removes unsupported items and moves all instructions', () => { const payload = sanitizePayload( { instructions: 'existing instruction', @@ -21,6 +21,7 @@ describe('payload sanitization', () => { { type: 'reasoning', content: 'cached reasoning' }, { role: 'user', content: '' }, { role: 'user', content: 'hello' }, + { role: 'system', content: 'later system instruction' }, ], include: ['reasoning.encrypted_content', 'message.output_text'], prompt_cache_retention: '24h', @@ -32,7 +33,7 @@ describe('payload sanitization', () => { ); expect(payload.instructions).toBe( - 'existing instruction\n\nsystem instruction\n\ndeveloper instruction\noutput text instruction', + 'existing instruction\n\nsystem instruction\n\ndeveloper instruction\noutput text instruction\n\nlater system instruction', ); expect(payload.input).toEqual([{ role: 'user', content: 'hello' }]); expect(payload.include).toEqual(['message.output_text']); @@ -43,6 +44,20 @@ describe('payload sanitization', () => { expect(payload.prompt_cache_key).toBe('session-123'); }); + it('preserves existing text while removing response_format', () => { + const payload = sanitizePayload( + { + input: 'plain prompt', + text: { format: { type: 'text' } }, + response_format: { type: 'json_object' }, + }, + 'grok-4.3', + ); + + expect(payload.text).toEqual({ format: { type: 'text' } }); + expect(payload.response_format).toBeUndefined(); + }); + it('strips reasoning fields for models that do not accept reasoning effort', () => { const payload = sanitizePayload( { From 982c360cee7f90c006baa1ba4f56c21c7380d518 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:04:51 +0900 Subject: [PATCH 20/24] docs: clarify GROK_CLI_OAUTH_TOKEN bypass behavior --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6d5ed50..2e2dce3 100644 --- a/README.md +++ b/README.md @@ -73,4 +73,4 @@ Select **"Grok CLI"** from the provider list. This opens the xAI OAuth page in y | `PI_GROK_CLI_MODELS` | (all models) | Comma-separated model IDs to expose | | `PI_GROK_CLI_OAUTH_CLIENT_ID` | `b1a00492-...` | Override OAuth client ID | | `PI_GROK_CLI_OAUTH_SCOPE` | `openid profile email offline_access grok-cli:access api:access` | Override OAuth scopes | -| `GROK_CLI_OAUTH_TOKEN` | — | Direct token bypass (no auto-refresh) | +| `GROK_CLI_OAUTH_TOKEN` | — | Direct token bypass that skips OAuth entirely. No automatic refresh or renewal is performed; provide a valid external access token and replace or rotate it when it expires. | From 79820414a8adadec9d2b21b7dcfe1d8c93e1abd2 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:10:21 +0900 Subject: [PATCH 21/24] test: accept platform-specific tool errors --- tests/tools/files.test.ts | 4 +++- tests/tools/search.test.ts | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/tools/files.test.ts b/tests/tools/files.test.ts index b7fd725..977a468 100644 --- a/tests/tools/files.test.ts +++ b/tests/tools/files.test.ts @@ -102,7 +102,9 @@ describe('file tools', () => { path: join(cwd, 'dir'), deleted: false, failed: true, - error: expect.stringContaining('operation not permitted'), + error: expect.stringMatching( + /EISDIR: illegal operation on a directory|operation not permitted/, + ), }); }); diff --git a/tests/tools/search.test.ts b/tests/tools/search.test.ts index a9d71af..dbff382 100644 --- a/tests/tools/search.test.ts +++ b/tests/tools/search.test.ts @@ -124,7 +124,7 @@ describe('search tools', () => { expect(result.details).toEqual({ matchCount: 0, failed: true, - error: expect.stringContaining('regex parse error'), + error: expect.stringMatching(/regex parse error|Invalid regular expression/), }); }); From eb0bea0d04fd85ba469c60240d4573d2437b620a Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:30:10 +0900 Subject: [PATCH 22/24] fix: make tool fallbacks and image paths portable --- src/payload/sanitize.ts | 42 ++++++++++++++++++++------------ src/provider/register.ts | 2 +- src/provider/status.ts | 8 +++--- src/tools/files.ts | 13 ++-------- src/tools/search.ts | 31 ++++++++++++++++++------ tests/payload/sanitize.test.ts | 43 ++++++++++++++++++++++++++++++++- tests/provider/register.test.ts | 22 +++++++++++++++++ tests/tools/files.test.ts | 24 +++++++++++++++++- tests/tools/search.test.ts | 23 ++++++++++++++++++ 9 files changed, 166 insertions(+), 42 deletions(-) diff --git a/src/payload/sanitize.ts b/src/payload/sanitize.ts index 1ec371b..e24f7ab 100644 --- a/src/payload/sanitize.ts +++ b/src/payload/sanitize.ts @@ -19,8 +19,8 @@ * - Uses prompt_cache_key for session affinity */ -import { existsSync, readFileSync } from 'node:fs'; -import { extname, isAbsolute, resolve } from 'node:path'; +import { existsSync, readFileSync, realpathSync } from 'node:fs'; +import { extname, isAbsolute, resolve, sep } from 'node:path'; import { fileURLToPath } from 'node:url'; import { supportsReasoningEffort } from '../models/catalog.js'; @@ -73,25 +73,34 @@ function imageMimeTypeForPath(path: string): string { } } -function resolveLocalImagePath(value: string): string | undefined { +function ensurePathWithinWorkspace(cwd: string, filePath: string) { + const realCwd = realpathSync(cwd); + const realPath = realpathSync(filePath); + if (realPath !== realCwd && !realPath.startsWith(`${realCwd}${sep}`)) { + throw new Error('Image path is outside the workspace'); + } + return realPath; +} + +function resolveLocalImagePath(value: string, cwd: string): string | undefined { const cleaned = unescapeShellPath(value); if (!cleaned) return undefined; if (cleaned.startsWith('file://')) { try { - return fileURLToPath(cleaned); + const filePath = fileURLToPath(cleaned); + return existsSync(filePath) ? ensurePathWithinWorkspace(cwd, filePath) : undefined; } catch { return undefined; } } - const candidates = [cleaned]; - if (!isAbsolute(cleaned)) candidates.push(resolve(process.cwd(), cleaned)); + const candidate = isAbsolute(cleaned) ? cleaned : resolve(cwd, cleaned); - return candidates.find((candidate) => existsSync(candidate)); + return existsSync(candidate) ? ensurePathWithinWorkspace(cwd, candidate) : undefined; } -function normalizeImageInput(value: unknown): string | undefined { +function normalizeImageInput(value: unknown, cwd: string): string | undefined { if (typeof value !== 'string' || !value.trim()) return undefined; const cleaned = stripShellQuotes(value); @@ -99,7 +108,7 @@ function normalizeImageInput(value: unknown): string | undefined { return cleaned; } - const localPath = resolveLocalImagePath(cleaned); + const localPath = resolveLocalImagePath(cleaned, cwd); if (!localPath) { throw new Error(`Image file does not exist or is not a valid URL: ${cleaned}`); } @@ -131,8 +140,8 @@ function getImageUrlAndDetail(obj: Record): { return { imageUrl: obj.image_url, detail: obj.detail }; } -function normalizeImageParts(value: unknown): unknown { - if (Array.isArray(value)) return value.map(normalizeImageParts); +function normalizeImageParts(value: unknown, cwd: string): unknown { + if (Array.isArray(value)) return value.map((item) => normalizeImageParts(item, cwd)); if (!value || typeof value !== 'object') return value; const obj = { ...(value as Record) }; @@ -154,14 +163,14 @@ function normalizeImageParts(value: unknown): unknown { if (obj.type === 'input_image') { const { imageUrl, detail } = getImageUrlAndDetail(obj); - const normalized = normalizeImageInput(imageUrl); + const normalized = normalizeImageInput(imageUrl, cwd); if (normalized) obj.image_url = normalized; if (typeof detail === 'string' && detail) obj.detail = detail; if (typeof obj.detail !== 'string' || !obj.detail) obj.detail = 'auto'; } - if (Array.isArray(obj.content)) obj.content = normalizeImageParts(obj.content); - if (Array.isArray(obj.output)) obj.output = normalizeImageParts(obj.output); + if (Array.isArray(obj.content)) obj.content = normalizeImageParts(obj.content, cwd); + if (Array.isArray(obj.output)) obj.output = normalizeImageParts(obj.output, cwd); return obj; } @@ -224,7 +233,8 @@ function rewriteFunctionCallOutput(input: Record[]): Record, modelId: string, - sessionId?: string, + sessionId: string | undefined, + cwd: string, ): Record { const next = params; @@ -263,7 +273,7 @@ export function sanitizePayload( } // Normalize image parts (resolve local paths, fix types) - input = normalizeImageParts(input) as Record[]; + input = normalizeImageParts(input, cwd) as Record[]; // Rewrite function_call_output with images input = rewriteFunctionCallOutput(input); diff --git a/src/provider/register.ts b/src/provider/register.ts index 80fff0c..075d1b1 100644 --- a/src/provider/register.ts +++ b/src/provider/register.ts @@ -74,7 +74,7 @@ export default function registerGrokCli(pi: ExtensionAPI) { const modelId = ctx.model?.id ?? ''; const sessionId = ctx.sessionManager?.getSessionId(); - return sanitizePayload(event.payload as Record, modelId, sessionId); + return sanitizePayload(event.payload as Record, modelId, sessionId, ctx.cwd); }); registerStatusCommand(pi); diff --git a/src/provider/status.ts b/src/provider/status.ts index 1727e9a..41ca0a3 100644 --- a/src/provider/status.ts +++ b/src/provider/status.ts @@ -35,10 +35,10 @@ export function registerStatusCommand(pi: Pick) const lines = [ ' Quota:', - '', - ...formatQuota('grok-build', getCachedRateLimit('grok-build')), - '', - ...formatQuota('grok-composer-2.5-fast', getCachedRateLimit('grok-composer-2.5-fast')), + ...grokModels.flatMap((model: Model) => [ + '', + ...formatQuota(model.id, getCachedRateLimit(model.id)), + ]), ]; ctx.ui.notify(lines.join('\n'), 'info'); } catch (err) { diff --git a/src/tools/files.ts b/src/tools/files.ts index 18c3cfb..f5e31e2 100644 --- a/src/tools/files.ts +++ b/src/tools/files.ts @@ -1,4 +1,3 @@ -import { execFile } from 'node:child_process'; import { existsSync, promises as fs, @@ -8,7 +7,6 @@ import { writeFileSync, } from 'node:fs'; import { basename, dirname, join, resolve, sep } from 'node:path'; -import { promisify } from 'node:util'; import { Type } from '@earendil-works/pi-ai'; import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; import { @@ -16,7 +14,6 @@ import { detailRecord, fileError, fileNotFound, - MAX_OUTPUT_BYTES, MAX_OUTPUT_CHARS, numberDetail, recordFrom, @@ -27,8 +24,6 @@ import { text, } from './rendering.js'; -const execFileAsync = promisify(execFile); - type ReplacementEdit = { oldText: string; newText: string }; type FileDetails = { path: string; [key: string]: unknown }; type WriteArgs = { path: string; content: string }; @@ -192,13 +187,9 @@ export function registerFileTools(pi: ExtensionAPI) { try { const safePath = await canonicalizeWithinWorkspace(ctx.cwd, params.path); - const { stdout } = await execFileAsync('ls', ['-la', safePath], { - cwd: ctx.cwd, - maxBuffer: MAX_OUTPUT_BYTES, - signal, - }); + if (signal?.aborted) throw new Error('The operation was aborted'); - let output = stdout.trim(); + let output = (await fs.readdir(safePath)).sort().join('\n'); if (output.length > MAX_OUTPUT_CHARS) { output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[LS: output truncated at 50KB]`; } diff --git a/src/tools/search.ts b/src/tools/search.ts index d605602..c8e7d21 100644 --- a/src/tools/search.ts +++ b/src/tools/search.ts @@ -1,6 +1,6 @@ import { execFile } from 'node:child_process'; -import { statSync } from 'node:fs'; -import { basename, relative, resolve } from 'node:path'; +import { promises as fs, statSync } from 'node:fs'; +import { basename, join, relative, resolve } from 'node:path'; import { promisify } from 'node:util'; import { Type } from '@earendil-works/pi-ai'; import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; @@ -58,6 +58,26 @@ function sortByModifiedNewest(files: string[]) { }); } +async function listFilesRecursive(searchPath: string, signal?: AbortSignal): Promise { + if (signal?.aborted) throw new Error('The operation was aborted'); + const stats = await fs.stat(searchPath); + if (stats.isFile()) return [searchPath]; + if (!stats.isDirectory()) return []; + + return ( + await Promise.all( + ( + await fs.readdir(searchPath, { withFileTypes: true }) + ).map((entry) => { + const entryPath = join(searchPath, entry.name); + if (entry.isDirectory()) return listFilesRecursive(entryPath, signal); + if (entry.isFile()) return [entryPath]; + return []; + }), + ) + ).flat(); +} + export function registerSearchTools(pi: ExtensionAPI) { const GrepParams = Type.Object({ pattern: Type.String({ @@ -194,12 +214,7 @@ export function registerSearchTools(pi: ExtensionAPI) { const matchesFile = normalizedPattern.includes('/') ? (file: string) => matcher.test(normalizePath(relative(ctx.cwd, file))) : (file: string) => matcher.test(basename(file)); - const result = await execFileAsync('find', [searchPath, '-type', 'f'], { - cwd: ctx.cwd, - maxBuffer: MAX_OUTPUT_BYTES, - signal, - }); - files = result.stdout.trim().split('\n').filter(Boolean).filter(matchesFile); + files = (await listFilesRecursive(searchPath, signal)).filter(matchesFile); } files = sortByModifiedNewest(files); diff --git a/tests/payload/sanitize.test.ts b/tests/payload/sanitize.test.ts index 0a19e91..f3935a5 100644 --- a/tests/payload/sanitize.test.ts +++ b/tests/payload/sanitize.test.ts @@ -1,4 +1,4 @@ -import { mkdtempSync, rmSync, writeFileSync } from 'node:fs'; +import { mkdirSync, mkdtempSync, rmSync, writeFileSync } from 'node:fs'; import { tmpdir } from 'node:os'; import { join } from 'node:path'; import { describe, expect, it } from 'vitest'; @@ -30,6 +30,7 @@ describe('payload sanitization', () => { }, 'grok-4.3', 'session-123', + process.cwd(), ); expect(payload.instructions).toBe( @@ -52,6 +53,8 @@ describe('payload sanitization', () => { response_format: { type: 'json_object' }, }, 'grok-4.3', + undefined, + process.cwd(), ); expect(payload.text).toEqual({ format: { type: 'text' } }); @@ -69,6 +72,7 @@ describe('payload sanitization', () => { }, 'grok-build', 'new-session', + process.cwd(), ); expect(payload.input).toBe('plain prompt'); @@ -106,6 +110,8 @@ describe('payload sanitization', () => { ], }, 'grok-composer-2.5-fast', + undefined, + process.cwd(), ); expect(payload.input).toEqual([ @@ -163,6 +169,8 @@ describe('payload sanitization', () => { ], }, 'grok-4.3', + undefined, + dir, ); expect(payload.input).toEqual([ @@ -194,7 +202,40 @@ describe('payload sanitization', () => { ], }, 'grok-4.3', + undefined, + process.cwd(), ), ).toThrow('Image file does not exist or is not a valid URL: missing.png'); }); + + it('rejects local image paths outside the workspace', () => { + const dir = mkdtempSync(join(tmpdir(), 'pi-grok-cli-test-')); + const workspace = join(dir, 'workspace'); + const originalCwd = process.cwd(); + writeFileSync(join(dir, 'secret.png'), Buffer.from('png image bytes')); + mkdirSync(workspace); + + try { + process.chdir(workspace); + + expect(() => + sanitizePayload( + { + input: [ + { + role: 'user', + content: [{ type: 'input_image', image_url: join('..', 'secret.png') }], + }, + ], + }, + 'grok-4.3', + undefined, + process.cwd(), + ), + ).toThrow('Image path is outside the workspace'); + } finally { + process.chdir(originalCwd); + rmSync(dir, { recursive: true, force: true }); + } + }); }); diff --git a/tests/provider/register.test.ts b/tests/provider/register.test.ts index 8b0873c..ad5a3f0 100644 --- a/tests/provider/register.test.ts +++ b/tests/provider/register.test.ts @@ -46,6 +46,7 @@ interface Renderable { } interface TestContext { + cwd?: string; modelRegistry: { getAll: () => { provider: string; id: string }[]; getApiKeyForProvider?: (provider: string) => Promise; @@ -216,6 +217,26 @@ describe('Grok CLI status command', () => { expect(notify.mock.calls.at(-1)?.[0]).toContain('Requests: 179/180 remaining'); }); + it('shows cached quotas for registered Grok models instead of hard-coded names', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + setupHome(); + const extension = await setupExtension(); + extension.providers + .get('grok-cli') + ?.streamSimple?.({ provider: 'grok-cli', id: 'custom' }, {}, {}); + const notify = vi.fn(); + + await extension.commands.get('grok-cli-status')?.handler([], { + modelRegistry: { + getAll: () => [{ provider: 'grok-cli', id: 'custom' }], + }, + ui: { notify }, + }); + + expect(notify.mock.calls.at(-1)?.[0]).toContain('custom:\n Cached:'); + expect(notify.mock.calls.at(-1)?.[0]).not.toContain('grok-build:'); + }); + it('persists cached quotas to the global pi config directory', async () => { delete process.env.GROK_CLI_OAUTH_TOKEN; const home = setupHome(); @@ -400,6 +421,7 @@ describe('Grok CLI provider registration', () => { }, }, { + cwd: process.cwd(), model: { provider: 'grok-cli', id: 'grok-4.3' }, modelRegistry: { getAll: () => [] }, sessionManager: { getSessionId: () => 'session-123' }, diff --git a/tests/tools/files.test.ts b/tests/tools/files.test.ts index 977a468..3e5ca45 100644 --- a/tests/tools/files.test.ts +++ b/tests/tools/files.test.ts @@ -7,7 +7,7 @@ import { writeFileSync, } from 'node:fs'; import { join } from 'node:path'; -import { describe, expect, it } from 'vitest'; +import { describe, expect, it, vi } from 'vitest'; import { registerFileTools } from '../../src/tools/files.js'; import { collectTools, @@ -61,6 +61,28 @@ describe('file tools', () => { expect(result.details).toEqual({ path: realpathSync(cwd) }); }); + it('lists directory contents when Unix ls is not on PATH', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + const oldPath = process.env.PATH; + process.env.PATH = tempDir('pi-grok-cli-empty-bin-'); + vi.resetModules(); + writeFileSync(join(cwd, 'visible.txt'), 'visible', 'utf-8'); + + try { + const result = await executeTool( + collectTools((await import('../../src/tools/files.js')).registerFileTools).get('LS'), + { path: '.' }, + cwd, + ); + + expect(firstText(result)).toContain('visible.txt'); + expect(result.details).toEqual({ path: realpathSync(cwd) }); + } finally { + process.env.PATH = oldPath; + vi.resetModules(); + } + }); + it('reports filesystem errors for invalid file operations', async () => { const cwd = tempDir('pi-grok-cli-files-'); mkdirSync(join(cwd, 'dir')); diff --git a/tests/tools/search.test.ts b/tests/tools/search.test.ts index dbff382..d567bdf 100644 --- a/tests/tools/search.test.ts +++ b/tests/tools/search.test.ts @@ -51,6 +51,20 @@ async function withFindFallbackTools( } } +async function withNoSearchBinaries( + run: (tools: ReturnType) => Promise, +) { + const oldPath = process.env.PATH; + process.env.PATH = tempDir('pi-grok-cli-empty-bin-'); + vi.resetModules(); + try { + await run(collectTools((await import('../../src/tools/search.js')).registerSearchTools)); + } finally { + process.env.PATH = oldPath; + vi.resetModules(); + } +} + describe('search tools', () => { it('greps matching file contents with include filters', async () => { const cwd = setupProject(); @@ -180,6 +194,15 @@ describe('search tools', () => { }); }); + it('globs files without ripgrep or Unix find on PATH', async () => { + const cwd = setupProject(); + await withNoSearchBinaries(async (fallbackTools) => { + const result = await executeTool(fallbackTools.get('Glob'), { pattern: 'src/**/*.ts' }, cwd); + + expectGlobResult(cwd, result); + }); + }); + it('sorts glob results by modification time newest first', async () => { const cwd = setupProject(); const oldTime = new Date('2024-01-01T00:00:00.000Z'); From eb1f664ea3844ab6334ae898f1782d87b46101c3 Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 03:54:41 +0900 Subject: [PATCH 23/24] Fix portable grep fallback --- src/tools/rendering.ts | 90 ++++++++++++++++++++++++++++++++++---- src/tools/search.ts | 62 ++++---------------------- tests/tools/search.test.ts | 13 ++++++ 3 files changed, 104 insertions(+), 61 deletions(-) diff --git a/src/tools/rendering.ts b/src/tools/rendering.ts index 724e90a..084bcda 100644 --- a/src/tools/rendering.ts +++ b/src/tools/rendering.ts @@ -1,4 +1,6 @@ import { execFile } from 'node:child_process'; +import { promises as fs } from 'node:fs'; +import { basename, join, relative } from 'node:path'; import { promisify } from 'node:util'; import { Text } from '@earendil-works/pi-tui'; @@ -35,6 +37,55 @@ export function truncateChars(output: string): string { return output; } +export function globToRegExp(pattern: string) { + let source = '^'; + for (let i = 0; i < pattern.length; i += 1) { + const char = pattern[i]; + const next = pattern[i + 1]; + if (char === '*' && next === '*' && pattern[i + 2] === '/') { + source += '(?:.*/)?'; + i += 2; + } else if (char === '*' && next === '*') { + source += '.*'; + i += 1; + } else if (char === '*') { + source += '[^/]*'; + } else if (char === '?') { + source += '[^/]'; + } else { + source += char.replace(/[|\\{}()[\]^$+?.]/g, '\\$&'); + } + } + return new RegExp(`${source}$`); +} + +export function normalizePath(filePath: string) { + return filePath.replaceAll('\\', '/'); +} + +export async function listFilesRecursive( + searchPath: string, + signal?: AbortSignal, +): Promise { + if (signal?.aborted) throw new Error('The operation was aborted'); + const stats = await fs.stat(searchPath); + if (stats.isFile()) return [searchPath]; + if (!stats.isDirectory()) return []; + + return ( + await Promise.all( + ( + await fs.readdir(searchPath, { withFileTypes: true }) + ).map((entry) => { + const entryPath = join(searchPath, entry.name); + if (entry.isDirectory()) return listFilesRecursive(entryPath, signal); + if (entry.isFile()) return [entryPath]; + return []; + }), + ) + ).flat(); +} + let rgAvailable: boolean | undefined; export async function hasRipgrep(): Promise { if (rgAvailable !== undefined) return rgAvailable; @@ -163,8 +214,13 @@ export function toolError(error: unknown, toolName: string, emptyDetails: T): export async function execWithRgFallback( rgArgs: string[], - grepArgs: string[], - options: { cwd: string; signal?: AbortSignal }, + options: { + cwd: string; + signal?: AbortSignal; + pattern: string; + searchPath: string; + include?: string; + }, ): Promise { if (await hasRipgrep()) { const result = await execFileAsync('rg', rgArgs, { @@ -174,10 +230,28 @@ export async function execWithRgFallback( }); return result.stdout; } - const result = await execFileAsync('grep', grepArgs, { - cwd: options.cwd, - maxBuffer: MAX_OUTPUT_BYTES, - signal: options.signal, - }); - return result.stdout; + + const regex = new RegExp(options.pattern); + const matcher = options.include ? globToRegExp(normalizePath(options.include)) : undefined; + return ( + await Promise.all( + ( + await listFilesRecursive(options.searchPath, options.signal) + ) + .filter((file) => { + if (!matcher) return true; + if (!options.include?.includes('/')) return matcher.test(basename(file)); + return matcher.test(normalizePath(relative(options.cwd, file))); + }) + .map(async (file) => + ( + await fs.readFile(file, 'utf8') + ) + .split(/\r?\n/) + .flatMap((line, index) => (regex.test(line) ? `${file}:${index + 1}:${line}` : [])), + ), + ) + ) + .flat() + .join('\n'); } diff --git a/src/tools/search.ts b/src/tools/search.ts index c8e7d21..b3ad03e 100644 --- a/src/tools/search.ts +++ b/src/tools/search.ts @@ -1,13 +1,16 @@ import { execFile } from 'node:child_process'; -import { promises as fs, statSync } from 'node:fs'; -import { basename, join, relative, resolve } from 'node:path'; +import { statSync } from 'node:fs'; +import { basename, relative, resolve } from 'node:path'; import { promisify } from 'node:util'; import { Type } from '@earendil-works/pi-ai'; import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; import { execWithRgFallback, + globToRegExp, hasRipgrep, + listFilesRecursive, MAX_OUTPUT_BYTES, + normalizePath, numberDetail, recordFrom, renderResultText, @@ -24,32 +27,6 @@ const execFileAsync = promisify(execFile); type GrepArgs = { pattern: string; path?: string; include?: string }; type GlobArgs = { pattern: string; path?: string }; -function globToRegExp(pattern: string) { - let source = '^'; - for (let i = 0; i < pattern.length; i += 1) { - const char = pattern[i]; - const next = pattern[i + 1]; - if (char === '*' && next === '*' && pattern[i + 2] === '/') { - source += '(?:.*/)?'; - i += 2; - } else if (char === '*' && next === '*') { - source += '.*'; - i += 1; - } else if (char === '*') { - source += '[^/]*'; - } else if (char === '?') { - source += '[^/]'; - } else { - source += char.replace(/[|\\{}()[\]^$+?.]/g, '\\$&'); - } - } - return new RegExp(`${source}$`); -} - -function normalizePath(filePath: string) { - return filePath.replaceAll('\\', '/'); -} - function sortByModifiedNewest(files: string[]) { return files.sort((a, b) => { const delta = statSync(b).mtimeMs - statSync(a).mtimeMs; @@ -58,26 +35,6 @@ function sortByModifiedNewest(files: string[]) { }); } -async function listFilesRecursive(searchPath: string, signal?: AbortSignal): Promise { - if (signal?.aborted) throw new Error('The operation was aborted'); - const stats = await fs.stat(searchPath); - if (stats.isFile()) return [searchPath]; - if (!stats.isDirectory()) return []; - - return ( - await Promise.all( - ( - await fs.readdir(searchPath, { withFileTypes: true }) - ).map((entry) => { - const entryPath = join(searchPath, entry.name); - if (entry.isDirectory()) return listFilesRecursive(entryPath, signal); - if (entry.isFile()) return [entryPath]; - return []; - }), - ) - ).flat(); -} - export function registerSearchTools(pi: ExtensionAPI) { const GrepParams = Type.Object({ pattern: Type.String({ @@ -119,13 +76,12 @@ export function registerSearchTools(pi: ExtensionAPI) { if (params.include) rgArgs.push('--glob', params.include); rgArgs.push('--', params.pattern, searchPath); - const grepArgs = ['-r', '-n', '-H', '--color=never']; - if (params.include) grepArgs.push(`--include=${params.include}`); - grepArgs.push('--', params.pattern, searchPath); - - const stdout = await execWithRgFallback(rgArgs, grepArgs, { + const stdout = await execWithRgFallback(rgArgs, { cwd: ctx.cwd, signal, + pattern: params.pattern, + searchPath, + include: params.include, }); const lines = stdout.trim().split('\n').filter(Boolean); diff --git a/tests/tools/search.test.ts b/tests/tools/search.test.ts index d567bdf..b4d5bf1 100644 --- a/tests/tools/search.test.ts +++ b/tests/tools/search.test.ts @@ -203,6 +203,19 @@ describe('search tools', () => { }); }); + it('greps files without ripgrep or Unix grep on PATH', async () => { + const cwd = setupProject(); + await withNoSearchBinaries(async (fallbackTools) => { + const result = await executeTool( + fallbackTools.get('Grep'), + { pattern: 'needle', path: 'src', include: '*.ts' }, + cwd, + ); + + expectGrepResult(cwd, result); + }); + }); + it('sorts glob results by modification time newest first', async () => { const cwd = setupProject(); const oldTime = new Date('2024-01-01T00:00:00.000Z'); From 381868e5c2ef81c461eeca4e34519c562ace7faa Mon Sep 17 00:00:00 2001 From: kenryu42 Date: Wed, 3 Jun 2026 04:05:10 +0900 Subject: [PATCH 24/24] fix: make glob sorting resilient to deleted files --- src/tools/search.ts | 12 ++++++++++-- tests/tools/search.test.ts | 16 ++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/tools/search.ts b/src/tools/search.ts index b3ad03e..69a784f 100644 --- a/src/tools/search.ts +++ b/src/tools/search.ts @@ -27,9 +27,17 @@ const execFileAsync = promisify(execFile); type GrepArgs = { pattern: string; path?: string; include?: string }; type GlobArgs = { pattern: string; path?: string }; -function sortByModifiedNewest(files: string[]) { +function modifiedTimeMs(file: string) { + try { + return statSync(file).mtimeMs; + } catch { + return 0; + } +} + +export function sortByModifiedNewest(files: string[]) { return files.sort((a, b) => { - const delta = statSync(b).mtimeMs - statSync(a).mtimeMs; + const delta = modifiedTimeMs(b) - modifiedTimeMs(a); if (delta !== 0) return delta; return a.localeCompare(b); }); diff --git a/tests/tools/search.test.ts b/tests/tools/search.test.ts index b4d5bf1..8237b7a 100644 --- a/tests/tools/search.test.ts +++ b/tests/tools/search.test.ts @@ -1,7 +1,7 @@ -import { mkdirSync, symlinkSync, utimesSync, writeFileSync } from 'node:fs'; +import { mkdirSync, rmSync, symlinkSync, utimesSync, writeFileSync } from 'node:fs'; import { join } from 'node:path'; import { describe, expect, it, vi } from 'vitest'; -import { registerSearchTools } from '../../src/tools/search.js'; +import { registerSearchTools, sortByModifiedNewest } from '../../src/tools/search.js'; import { collectTools, executePreparedTool, @@ -234,6 +234,18 @@ describe('search tools', () => { ]); }); + it('sorts existing glob results when another match is deleted before stat', () => { + const cwd = setupProject(); + const deleted = join(cwd, 'src', 'deleted.ts'); + writeFileSync(deleted, 'deleted\n', 'utf-8'); + rmSync(deleted); + + expect(sortByModifiedNewest([deleted, join(cwd, 'src', 'alpha.ts')])).toEqual([ + join(cwd, 'src', 'alpha.ts'), + deleted, + ]); + }); + it('renders grep calls and result states', () => { const grep = collectTools(registerSearchTools).get('Grep'); const result = {