diff --git a/.gitignore b/.gitignore index 45a1e50..b76e3ab 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ node_modules/ coverage dist/ +docs/ *.tgz .DS_Store diff --git a/.ignore b/.ignore new file mode 100644 index 0000000..00a2480 --- /dev/null +++ b/.ignore @@ -0,0 +1 @@ +!docs/ \ No newline at end of file diff --git a/.release-tools/cache.json b/.release-tools/cache.json index 658e7f1..692beef 100644 --- a/.release-tools/cache.json +++ b/.release-tools/cache.json @@ -1,17 +1,17 @@ { - "scripts": { - "typecheck": "tsc --noEmit", - "knip": "knip --production", - "lint": "biome check --write .", - "lint:ci": "biome ci .", - "check-duplicates": "bunx jscpd src tests --exitCode 1 --reporters ai --noTips", - "check": "bun run lint && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", - "check:ci": "bun run lint:ci && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", - "prepare": "husky" - }, - "lintStaged": { - "*": ["biome check --write --no-errors-on-unmatched"] - }, - "tsconfig": "{\n\t\"compilerOptions\": {\n\t\t\"target\": \"ES2022\",\n\t\t\"module\": \"ES2022\",\n\t\t\"moduleResolution\": \"bundler\",\n\t\t\"strict\": true,\n\t\t\"esModuleInterop\": true,\n\t\t\"skipLibCheck\": true,\n\t\t\"forceConsistentCasingInFileNames\": true,\n\t\t\"resolveJsonModule\": true,\n\t\t\"declaration\": true,\n\t\t\"declarationMap\": true,\n\t\t\"sourceMap\": true,\n\t\t\"outDir\": \"./dist\",\n\t\t\"rootDir\": \".\"\n\t},\n\t\"include\": [\"src/**/*.ts\"],\n\t\"exclude\": [\"node_modules\", \"dist\"]\n}\n", - "installedDeps": [] + "scripts": { + "typecheck": "tsc --noEmit", + "knip": "knip --production", + "lint": "biome check --write .", + "lint:ci": "biome ci .", + "check-duplicates": "bunx jscpd src tests --exitCode 1 --reporters ai --noTips", + "check": "bun run lint && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", + "check:ci": "bun 
run lint:ci && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", + "prepare": "husky" + }, + "lintStaged": { + "*": ["biome check --write --no-errors-on-unmatched"] + }, + "tsconfig": "{\n\t\"compilerOptions\": {\n\t\t\"target\": \"ES2022\",\n\t\t\"module\": \"ES2022\",\n\t\t\"moduleResolution\": \"bundler\",\n\t\t\"strict\": true,\n\t\t\"esModuleInterop\": true,\n\t\t\"skipLibCheck\": true,\n\t\t\"forceConsistentCasingInFileNames\": true,\n\t\t\"resolveJsonModule\": true,\n\t\t\"declaration\": true,\n\t\t\"declarationMap\": true,\n\t\t\"sourceMap\": true,\n\t\t\"outDir\": \"./dist\",\n\t\t\"rootDir\": \".\"\n\t},\n\t\"include\": [\"src/**/*.ts\"],\n\t\"exclude\": [\"node_modules\", \"dist\"]\n}\n", + "installedDeps": [] } diff --git a/.release-tools/config.ts b/.release-tools/config.ts index 8636522..c595f22 100644 --- a/.release-tools/config.ts +++ b/.release-tools/config.ts @@ -1,7 +1,7 @@ -import { defineConfig } from "release-tools/config"; +import { defineConfig } from 'release-tools/config'; export default defineConfig({ - packageName: "pi-grok-cli", - repo: "kenryu42/pi-grok-cli", - excludedAuthors: ["kenryu42"], + packageName: 'pi-grok-cli', + repo: 'kenryu42/pi-grok-cli', + excludedAuthors: ['kenryu42'], }); diff --git a/README.md b/README.md index 3570791..2e2dce3 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,8 @@ [![Version](https://img.shields.io/github/v/tag/kenryu42/pi-grok-cli?label=version&color=blue)](https://github.com/kenryu42/pi-grok-cli) [![License: MIT](https://img.shields.io/badge/License-MIT-red.svg)](https://opensource.org/licenses/MIT) -A pi extension that connects to **Grok CLI's API endpoint** . - -## Why? - -The Grok CLI uhas access to models **not available** on the public `api.x.ai` API yet: +A pi extension that connects to **Grok CLI's API endpoint**. 
+The Grok CLI has access to models **not available** on the public `api.x.ai` API yet: | Model | Public API (`api.x.ai`) | Grok CLI | |---|---|---| @@ -18,6 +15,16 @@ The Grok CLI uhas access to models **not available** on the public `api.x.ai` AP `grok-composer-2.5-fast` is Cursor's Composer 2.5 model, a purpose-built agentic coding model optimized for long-horizon coding tasks. +## Cursor Tool Compatibility + +Grok CLI models are trained to use Cursor-style coding tools. This extension includes compatibility shims so those models can keep using familiar tool calls inside pi: + +- File tools: `Read`, `Write`, `StrReplace`, `Edit`, `Delete`, and `LS` +- Search tools: `Grep` and `Glob` +- Terminal tool: `Shell` + +The shims also normalize common Cursor/Grok argument shapes, such as `contents` for writes, `glob_pattern` for file search, `glob_filter` for grep filters, and `old_string`/`new_string` or `oldText`/`newText` for exact replacements. This keeps agentic coding workflows moving instead of failing on tool schema mismatches. + ## Requirements You need an active Grok subscription or an X Premium subscription with Grok access to use this extension. @@ -66,4 +73,4 @@ Select **"Grok CLI"** from the provider list. This opens the xAI OAuth page in y | `PI_GROK_CLI_MODELS` | (all models) | Comma-separated model IDs to expose | | `PI_GROK_CLI_OAUTH_CLIENT_ID` | `b1a00492-...` | Override OAuth client ID | | `PI_GROK_CLI_OAUTH_SCOPE` | `openid profile email offline_access grok-cli:access api:access` | Override OAuth scopes | -| `GROK_CLI_OAUTH_TOKEN` | — | Direct token bypass (no auto-refresh) | \ No newline at end of file +| `GROK_CLI_OAUTH_TOKEN` | — | Direct token bypass that skips OAuth entirely. No automatic refresh or renewal is performed; provide a valid external access token and replace or rotate it when it expires. 
| diff --git a/biome.json b/biome.json index eb66ea5..d8abd7a 100644 --- a/biome.json +++ b/biome.json @@ -1,34 +1,38 @@ { - "$schema": "https://biomejs.dev/schemas/2.4.16/schema.json", - "vcs": { - "enabled": true, - "clientKind": "git", - "useIgnoreFile": true - }, - "files": { - "ignoreUnknown": false - }, - "formatter": { - "enabled": true, - "indentStyle": "tab" - }, - "linter": { - "enabled": true, - "rules": { - "recommended": true - } - }, - "javascript": { - "formatter": { - "quoteStyle": "double" - } - }, - "assist": { - "enabled": true, - "actions": { - "source": { - "organizeImports": "on" - } - } - } + "$schema": "https://biomejs.dev/schemas/2.4.16/schema.json", + "vcs": { + "enabled": true, + "clientKind": "git", + "useIgnoreFile": true + }, + "files": { + "ignoreUnknown": false + }, + "formatter": { + "enabled": true, + "indentStyle": "space", + "indentWidth": 2, + "lineWidth": 100 + }, + "linter": { + "enabled": true, + "rules": { + "recommended": true + } + }, + "javascript": { + "formatter": { + "quoteStyle": "single", + "trailingCommas": "all", + "semicolons": "always" + } + }, + "assist": { + "enabled": true, + "actions": { + "source": { + "organizeImports": "on" + } + } + } } diff --git a/bun.lock b/bun.lock index ec5bcd7..88c55bc 100644 --- a/bun.lock +++ b/bun.lock @@ -8,6 +8,7 @@ "@biomejs/biome": "2.4.16", "@earendil-works/pi-ai": "^0.78.0", "@earendil-works/pi-coding-agent": "^0.78.0", + "@earendil-works/pi-tui": "^0.78.0", "@vitest/coverage-v8": "^4.1.8", "release-tools": "github:kenryu42/release-tools", "husky": "^9.1.7", @@ -20,6 +21,7 @@ "peerDependencies": { "@earendil-works/pi-ai": "*", "@earendil-works/pi-coding-agent": "*", + "@earendil-works/pi-tui": "*", }, }, }, diff --git a/knip.json b/knip.json index 8a668a3..5c61008 100644 --- a/knip.json +++ b/knip.json @@ -1,5 +1,5 @@ { - "entry": ["src/shared/backup_worker.ts"], - "project": ["**/*.ts", "!**/*.d.ts", "!.release-tools/**"], - "ignoreDependencies": ["lint-staged"] + "entry": 
["src/shared/backup_worker.ts"], + "project": ["**/*.ts", "!**/*.d.ts", "!.release-tools/**"], + "ignoreDependencies": ["lint-staged"] } diff --git a/package.json b/package.json index ac8380b..ce8f914 100644 --- a/package.json +++ b/package.json @@ -1,74 +1,76 @@ { - "name": "pi-grok-cli", - "version": "0.1.1", - "description": "Use Grok CLI's API endpoint in pi.", - "keywords": [ - "pi-package", - "pi-extension", - "xai", - "grok", - "grok-cli", - "oauth", - "xai-oauth" - ], - "type": "module", - "main": "./src/index.ts", - "files": [ - "README.md", - "src", - "tsconfig.json" - ], - "scripts": { - "test": "vitest run --reporter=agent", - "coverage": "vitest run --reporter=agent --coverage", - "typecheck": "tsc --noEmit", - "prepack": "bun run test && bun run coverage && bun run typecheck", - "knip": "knip --production", - "lint": "biome check --write .", - "lint:ci": "biome ci .", - "check": "bun run lint && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", - "check:ci": "bun run lint:ci && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", - "prepare": "husky", - "check-duplicates": "bunx jscpd src tests --exitCode 1 --reporters ai --noTips" - }, - "author": { - "name": "J Liew", - "email": "jliew@420024lab.com" - }, - "license": "MIT", - "repository": { - "type": "git", - "url": "git+https://github.com/kenryu42/pi-grok-cli.git" - }, - "bugs": { - "url": "https://github.com/kenryu42/pi-grok-cli/issues" - }, - "homepage": "https://github.com/kenryu42/pi-grok-cli#readme", - "pi": { - "extensions": [ - "./src/index.ts" - ] - }, - "peerDependencies": { - "@earendil-works/pi-ai": "*", - "@earendil-works/pi-coding-agent": "*" - }, - "devDependencies": { - "@biomejs/biome": "2.4.16", - "@earendil-works/pi-ai": "^0.78.0", - "@earendil-works/pi-coding-agent": "^0.78.0", - "@vitest/coverage-v8": "^4.1.8", - "husky": "^9.1.7", - "jscpd": "^4.2.4", - "knip": "^6.15.0", - "lint-staged": 
"^17.0.7", - "release-tools": "github:kenryu42/release-tools", - "typescript": "^6.0.3", - "vitest": "^4.1.8" - }, - "lint-staged": { - "*": [ - "biome check --write --no-errors-on-unmatched" - ] - } + "name": "pi-grok-cli", + "version": "0.1.1", + "description": "Use Grok CLI's API endpoint in pi.", + "keywords": [ + "pi-package", + "pi-extension", + "xai", + "grok", + "grok-cli", + "oauth", + "xai-oauth" + ], + "type": "module", + "main": "./src/index.ts", + "files": [ + "README.md", + "src", + "tsconfig.json" + ], + "scripts": { + "test": "vitest run --reporter=agent", + "coverage": "vitest run --reporter=agent --coverage", + "typecheck": "tsc --noEmit", + "prepack": "bun run test && bun run coverage && bun run typecheck", + "knip": "knip --production", + "lint": "biome check --write .", + "lint:ci": "biome ci .", + "check": "bun run lint && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", + "check:ci": "bun run lint:ci && bun run typecheck && bun run knip && bun run check-duplicates && AGENT=1 bun run coverage", + "prepare": "husky", + "check-duplicates": "bunx jscpd src tests --exitCode 1 --reporters ai --noTips" + }, + "author": { + "name": "J Liew", + "email": "jliew@420024lab.com" + }, + "license": "MIT", + "repository": { + "type": "git", + "url": "git+https://github.com/kenryu42/pi-grok-cli.git" + }, + "bugs": { + "url": "https://github.com/kenryu42/pi-grok-cli/issues" + }, + "homepage": "https://github.com/kenryu42/pi-grok-cli#readme", + "pi": { + "extensions": [ + "./src/index.ts" + ] + }, + "peerDependencies": { + "@earendil-works/pi-ai": "*", + "@earendil-works/pi-coding-agent": "*", + "@earendil-works/pi-tui": "*" + }, + "devDependencies": { + "@biomejs/biome": "2.4.16", + "@earendil-works/pi-ai": "^0.78.0", + "@earendil-works/pi-coding-agent": "^0.78.0", + "@earendil-works/pi-tui": "^0.78.0", + "@vitest/coverage-v8": "^4.1.8", + "husky": "^9.1.7", + "jscpd": "^4.2.4", + "knip": "^6.15.0", + "lint-staged": 
"^17.0.7", + "release-tools": "github:kenryu42/release-tools", + "typescript": "^6.0.3", + "vitest": "^4.1.8" + }, + "lint-staged": { + "*": [ + "biome check --write --no-errors-on-unmatched" + ] + } } diff --git a/src/auth/oauth.ts b/src/auth/oauth.ts new file mode 100644 index 0000000..aba62a6 --- /dev/null +++ b/src/auth/oauth.ts @@ -0,0 +1,509 @@ +/** + * xAI Grok OAuth 2.0 + PKCE implementation. + * + * Uses Web Crypto API (crypto.subtle) for PKCE so the extension is + * portable across Node versions and potential non-Node runtimes. + * + * The OAuth flow is identical to pi-grok — same client_id, same auth.x.ai + * issuer. The difference is in the API endpoint: this extension targets + * cli-chat-proxy.grok.com instead of api.x.ai. + */ + +import { createServer } from 'node:http'; +import { XaiErrorCode, XaiOAuthError } from '../shared/errors.js'; + +// ─── Constants ──────────────────────────────────────────────────────────────── + +const DEFAULT_BASE_URL = 'https://cli-chat-proxy.grok.com/v1'; +const ISSUER = 'https://auth.x.ai'; +const DISCOVERY_URL = `${ISSUER}/.well-known/openid-configuration`; +const CLIENT_ID = process.env.PI_GROK_CLI_OAUTH_CLIENT_ID || 'b1a00492-073a-47ea-816f-4c329264a828'; +const SCOPE = + process.env.PI_GROK_CLI_OAUTH_SCOPE || + 'openid profile email offline_access grok-cli:access api:access'; +const CALLBACK_HOST = process.env.PI_GROK_CLI_CALLBACK_HOST || '127.0.0.1'; +const CALLBACK_PORT = Number.parseInt(process.env.PI_GROK_CLI_CALLBACK_PORT || '56122', 10); +const CALLBACK_PATH = '/callback'; +/** Refresh 120s before actual expiry. 
*/ +const REFRESH_SKEW_MS = 120_000; +const TOKEN_REQUEST_TIMEOUT_MS = Number.parseInt( + process.env.PI_GROK_CLI_TOKEN_TIMEOUT_MS || '30000', + 10, +); + +// ─── Types ──────────────────────────────────────────────────────────────────── + +interface XaiDiscovery { + authorization_endpoint: string; + token_endpoint: string; +} + +export interface XaiOAuthCredentials { + [key: string]: unknown; + refresh: string; + access: string; + expires: number; + tokenEndpoint?: string; + discovery?: XaiDiscovery; + idToken?: string; + tokenType?: string; + baseUrl?: string; +} + +// ─── Helpers ────────────────────────────────────────────────────────────────── + +export function getBaseUrl(): string { + return ( + process.env.PI_GROK_CLI_BASE_URL || + process.env.GROK_CLI_BASE_URL || + DEFAULT_BASE_URL + ).replace(/\/+$/, ''); +} + +function base64Url(buffer: ArrayBuffer | Uint8Array): string { + const bytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer); + let binary = ''; + for (const b of bytes) binary += String.fromCharCode(b); + return btoa(binary).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, ''); +} + +// ─── PKCE ───────────────────────────────────────────────────────────────────── + +async function generatePKCE(): Promise<{ + verifier: string; + challenge: string; +}> { + const verifier = base64Url(crypto.getRandomValues(new Uint8Array(32))); + const hash = await crypto.subtle.digest('SHA-256', new TextEncoder().encode(verifier)); + return { verifier, challenge: base64Url(hash) }; +} + +// ─── Endpoint validation ────────────────────────────────────────────────────── + +function validateEndpoint(value: string, field: string): string { + let url: URL; + try { + url = new URL(value); + } catch { + throw new XaiOAuthError( + `xAI OAuth discovery returned invalid ${field}: ${value}`, + XaiErrorCode.DISCOVERY_INVALID_ORIGIN, + ); + } + if (url.protocol !== 'https:') { + throw new XaiOAuthError( + `xAI OAuth ${field} must use HTTPS: ${value}`, + 
XaiErrorCode.DISCOVERY_INVALID_ORIGIN, + ); + } + const host = url.hostname.toLowerCase(); + if ( + host !== 'x.ai' && + host !== 'auth.x.ai' && + host !== 'accounts.x.ai' && + !host.endsWith('.x.ai') + ) { + throw new XaiOAuthError( + `Refusing non-xAI OAuth ${field}: ${value}`, + XaiErrorCode.DISCOVERY_INVALID_ORIGIN, + ); + } + return url.toString(); +} + +// ─── OIDC Discovery ────────────────────────────────────────────────────────── + +async function discover(): Promise { + let response: Response; + try { + response = await fetch(DISCOVERY_URL, { + headers: { Accept: 'application/json' }, + signal: AbortSignal.timeout(15_000), + }); + } catch (cause) { + throw new XaiOAuthError( + `xAI OIDC discovery failed: ${cause instanceof Error ? cause.message : String(cause)}`, + XaiErrorCode.DISCOVERY_FAILED, + ); + } + if (!response.ok) { + throw new XaiOAuthError( + `xAI OIDC discovery returned ${response.status}`, + XaiErrorCode.DISCOVERY_FAILED, + ); + } + let payload: Record; + try { + payload = (await response.json()) as Record; + } catch (cause) { + throw new XaiOAuthError( + `xAI OIDC discovery returned invalid JSON: ${cause instanceof Error ? cause.message : String(cause)}`, + XaiErrorCode.DISCOVERY_FAILED, + ); + } + const authorizationEndpoint = validateEndpoint( + String(payload.authorization_endpoint ?? ''), + 'authorization_endpoint', + ); + const tokenEndpoint = validateEndpoint(String(payload.token_endpoint ?? 
''), 'token_endpoint'); + return { + authorization_endpoint: authorizationEndpoint, + token_endpoint: tokenEndpoint, + }; +} + +// ─── Loopback callback server ──────────────────────────────────────────────── + +interface CallbackResult { + code?: string; + state?: string; + error?: string; + errorDescription?: string; +} + +function startCallbackServer(): Promise<{ + server: import('node:http').Server; + redirectUri: string; + waitForCallback: (timeoutMs: number) => Promise; +}> { + let settle: ((value: CallbackResult) => void) | undefined; + const callbackPromise = new Promise((resolve) => { + settle = resolve; + }); + + const server = createServer((req, res) => { + try { + const origin = req.headers.origin; + if (origin === 'https://accounts.x.ai' || origin === 'https://auth.x.ai') { + res.setHeader('Access-Control-Allow-Origin', origin); + res.setHeader('Access-Control-Allow-Methods', 'GET, OPTIONS'); + res.setHeader('Access-Control-Allow-Headers', 'Content-Type'); + res.setHeader('Access-Control-Allow-Private-Network', 'true'); + res.setHeader('Vary', 'Origin'); + } + if (req.method === 'OPTIONS') { + res.statusCode = 204; + res.end(); + return; + } + + const url = new URL(req.url ?? '/', `http://${CALLBACK_HOST}`); + if (url.pathname !== CALLBACK_PATH) { + res.statusCode = 404; + res.end('Not found'); + return; + } + + const result: CallbackResult = { + code: url.searchParams.get('code') ?? undefined, + state: url.searchParams.get('state') ?? undefined, + error: url.searchParams.get('error') ?? undefined, + errorDescription: url.searchParams.get('error_description') ?? undefined, + }; + + res.statusCode = result.error ? 400 : 200; + res.setHeader('Content-Type', 'text/html; charset=utf-8'); + const html = result.error + ? '

xAI authorization failed.

You can close this tab.' + : '

xAI authorization received.

You can close this tab.'; + res.end(html); + settle?.(result); + } catch { + res.statusCode = 500; + res.end('Internal error'); + } + }); + + const listen = (port: number) => + new Promise((resolve, reject) => { + server.once('error', reject); + server.listen(port, CALLBACK_HOST, () => { + server.removeListener('error', reject); + const addr = server.address(); + resolve(typeof addr === 'object' && addr ? addr.port : port); + }); + }); + + return (async () => { + let actualPort: number; + try { + actualPort = await listen(CALLBACK_PORT); + } catch (firstError) { + try { + actualPort = await listen(0); + } catch (secondError) { + const errorDescription = `Could not bind xAI OAuth callback server on ${CALLBACK_HOST}:${CALLBACK_PORT} or an ephemeral port: ${secondError instanceof Error ? secondError.message : String(secondError)} (initial error: ${firstError instanceof Error ? firstError.message : String(firstError)})`; + return { + server, + redirectUri: `http://${CALLBACK_HOST}:${CALLBACK_PORT}${CALLBACK_PATH}`, + waitForCallback: async () => ({ + error: XaiErrorCode.CALLBACK_BIND_FAILED, + errorDescription, + }), + }; + } + } + const redirectUri = `http://${CALLBACK_HOST}:${actualPort}${CALLBACK_PATH}`; + return { + server, + redirectUri, + waitForCallback: (timeoutMs: number) => + Promise.race([ + callbackPromise, + new Promise((resolve) => + setTimeout( + () => + resolve({ + error: XaiErrorCode.CALLBACK_TIMEOUT, + errorDescription: 'Timed out waiting for xAI OAuth callback.', + }), + timeoutMs, + ), + ), + ]), + }; + })(); +} + +// ─── Token exchange ─────────────────────────────────────────────────────────── + +async function fetchTokenResponse( + tokenEndpoint: string, + body: URLSearchParams, + errorCode: string, + label: string, +): Promise { + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), TOKEN_REQUEST_TIMEOUT_MS); + try { + return await fetch(tokenEndpoint, { + method: 'POST', + headers: { + 'Content-Type': 
'application/x-www-form-urlencoded', + Accept: 'application/json', + }, + body, + signal: controller.signal, + }); + } catch (cause) { + throw new XaiOAuthError( + `xAI ${label} failed: ${cause instanceof Error ? cause.message : String(cause)}`, + errorCode, + ); + } finally { + clearTimeout(timeout); + } +} + +async function tokenResponseText(response: Response) { + try { + return await response.text(); + } catch (cause) { + return `unable to read response body: ${cause instanceof Error ? cause.message : String(cause)}`; + } +} + +async function tokenResponseJson( + response: Response, + errorCode: string, + label: string, +): Promise> { + try { + return (await response.json()) as Record; + } catch (cause) { + throw new XaiOAuthError( + `xAI ${label} returned invalid JSON: ${cause instanceof Error ? cause.message : String(cause)}`, + errorCode, + ); + } +} + +async function exchangeCode( + tokenEndpoint: string, + code: string, + redirectUri: string, + verifier: string, +): Promise { + const response = await fetchTokenResponse( + tokenEndpoint, + new URLSearchParams({ + grant_type: 'authorization_code', + client_id: CLIENT_ID, + code, + redirect_uri: redirectUri, + code_verifier: verifier, + }), + XaiErrorCode.TOKEN_EXCHANGE_FAILED, + 'token exchange', + ); + if (!response.ok) { + throw new XaiOAuthError( + `xAI token exchange failed: ${response.status} ${await tokenResponseText(response)}`, + XaiErrorCode.TOKEN_EXCHANGE_FAILED, + ); + } + const payload = await tokenResponseJson( + response, + XaiErrorCode.TOKEN_EXCHANGE_FAILED, + 'token exchange', + ); + const access = String(payload.access_token ?? ''); + const refresh = String(payload.refresh_token ?? 
''); + if (!access) { + throw new XaiOAuthError( + 'xAI token exchange did not return access_token.', + XaiErrorCode.TOKEN_EXCHANGE_INVALID, + ); + } + if (!refresh) { + throw new XaiOAuthError( + 'xAI token exchange did not return refresh_token.', + XaiErrorCode.TOKEN_EXCHANGE_INVALID, + ); + } + const expiresIn = + typeof payload.expires_in === 'number' + ? payload.expires_in + : Number(payload.expires_in ?? 3600); + return { + access, + refresh, + expires: Date.now() + expiresIn * 1000 - REFRESH_SKEW_MS, + tokenEndpoint, + discovery: { authorization_endpoint: '', token_endpoint: tokenEndpoint }, + idToken: String(payload.id_token ?? ''), + tokenType: String(payload.token_type ?? 'Bearer'), + baseUrl: getBaseUrl(), + }; +} + +// ─── Login (called by pi's /login flow) ────────────────────────────────────── + +export async function login( + callbacks: import('@earendil-works/pi-ai').OAuthLoginCallbacks, +): Promise { + const discovery = await discover(); + const { verifier, challenge } = await generatePKCE(); + const state = base64Url(crypto.getRandomValues(new Uint8Array(16))); + const nonce = base64Url(crypto.getRandomValues(new Uint8Array(16))); + const callback = await startCallbackServer(); + + try { + const authUrl = new URL(discovery.authorization_endpoint); + authUrl.searchParams.set('response_type', 'code'); + authUrl.searchParams.set('client_id', CLIENT_ID); + authUrl.searchParams.set('redirect_uri', callback.redirectUri); + authUrl.searchParams.set('scope', SCOPE); + authUrl.searchParams.set('code_challenge', challenge); + authUrl.searchParams.set('code_challenge_method', 'S256'); + authUrl.searchParams.set('state', state); + authUrl.searchParams.set('nonce', nonce); + authUrl.searchParams.set('plan', 'generic'); + authUrl.searchParams.set('referrer', 'pi-grok-cli'); + + callbacks.onAuth({ + url: authUrl.toString(), + instructions: `Authorize xAI, then return to pi. 
Callback listener: ${callback.redirectUri}`, + }); + + const result = await callback.waitForCallback(180_000); + + if (result.error) { + const code = + result.error === XaiErrorCode.CALLBACK_BIND_FAILED || + result.error === XaiErrorCode.CALLBACK_TIMEOUT + ? result.error + : XaiErrorCode.AUTHORIZATION_FAILED; + throw new XaiOAuthError(result.errorDescription ?? result.error, code); + } + if (result.state !== state) { + throw new XaiOAuthError( + 'xAI OAuth state mismatch — possible CSRF.', + XaiErrorCode.STATE_MISMATCH, + ); + } + if (!result.code) { + throw new XaiOAuthError( + 'xAI OAuth callback did not include an authorization code.', + XaiErrorCode.CODE_MISSING, + ); + } + + const credentials = await exchangeCode( + discovery.token_endpoint, + result.code, + callback.redirectUri, + verifier, + ); + credentials.discovery = discovery; + return credentials; + } finally { + callback.server.close(); + } +} + +// ─── Token refresh ──────────────────────────────────────────────────────────── + +export async function refresh( + credentials: import('@earendil-works/pi-ai').OAuthCredentials, +): Promise { + const xai = credentials as XaiOAuthCredentials; + const tokenEndpoint = + xai.tokenEndpoint || xai.discovery?.token_endpoint || (await discover()).token_endpoint; + validateEndpoint(tokenEndpoint, 'token_endpoint'); + + if (!credentials.refresh) { + throw new XaiOAuthError( + 'Missing refresh_token. 
Re-login required.', + XaiErrorCode.REFRESH_MISSING, + true, + ); + } + + const response = await fetchTokenResponse( + tokenEndpoint, + new URLSearchParams({ + grant_type: 'refresh_token', + client_id: CLIENT_ID, + refresh_token: credentials.refresh, + }), + XaiErrorCode.REFRESH_FAILED, + 'token refresh', + ); + + if (!response.ok) { + const isFatal = response.status === 400 || response.status === 401 || response.status === 403; + throw new XaiOAuthError( + `xAI token refresh failed: ${response.status} ${await tokenResponseText(response)}`, + XaiErrorCode.REFRESH_FAILED, + isFatal, + ); + } + + const payload = await tokenResponseJson(response, XaiErrorCode.REFRESH_FAILED, 'token refresh'); + const access = String(payload.access_token ?? ''); + if (!access) { + throw new XaiOAuthError( + 'xAI token refresh did not return access_token.', + XaiErrorCode.REFRESH_FAILED, + true, + ); + } + + const refresh_new = String(payload.refresh_token ?? credentials.refresh); + const expiresIn = + typeof payload.expires_in === 'number' + ? payload.expires_in + : Number(payload.expires_in ?? 3600); + + return { + ...xai, + access, + refresh: refresh_new, + expires: Date.now() + expiresIn * 1000 - REFRESH_SKEW_MS, + tokenEndpoint, + idToken: String(payload.id_token ?? xai.idToken ?? ''), + tokenType: String(payload.token_type ?? xai.tokenType ?? 'Bearer'), + baseUrl: getBaseUrl(), + }; +} diff --git a/src/errors.ts b/src/errors.ts deleted file mode 100644 index cde32e6..0000000 --- a/src/errors.ts +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Typed error for xAI OAuth failures. - * - * Codes allow the login flow and stream handlers to distinguish - * retryable failures (network) from fatal ones (revoked refresh token). - */ -export class XaiOAuthError extends Error { - constructor( - message: string, - public readonly code: string, - public readonly reloginRequired = false, - ) { - super(message); - this.name = "XaiOAuthError"; - } -} - -/** Well-known error codes. 
*/ -export const XaiErrorCode = { - /** OIDC discovery failed (network, invalid response). */ - DISCOVERY_FAILED: "discovery_failed", - /** Discovery endpoint returned a non-xAI origin. */ - DISCOVERY_INVALID_ORIGIN: "discovery_invalid_origin", - /** Authorization was denied or errored in the browser. */ - AUTHORIZATION_FAILED: "authorization_failed", - /** CSRF state mismatch between request and callback. */ - STATE_MISMATCH: "state_mismatch", - /** Callback did not include an authorization code. */ - CODE_MISSING: "code_missing", - /** Token exchange failed (network, invalid response). */ - TOKEN_EXCHANGE_FAILED: "token_exchange_failed", - /** Token exchange returned an invalid payload. */ - TOKEN_EXCHANGE_INVALID: "token_exchange_invalid", - /** Refresh token is missing or empty. */ - REFRESH_MISSING: "refresh_missing", - /** Token refresh failed (expired, revoked). */ - REFRESH_FAILED: "refresh_failed", - /** No credentials stored. */ - AUTH_MISSING: "auth_missing", - /** Loopback callback server could not bind. */ - CALLBACK_BIND_FAILED: "callback_bind_failed", - /** Loopback callback timed out. */ - CALLBACK_TIMEOUT: "callback_timeout", -} as const; diff --git a/src/index.ts b/src/index.ts index fb2b05a..ea9eb6b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,348 +1,7 @@ /** * pi-grok-cli — Grok CLI API provider for pi * - * Brings access to the Grok CLI's endpoint - * into pi. This endpoint has access to models not available on the public - * xAI API, including grok-composer-2.5-fast (Cursor's Composer 2.5 model). 
- * - * Environment variables: - * PI_GROK_CLI_BASE_URL - Override the API base URL - * PI_GROK_CLI_MODELS - Comma-separated model IDs to expose - * PI_GROK_CLI_OAUTH_CLIENT_ID - Override OAuth client ID - * PI_GROK_CLI_OAUTH_SCOPE - Override OAuth scopes - */ - -import { existsSync, mkdirSync, readFileSync, writeFileSync } from "node:fs"; -import { homedir } from "node:os"; -import { dirname, join } from "node:path"; -import { - type Api, - type AssistantMessageEventStream, - type Context, - type Model, - type OAuthCredentials, - type OAuthLoginCallbacks, - type SimpleStreamOptions, - streamSimpleOpenAIResponses, -} from "@earendil-works/pi-ai"; -import type { - ExtensionAPI, - ProviderConfig, -} from "@earendil-works/pi-coding-agent"; -import { XaiOAuthError } from "./errors.js"; -import { type GrokCliModelConfig, resolveModels } from "./models.js"; -import * as oauth from "./oauth.js"; -import { getBaseUrl, type XaiOAuthCredentials } from "./oauth.js"; -import { sanitizePayload } from "./sanitize.js"; - -// ─── Grok CLI version (observed from traffic capture) ───────────────────────── - -const GROK_CLI_VERSION = "0.2.16"; -const QUOTA_CACHE_FILE = "grok-cli-quota.json"; - -// ─── Rate limit cache (piggybacks on onResponse from normal traffic) ────────── - -interface RateLimitInfo { - remainingRequests: number; - limitRequests: number; - remainingTokens: number; - limitTokens: number; - contextWindow: number; - zeroDataRetention: boolean; - capturedAt: number; -} - -const cachedRateLimits = new Map(); - -function quotaCachePath() { - return join(homedir(), ".pi", QUOTA_CACHE_FILE); -} - -function isRateLimitInfo(value: unknown): value is RateLimitInfo { - if (!value || typeof value !== "object") return false; - const info = value as Record; - return ( - typeof info.remainingRequests === "number" && - typeof info.limitRequests === "number" && - typeof info.remainingTokens === "number" && - typeof info.limitTokens === "number" && - typeof info.contextWindow === 
"number" && - typeof info.zeroDataRetention === "boolean" && - typeof info.capturedAt === "number" - ); -} - -function loadQuotaCache() { - cachedRateLimits.clear(); - if (!existsSync(quotaCachePath())) return; - - try { - const payload = JSON.parse( - readFileSync(quotaCachePath(), "utf8"), - ) as Record; - const models = payload.models; - if (!models || typeof models !== "object") return; - - Object.entries(models).forEach(([model, rateLimit]) => { - if (isRateLimitInfo(rateLimit)) cachedRateLimits.set(model, rateLimit); - }); - } catch { - cachedRateLimits.clear(); - } -} - -function persistQuotaCache() { - try { - mkdirSync(dirname(quotaCachePath()), { recursive: true }); - writeFileSync( - quotaCachePath(), - JSON.stringify( - { version: 1, models: Object.fromEntries(cachedRateLimits) }, - null, - "\t", - ), - ); - } catch { - // Status remains cache-only; persistence failures should not break requests. - } -} - -/** - * Extract rate limit info from response headers. - * Returns undefined if no rate limit headers are present. 
- */ -function extractRateLimit( - h: Record, -): RateLimitInfo | undefined { - const remainingReqs = Number(h["x-ratelimit-remaining-requests"]); - const limitReqs = Number(h["x-ratelimit-limit-requests"]); - const remainingTokens = Number(h["x-ratelimit-remaining-tokens"]); - const limitTokens = Number(h["x-ratelimit-limit-tokens"]); - const contextWindow = Number(h["x-grok-context-window"]); - - if (Number.isNaN(remainingReqs) && Number.isNaN(remainingTokens)) - return undefined; - - return { - remainingRequests: remainingReqs, - limitRequests: limitReqs, - remainingTokens, - limitTokens, - contextWindow: contextWindow || 512_000, - zeroDataRetention: h["x-zero-data-retention"] === "true", - capturedAt: Date.now(), - }; -} - -function formatQuota(name: string, rateLimit: RateLimitInfo | undefined) { - if (!rateLimit) { - return [ - ` ${name}:`, - " no cached quota data — make a request with this model first", - ]; - } - - const ageSec = Math.round((Date.now() - rateLimit.capturedAt) / 1000); - const ageStr = - ageSec < 60 ? `${ageSec}s ago` : `${Math.round(ageSec / 60)}m ago`; - const lines = [` ${name}:`]; - lines.push(` Cached: ${ageStr}`); - lines.push( - ` Requests: ${rateLimit.remainingRequests}/${rateLimit.limitRequests} remaining`, - ); - lines.push( - ` Tokens: ${rateLimit.remainingTokens.toLocaleString()}/${rateLimit.limitTokens.toLocaleString()} remaining`, - ); - lines.push( - ` Context Limit: ${rateLimit.contextWindow.toLocaleString()} tokens`, - ); - if (rateLimit.zeroDataRetention) { - lines.push(" Data: Zero retention ✓"); - } - return lines; -} - -// ─── Stream function ───────────────────────────────────────────────────────── - -/** - * Stream function that adds Grok CLI-specific headers to requests. 
- * - * The real Grok CLI sends these headers: - * - x-grok-client-identifier: grok-shell - * - x-grok-client-version: 0.2.16 - * - x-grok-conv-id: - * - x-grok-model-override: - * - x-xai-token-auth: xai-grok-cli + * Brings access to the Grok CLI's endpoint into pi. */ -function streamGrokCli( - model: Model, - context: Context, - options?: SimpleStreamOptions, -): AssistantMessageEventStream { - const sessionId = options?.sessionId; - const headers: Record = { - ...options?.headers, - "x-grok-client-identifier": "pi-grok-cli", - "x-grok-client-version": GROK_CLI_VERSION, - "x-xai-token-auth": "xai-grok-cli", - "x-grok-model-override": model.id, - }; - - if (sessionId) { - headers["x-grok-conv-id"] = sessionId; - } - - return streamSimpleOpenAIResponses( - model as Model<"openai-responses">, - context, - { - ...options, - headers, - onResponse(response) { - const rateLimit = extractRateLimit(response.headers); - if (rateLimit) { - cachedRateLimits.set(model.id, rateLimit); - persistQuotaCache(); - } - options?.onResponse?.(response, model); - }, - }, - ); -} - -// ─── Extension entry point ─────────────────────────────────────────────────── - -export default function (pi: ExtensionAPI) { - loadQuotaCache(); - const baseUrl = getBaseUrl(); - const models = resolveModels(); - - // ── Register provider ───────────────────────────────────────────────── - pi.registerProvider("grok-cli", { - name: "Grok CLI", - baseUrl, - apiKey: "$GROK_CLI_OAUTH_TOKEN", - api: "openai-responses", - models: models.map((m: GrokCliModelConfig) => ({ - id: m.id, - name: m.name, - reasoning: m.reasoning, - thinkingLevelMap: m.thinkingLevelMap, - input: m.input, - cost: m.cost, - contextWindow: m.contextWindow, - maxTokens: m.maxTokens, - })), - oauth: { - name: "Grok CLI", - - async login(callbacks: OAuthLoginCallbacks): Promise { - return oauth.login(callbacks); - }, - - async refreshToken( - credentials: OAuthCredentials, - ): Promise { - return oauth.refresh(credentials); - }, - - 
getApiKey(credentials: OAuthCredentials): string { - return credentials.access; - }, - - modifyModels(models: Model[], credentials: OAuthCredentials) { - const effectiveBaseUrl = String( - (credentials as XaiOAuthCredentials).baseUrl ?? getBaseUrl(), - ).replace(/\/+$/, ""); - - return models.map((m) => - m.provider === "grok-cli" ? { ...m, baseUrl: effectiveBaseUrl } : m, - ); - }, - } satisfies ProviderConfig["oauth"], - - streamSimple: streamGrokCli, - }); - - // ── Payload sanitization via event ──────────────────────────────────── - pi.on("before_provider_request", (event, ctx) => { - if (ctx.model?.provider !== "grok-cli") return; - - const modelId = ctx.model?.id ?? ""; - const sessionId = ctx.sessionManager?.getSessionId(); - return sanitizePayload( - event.payload as Record, - modelId, - sessionId, - ); - }); - - // ── /grok-cli-status command ───────────────────────────────────────── - pi.registerCommand("grok-cli-status", { - description: "Show Grok CLI provider status, quota, and token health", - handler: async (_args, ctx) => { - const token = process.env.GROK_CLI_OAUTH_TOKEN; - if (token) { - ctx.ui.notify( - "⚠️ Grok CLI: using GROK_CLI_OAUTH_TOKEN env bypass — no auto-refresh available", - "warning", - ); - } - - try { - const registry = ctx.modelRegistry; - const grokModels = registry - .getAll() - .filter((m: Model) => m.provider === "grok-cli"); - if (grokModels.length === 0) { - ctx.ui.notify( - "Grok CLI: no models registered. Run /login grok-cli first.", - "warning", - ); - return; - } - - const modelNames = grokModels - .slice(0, 5) - .map((m: Model) => m.id) - .join(", "); - const suffix = - grokModels.length > 5 ? 
` (+${grokModels.length - 5} more)` : ""; - ctx.ui.notify( - `✓ Grok CLI: ${grokModels.length} models available (${modelNames}${suffix})`, - "info", - ); - - const lines = [ - " Quota:", - "", - ...formatQuota("grok-build", cachedRateLimits.get("grok-build")), - "", - ...formatQuota( - "grok-composer-2.5-fast", - cachedRateLimits.get("grok-composer-2.5-fast"), - ), - ]; - ctx.ui.notify(lines.join("\n"), "info"); - } catch (err) { - const msg = - err instanceof XaiOAuthError - ? `${err.message} (code: ${err.code})` - : err instanceof Error - ? err.message - : String(err); - ctx.ui.notify(`Grok CLI: ${msg}`, "warning"); - } - }, - }); - // ── Warn on env bypass ──────────────────────────────────────────────── - if (process.env.GROK_CLI_OAUTH_TOKEN) { - pi.on("session_start", async (_event, ctx) => { - ctx.ui.notify( - "[pi-grok-cli] Using GROK_CLI_OAUTH_TOKEN bypass — no auto-refresh, no model discovery", - "warning", - ); - }); - } -} +export { default } from './provider/register.js'; diff --git a/src/models.ts b/src/models.ts deleted file mode 100644 index 300981e..0000000 --- a/src/models.ts +++ /dev/null @@ -1,154 +0,0 @@ -/** - * Model definitions for Grok CLI's API. 
- */ - -// ─── Cost constants ($/M tokens) ────────────────────────────────────────────── - -const COST_BUILD = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }; -const COST_COMPOSER = { input: 3, output: 15, cacheRead: 0, cacheWrite: 0 }; -const COST_43 = { input: 1.25, output: 2.5, cacheRead: 0.2, cacheWrite: 0 }; -const COST_420 = { input: 2, output: 6, cacheRead: 0.2, cacheWrite: 0 }; - -// ─── Model type ─────────────────────────────────────────────────────────────── - -export interface GrokCliModelConfig { - id: string; - name: string; - reasoning: boolean; - input: ("text" | "image")[]; - cost: { - input: number; - output: number; - cacheRead: number; - cacheWrite: number; - }; - contextWindow: number; - maxTokens: number; - /** Models that don't support reasoning.effort get a thinkingLevelMap. */ - thinkingLevelMap?: Record; -} - -// ─── Hardcoded fallback catalog ─────────────────────────────────────────────── -// -// These are the models observed via the Grok CLI's /v1/models endpoint and -// the actual traffic captured through cli-chat-proxy.grok.com. 
- -const FALLBACK_MODELS: GrokCliModelConfig[] = [ - { - id: "grok-composer-2.5-fast", - name: "Composer 2.5 Fast (Grok CLI)", - reasoning: false, - input: ["text", "image"], - cost: COST_COMPOSER, - contextWindow: 512_000, - maxTokens: 30_000, - thinkingLevelMap: { - off: "none", - minimal: null, - low: null, - medium: null, - high: null, - xhigh: null, - }, - }, - { - id: "grok-build", - name: "Grok Build", - reasoning: true, - input: ["text", "image"], - cost: COST_BUILD, - contextWindow: 1_000_000, - maxTokens: 30_000, - }, - { - id: "grok-4.3", - name: "Grok 4.3", - reasoning: true, - input: ["text", "image"], - cost: COST_43, - contextWindow: 1_000_000, - maxTokens: 30_000, - }, - { - id: "grok-4.20-0309-reasoning", - name: "Grok 4.20 Reasoning", - reasoning: true, - input: ["text", "image"], - cost: COST_420, - contextWindow: 2_000_000, - maxTokens: 30_000, - }, - { - id: "grok-4.20-0309-non-reasoning", - name: "Grok 4.20 Non-Reasoning", - reasoning: false, - input: ["text", "image"], - cost: COST_420, - contextWindow: 2_000_000, - maxTokens: 30_000, - thinkingLevelMap: { - off: "none", - minimal: null, - low: null, - medium: null, - high: null, - xhigh: null, - }, - }, - { - id: "grok-4.20-multi-agent-0309", - name: "Grok 4.20 Multi-Agent", - reasoning: true, - input: ["text", "image"], - cost: COST_420, - contextWindow: 2_000_000, - maxTokens: 30_000, - }, -]; - -// ─── Reasoning-effort allowlist ─────────────────────────────────────────────── - -/** - * Only these model prefixes support `reasoning.effort` in the Responses API. - * Everything else gets the param stripped in the sanitizer. - */ -const EFFORT_CAPABLE_PREFIXES = [ - "grok-3-mini", - "grok-4.20-multi-agent", - "grok-4.3", - "grok-composer", -]; - -export function supportsReasoningEffort(modelId: string): boolean { - const parts = modelId.split("/"); - const name = parts.at(-1) ?? 
modelId; - return EFFORT_CAPABLE_PREFIXES.some((p) => name.toLowerCase().startsWith(p)); -} - -// ─── PI_GROK_CLI_MODELS env override ────────────────────────────────────────── - -/** - * Resolve the active model list. If `PI_GROK_CLI_MODELS` is set, - * it filters/reorders the fallback list; unknown IDs get sensible defaults. - */ -export function resolveModels(): GrokCliModelConfig[] { - const env = (process.env.PI_GROK_CLI_MODELS || "") - .split(",") - .map((s) => s.trim()) - .filter(Boolean); - if (env.length === 0) return FALLBACK_MODELS; - - const byId = new Map(FALLBACK_MODELS.map((m) => [m.id, m])); - return env.map( - (id) => - byId.get(id) ?? { - id, - name: id, - reasoning: true, - input: ["text"] as ("text" | "image")[], - cost: COST_BUILD, - contextWindow: 1_000_000, - maxTokens: 30_000, - }, - ); -} diff --git a/src/models/catalog.ts b/src/models/catalog.ts new file mode 100644 index 0000000..73f4858 --- /dev/null +++ b/src/models/catalog.ts @@ -0,0 +1,149 @@ +/** + * Model definitions for Grok CLI's API. + */ + +// ─── Cost constants ($/M tokens) ────────────────────────────────────────────── + +const COST_BUILD = { input: 1, output: 2, cacheRead: 0.2, cacheWrite: 0.2 }; +const COST_COMPOSER_FAST = { input: 3, output: 15, cacheRead: 0.5, cacheWrite: 0 }; +const COST_43 = { input: 1.25, output: 2.5, cacheRead: 0.2, cacheWrite: 0 }; +const COST_420 = { input: 2, output: 6, cacheRead: 0.2, cacheWrite: 0 }; + +// ─── Model type ─────────────────────────────────────────────────────────────── + +export interface GrokCliModelConfig { + id: string; + name: string; + reasoning: boolean; + input: ('text' | 'image')[]; + cost: { + input: number; + output: number; + cacheRead: number; + cacheWrite: number; + }; + contextWindow: number; + maxTokens: number; + /** Models that don't support reasoning.effort get a thinkingLevelMap. 
*/ + thinkingLevelMap?: Record; +} + +// ─── Hardcoded fallback catalog ─────────────────────────────────────────────── +// +// These are the models observed via the Grok CLI's /v1/models endpoint and +// the actual traffic captured through cli-chat-proxy.grok.com. + +const FALLBACK_MODELS: GrokCliModelConfig[] = [ + { + id: 'grok-composer-2.5-fast', + name: 'Composer 2.5 Fast (Grok CLI)', + reasoning: false, + input: ['text', 'image'], + cost: COST_COMPOSER_FAST, + contextWindow: 200_000, + maxTokens: 30_000, + thinkingLevelMap: { + off: 'none', + minimal: null, + low: null, + medium: null, + high: null, + xhigh: null, + }, + }, + { + id: 'grok-build', + name: 'Grok Build', + reasoning: true, + input: ['text', 'image'], + cost: COST_BUILD, + contextWindow: 512_000, + maxTokens: 30_000, + }, + { + id: 'grok-4.3', + name: 'Grok 4.3', + reasoning: true, + input: ['text', 'image'], + cost: COST_43, + contextWindow: 1_000_000, + maxTokens: 30_000, + }, + { + id: 'grok-4.20-0309-reasoning', + name: 'Grok 4.20 Reasoning', + reasoning: true, + input: ['text', 'image'], + cost: COST_420, + contextWindow: 2_000_000, + maxTokens: 30_000, + }, + { + id: 'grok-4.20-0309-non-reasoning', + name: 'Grok 4.20 Non-Reasoning', + reasoning: false, + input: ['text', 'image'], + cost: COST_420, + contextWindow: 2_000_000, + maxTokens: 30_000, + thinkingLevelMap: { + off: 'none', + minimal: null, + low: null, + medium: null, + high: null, + xhigh: null, + }, + }, + { + id: 'grok-4.20-multi-agent-0309', + name: 'Grok 4.20 Multi-Agent', + reasoning: true, + input: ['text', 'image'], + cost: COST_420, + contextWindow: 2_000_000, + maxTokens: 30_000, + }, +]; + +const EFFORT_CAPABLE_PREFIXES = ['grok-3-mini', 'grok-4.20-multi-agent', 'grok-4.3']; + +export function supportsReasoningEffort(modelId: string): boolean { + const parts = modelId.split('/'); + const name = parts.at(-1) ?? 
modelId; + const model = resolveModels().find((entry) => entry.id.toLowerCase() === name.toLowerCase()); + if (!EFFORT_CAPABLE_PREFIXES.some((prefix) => name.toLowerCase().startsWith(prefix))) { + return false; + } + if (!model?.reasoning) return false; + if (!model.thinkingLevelMap) return true; + return Object.values(model.thinkingLevelMap).some((level) => level !== null && level !== 'none'); +} + +// ─── PI_GROK_CLI_MODELS env override ────────────────────────────────────────── + +/** + * Resolve the active model list. If `PI_GROK_CLI_MODELS` is set, + * it filters/reorders the fallback list; unknown IDs get sensible defaults. + */ +export function resolveModels(): GrokCliModelConfig[] { + const env = (process.env.PI_GROK_CLI_MODELS || '') + .split(',') + .map((s) => s.trim()) + .filter(Boolean); + if (env.length === 0) return FALLBACK_MODELS; + + const byId = new Map(FALLBACK_MODELS.map((m) => [m.id, m])); + return env.map( + (id) => + byId.get(id) ?? { + id, + name: id, + reasoning: true, + input: ['text'] as ('text' | 'image')[], + cost: COST_BUILD, + contextWindow: 1_000_000, + maxTokens: 30_000, + }, + ); +} diff --git a/src/oauth.ts b/src/oauth.ts deleted file mode 100644 index 6e515e2..0000000 --- a/src/oauth.ts +++ /dev/null @@ -1,455 +0,0 @@ -/** - * xAI Grok OAuth 2.0 + PKCE implementation. - * - * Uses Web Crypto API (crypto.subtle) for PKCE so the extension is - * portable across Node versions and potential non-Node runtimes. - * - * The OAuth flow is identical to pi-grok — same client_id, same auth.x.ai - * issuer. The difference is in the API endpoint: this extension targets - * cli-chat-proxy.grok.com instead of api.x.ai. 
- */ - -import { createServer } from "node:http"; -import { XaiErrorCode, XaiOAuthError } from "./errors.js"; - -// ─── Constants ──────────────────────────────────────────────────────────────── - -const DEFAULT_BASE_URL = "https://cli-chat-proxy.grok.com/v1"; -const ISSUER = "https://auth.x.ai"; -const DISCOVERY_URL = `${ISSUER}/.well-known/openid-configuration`; -const CLIENT_ID = - process.env.PI_GROK_CLI_OAUTH_CLIENT_ID || - "b1a00492-073a-47ea-816f-4c329264a828"; -const SCOPE = - process.env.PI_GROK_CLI_OAUTH_SCOPE || - "openid profile email offline_access grok-cli:access api:access"; -const CALLBACK_HOST = process.env.PI_GROK_CLI_CALLBACK_HOST || "127.0.0.1"; -const CALLBACK_PORT = Number.parseInt( - process.env.PI_GROK_CLI_CALLBACK_PORT || "56122", - 10, -); -const CALLBACK_PATH = "/callback"; -/** Refresh 120s before actual expiry. */ -const REFRESH_SKEW_MS = 120_000; - -// ─── Types ──────────────────────────────────────────────────────────────────── - -interface XaiDiscovery { - authorization_endpoint: string; - token_endpoint: string; -} - -export interface XaiOAuthCredentials { - [key: string]: unknown; - refresh: string; - access: string; - expires: number; - tokenEndpoint?: string; - discovery?: XaiDiscovery; - idToken?: string; - tokenType?: string; - baseUrl?: string; -} - -// ─── Helpers ────────────────────────────────────────────────────────────────── - -export function getBaseUrl(): string { - return ( - process.env.PI_GROK_CLI_BASE_URL || - process.env.GROK_CLI_BASE_URL || - DEFAULT_BASE_URL - ).replace(/\/+$/, ""); -} - -function base64Url(buffer: ArrayBuffer | Uint8Array): string { - const bytes = buffer instanceof Uint8Array ? 
buffer : new Uint8Array(buffer); - let binary = ""; - for (const b of bytes) binary += String.fromCharCode(b); - return btoa(binary) - .replace(/\+/g, "-") - .replace(/\//g, "_") - .replace(/=+$/, ""); -} - -// ─── PKCE ───────────────────────────────────────────────────────────────────── - -async function generatePKCE(): Promise<{ - verifier: string; - challenge: string; -}> { - const verifier = base64Url(crypto.getRandomValues(new Uint8Array(32))); - const hash = await crypto.subtle.digest( - "SHA-256", - new TextEncoder().encode(verifier), - ); - return { verifier, challenge: base64Url(hash) }; -} - -// ─── Endpoint validation ────────────────────────────────────────────────────── - -function validateEndpoint(value: string, field: string): string { - let url: URL; - try { - url = new URL(value); - } catch { - throw new XaiOAuthError( - `xAI OAuth discovery returned invalid ${field}: ${value}`, - XaiErrorCode.DISCOVERY_INVALID_ORIGIN, - ); - } - if (url.protocol !== "https:") { - throw new XaiOAuthError( - `xAI OAuth ${field} must use HTTPS: ${value}`, - XaiErrorCode.DISCOVERY_INVALID_ORIGIN, - ); - } - const host = url.hostname.toLowerCase(); - if ( - host !== "x.ai" && - host !== "auth.x.ai" && - host !== "accounts.x.ai" && - !host.endsWith(".x.ai") - ) { - throw new XaiOAuthError( - `Refusing non-xAI OAuth ${field}: ${value}`, - XaiErrorCode.DISCOVERY_INVALID_ORIGIN, - ); - } - return url.toString(); -} - -// ─── OIDC Discovery ────────────────────────────────────────────────────────── - -async function discover(): Promise { - let response: Response; - try { - response = await fetch(DISCOVERY_URL, { - headers: { Accept: "application/json" }, - signal: AbortSignal.timeout(15_000), - }); - } catch (cause) { - throw new XaiOAuthError( - `xAI OIDC discovery failed: ${cause instanceof Error ? 
cause.message : String(cause)}`, - XaiErrorCode.DISCOVERY_FAILED, - ); - } - if (!response.ok) { - throw new XaiOAuthError( - `xAI OIDC discovery returned ${response.status}`, - XaiErrorCode.DISCOVERY_FAILED, - ); - } - const payload = (await response.json()) as Record; - const authorizationEndpoint = validateEndpoint( - String(payload.authorization_endpoint ?? ""), - "authorization_endpoint", - ); - const tokenEndpoint = validateEndpoint( - String(payload.token_endpoint ?? ""), - "token_endpoint", - ); - return { - authorization_endpoint: authorizationEndpoint, - token_endpoint: tokenEndpoint, - }; -} - -// ─── Loopback callback server ──────────────────────────────────────────────── - -interface CallbackResult { - code?: string; - state?: string; - error?: string; - errorDescription?: string; -} - -function startCallbackServer(): Promise<{ - server: import("node:http").Server; - redirectUri: string; - waitForCallback: (timeoutMs: number) => Promise; -}> { - let settle: ((value: CallbackResult) => void) | undefined; - const callbackPromise = new Promise((resolve) => { - settle = resolve; - }); - - const server = createServer((req, res) => { - try { - const origin = req.headers.origin; - if ( - origin === "https://accounts.x.ai" || - origin === "https://auth.x.ai" - ) { - res.setHeader("Access-Control-Allow-Origin", origin); - res.setHeader("Access-Control-Allow-Methods", "GET, OPTIONS"); - res.setHeader("Access-Control-Allow-Headers", "Content-Type"); - res.setHeader("Access-Control-Allow-Private-Network", "true"); - res.setHeader("Vary", "Origin"); - } - if (req.method === "OPTIONS") { - res.statusCode = 204; - res.end(); - return; - } - - const url = new URL(req.url ?? "/", `http://${CALLBACK_HOST}`); - if (url.pathname !== CALLBACK_PATH) { - res.statusCode = 404; - res.end("Not found"); - return; - } - - const result: CallbackResult = { - code: url.searchParams.get("code") ?? undefined, - state: url.searchParams.get("state") ?? 
undefined, - error: url.searchParams.get("error") ?? undefined, - errorDescription: - url.searchParams.get("error_description") ?? undefined, - }; - - res.statusCode = result.error ? 400 : 200; - res.setHeader("Content-Type", "text/html; charset=utf-8"); - const html = result.error - ? "

xAI authorization failed.

You can close this tab." - : "

xAI authorization received.

You can close this tab."; - res.end(html); - settle?.(result); - } catch { - res.statusCode = 500; - res.end("Internal error"); - } - }); - - const listen = (port: number) => - new Promise((resolve, reject) => { - server.once("error", reject); - server.listen(port, CALLBACK_HOST, () => { - server.removeListener("error", reject); - const addr = server.address(); - resolve(typeof addr === "object" && addr ? addr.port : port); - }); - }); - - return (async () => { - let actualPort: number; - try { - actualPort = await listen(CALLBACK_PORT); - } catch { - actualPort = await listen(0); - } - const redirectUri = `http://${CALLBACK_HOST}:${actualPort}${CALLBACK_PATH}`; - return { - server, - redirectUri, - waitForCallback: (timeoutMs: number) => - Promise.race([ - callbackPromise, - new Promise((resolve) => - setTimeout( - () => - resolve({ - error: "timeout", - errorDescription: "Timed out waiting for xAI OAuth callback.", - }), - timeoutMs, - ), - ), - ]), - }; - })(); -} - -// ─── Token exchange ─────────────────────────────────────────────────────────── - -async function exchangeCode( - tokenEndpoint: string, - code: string, - redirectUri: string, - verifier: string, -): Promise { - const response = await fetch(tokenEndpoint, { - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - Accept: "application/json", - }, - body: new URLSearchParams({ - grant_type: "authorization_code", - client_id: CLIENT_ID, - code, - redirect_uri: redirectUri, - code_verifier: verifier, - }), - }); - if (!response.ok) { - throw new XaiOAuthError( - `xAI token exchange failed: ${response.status} ${await response.text()}`, - XaiErrorCode.TOKEN_EXCHANGE_FAILED, - ); - } - const payload = (await response.json()) as Record; - const access = String(payload.access_token ?? ""); - const refresh = String(payload.refresh_token ?? 
""); - if (!access) { - throw new XaiOAuthError( - "xAI token exchange did not return access_token.", - XaiErrorCode.TOKEN_EXCHANGE_INVALID, - ); - } - if (!refresh) { - throw new XaiOAuthError( - "xAI token exchange did not return refresh_token.", - XaiErrorCode.TOKEN_EXCHANGE_INVALID, - ); - } - const expiresIn = - typeof payload.expires_in === "number" - ? payload.expires_in - : Number(payload.expires_in ?? 3600); - return { - access, - refresh, - expires: Date.now() + expiresIn * 1000 - REFRESH_SKEW_MS, - tokenEndpoint, - discovery: { authorization_endpoint: "", token_endpoint: tokenEndpoint }, - idToken: String(payload.id_token ?? ""), - tokenType: String(payload.token_type ?? "Bearer"), - baseUrl: getBaseUrl(), - }; -} - -// ─── Login (called by pi's /login flow) ────────────────────────────────────── - -export async function login( - callbacks: import("@earendil-works/pi-ai").OAuthLoginCallbacks, -): Promise { - const discovery = await discover(); - const { verifier, challenge } = await generatePKCE(); - const state = base64Url(crypto.getRandomValues(new Uint8Array(16))); - const nonce = base64Url(crypto.getRandomValues(new Uint8Array(16))); - const callback = await startCallbackServer(); - - try { - const authUrl = new URL(discovery.authorization_endpoint); - authUrl.searchParams.set("response_type", "code"); - authUrl.searchParams.set("client_id", CLIENT_ID); - authUrl.searchParams.set("redirect_uri", callback.redirectUri); - authUrl.searchParams.set("scope", SCOPE); - authUrl.searchParams.set("code_challenge", challenge); - authUrl.searchParams.set("code_challenge_method", "S256"); - authUrl.searchParams.set("state", state); - authUrl.searchParams.set("nonce", nonce); - authUrl.searchParams.set("plan", "generic"); - authUrl.searchParams.set("referrer", "pi-grok-cli"); - - callbacks.onAuth({ - url: authUrl.toString(), - instructions: `Authorize xAI, then return to pi. 
Callback listener: ${callback.redirectUri}`, - }); - - const result = await callback.waitForCallback(180_000); - - if (result.error) { - throw new XaiOAuthError( - result.errorDescription ?? result.error, - XaiErrorCode.AUTHORIZATION_FAILED, - ); - } - if (result.state !== state) { - throw new XaiOAuthError( - "xAI OAuth state mismatch — possible CSRF.", - XaiErrorCode.STATE_MISMATCH, - ); - } - if (!result.code) { - throw new XaiOAuthError( - "xAI OAuth callback did not include an authorization code.", - XaiErrorCode.CODE_MISSING, - ); - } - - const credentials = await exchangeCode( - discovery.token_endpoint, - result.code, - callback.redirectUri, - verifier, - ); - credentials.discovery = discovery; - return credentials; - } finally { - callback.server.close(); - } -} - -// ─── Token refresh ──────────────────────────────────────────────────────────── - -export async function refresh( - credentials: import("@earendil-works/pi-ai").OAuthCredentials, -): Promise { - const xai = credentials as XaiOAuthCredentials; - const tokenEndpoint = - xai.tokenEndpoint || - xai.discovery?.token_endpoint || - (await discover()).token_endpoint; - validateEndpoint(tokenEndpoint, "token_endpoint"); - - if (!credentials.refresh) { - throw new XaiOAuthError( - "Missing refresh_token. 
Re-login required.", - XaiErrorCode.REFRESH_MISSING, - true, - ); - } - - const response = await fetch(tokenEndpoint, { - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - Accept: "application/json", - }, - body: new URLSearchParams({ - grant_type: "refresh_token", - client_id: CLIENT_ID, - refresh_token: credentials.refresh, - }), - }); - - if (!response.ok) { - const isFatal = - response.status === 400 || - response.status === 401 || - response.status === 403; - throw new XaiOAuthError( - `xAI token refresh failed: ${response.status} ${await response.text()}`, - XaiErrorCode.REFRESH_FAILED, - isFatal, - ); - } - - const payload = (await response.json()) as Record; - const access = String(payload.access_token ?? ""); - if (!access) { - throw new XaiOAuthError( - "xAI token refresh did not return access_token.", - XaiErrorCode.REFRESH_FAILED, - true, - ); - } - - const refresh_new = String(payload.refresh_token ?? credentials.refresh); - const expiresIn = - typeof payload.expires_in === "number" - ? payload.expires_in - : Number(payload.expires_in ?? 3600); - - return { - ...xai, - access, - refresh: refresh_new, - expires: Date.now() + expiresIn * 1000 - REFRESH_SKEW_MS, - tokenEndpoint, - idToken: String(payload.id_token ?? xai.idToken ?? ""), - tokenType: String(payload.token_type ?? xai.tokenType ?? "Bearer"), - baseUrl: getBaseUrl(), - }; -} diff --git a/src/payload/sanitize.ts b/src/payload/sanitize.ts new file mode 100644 index 0000000..e24f7ab --- /dev/null +++ b/src/payload/sanitize.ts @@ -0,0 +1,320 @@ +/** + * Payload sanitization for xAI's Responses API via cli-chat-proxy.grok.com. + * + * xAI's endpoint has quirks compared to stock OpenAI: + * - Replayed `reasoning` items in input cause 400 errors. + * - `reasoning.effort` is only supported on a subset of models. + * - Empty-string content items cause validation failures. + * - `function_call_output.output` cannot contain image arrays. 
+ * - `image_url` parts must be normalized to `input_image` with data URIs. + * - Local image paths must be resolved to base64 data URIs. + * - xAI rejects `role: "developer"` and `role: "system"` in the input + * array; these must be moved to top-level `instructions`. + * - xAI uses `text.format` instead of OpenAI's `response_format`. + * - xAI uses `prompt_cache_key` for conversation caching. + * - xAI doesn't support `prompt_cache_retention`. + * + * Additional Grok CLI-specific behavior: + * - Adds x-grok-* headers for client identification + * - Uses prompt_cache_key for session affinity + */ + +import { existsSync, readFileSync, realpathSync } from 'node:fs'; +import { extname, isAbsolute, resolve, sep } from 'node:path'; +import { fileURLToPath } from 'node:url'; +import { supportsReasoningEffort } from '../models/catalog.js'; + +// ─── Content text extraction ───────────────────────────────────────────────── + +function textFromContent(content: unknown): string { + if (typeof content === 'string') return content; + if (!Array.isArray(content)) return ''; + return content + .map((part) => { + if (typeof part === 'string') return part; + if (!part || typeof part !== 'object') return ''; + const item = part as Record; + const type = typeof item.type === 'string' ? item.type : ''; + return ['text', 'input_text', 'output_text'].includes(type) && typeof item.text === 'string' + ? 
// ─── Image helpers ────────────────────────────────────────────────────────────

/**
 * Trim the value and remove one pair of matching surrounding quotes.
 *
 * Only a leading/trailing pair of the same quote character (`"…"` or `'…'`)
 * is removed; mismatched or single-character input is returned trimmed.
 */
function stripShellQuotes(value: string): string {
  const trimmed = value.trim();
  if (trimmed.length < 2) return trimmed;

  const doubleQuoted = trimmed.startsWith('"') && trimmed.endsWith('"');
  const singleQuoted = trimmed.startsWith("'") && trimmed.endsWith("'");
  return doubleQuoted || singleQuoted ? trimmed.slice(1, -1) : trimmed;
}

/**
 * Strip surrounding quotes, then drop the backslash in front of characters
 * that are commonly backslash-escaped in pasted shell paths (backslash,
 * whitespace, quotes, parentheses, `&`, `;`, `@`).
 */
function unescapeShellPath(value: string): string {
  return stripShellQuotes(value).replace(/\\([\\\s'"()&;@])/g, '$1');
}

/**
 * Map a file path to the image MIME type implied by its extension.
 *
 * @throws Error when the extension is not `.jpg`, `.jpeg`, or `.png`
 *   (compared case-insensitively).
 */
function imageMimeTypeForPath(path: string): string {
  switch (extname(path).toLowerCase()) {
    case '.jpg':
    case '.jpeg':
      return 'image/jpeg';
    case '.png':
      return 'image/png';
    default:
      throw new Error('xAI image understanding supports local .jpg, .jpeg, and .png files only');
  }
}

/**
 * Verify that `filePath` resolves to a location inside `cwd` and return its
 * canonical path.
 *
 * Both paths are canonicalised with `realpathSync` before comparing, so a
 * symlink inside the workspace that points elsewhere is rejected.
 *
 * @param cwd - Workspace directory the file must live in.
 * @param filePath - Candidate file (must exist; `realpathSync` throws otherwise).
 * @returns The canonical path of `filePath`.
 * @throws Error with message `Image path is outside the workspace` when the
 *   canonical path is neither the workspace itself nor below it.
 */
function ensurePathWithinWorkspace(cwd: string, filePath: string): string {
  const realCwd = realpathSync(cwd);
  const realPath = realpathSync(filePath);

  // A filesystem root ("/" or "C:\") already ends with the separator. Appending
  // another one would yield "//", which no canonical path starts with, and every
  // file would be rejected when the workspace is the root directory.
  const workspacePrefix = realCwd.endsWith(sep) ? realCwd : `${realCwd}${sep}`;

  if (realPath !== realCwd && !realPath.startsWith(workspacePrefix)) {
    throw new Error('Image path is outside the workspace');
  }
  return realPath;
}
ensurePathWithinWorkspace(cwd, candidate) : undefined; +} + +function normalizeImageInput(value: unknown, cwd: string): string | undefined { + if (typeof value !== 'string' || !value.trim()) return undefined; + const cleaned = stripShellQuotes(value); + + if (/^https?:\/\//i.test(cleaned) || /^data:image\//i.test(cleaned)) { + return cleaned; + } + + const localPath = resolveLocalImagePath(cleaned, cwd); + if (!localPath) { + throw new Error(`Image file does not exist or is not a valid URL: ${cleaned}`); + } + + const mimeType = imageMimeTypeForPath(localPath); + const data = readFileSync(localPath).toString('base64'); + return `data:${mimeType};base64,${data}`; +} + +// ─── Content part normalization ─────────────────────────────────────────────── + +function isInputImagePart(value: unknown): value is Record { + return ( + !!value && + typeof value === 'object' && + (value as Record).type === 'input_image' + ); +} + +function getImageUrlAndDetail(obj: Record): { + imageUrl: unknown; + detail: unknown; +} { + if (typeof obj.image_url === 'object' && obj.image_url) { + const imageUrl = obj.image_url as Record; + return { imageUrl: imageUrl.url, detail: imageUrl.detail }; + } + + return { imageUrl: obj.image_url, detail: obj.detail }; +} + +function normalizeImageParts(value: unknown, cwd: string): unknown { + if (Array.isArray(value)) return value.map((item) => normalizeImageParts(item, cwd)); + if (!value || typeof value !== 'object') return value; + + const obj = { ...(value as Record) }; + + if (obj.type === 'image' && typeof obj.data === 'string' && typeof obj.mimeType === 'string') { + return { + type: 'input_image', + image_url: `data:${obj.mimeType};base64,${obj.data}`, + detail: typeof obj.detail === 'string' && obj.detail ? 
obj.detail : 'auto', + }; + } + + if (obj.type === 'image_url') { + const { imageUrl, detail } = getImageUrlAndDetail(obj); + obj.type = 'input_image'; + obj.image_url = imageUrl; + if (typeof detail === 'string' && detail) obj.detail = detail; + } + + if (obj.type === 'input_image') { + const { imageUrl, detail } = getImageUrlAndDetail(obj); + const normalized = normalizeImageInput(imageUrl, cwd); + if (normalized) obj.image_url = normalized; + if (typeof detail === 'string' && detail) obj.detail = detail; + if (typeof obj.detail !== 'string' || !obj.detail) obj.detail = 'auto'; + } + + if (Array.isArray(obj.content)) obj.content = normalizeImageParts(obj.content, cwd); + if (Array.isArray(obj.output)) obj.output = normalizeImageParts(obj.output, cwd); + return obj; +} + +// ─── function_call_output rewrite ───────────────────────────────────────────── + +function rewriteFunctionCallOutput(input: Record[]): Record[] { + const rewritten: Record[] = []; + + for (const item of input) { + if ( + !item || + typeof item !== 'object' || + item.type !== 'function_call_output' || + !Array.isArray(item.output) + ) { + rewritten.push(item); + continue; + } + + const outputParts = item.output as unknown[]; + const imageParts = outputParts.filter(isInputImagePart); + const textParts = outputParts.filter((p) => !isInputImagePart(p)); + + const textChunks: string[] = []; + for (const part of textParts) { + if (typeof part === 'string') { + textChunks.push(part); + } else if (part && typeof part === 'object') { + const p = part as Record; + if (typeof p.text === 'string') textChunks.push(p.text); + } + } + let imageCount = 0; + for (const _ of imageParts) imageCount++; + + const outputText = textChunks.join('\n') || '(tool returned no text output)'; + rewritten.push({ ...item, output: outputText }); + + if (imageCount > 0) { + const callId = item.call_id ? 
` (${String(item.call_id)})` : ''; + const label = `The previous tool result${callId} included ${imageCount} image${imageCount === 1 ? '' : 's'}. Use the attached image${imageCount === 1 ? '' : 's'} as the visual output from that tool.`; + rewritten.push({ + role: 'user', + content: [{ type: 'input_text', text: label }, ...imageParts], + }); + } + } + + return rewritten; +} + +// ─── Main sanitization ──────────────────────────────────────────────────────── + +/** + * Sanitize a provider request payload for xAI's Responses API via + * cli-chat-proxy.grok.com. + * + * Returns the modified payload. Mutates the input in place for efficiency. + */ +export function sanitizePayload( + params: Record, + modelId: string, + sessionId: string | undefined, + cwd: string, +): Record { + const next = params; + + // ── Sanitize input array ────────────────────────────────────────────── + if (Array.isArray(next.input)) { + let input = (next.input as unknown[]) + .map((item: unknown) => { + if (!item || typeof item !== 'object') return item; + const obj = item as Record; + + // Strip replayed reasoning items + if (obj.type === 'reasoning') return null; + + // Drop empty string content + if (typeof obj.content === 'string' && obj.content.length === 0) return null; + + return obj; + }) + .filter(Boolean) as Record[]; + + // Move system/developer messages to top-level instructions. + // xAI rejects role: "developer" and role: "system" in the input array. + const instructionParts: string[] = []; + input = input.filter((item) => { + const role = (item as Record).role; + if (role !== 'developer' && role !== 'system') return true; + const text = textFromContent((item as Record).content).trim(); + if (text) instructionParts.push(text); + return false; + }); + if (instructionParts.length > 0) { + const existing = + typeof next.instructions === 'string' && next.instructions ? 
next.instructions : ''; + const merged = [existing, ...instructionParts].filter((part) => part.length > 0).join('\n\n'); + next.instructions = merged; + } + + // Normalize image parts (resolve local paths, fix types) + input = normalizeImageParts(input, cwd) as Record[]; + + // Rewrite function_call_output with images + input = rewriteFunctionCallOutput(input); + + next.input = input; + } else if (typeof next.input === 'string') { + // String input is valid and should stay string-shaped. + } + + // ── response_format → text.format ──────────────────────────────────── + if (next.response_format) { + if (!next.text) next.text = { format: next.response_format }; + delete next.response_format; + } + + // ── Reasoning effort ────────────────────────────────────────────────── + if (supportsReasoningEffort(modelId)) { + const reasoning = next.reasoning as Record | undefined; + if (reasoning) { + const effort = reasoning.effort === 'minimal' ? 'low' : reasoning.effort; + next.reasoning = reasoning.summary !== undefined ? { effort } : { ...reasoning, effort }; + } + } else { + delete next.reasoning; + delete next.reasoningEffort; + } + + // ── Strip/filter unsupported fields ────────────────────────────────── + if (Array.isArray(next.include)) { + next.include = (next.include as unknown[]).filter( + (item) => item !== 'reasoning.encrypted_content', + ); + if ((next.include as unknown[]).length === 0) delete next.include; + } + + delete next.prompt_cache_retention; + + // Add prompt_cache_key for conversation caching (routes to same server). 
+ if (sessionId && !next.prompt_cache_key) { + next.prompt_cache_key = sessionId; + } + + return next; +} diff --git a/src/provider/quota.ts b/src/provider/quota.ts new file mode 100644 index 0000000..ccbc038 --- /dev/null +++ b/src/provider/quota.ts @@ -0,0 +1,127 @@ +import { existsSync, mkdirSync, readFileSync, writeFileSync } from 'node:fs'; +import { homedir } from 'node:os'; +import { dirname, join } from 'node:path'; + +const QUOTA_CACHE_FILE = 'grok-cli-quota.json'; + +// ─── Rate limit cache (piggybacks on onResponse from normal traffic) ────────── + +interface RateLimitInfo { + remainingRequests: number; + limitRequests: number; + remainingTokens: number; + limitTokens: number; + contextWindow: number; + zeroDataRetention: boolean; + capturedAt: number; +} + +const cachedRateLimits = new Map(); + +function quotaCachePath() { + return join(homedir(), '.pi', QUOTA_CACHE_FILE); +} + +function isRateLimitInfo(value: unknown): value is RateLimitInfo { + if (!value || typeof value !== 'object') return false; + const info = value as Record; + return ( + typeof info.remainingRequests === 'number' && + typeof info.limitRequests === 'number' && + typeof info.remainingTokens === 'number' && + typeof info.limitTokens === 'number' && + typeof info.contextWindow === 'number' && + typeof info.zeroDataRetention === 'boolean' && + typeof info.capturedAt === 'number' + ); +} + +export function loadQuotaCache() { + cachedRateLimits.clear(); + if (!existsSync(quotaCachePath())) return; + + try { + const payload = JSON.parse(readFileSync(quotaCachePath(), 'utf8')) as Record; + const models = payload.models; + if (!models || typeof models !== 'object') return; + + Object.entries(models).forEach(([model, rateLimit]) => { + if (isRateLimitInfo(rateLimit)) cachedRateLimits.set(model, rateLimit); + }); + } catch { + cachedRateLimits.clear(); + } +} + +function persistQuotaCache() { + try { + mkdirSync(dirname(quotaCachePath()), { recursive: true }); + writeFileSync( + 
quotaCachePath(), + JSON.stringify({ version: 1, models: Object.fromEntries(cachedRateLimits) }, null, '\t'), + ); + } catch { + // Status remains cache-only; persistence failures should not break requests. + } +} + +/** + * Extract rate limit info from response headers. + * Returns undefined if no rate limit headers are present. + */ +function extractRateLimit(h: Record): RateLimitInfo | undefined { + const remainingReqs = Number(h['x-ratelimit-remaining-requests']); + const limitReqs = Number(h['x-ratelimit-limit-requests']); + const remainingTokens = Number(h['x-ratelimit-remaining-tokens']); + const limitTokens = Number(h['x-ratelimit-limit-tokens']); + const contextWindow = Number(h['x-grok-context-window']); + + if ( + [remainingReqs, limitReqs, remainingTokens, limitTokens].some( + (value) => !Number.isFinite(value), + ) + ) { + return undefined; + } + + return { + remainingRequests: remainingReqs, + limitRequests: limitReqs, + remainingTokens, + limitTokens, + contextWindow: contextWindow || 512_000, + zeroDataRetention: h['x-zero-data-retention'] === 'true', + capturedAt: Date.now(), + }; +} + +export function formatQuota(name: string, rateLimit: RateLimitInfo | undefined) { + if (!rateLimit) { + return [` ${name}:`, ' no cached quota data — make a request with this model first']; + } + + const ageSec = Math.round((Date.now() - rateLimit.capturedAt) / 1000); + const ageStr = ageSec < 60 ? 
`${ageSec}s ago` : `${Math.round(ageSec / 60)}m ago`; + const lines = [` ${name}:`]; + lines.push(` Cached: ${ageStr}`); + lines.push(` Requests: ${rateLimit.remainingRequests}/${rateLimit.limitRequests} remaining`); + lines.push( + ` Tokens: ${rateLimit.remainingTokens.toLocaleString()}/${rateLimit.limitTokens.toLocaleString()} remaining`, + ); + lines.push(` Context Limit: ${rateLimit.contextWindow.toLocaleString()} tokens`); + if (rateLimit.zeroDataRetention) { + lines.push(' Data: Zero retention ✓'); + } + return lines; +} + +export function captureRateLimit(modelId: string, headers: Record) { + const rateLimit = extractRateLimit(headers); + if (!rateLimit) return; + cachedRateLimits.set(modelId, rateLimit); + persistQuotaCache(); +} + +export function getCachedRateLimit(modelId: string): RateLimitInfo | undefined { + return cachedRateLimits.get(modelId); +} diff --git a/src/provider/register.ts b/src/provider/register.ts new file mode 100644 index 0000000..075d1b1 --- /dev/null +++ b/src/provider/register.ts @@ -0,0 +1,90 @@ +import type { Api, Model, OAuthCredentials, OAuthLoginCallbacks } from '@earendil-works/pi-ai'; +import type { ExtensionAPI, ProviderConfig } from '@earendil-works/pi-coding-agent'; +import * as oauth from '../auth/oauth.js'; +import { getBaseUrl, type XaiOAuthCredentials } from '../auth/oauth.js'; +import { type GrokCliModelConfig, resolveModels } from '../models/catalog.js'; +import { sanitizePayload } from '../payload/sanitize.js'; +import { registerGrokTools } from '../tools/register.js'; +import { loadQuotaCache } from './quota.js'; +import { registerStatusCommand } from './status.js'; +import { streamGrokCli } from './stream.js'; +import { syncGrokTools } from './toolScope.js'; + +export default function registerGrokCli(pi: ExtensionAPI) { + loadQuotaCache(); + const baseUrl = getBaseUrl(); + const models = resolveModels(); + + pi.on('model_select', (event) => { + syncGrokTools(pi, event.model.provider); + }); + + 
pi.on('before_agent_start', (_event, ctx) => { + syncGrokTools(pi, ctx.model?.provider); + }); + + pi.registerProvider('grok-cli', { + name: 'Grok CLI', + baseUrl, + apiKey: '$GROK_CLI_OAUTH_TOKEN', + api: 'openai-responses', + models: models.map((m: GrokCliModelConfig) => ({ + id: m.id, + name: m.name, + reasoning: m.reasoning, + thinkingLevelMap: m.thinkingLevelMap, + input: m.input, + cost: m.cost, + contextWindow: m.contextWindow, + maxTokens: m.maxTokens, + })), + oauth: { + name: 'Grok CLI', + + async login(callbacks: OAuthLoginCallbacks): Promise { + return oauth.login(callbacks); + }, + + async refreshToken(credentials: OAuthCredentials): Promise { + return oauth.refresh(credentials); + }, + + getApiKey(credentials: OAuthCredentials): string { + return credentials.access; + }, + + modifyModels(models: Model[], credentials: OAuthCredentials) { + const effectiveBaseUrl = String( + (credentials as XaiOAuthCredentials).baseUrl ?? getBaseUrl(), + ).replace(/\/+$/, ''); + + return models.map((m) => + m.provider === 'grok-cli' ? { ...m, baseUrl: effectiveBaseUrl } : m, + ); + }, + } satisfies ProviderConfig['oauth'], + + streamSimple: streamGrokCli, + }); + + registerGrokTools(pi); + + pi.on('before_provider_request', (event, ctx) => { + if (ctx.model?.provider !== 'grok-cli') return; + + const modelId = ctx.model?.id ?? 
''; + const sessionId = ctx.sessionManager?.getSessionId(); + return sanitizePayload(event.payload as Record, modelId, sessionId, ctx.cwd); + }); + + registerStatusCommand(pi); + + if (process.env.GROK_CLI_OAUTH_TOKEN) { + pi.on('session_start', async (_event, ctx) => { + ctx.ui.notify( + '[pi-grok-cli] Using GROK_CLI_OAUTH_TOKEN bypass — no auto-refresh, no model discovery', + 'warning', + ); + }); + } +} diff --git a/src/provider/status.ts b/src/provider/status.ts new file mode 100644 index 0000000..41ca0a3 --- /dev/null +++ b/src/provider/status.ts @@ -0,0 +1,55 @@ +import type { Api, Model } from '@earendil-works/pi-ai'; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { XaiOAuthError } from '../shared/errors.js'; +import { formatQuota, getCachedRateLimit } from './quota.js'; + +export function registerStatusCommand(pi: Pick) { + pi.registerCommand('grok-cli-status', { + description: 'Show Grok CLI provider status, quota, and token health', + handler: async (_args, ctx) => { + const token = process.env.GROK_CLI_OAUTH_TOKEN; + if (token) { + ctx.ui.notify( + '⚠️ Grok CLI: using GROK_CLI_OAUTH_TOKEN env bypass — no auto-refresh available', + 'warning', + ); + } + + try { + const registry = ctx.modelRegistry; + const grokModels = registry.getAll().filter((m: Model) => m.provider === 'grok-cli'); + if (grokModels.length === 0) { + ctx.ui.notify('Grok CLI: no models registered. Run /login grok-cli first.', 'warning'); + return; + } + + const modelNames = grokModels + .slice(0, 5) + .map((m: Model) => m.id) + .join(', '); + const suffix = grokModels.length > 5 ? 
` (+${grokModels.length - 5} more)` : ''; + ctx.ui.notify( + `✓ Grok CLI: ${grokModels.length} models available (${modelNames}${suffix})`, + 'info', + ); + + const lines = [ + ' Quota:', + ...grokModels.flatMap((model: Model) => [ + '', + ...formatQuota(model.id, getCachedRateLimit(model.id)), + ]), + ]; + ctx.ui.notify(lines.join('\n'), 'info'); + } catch (err) { + const msg = + err instanceof XaiOAuthError + ? `${err.message} (code: ${err.code})` + : err instanceof Error + ? err.message + : String(err); + ctx.ui.notify(`Grok CLI: ${msg}`, 'warning'); + } + }, + }); +} diff --git a/src/provider/stream.ts b/src/provider/stream.ts new file mode 100644 index 0000000..afcafae --- /dev/null +++ b/src/provider/stream.ts @@ -0,0 +1,49 @@ +import { + type Api, + type AssistantMessageEventStream, + type Context, + type Model, + type SimpleStreamOptions, + streamSimpleOpenAIResponses, +} from '@earendil-works/pi-ai'; +import { captureRateLimit } from './quota.js'; + +const GROK_CLI_VERSION = '0.2.16'; + +/** + * Stream function that adds Grok CLI-specific headers to requests. 
+ * + * The real Grok CLI sends these headers: + * - x-grok-client-identifier: grok-shell + * - x-grok-client-version: 0.2.16 + * - x-grok-conv-id: + * - x-grok-model-override: + * - x-xai-token-auth: xai-grok-cli + */ +export function streamGrokCli( + model: Model, + context: Context, + options?: SimpleStreamOptions, +): AssistantMessageEventStream { + const sessionId = options?.sessionId; + const headers: Record = { + ...options?.headers, + 'x-grok-client-identifier': 'pi-grok-cli', + 'x-grok-client-version': GROK_CLI_VERSION, + 'x-xai-token-auth': 'xai-grok-cli', + 'x-grok-model-override': model.id, + }; + + if (sessionId) { + headers['x-grok-conv-id'] = sessionId; + } + + return streamSimpleOpenAIResponses(model as Model<'openai-responses'>, context, { + ...options, + headers, + onResponse(response) { + captureRateLimit(model.id, response.headers); + options?.onResponse?.(response, model); + }, + }); +} diff --git a/src/provider/toolScope.ts b/src/provider/toolScope.ts new file mode 100644 index 0000000..5730cc6 --- /dev/null +++ b/src/provider/toolScope.ts @@ -0,0 +1,20 @@ +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { GROK_TOOL_NAMES } from '../tools/register.js'; + +export function syncGrokTools( + pi: Pick, + provider: string | undefined, +) { + const currentTools = pi.getActiveTools(); + const baseTools = currentTools.filter((toolName) => !GROK_TOOL_NAMES.includes(toolName)); + const nextTools = provider === 'grok-cli' ? [...baseTools, ...GROK_TOOL_NAMES] : baseTools; + + if ( + currentTools.length === nextTools.length && + currentTools.every((toolName, i) => toolName === nextTools[i]) + ) { + return; + } + + pi.setActiveTools(nextTools); +} diff --git a/src/sanitize.ts b/src/sanitize.ts deleted file mode 100644 index 00804ab..0000000 --- a/src/sanitize.ts +++ /dev/null @@ -1,335 +0,0 @@ -/** - * Payload sanitization for xAI's Responses API via cli-chat-proxy.grok.com. 
- * - * xAI's endpoint has quirks compared to stock OpenAI: - * - Replayed `reasoning` items in input cause 400 errors. - * - `reasoning.effort` is only supported on a subset of models. - * - Empty-string content items cause validation failures. - * - `function_call_output.output` cannot contain image arrays. - * - `image_url` parts must be normalized to `input_image` with data URIs. - * - Local image paths must be resolved to base64 data URIs. - * - xAI rejects `role: "developer"` and `role: "system"` in the input - * array; these must be moved to top-level `instructions`. - * - xAI uses `text.format` instead of OpenAI's `response_format`. - * - xAI uses `prompt_cache_key` for conversation caching. - * - xAI doesn't support `prompt_cache_retention`. - * - * Additional Grok CLI-specific behavior: - * - Adds x-grok-* headers for client identification - * - Uses prompt_cache_key for session affinity - */ - -import { existsSync, readFileSync } from "node:fs"; -import { extname, isAbsolute, resolve } from "node:path"; -import { fileURLToPath } from "node:url"; -import { supportsReasoningEffort } from "./models.js"; - -// ─── Content text extraction ───────────────────────────────────────────────── - -function textFromContent(content: unknown): string { - if (typeof content === "string") return content; - if (!Array.isArray(content)) return ""; - return content - .map((part) => { - if (typeof part === "string") return part; - if (!part || typeof part !== "object") return ""; - const item = part as Record; - const type = typeof item.type === "string" ? item.type : ""; - return ["text", "input_text", "output_text"].includes(type) && - typeof item.text === "string" - ? 
item.text - : ""; - }) - .filter(Boolean) - .join("\n"); -} - -// ─── Image helpers ──────────────────────────────────────────────────────────── - -function stripShellQuotes(value: string): string { - const trimmed = value.trim(); - if ( - trimmed.length >= 2 && - ((trimmed.startsWith('"') && trimmed.endsWith('"')) || - (trimmed.startsWith("'") && trimmed.endsWith("'"))) - ) { - return trimmed.slice(1, -1); - } - return trimmed; -} - -function unescapeShellPath(value: string): string { - return stripShellQuotes(value).replace(/\\([\\\s'"()&;@])/g, "$1"); -} - -function imageMimeTypeForPath(path: string): string { - switch (extname(path).toLowerCase()) { - case ".jpg": - case ".jpeg": - return "image/jpeg"; - case ".png": - return "image/png"; - default: - throw new Error( - "xAI image understanding supports local .jpg, .jpeg, and .png files only", - ); - } -} - -function resolveLocalImagePath(value: string): string | undefined { - const cleaned = unescapeShellPath(value); - if (!cleaned) return undefined; - - if (cleaned.startsWith("file://")) { - try { - return fileURLToPath(cleaned); - } catch { - return undefined; - } - } - - const candidates = [cleaned]; - if (!isAbsolute(cleaned)) candidates.push(resolve(process.cwd(), cleaned)); - - return candidates.find((candidate) => existsSync(candidate)); -} - -function normalizeImageInput(value: unknown): string | undefined { - if (typeof value !== "string" || !value.trim()) return undefined; - const cleaned = stripShellQuotes(value); - - if (/^https?:\/\//i.test(cleaned) || /^data:image\//i.test(cleaned)) { - return cleaned; - } - - const localPath = resolveLocalImagePath(cleaned); - if (!localPath) { - throw new Error( - `Image file does not exist or is not a valid URL: ${cleaned}`, - ); - } - - const mimeType = imageMimeTypeForPath(localPath); - const data = readFileSync(localPath).toString("base64"); - return `data:${mimeType};base64,${data}`; -} - -// ─── Content part normalization 
─────────────────────────────────────────────── - -function isInputImagePart(value: unknown): value is Record { - return ( - !!value && - typeof value === "object" && - (value as Record).type === "input_image" - ); -} - -function getImageUrlAndDetail(obj: Record): { - imageUrl: unknown; - detail: unknown; -} { - if (typeof obj.image_url === "object" && obj.image_url) { - const imageUrl = obj.image_url as Record; - return { imageUrl: imageUrl.url, detail: imageUrl.detail }; - } - - return { imageUrl: obj.image_url, detail: obj.detail }; -} - -function normalizeImageParts(value: unknown): unknown { - if (Array.isArray(value)) return value.map(normalizeImageParts); - if (!value || typeof value !== "object") return value; - - const obj = { ...(value as Record) }; - - if ( - obj.type === "image" && - typeof obj.data === "string" && - typeof obj.mimeType === "string" - ) { - return { - type: "input_image", - image_url: `data:${obj.mimeType};base64,${obj.data}`, - detail: - typeof obj.detail === "string" && obj.detail ? 
obj.detail : "auto", - }; - } - - if (obj.type === "image_url") { - const { imageUrl, detail } = getImageUrlAndDetail(obj); - obj.type = "input_image"; - obj.image_url = imageUrl; - if (typeof detail === "string" && detail) obj.detail = detail; - } - - if (obj.type === "input_image") { - const { imageUrl, detail } = getImageUrlAndDetail(obj); - const normalized = normalizeImageInput(imageUrl); - if (normalized) obj.image_url = normalized; - if (typeof detail === "string" && detail) obj.detail = detail; - if (typeof obj.detail !== "string" || !obj.detail) obj.detail = "auto"; - } - - if (Array.isArray(obj.content)) - obj.content = normalizeImageParts(obj.content); - if (Array.isArray(obj.output)) obj.output = normalizeImageParts(obj.output); - return obj; -} - -// ─── function_call_output rewrite ───────────────────────────────────────────── - -function rewriteFunctionCallOutput( - input: Record[], -): Record[] { - const rewritten: Record[] = []; - - for (const item of input) { - if ( - !item || - typeof item !== "object" || - item.type !== "function_call_output" || - !Array.isArray(item.output) - ) { - rewritten.push(item); - continue; - } - - const outputParts = item.output as unknown[]; - const imageParts = outputParts.filter(isInputImagePart); - const textParts = outputParts.filter((p) => !isInputImagePart(p)); - - const textChunks: string[] = []; - for (const part of textParts) { - if (typeof part === "string") { - textChunks.push(part); - } else if (part && typeof part === "object") { - const p = part as Record; - if (typeof p.text === "string") textChunks.push(p.text); - } - } - let imageCount = 0; - for (const _ of imageParts) imageCount++; - - const outputText = - textChunks.join("\n") || "(tool returned no text output)"; - rewritten.push({ ...item, output: outputText }); - - if (imageCount > 0) { - const callId = item.call_id ? 
` (${String(item.call_id)})` : ""; - const label = `The previous tool result${callId} included ${imageCount} image${imageCount === 1 ? "" : "s"}. Use the attached image${imageCount === 1 ? "" : "s"} as the visual output from that tool.`; - rewritten.push({ - role: "user", - content: [{ type: "input_text", text: label }, ...imageParts], - }); - } - } - - return rewritten; -} - -// ─── Main sanitization ──────────────────────────────────────────────────────── - -/** - * Sanitize a provider request payload for xAI's Responses API via - * cli-chat-proxy.grok.com. - * - * Returns the modified payload. Mutates the input in place for efficiency. - */ -export function sanitizePayload( - params: Record, - modelId: string, - sessionId?: string, -): Record { - const next = params; - - // ── Sanitize input array ────────────────────────────────────────────── - if (Array.isArray(next.input)) { - let input = (next.input as unknown[]) - .map((item: unknown) => { - if (!item || typeof item !== "object") return item; - const obj = item as Record; - - // Strip replayed reasoning items - if (obj.type === "reasoning") return null; - - // Drop empty string content - if (typeof obj.content === "string" && obj.content.length === 0) - return null; - - return obj; - }) - .filter(Boolean) as Record[]; - - // Move system/developer messages to top-level instructions. - // xAI rejects role: "developer" and role: "system" in the input array. - const instructionParts: string[] = []; - while (input.length > 0) { - const first = input[0]; - if (!first || typeof first !== "object") break; - const role = (first as Record).role; - if (role !== "developer" && role !== "system") break; - const text = textFromContent( - (first as Record).content, - ).trim(); - if (text) instructionParts.push(text); - input.shift(); - } - if (instructionParts.length > 0) { - const existing = - typeof next.instructions === "string" && next.instructions - ? 
next.instructions - : ""; - const merged = [existing, ...instructionParts] - .filter((part) => part.length > 0) - .join("\n\n"); - next.instructions = merged; - } - - // Normalize image parts (resolve local paths, fix types) - input = normalizeImageParts(input) as Record[]; - - // Rewrite function_call_output with images - input = rewriteFunctionCallOutput(input); - - next.input = input; - } else if (typeof next.input === "string") { - // String input is valid and should stay string-shaped. - } - - // ── response_format → text.format ──────────────────────────────────── - if (next.response_format && !next.text) { - next.text = { format: next.response_format }; - delete next.response_format; - } - - // ── Reasoning effort ────────────────────────────────────────────────── - if (supportsReasoningEffort(modelId)) { - const reasoning = next.reasoning as Record | undefined; - if (reasoning && reasoning.effort === "minimal") { - next.reasoning = { ...reasoning, effort: "low" }; - } - if (reasoning && reasoning.summary !== undefined) { - next.reasoning = { effort: reasoning.effort }; - } - } else { - delete next.reasoning; - delete next.reasoningEffort; - } - - // ── Strip/filter unsupported fields ────────────────────────────────── - if (Array.isArray(next.include)) { - next.include = (next.include as unknown[]).filter( - (item) => item !== "reasoning.encrypted_content", - ); - if ((next.include as unknown[]).length === 0) delete next.include; - } - - delete next.prompt_cache_retention; - - // Add prompt_cache_key for conversation caching (routes to same server). - if (sessionId && !next.prompt_cache_key) { - next.prompt_cache_key = sessionId; - } - - return next; -} diff --git a/src/shared/errors.ts b/src/shared/errors.ts new file mode 100644 index 0000000..3d2bcc5 --- /dev/null +++ b/src/shared/errors.ts @@ -0,0 +1,44 @@ +/** + * Typed error for xAI OAuth failures. 
+ * + * Codes allow the login flow and stream handlers to distinguish + * retryable failures (network) from fatal ones (revoked refresh token). + */ +export class XaiOAuthError extends Error { + constructor( + message: string, + public readonly code: string, + public readonly reloginRequired = false, + ) { + super(message); + this.name = 'XaiOAuthError'; + } +} + +/** Well-known error codes. */ +export const XaiErrorCode = { + /** OIDC discovery failed (network, invalid response). */ + DISCOVERY_FAILED: 'discovery_failed', + /** Discovery endpoint returned a non-xAI origin. */ + DISCOVERY_INVALID_ORIGIN: 'discovery_invalid_origin', + /** Authorization was denied or errored in the browser. */ + AUTHORIZATION_FAILED: 'authorization_failed', + /** CSRF state mismatch between request and callback. */ + STATE_MISMATCH: 'state_mismatch', + /** Callback did not include an authorization code. */ + CODE_MISSING: 'code_missing', + /** Token exchange failed (network, invalid response). */ + TOKEN_EXCHANGE_FAILED: 'token_exchange_failed', + /** Token exchange returned an invalid payload. */ + TOKEN_EXCHANGE_INVALID: 'token_exchange_invalid', + /** Refresh token is missing or empty. */ + REFRESH_MISSING: 'refresh_missing', + /** Token refresh failed (expired, revoked). */ + REFRESH_FAILED: 'refresh_failed', + /** No credentials stored. */ + AUTH_MISSING: 'auth_missing', + /** Loopback callback server could not bind. */ + CALLBACK_BIND_FAILED: 'callback_bind_failed', + /** Loopback callback timed out. 
*/ + CALLBACK_TIMEOUT: 'callback_timeout', +} as const; diff --git a/src/tools/files.ts b/src/tools/files.ts new file mode 100644 index 0000000..f5e31e2 --- /dev/null +++ b/src/tools/files.ts @@ -0,0 +1,632 @@ +import { + existsSync, + promises as fs, + mkdirSync, + readFileSync, + unlinkSync, + writeFileSync, +} from 'node:fs'; +import { basename, dirname, join, resolve, sep } from 'node:path'; +import { Type } from '@earendil-works/pi-ai'; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { + booleanDetail, + detailRecord, + fileError, + fileNotFound, + MAX_OUTPUT_CHARS, + numberDetail, + recordFrom, + renderResultSummary, + stringDetail, + stringFrom, + type ToolError, + text, +} from './rendering.js'; + +type ReplacementEdit = { oldText: string; newText: string }; +type FileDetails = { path: string; [key: string]: unknown }; +type WriteArgs = { path: string; content: string }; +type StrReplaceArgs = { path: string; old_str: string; new_str: string }; +type EditArgs = { + path: string; + edits?: ReplacementEdit[]; + applyPatch?: { patchContent: string }; + strReplace?: ReplacementEdit; + multiStrReplace?: { edits: ReplacementEdit[] }; +}; + +type ToolTheme = { + bold: (text: string) => string; + fg: (name: 'accent' | 'toolTitle', text: string) => string; +}; + +function parseEditList(value: unknown): ReplacementEdit[] | undefined { + const editList = typeof value === 'string' ? parseJson(value) : value; + if (!Array.isArray(editList)) return undefined; + if ( + !editList.every( + (edit) => + typeof recordFrom(edit)?.oldText === 'string' && + typeof recordFrom(edit)?.newText === 'string', + ) + ) { + return undefined; + } + return editList.map((edit) => ({ + oldText: stringFrom(recordFrom(edit)?.oldText) ?? '', + newText: stringFrom(recordFrom(edit)?.newText) ?? 
'', + })); +} + +function parseJson(value: string): unknown { + try { + return JSON.parse(value); + } catch { + return undefined; + } +} + +function editFromText(oldText: unknown, newText: unknown) { + if (typeof oldText !== 'string' || typeof newText !== 'string') return undefined; + return [{ oldText, newText }]; +} + +function editsFromArgs(input: Record) { + return ( + parseEditList(input.edits) ?? + parseEditList(recordFrom(input.multiStrReplace)?.edits) ?? + editFromText(input.oldText, input.newText) ?? + editFromText(recordFrom(input.strReplace)?.oldText, recordFrom(input.strReplace)?.newText) + ); +} + +function applyEdits(content: string, edits: ReplacementEdit[]) { + return edits.reduce( + (result, edit) => { + const count = result.content.split(edit.oldText).length - 1; + return { + content: + count === 0 + ? result.content + : result.content.replaceAll(edit.oldText, () => edit.newText), + replacements: result.replacements + count, + }; + }, + { content, replacements: 0 }, + ); +} + +function replacementResult(text: string, filePath: string) { + return { + content: [{ type: 'text' as const, text }], + details: { path: filePath, replacements: 0 }, + }; +} + +function renderReplacementResult( + result: { content: { type: string; text?: string }[]; details: unknown }, + expanded: boolean, + isPartial: boolean, + theme: { fg: (name: 'dim' | 'muted', text: string) => string }, +) { + const replacements = numberDetail(result, 'replacements'); + return renderResultSummary( + result, + expanded, + isPartial, + replacements === 0 + ? 
theme.fg('dim', 'No replacements') + : theme.fg('muted', `${replacements} replacement(s)`), + ); +} + +function renderPathToolCall(toolName: string, filePath: string, theme: ToolTheme) { + return text(theme.fg('toolTitle', theme.bold(`${toolName} `)) + theme.fg('accent', filePath)); +} + +async function canonicalizeWithinWorkspace(cwd: string, requestedPath: string) { + const targetPath = resolve(cwd, requestedPath); + const realCwd = await fs.realpath(cwd); + const missingParts: string[] = []; + let currentPath = targetPath; + let realTarget: string | undefined; + while (!realTarget) { + try { + realTarget = join(await fs.realpath(currentPath), ...[...missingParts].reverse()); + } catch (error) { + const parentPath = dirname(currentPath); + if (parentPath === currentPath) throw error; + missingParts.push(basename(currentPath)); + currentPath = parentPath; + } + } + if (realTarget !== realCwd && !realTarget.startsWith(`${realCwd}${sep}`)) { + throw new Error('Path is outside the workspace'); + } + return realTarget; +} + +async function existingPathWithinWorkspace(cwd: string, requestedPath: string) { + const safePath = await canonicalizeWithinWorkspace(cwd, requestedPath); + return existsSync(safePath) ? safePath : undefined; +} + +async function existingPathOrNotFound( + cwd: string, + requestedPath: string, + extraDetails: Omit, +) { + return ( + (await existingPathWithinWorkspace(cwd, requestedPath)) ?? 
+ fileNotFound(resolve(cwd, requestedPath), extraDetails) + ); +} + +function replacementPathOrNotFound(cwd: string, requestedPath: string) { + return existingPathOrNotFound(cwd, requestedPath, { replacements: 0 }); +} + +export function registerFileTools(pi: ExtensionAPI) { + // ── LS tool ────────────────────────────────────────────────────────── + + const LsParams = Type.Object({ + path: Type.String({ + description: 'Directory path to list', + }), + }); + + pi.registerTool({ + name: 'LS', + label: 'LS', + description: 'List the contents of a directory, including hidden files.', + parameters: LsParams, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const targetPath = resolve(ctx.cwd, params.path); + + try { + const safePath = await canonicalizeWithinWorkspace(ctx.cwd, params.path); + if (signal?.aborted) throw new Error('The operation was aborted'); + + let output = (await fs.readdir(safePath)).sort().join('\n'); + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[LS: output truncated at 50KB]`; + } + + return { + content: [{ type: 'text', text: output }], + details: { path: safePath }, + }; + } catch (error: unknown) { + const err = error as ToolError; + const message = err.message ?? 
'Unknown error'; + return { + content: [ + { + type: 'text', + text: `LS error: ${message}`, + }, + ], + details: { path: targetPath, failed: true, error: message }, + }; + } + }, + renderCall(args, theme) { + return renderPathToolCall('LS', args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + theme.fg('muted', stringDetail(result, 'path')), + ); + }, + }); + + // ── Read tool ──────────────────────────────────────────────────────── + + const ReadParams = Type.Object({ + path: Type.String({ + description: 'Path to the file to read', + }), + offset: Type.Optional( + Type.Number({ + description: 'Line number to start reading from (0-indexed)', + }), + ), + limit: Type.Optional( + Type.Number({ + description: 'Maximum number of lines to read', + }), + ), + }); + + pi.registerTool({ + name: 'Read', + label: 'Read', + description: 'Read the contents of a file. Returns the file content with line numbers.', + parameters: ReadParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + const safePath = await existingPathOrNotFound(ctx.cwd, params.path, { + exists: false, + totalLines: 0, + }); + if (typeof safePath !== 'string') return safePath; + + const content = readFileSync(safePath, 'utf-8'); + const lines = content.endsWith('\n') + ? content.slice(0, -1).split('\n') + : content.split('\n'); + + const startLine = params.offset ?? 0; + const endLine = + params.limit !== undefined + ? Math.min(startLine + params.limit, lines.length) + : Math.min(startLine + 2000, lines.length); + + const selectedLines = lines.slice(startLine, endLine); + const numberedLines = selectedLines.map((line, i) => `${startLine + i + 1}\t${line}`); + + let output = numberedLines.join('\n'); + if (endLine < lines.length) { + output += `\n\n[Showing lines ${startLine + 1}-${endLine} of ${lines.length} total lines. 
Use offset to see more.]`; + } + + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + + return { + content: [{ type: 'text', text: output }], + details: { path: safePath, totalLines: lines.length }, + }; + } catch (error: unknown) { + const err = error as { code?: string }; + return fileError(error, 'Read', filePath, { + exists: err.code !== 'ENOENT', + totalLines: 0, + }); + } + }, + renderCall(args, theme) { + const range = + args.offset !== undefined || args.limit !== undefined + ? theme.fg( + 'muted', + ` (from ${args.offset ?? 0}${args.limit !== undefined ? `, ${args.limit} lines` : ''})`, + ) + : ''; + return text( + theme.fg('toolTitle', theme.bold('Read ')) + theme.fg('accent', args.path) + range, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + detailRecord(result).exists === false + ? theme.fg('error', 'File not found') + : theme.fg('muted', `${numberDetail(result, 'totalLines')} line(s)`), + ); + }, + }); + + // ── Write tool ─────────────────────────────────────────────────────── + + const WriteParams = Type.Object({ + path: Type.String({ + description: 'Path to the file to write', + }), + content: Type.String({ + description: 'Content to write to the file', + }), + }); + + pi.registerTool({ + name: 'Write', + label: 'Write', + description: + 'Create or overwrite a file with the given content. Creates parent directories if needed.', + parameters: WriteParams, + + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as WriteArgs; + return { + ...input, + content: stringFrom(input.content) ?? 
stringFrom(input.contents), + } as WriteArgs; + }, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + const safePath = await canonicalizeWithinWorkspace(ctx.cwd, params.path); + mkdirSync(dirname(safePath), { recursive: true }); + writeFileSync(safePath, params.content, 'utf-8'); + const bytesWritten = Buffer.byteLength(params.content, 'utf8'); + + return { + content: [ + { + type: 'text', + text: `Successfully wrote ${bytesWritten} bytes to ${params.path}`, + }, + ], + details: { path: safePath, bytesWritten }, + }; + } catch (error: unknown) { + const err = error as ToolError; + const message = err.message ?? 'Unknown error'; + return { + content: [ + { + type: 'text', + text: `Write error: ${message}`, + }, + ], + details: { path: filePath, bytesWritten: 0, failed: true, error: message }, + }; + } + }, + renderCall(args, theme) { + return renderPathToolCall('Write', args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + theme.fg('muted', `${numberDetail(result, 'bytesWritten')} bytes written`), + ); + }, + }); + + // ── StrReplace tool ────────────────────────────────────────────────── + + const StrReplaceParams = Type.Object({ + path: Type.String({ + description: 'Path to the file to modify', + }), + old_str: Type.String({ + description: 'String to search for (exact match)', + }), + new_str: Type.String({ + description: 'String to replace with', + }), + }); + + pi.registerTool({ + name: 'StrReplace', + label: 'StrReplace', + description: + 'Replace all occurrences of a string in a file. The old_str must be an exact match.', + parameters: StrReplaceParams, + + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as StrReplaceArgs; + return { + ...input, + old_str: + stringFrom(input.old_str) ?? + stringFrom(input.old_string) ?? + stringFrom(input.oldText) ?? 
+ stringFrom(recordFrom(input.strReplace)?.oldText), + new_str: + stringFrom(input.new_str) ?? + stringFrom(input.new_string) ?? + stringFrom(input.newText) ?? + stringFrom(recordFrom(input.strReplace)?.newText), + } as StrReplaceArgs; + }, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const requestedPath = params.path; + const filePath = resolve(ctx.cwd, requestedPath); + + try { + const safePath = await replacementPathOrNotFound(ctx.cwd, requestedPath); + if (typeof safePath !== 'string') return safePath; + + const content = readFileSync(safePath, 'utf-8'); + if (params.old_str === '') { + return replacementResult('StrReplace error: old_str must not be empty', safePath); + } + + const count = content.split(params.old_str).length - 1; + + if (count === 0) { + return replacementResult( + `String not found in ${params.path}: "${params.old_str}"`, + safePath, + ); + } + + const newContent = content.replaceAll(params.old_str, () => params.new_str); + writeFileSync(safePath, newContent, 'utf-8'); + + return { + content: [ + { + type: 'text', + text: `Replaced ${count} occurrence(s) in ${params.path}`, + }, + ], + details: { path: safePath, replacements: count }, + }; + } catch (error: unknown) { + return fileError(error, 'StrReplace', filePath, { replacements: 0 }); + } + }, + renderCall(args, theme) { + return renderPathToolCall('StrReplace', args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderReplacementResult(result, expanded, isPartial, theme); + }, + }); + + // ── Edit tool ──────────────────────────────────────────────────────── + + const EditItemParams = Type.Object({ + oldText: Type.String({ + description: 'String to search for (exact match)', + }), + newText: Type.String({ + description: 'String to replace with', + }), + replaceAll: Type.Optional( + Type.Boolean({ + description: + 'Accepted for Cursor compatibility. 
Replacements are always applied to all matches.', + }), + ), + }); + + const EditParams = Type.Object({ + path: Type.String({ + description: 'Path to the file to modify', + }), + edits: Type.Optional( + Type.Array(EditItemParams, { + description: 'Exact text replacements to apply sequentially', + }), + ), + applyPatch: Type.Optional( + Type.Object({ + patchContent: Type.String({ + description: 'Unsupported unified patch content', + }), + }), + ), + strReplace: Type.Optional(EditItemParams), + multiStrReplace: Type.Optional( + Type.Object({ + edits: Type.Array(EditItemParams), + }), + ), + }); + + pi.registerTool({ + name: 'Edit', + label: 'Edit', + description: + 'Modify a file with exact text replacement. applyPatch is not supported by this Grok tool shim.', + parameters: EditParams, + + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as EditArgs; + return { + ...input, + edits: editsFromArgs(input), + } as EditArgs; + }, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + const safePath = await replacementPathOrNotFound(ctx.cwd, params.path); + if (typeof safePath !== 'string') return safePath; + if (!params.edits?.length) { + return { + content: [ + { + type: 'text', + text: params.applyPatch + ? 
'Edit error: applyPatch is not supported by this Grok tool shim' + : 'Edit error: provide at least one exact text replacement', + }, + ], + details: { path: safePath, replacements: 0 }, + }; + } + if (params.edits.some((edit) => edit.oldText === '')) { + return replacementResult('Edit error: oldText must not be empty', safePath); + } + + const result = applyEdits(readFileSync(safePath, 'utf-8'), params.edits); + + if (result.replacements === 0) { + return replacementResult(`No replacement strings found in ${params.path}`, safePath); + } + + writeFileSync(safePath, result.content, 'utf-8'); + + return { + content: [ + { + type: 'text', + text: `Applied ${result.replacements} replacement(s) in ${params.path}`, + }, + ], + details: { path: safePath, replacements: result.replacements }, + }; + } catch (error: unknown) { + return fileError(error, 'Edit', filePath, { replacements: 0 }); + } + }, + renderCall(args, theme) { + return renderPathToolCall('Edit', args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderReplacementResult(result, expanded, isPartial, theme); + }, + }); + + // ── Delete tool ────────────────────────────────────────────────────── + + const DeleteParams = Type.Object({ + path: Type.String({ + description: 'Path to the file to delete', + }), + }); + + pi.registerTool({ + name: 'Delete', + label: 'Delete', + description: 'Delete a file from the filesystem.', + parameters: DeleteParams, + + async execute(_toolCallId, params, _signal, _onUpdate, ctx) { + const filePath = resolve(ctx.cwd, params.path); + + try { + const safePath = await existingPathOrNotFound(ctx.cwd, params.path, { deleted: false }); + if (typeof safePath !== 'string') return safePath; + + unlinkSync(safePath); + + return { + content: [{ type: 'text', text: `Successfully deleted ${params.path}` }], + details: { path: safePath, deleted: true }, + }; + } catch (error: unknown) { + return fileError(error, 'Delete', filePath, { deleted: false }); + } 
+ }, + renderCall(args, theme) { + return renderPathToolCall('Delete', args.path, theme); + }, + renderResult(result, { expanded, isPartial }, theme) { + return renderResultSummary( + result, + expanded, + isPartial, + booleanDetail(result, 'deleted') + ? theme.fg('muted', 'Deleted') + : theme.fg('error', 'Not deleted'), + ); + }, + }); +} diff --git a/src/tools/register.ts b/src/tools/register.ts new file mode 100644 index 0000000..b981582 --- /dev/null +++ b/src/tools/register.ts @@ -0,0 +1,22 @@ +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { registerFileTools } from './files.js'; +import { registerSearchTools } from './search.js'; +import { registerShellTool } from './shell.js'; + +export const GROK_TOOL_NAMES = [ + 'Grep', + 'Glob', + 'LS', + 'Read', + 'Write', + 'StrReplace', + 'Edit', + 'Delete', + 'Shell', +]; + +export function registerGrokTools(pi: ExtensionAPI) { + registerSearchTools(pi); + registerFileTools(pi); + registerShellTool(pi); +} diff --git a/src/tools/rendering.ts b/src/tools/rendering.ts new file mode 100644 index 0000000..084bcda --- /dev/null +++ b/src/tools/rendering.ts @@ -0,0 +1,257 @@ +import { execFile } from 'node:child_process'; +import { promises as fs } from 'node:fs'; +import { basename, join, relative } from 'node:path'; +import { promisify } from 'node:util'; +import { Text } from '@earendil-works/pi-tui'; + +const execFileAsync = promisify(execFile); + +export const MAX_OUTPUT_CHARS = 50_000; +export const MAX_OUTPUT_BYTES = MAX_OUTPUT_CHARS * 4 + 1024; +export const MAX_LINES = 500; + +export function recordFrom(value: unknown): Record | undefined { + if (!value || typeof value !== 'object') return undefined; + return value as Record; +} + +export function stringFrom(value: unknown): string | undefined { + if (typeof value !== 'string') return undefined; + return value; +} + +export function truncateLines(lines: string[]): string { + if (lines.length > MAX_LINES) { + return ( + lines.slice(0, 
MAX_LINES).join('\n') + + `\n\n[Showing first ${MAX_LINES} of ${lines.length} results. Refine your pattern to narrow results.]` + ); + } + return lines.join('\n'); +} + +export function truncateChars(output: string): string { + if (output.length > MAX_OUTPUT_CHARS) { + return `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + return output; +} + +export function globToRegExp(pattern: string) { + let source = '^'; + for (let i = 0; i < pattern.length; i += 1) { + const char = pattern[i]; + const next = pattern[i + 1]; + if (char === '*' && next === '*' && pattern[i + 2] === '/') { + source += '(?:.*/)?'; + i += 2; + } else if (char === '*' && next === '*') { + source += '.*'; + i += 1; + } else if (char === '*') { + source += '[^/]*'; + } else if (char === '?') { + source += '[^/]'; + } else { + source += char.replace(/[|\\{}()[\]^$+?.]/g, '\\$&'); + } + } + return new RegExp(`${source}$`); +} + +export function normalizePath(filePath: string) { + return filePath.replaceAll('\\', '/'); +} + +export async function listFilesRecursive( + searchPath: string, + signal?: AbortSignal, +): Promise { + if (signal?.aborted) throw new Error('The operation was aborted'); + const stats = await fs.stat(searchPath); + if (stats.isFile()) return [searchPath]; + if (!stats.isDirectory()) return []; + + return ( + await Promise.all( + ( + await fs.readdir(searchPath, { withFileTypes: true }) + ).map((entry) => { + const entryPath = join(searchPath, entry.name); + if (entry.isDirectory()) return listFilesRecursive(entryPath, signal); + if (entry.isFile()) return [entryPath]; + return []; + }), + ) + ).flat(); +} + +let rgAvailable: boolean | undefined; +export async function hasRipgrep(): Promise { + if (rgAvailable !== undefined) return rgAvailable; + try { + await execFileAsync('rg', ['--version']); + rgAvailable = true; + } catch { + rgAvailable = false; + } + return rgAvailable; +} + +export type ToolError = { code?: number; message?: string }; +export 
type ToolResult = { + content: [{ type: 'text'; text: string }]; + details: T; +}; + +export function text(text: string): Text { + return new Text(text, 0, 0); +} + +function firstText(result: { content: { type: string; text?: string }[] }) { + const first = result.content[0]; + if (first?.type !== 'text') return undefined; + return first.text; +} + +export function renderResultText( + result: { content: { type: string; text?: string }[] }, + expanded: boolean, + summary: string, +): Text { + if (expanded) return text(firstText(result) ?? summary); + return text(summary); +} + +export function renderRunning(isPartial: boolean): Text | undefined { + if (!isPartial) return undefined; + return text('Running...'); +} + +export function renderResultSummary( + result: { content: { type: string; text?: string }[] }, + expanded: boolean, + isPartial: boolean, + summary: string, +): Text { + const running = renderRunning(isPartial); + if (running) return running; + return renderResultText(result, expanded, summary); +} + +export function detailRecord(result: { details: unknown }): Record { + if (!result.details || typeof result.details !== 'object') return {}; + return result.details as Record; +} + +export function numberDetail(result: { details: unknown }, key: string): number { + const value = detailRecord(result)[key]; + if (typeof value !== 'number') return 0; + return value; +} + +export function stringDetail(result: { details: unknown }, key: string): string { + const value = detailRecord(result)[key]; + if (typeof value !== 'string') return ''; + return value; +} + +export function booleanDetail(result: { details: unknown }, key: string): boolean { + const value = detailRecord(result)[key]; + return value === true; +} + +type FileDetails = { path: string; [key: string]: unknown }; + +export function fileNotFound( + filePath: string, + extraDetails: Omit, +): ToolResult { + return { + content: [{ type: 'text', text: `File not found: ${filePath}` }], + details: { 
path: filePath, ...extraDetails } as T, + }; +} + +export function fileError( + error: unknown, + toolName: string, + filePath: string, + extraDetails: Omit, +): ToolResult { + const err = error as ToolError; + const message = err.message ?? 'Unknown error'; + return { + content: [ + { + type: 'text', + text: `${toolName} error: ${message}`, + }, + ], + details: { path: filePath, ...extraDetails, failed: true, error: message } as unknown as T, + }; +} + +export function toolError(error: unknown, toolName: string, emptyDetails: T): ToolResult { + const err = error as ToolError; + if (toolName === 'Grep' && err.code === 1) { + return { + content: [{ type: 'text', text: 'No matches found' }], + details: emptyDetails, + }; + } + const message = err.message ?? 'Unknown error'; + return { + content: [ + { + type: 'text', + text: `${toolName} error: ${message}`, + }, + ], + details: { ...emptyDetails, failed: true, error: message } as T, + }; +} + +export async function execWithRgFallback( + rgArgs: string[], + options: { + cwd: string; + signal?: AbortSignal; + pattern: string; + searchPath: string; + include?: string; + }, +): Promise { + if (await hasRipgrep()) { + const result = await execFileAsync('rg', rgArgs, { + cwd: options.cwd, + maxBuffer: MAX_OUTPUT_BYTES, + signal: options.signal, + }); + return result.stdout; + } + + const regex = new RegExp(options.pattern); + const matcher = options.include ? globToRegExp(normalizePath(options.include)) : undefined; + return ( + await Promise.all( + ( + await listFilesRecursive(options.searchPath, options.signal) + ) + .filter((file) => { + if (!matcher) return true; + if (!options.include?.includes('/')) return matcher.test(basename(file)); + return matcher.test(normalizePath(relative(options.cwd, file))); + }) + .map(async (file) => + ( + await fs.readFile(file, 'utf8') + ) + .split(/\r?\n/) + .flatMap((line, index) => (regex.test(line) ? 
`${file}:${index + 1}:${line}` : [])), + ), + ) + ) + .flat() + .join('\n'); +} diff --git a/src/tools/search.ts b/src/tools/search.ts new file mode 100644 index 0000000..69a784f --- /dev/null +++ b/src/tools/search.ts @@ -0,0 +1,224 @@ +import { execFile } from 'node:child_process'; +import { statSync } from 'node:fs'; +import { basename, relative, resolve } from 'node:path'; +import { promisify } from 'node:util'; +import { Type } from '@earendil-works/pi-ai'; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { + execWithRgFallback, + globToRegExp, + hasRipgrep, + listFilesRecursive, + MAX_OUTPUT_BYTES, + normalizePath, + numberDetail, + recordFrom, + renderResultText, + renderRunning, + stringFrom, + text, + toolError, + truncateChars, + truncateLines, +} from './rendering.js'; + +const execFileAsync = promisify(execFile); + +type GrepArgs = { pattern: string; path?: string; include?: string }; +type GlobArgs = { pattern: string; path?: string }; + +function modifiedTimeMs(file: string) { + try { + return statSync(file).mtimeMs; + } catch { + return 0; + } +} + +export function sortByModifiedNewest(files: string[]) { + return files.sort((a, b) => { + const delta = modifiedTimeMs(b) - modifiedTimeMs(a); + if (delta !== 0) return delta; + return a.localeCompare(b); + }); +} + +export function registerSearchTools(pi: ExtensionAPI) { + const GrepParams = Type.Object({ + pattern: Type.String({ + description: 'Regex pattern to search for in file contents', + }), + path: Type.Optional( + Type.String({ + description: 'Directory or file to search. Defaults to current working directory.', + }), + ), + include: Type.Optional( + Type.String({ + description: 'Glob pattern to filter which files are searched (e.g. *.ts, **/*.md)', + }), + ), + }); + + pi.registerTool({ + name: 'Grep', + label: 'Grep', + description: + 'Search for a regex pattern in file contents. Returns matching lines with file path and line number. 
Use the include parameter to filter by file type.', + parameters: GrepParams, + + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as GrepArgs; + return { + ...input, + include: stringFrom(input.include) ?? stringFrom(input.glob_filter), + } as GrepArgs; + }, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const searchPath = resolve(ctx.cwd, params.path ?? '.'); + + try { + const rgArgs = ['-n', '-H', '--no-heading', '--color=never']; + if (params.include) rgArgs.push('--glob', params.include); + rgArgs.push('--', params.pattern, searchPath); + + const stdout = await execWithRgFallback(rgArgs, { + cwd: ctx.cwd, + signal, + pattern: params.pattern, + searchPath, + include: params.include, + }); + + const lines = stdout.trim().split('\n').filter(Boolean); + if (lines.length === 0) { + return { + content: [{ type: 'text', text: 'No matches found' }], + details: { matchCount: 0 }, + }; + } + + return { + content: [{ type: 'text', text: truncateChars(truncateLines(lines)) }], + details: { matchCount: lines.length }, + }; + } catch (error: unknown) { + return toolError(error, 'Grep', { matchCount: 0 }); + } + }, + renderCall(args, theme) { + const path = args.path ? theme.fg('muted', ` in ${args.path}`) : ''; + const include = args.include ? theme.fg('dim', ` [${args.include}]`) : ''; + return text( + theme.fg('toolTitle', theme.bold('Grep ')) + + theme.fg('accent', `"${args.pattern}"`) + + path + + include, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + const matchCount = numberDetail(result, 'matchCount'); + return renderResultText( + result, + expanded, + matchCount === 0 + ? theme.fg('dim', 'No matches') + : theme.fg('muted', `${matchCount} match(es)`), + ); + }, + }); + + const GlobParams = Type.Object({ + pattern: Type.String({ + description: 'Glob pattern to match files (e.g. 
**/*.ts, src/**/*.json)', + }), + path: Type.Optional( + Type.String({ + description: 'Directory to search within. Defaults to current working directory.', + }), + ), + }); + + pi.registerTool({ + name: 'Glob', + label: 'Glob', + description: + 'Find files matching a glob pattern. Returns a list of matching file paths sorted by modification time (newest first).', + parameters: GlobParams, + + prepareArguments(args) { + const input = recordFrom(args); + if (!input) return args as GlobArgs; + return { + ...input, + pattern: stringFrom(input.pattern) ?? stringFrom(input.glob_pattern), + } as GlobArgs; + }, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const searchPath = resolve(ctx.cwd, params.path ?? '.'); + + try { + let files: string[]; + + if (await hasRipgrep()) { + const result = await execFileAsync( + 'rg', + ['--files', '--color=never', '--glob', params.pattern, searchPath], + { cwd: ctx.cwd, maxBuffer: MAX_OUTPUT_BYTES, signal }, + ); + files = result.stdout.trim().split('\n').filter(Boolean); + } else { + const normalizedPattern = normalizePath(params.pattern); + const matcher = globToRegExp(normalizedPattern); + const matchesFile = normalizedPattern.includes('/') + ? 
(file: string) => matcher.test(normalizePath(relative(ctx.cwd, file))) + : (file: string) => matcher.test(basename(file)); + files = (await listFilesRecursive(searchPath, signal)).filter(matchesFile); + } + files = sortByModifiedNewest(files); + + if (files.length === 0) { + return { + content: [{ type: 'text', text: 'No files found' }], + details: { fileCount: 0 }, + }; + } + + return { + content: [{ type: 'text', text: truncateChars(truncateLines(files)) }], + details: { fileCount: files.length }, + }; + } catch (error: unknown) { + const err = error as { code?: unknown; stderr?: string }; + if (err.code === 1 && !err.stderr) { + return { + content: [{ type: 'text', text: 'No files found' }], + details: { fileCount: 0 }, + }; + } + return toolError(error, 'Glob', { fileCount: 0 }); + } + }, + renderCall(args, theme) { + const path = args.path ? theme.fg('muted', ` in ${args.path}`) : ''; + return text( + theme.fg('toolTitle', theme.bold('Glob ')) + theme.fg('accent', args.pattern) + path, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + const fileCount = numberDetail(result, 'fileCount'); + return renderResultText( + result, + expanded, + fileCount === 0 ? 
theme.fg('dim', 'No files') : theme.fg('muted', `${fileCount} file(s)`), + ); + }, + }); +} diff --git a/src/tools/shell.ts b/src/tools/shell.ts new file mode 100644 index 0000000..5078826 --- /dev/null +++ b/src/tools/shell.ts @@ -0,0 +1,157 @@ +import { execFile } from 'node:child_process'; +import { existsSync } from 'node:fs'; +import { resolve } from 'node:path'; +import { promisify } from 'node:util'; +import { Type } from '@earendil-works/pi-ai'; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { + detailRecord, + MAX_OUTPUT_BYTES, + MAX_OUTPUT_CHARS, + renderResultText, + renderRunning, + text, +} from './rendering.js'; + +const execFileAsync = promisify(execFile); + +function shellCommand(command: string): { file: string; args: string[] } | undefined { + if (process.platform === 'win32') { + if (existsSync('C:\\Windows\\System32\\cmd.exe')) { + return { file: 'cmd.exe', args: ['/d', '/s', '/c', command] }; + } + if (existsSync('C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe')) { + return { + file: 'powershell.exe', + args: ['-NoLogo', '-NoProfile', '-Command', command], + }; + } + return undefined; + } + + if ( + process.platform !== 'darwin' && + process.platform !== 'linux' && + process.platform !== 'freebsd' + ) { + return undefined; + } + + if (existsSync('/bin/bash')) return { file: '/bin/bash', args: ['-c', command] }; + if (existsSync('/usr/bin/bash')) return { file: '/usr/bin/bash', args: ['-c', command] }; + if (existsSync('/bin/sh')) return { file: '/bin/sh', args: ['-c', command] }; + if (existsSync('/usr/bin/sh')) return { file: '/usr/bin/sh', args: ['-c', command] }; + return undefined; +} + +export function registerShellTool(pi: ExtensionAPI) { + // ── Shell tool ─────────────────────────────────────────────────────── + + const ShellParams = Type.Object({ + command: Type.String({ + description: 'Shell command to execute', + }), + working_directory: Type.Optional( + Type.String({ + description: 
'Working directory for the command', + }), + ), + timeout: Type.Optional( + Type.Number({ + description: 'Timeout in milliseconds (default: 120000)', + }), + ), + }); + + pi.registerTool({ + name: 'Shell', + label: 'Shell', + description: 'Execute a shell command and return stdout, stderr, and exit code.', + parameters: ShellParams, + + async execute(_toolCallId, params, signal, _onUpdate, ctx) { + const cwd = params.working_directory ? resolve(ctx.cwd, params.working_directory) : ctx.cwd; + const timeout = params.timeout ?? 120_000; + + try { + const shell = shellCommand(params.command); + if (!shell) { + return { + content: [ + { + type: 'text', + text: 'Shell error: unsupported platform or shell not found', + }, + ], + details: { exitCode: 1, command: params.command }, + }; + } + const { stdout, stderr } = await execFileAsync(shell.file, shell.args, { + cwd, + maxBuffer: MAX_OUTPUT_BYTES, + timeout, + signal, + }); + + let output = ''; + if (stdout) output += stdout; + if (stderr) output += `\n[stderr]\n${stderr}`; + + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + + return { + content: [{ type: 'text', text: output || '(no output)' }], + details: { exitCode: 0, command: params.command }, + }; + } catch (error: unknown) { + const err = error as { + code?: unknown; + message?: string; + stdout?: string; + stderr?: string; + }; + const exitCode = typeof err.code === 'number' ? err.code : 1; + + let output = ''; + if (err.stdout) output += err.stdout; + if (err.stderr) output += `\n[stderr]\n${err.stderr}`; + + if (output.length > MAX_OUTPUT_CHARS) { + output = `${output.slice(0, MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`; + } + + return { + content: [ + { + type: 'text', + text: `Shell error (exit code ${err.code ?? 'unknown'}): ${err.message ?? 'Unknown error'}${output ? 
`\n${output}` : ''}`, + }, + ], + details: { + exitCode, + command: params.command, + }, + }; + } + }, + renderCall(args, theme) { + const cwd = args.working_directory ? theme.fg('muted', ` in ${args.working_directory}`) : ''; + return text( + theme.fg('toolTitle', theme.bold('Shell ')) + theme.fg('accent', args.command) + cwd, + ); + }, + renderResult(result, { expanded, isPartial }, theme) { + const running = renderRunning(isPartial); + if (running) return running; + const exitCode = + typeof detailRecord(result).exitCode === 'number' ? detailRecord(result).exitCode : 1; + return renderResultText( + result, + expanded, + exitCode === 0 ? theme.fg('muted', 'Exit 0') : theme.fg('warning', `Exit ${exitCode}`), + ); + }, + }); +} diff --git a/tests/auth/oauth.test.ts b/tests/auth/oauth.test.ts new file mode 100644 index 0000000..57ea9d2 --- /dev/null +++ b/tests/auth/oauth.test.ts @@ -0,0 +1,366 @@ +import { afterEach, describe, expect, it, vi } from 'vitest'; +import { getBaseUrl, login, refresh } from '../../src/auth/oauth.js'; +import { XaiErrorCode } from '../../src/shared/errors.js'; + +const originalEnv = { ...process.env }; +const originalFetch = globalThis.fetch; +const storedRefreshCredentials = { + access: 'access-token', + refresh: 'refresh-token', + expires: 0, + tokenEndpoint: 'https://auth.x.ai/oauth/token', +}; +const credentialsWithoutEndpoint = { + access: 'old-access', + refresh: 'old-refresh', + expires: 0, +}; +const discoveryDocument = { + authorization_endpoint: 'https://auth.x.ai/oauth/authorize', + token_endpoint: 'https://auth.x.ai/oauth/token', +}; + +function authorizeCallback(auth: { url: string }) { + const url = new URL(auth.url); + void originalFetch( + `${url.searchParams.get('redirect_uri')}?code=callback-code&state=${url.searchParams.get('state')}`, + ); +} + +afterEach(() => { + process.env = { ...originalEnv }; + globalThis.fetch = originalFetch; + vi.restoreAllMocks(); + vi.useRealTimers(); +}); + +describe('OAuth helpers without 
network access', () => { + it('resolves and trims the configured base URL', () => { + delete process.env.GROK_CLI_BASE_URL; + delete process.env.PI_GROK_CLI_BASE_URL; + expect(getBaseUrl()).toBe('https://cli-chat-proxy.grok.com/v1'); + + process.env.GROK_CLI_BASE_URL = 'https://example.invalid/v1///'; + expect(getBaseUrl()).toBe('https://example.invalid/v1'); + + process.env.PI_GROK_CLI_BASE_URL = 'https://override.invalid/api//'; + expect(getBaseUrl()).toBe('https://override.invalid/api'); + }); + + it('rejects refresh credentials with no refresh token before fetching', async () => { + const fetchMock = vi.fn(); + globalThis.fetch = fetchMock; + + await expect( + refresh({ + access: 'access-token', + refresh: '', + expires: 0, + tokenEndpoint: 'https://auth.x.ai/oauth/token', + }), + ).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_MISSING, + reloginRequired: true, + }); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it('refreshes credentials with the configured token endpoint', async () => { + vi.useFakeTimers(); + vi.setSystemTime(1_700_000_000_000); + process.env.PI_GROK_CLI_BASE_URL = 'https://proxy.example/v1//'; + const fetchMock = vi.fn(async () => + Response.json({ + access_token: 'new-access', + refresh_token: 'new-refresh', + expires_in: 600, + id_token: 'new-id', + token_type: 'DPoP', + }), + ); + globalThis.fetch = fetchMock; + + await expect( + refresh({ + access: 'old-access', + refresh: 'old-refresh', + expires: 0, + tokenEndpoint: 'https://auth.x.ai/oauth/token', + idToken: 'old-id', + tokenType: 'Bearer', + }), + ).resolves.toMatchObject({ + access: 'new-access', + refresh: 'new-refresh', + expires: 1_700_000_480_000, + tokenEndpoint: 'https://auth.x.ai/oauth/token', + idToken: 'new-id', + tokenType: 'DPoP', + baseUrl: 'https://proxy.example/v1', + }); + + expect(fetchMock).toHaveBeenCalledOnce(); + expect(fetchMock.mock.calls[0]?.[0]).toBe('https://auth.x.ai/oauth/token'); + expect(fetchMock.mock.calls[0]?.[1]).toMatchObject({ + 
method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + Accept: 'application/json', + }, + }); + expect((fetchMock.mock.calls[0]?.[1]?.body as URLSearchParams).toString()).toBe( + 'grant_type=refresh_token&client_id=b1a00492-073a-47ea-816f-4c329264a828&refresh_token=old-refresh', + ); + }); + + it('keeps the existing refresh token and metadata when refresh omits optional fields', async () => { + const fetchMock = vi.fn(async () => + Response.json({ access_token: 'new-access', expires_in: '900' }), + ); + globalThis.fetch = fetchMock; + + await expect( + refresh({ + access: 'old-access', + refresh: 'old-refresh', + expires: 0, + discovery: { + authorization_endpoint: 'https://auth.x.ai/oauth/authorize', + token_endpoint: 'https://accounts.x.ai/oauth/token', + }, + idToken: 'old-id', + tokenType: 'Bearer', + }), + ).resolves.toMatchObject({ + access: 'new-access', + refresh: 'old-refresh', + tokenEndpoint: 'https://accounts.x.ai/oauth/token', + idToken: 'old-id', + tokenType: 'Bearer', + }); + }); + + it('marks unauthorized refresh failures as requiring login', async () => { + const fetchMock = vi.fn(async () => new Response('revoked', { status: 401 })); + globalThis.fetch = fetchMock; + + await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_FAILED, + reloginRequired: true, + message: 'xAI token refresh failed: 401 revoked', + }); + }); + + it('keeps server refresh failures retryable', async () => { + const fetchMock = vi.fn( + async () => new Response('temporarily unavailable', { status: 500 }), + ); + globalThis.fetch = fetchMock; + + await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_FAILED, + reloginRequired: false, + message: 'xAI token refresh failed: 500 temporarily unavailable', + }); + }); + + it('rejects refresh responses without an access token', async () => { + const fetchMock = vi.fn(async () => Response.json({})); + 
globalThis.fetch = fetchMock; + + await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_FAILED, + reloginRequired: true, + message: 'xAI token refresh did not return access_token.', + }); + }); + + it('wraps refresh transport and JSON failures', async () => { + globalThis.fetch = vi.fn(async () => { + throw new Error('socket closed'); + }); + + await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_FAILED, + message: 'xAI token refresh failed: socket closed', + }); + + globalThis.fetch = vi.fn( + async () => new Response('proxy error', { status: 200 }), + ); + + await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ + code: XaiErrorCode.REFRESH_FAILED, + message: expect.stringContaining('xAI token refresh returned invalid JSON:'), + }); + }); + + it('rejects unsafe token endpoints before fetching', async () => { + const fetchMock = vi.fn(); + globalThis.fetch = fetchMock; + + await expect( + refresh({ + ...storedRefreshCredentials, + tokenEndpoint: 'https://evil.example/oauth/token', + }), + ).rejects.toMatchObject({ + code: XaiErrorCode.DISCOVERY_INVALID_ORIGIN, + message: 'Refusing non-xAI OAuth token_endpoint: https://evil.example/oauth/token', + }); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it('discovers the token endpoint when credentials do not include it', async () => { + const fetchMock = vi.fn(async (input) => { + if (input === 'https://auth.x.ai/.well-known/openid-configuration') { + return Response.json(discoveryDocument); + } + return Response.json({ access_token: 'new-access' }); + }); + globalThis.fetch = fetchMock; + + await expect(refresh(credentialsWithoutEndpoint)).resolves.toMatchObject({ + access: 'new-access', + refresh: 'old-refresh', + tokenEndpoint: 'https://auth.x.ai/oauth/token', + }); + expect(fetchMock.mock.calls.map((call) => call[0])).toEqual([ + 'https://auth.x.ai/.well-known/openid-configuration', + 
'https://auth.x.ai/oauth/token', + ]); + }); + + it('wraps discovery network failures', async () => { + globalThis.fetch = vi.fn(async () => { + throw new Error('network down'); + }); + + await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ + code: XaiErrorCode.DISCOVERY_FAILED, + message: 'xAI OIDC discovery failed: network down', + }); + }); + + it('wraps malformed discovery JSON as discovery failure', async () => { + globalThis.fetch = vi.fn( + async () => new Response('proxy error', { status: 200 }), + ); + + await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ + code: XaiErrorCode.DISCOVERY_FAILED, + message: expect.stringContaining('xAI OIDC discovery returned invalid JSON:'), + }); + }); + + it('rejects failed and invalid discovery responses', async () => { + globalThis.fetch = vi.fn( + async () => new Response('unavailable', { status: 503 }), + ); + await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ + code: XaiErrorCode.DISCOVERY_FAILED, + message: 'xAI OIDC discovery returned 503', + }); + + globalThis.fetch = vi.fn(async () => + Response.json({ + authorization_endpoint: 'http://auth.x.ai/oauth/authorize', + token_endpoint: 'https://auth.x.ai/oauth/token', + }), + ); + await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ + code: XaiErrorCode.DISCOVERY_INVALID_ORIGIN, + message: 'xAI OAuth authorization_endpoint must use HTTPS: http://auth.x.ai/oauth/authorize', + }); + }); + + it('logs in with a loopback callback and exchanges the authorization code', async () => { + vi.useFakeTimers(); + vi.setSystemTime(1_700_000_000_000); + const fetchMock = vi.fn(async (input) => { + if (input === 'https://auth.x.ai/.well-known/openid-configuration') { + return Response.json(discoveryDocument); + } + return Response.json({ + access_token: 'login-access', + refresh_token: 'login-refresh', + expires_in: 900, + id_token: 'login-id', + token_type: 'Bearer', + }); + }); + globalThis.fetch 
= fetchMock; + + await expect( + login({ + onAuth: authorizeCallback, + }), + ).resolves.toMatchObject({ + access: 'login-access', + refresh: 'login-refresh', + expires: 1_700_000_780_000, + tokenEndpoint: 'https://auth.x.ai/oauth/token', + discovery: discoveryDocument, + idToken: 'login-id', + tokenType: 'Bearer', + }); + + expect(fetchMock.mock.calls[1]?.[0]).toBe('https://auth.x.ai/oauth/token'); + expect((fetchMock.mock.calls[1]?.[1]?.body as URLSearchParams).get('code')).toBe( + 'callback-code', + ); + }); + + it('reports callback timeouts with a dedicated error code', async () => { + vi.useFakeTimers(); + globalThis.fetch = vi.fn(async () => Response.json(discoveryDocument)); + const onAuth = vi.fn(); + const resultPromise = login({ onAuth }).then( + () => undefined, + (error: unknown) => error, + ); + + await vi.waitFor(() => expect(onAuth).toHaveBeenCalledOnce()); + await vi.advanceTimersByTimeAsync(180_000); + + await expect(resultPromise).resolves.toMatchObject({ + code: XaiErrorCode.CALLBACK_TIMEOUT, + message: 'Timed out waiting for xAI OAuth callback.', + }); + }); + + it('wraps token exchange transport and JSON failures', async () => { + const fetchMock = vi.fn(async (input) => { + if (input === 'https://auth.x.ai/.well-known/openid-configuration') { + return Response.json(discoveryDocument); + } + throw new Error('exchange socket closed'); + }); + globalThis.fetch = fetchMock; + + await expect( + login({ + onAuth: authorizeCallback, + }), + ).rejects.toMatchObject({ + code: XaiErrorCode.TOKEN_EXCHANGE_FAILED, + message: 'xAI token exchange failed: exchange socket closed', + }); + + globalThis.fetch = vi.fn(async (input) => { + if (input === 'https://auth.x.ai/.well-known/openid-configuration') { + return Response.json(discoveryDocument); + } + return new Response('proxy error', { status: 200 }); + }); + + await expect( + login({ + onAuth: authorizeCallback, + }), + ).rejects.toMatchObject({ + code: XaiErrorCode.TOKEN_EXCHANGE_FAILED, + message: 
expect.stringContaining('xAI token exchange returned invalid JSON:'), + }); + }); +}); diff --git a/tests/errors.test.ts b/tests/errors.test.ts deleted file mode 100644 index 9e35aac..0000000 --- a/tests/errors.test.ts +++ /dev/null @@ -1,18 +0,0 @@ -import { describe, expect, it } from "vitest"; -import { XaiErrorCode, XaiOAuthError } from "../src/errors.js"; - -describe("OAuth errors", () => { - it("keeps machine-readable code and relogin state", () => { - const error = new XaiOAuthError( - "Refresh token was revoked", - XaiErrorCode.REFRESH_FAILED, - true, - ); - - expect(error).toBeInstanceOf(Error); - expect(error.name).toBe("XaiOAuthError"); - expect(error.message).toBe("Refresh token was revoked"); - expect(error.code).toBe("refresh_failed"); - expect(error.reloginRequired).toBe(true); - }); -}); diff --git a/tests/index.test.ts b/tests/index.test.ts deleted file mode 100644 index a2f9b95..0000000 --- a/tests/index.test.ts +++ /dev/null @@ -1,215 +0,0 @@ -import { - mkdirSync, - mkdtempSync, - readFileSync, - rmSync, - writeFileSync, -} from "node:fs"; -import { tmpdir } from "node:os"; -import { join } from "node:path"; -import type { - ExtensionAPI, - ProviderConfig, -} from "@earendil-works/pi-coding-agent"; -import { afterEach, describe, expect, it, vi } from "vitest"; - -const streamSimpleOpenAIResponses = vi.fn( - ( - _model: unknown, - _context: unknown, - options?: { - onResponse?: (response: { headers: Record }) => void; - }, - ) => { - options?.onResponse?.({ - headers: { - "x-ratelimit-remaining-requests": "179", - "x-ratelimit-limit-requests": "180", - "x-ratelimit-remaining-tokens": "7500000", - "x-ratelimit-limit-tokens": "7500000", - "x-grok-context-window": "512000", - "x-zero-data-retention": "true", - }, - }); - return {}; - }, -); - -vi.mock("@earendil-works/pi-ai", async (importOriginal) => ({ - ...(await importOriginal()), - streamSimpleOpenAIResponses, -})); - -interface CommandConfig { - handler: (args: string[], ctx: TestContext) => 
Promise; -} - -interface TestContext { - modelRegistry: { - getAll: () => { provider: string; id: string }[]; - getApiKeyForProvider?: (provider: string) => Promise; - }; - ui: { - notify: (message: string, level: string) => void; - }; -} - -const originalFetch = globalThis.fetch; -const originalHome = process.env.HOME; -const originalToken = process.env.GROK_CLI_OAUTH_TOKEN; -const tempDirs: string[] = []; - -afterEach(() => { - vi.resetModules(); - streamSimpleOpenAIResponses.mockClear(); - globalThis.fetch = originalFetch; - if (originalHome === undefined) { - delete process.env.HOME; - } else { - process.env.HOME = originalHome; - } - if (originalToken === undefined) { - delete process.env.GROK_CLI_OAUTH_TOKEN; - } else { - process.env.GROK_CLI_OAUTH_TOKEN = originalToken; - } - for (const dir of tempDirs.splice(0)) rmSync(dir, { recursive: true }); -}); - -async function setupExtension() { - const commands = new Map(); - const providers = new Map(); - const registerGrokCli = (await import("../src/index.js")).default; - registerGrokCli({ - registerProvider(name: string, config: ProviderConfig) { - providers.set(name, config); - }, - on() {}, - registerCommand(name: string, config: unknown) { - commands.set(name, config as CommandConfig); - }, - } as unknown as ExtensionAPI); - return { commands, providers }; -} - -function statusContext(notify: TestContext["ui"]["notify"]): TestContext { - return { - modelRegistry: { - getAll: () => [ - { provider: "grok-cli", id: "grok-build" }, - { provider: "grok-cli", id: "grok-composer-2.5-fast" }, - ], - }, - ui: { notify }, - }; -} - -function setupHome() { - const dir = mkdtempSync(join(tmpdir(), "pi-grok-cli-home-")); - mkdirSync(join(dir, ".pi")); - tempDirs.push(dir); - process.env.HOME = dir; - return dir; -} - -async function runStatus( - extension: Awaited>, -) { - const notify = vi.fn(); - await extension.commands - .get("grok-cli-status") - ?.handler([], statusContext(notify)); - return notify; -} - 
-describe("Grok CLI status command", () => { - it("uses only cached quota data and tells users to make requests first", async () => { - delete process.env.GROK_CLI_OAUTH_TOKEN; - setupHome(); - const fetchMock = vi.fn(); - globalThis.fetch = fetchMock; - const extension = await setupExtension(); - const notify = await runStatus(extension); - - expect(fetchMock).not.toHaveBeenCalled(); - expect(notify.mock.calls.at(-1)?.[0]).toBe( - [ - " Quota:", - "", - " grok-build:", - " no cached quota data — make a request with this model first", - "", - " grok-composer-2.5-fast:", - " no cached quota data — make a request with this model first", - ].join("\n"), - ); - }); - - it("shows separate cached quotas for build and composer", async () => { - delete process.env.GROK_CLI_OAUTH_TOKEN; - setupHome(); - const extension = await setupExtension(); - const provider = extension.providers.get("grok-cli"); - provider?.streamSimple?.( - { provider: "grok-cli", id: "grok-build" }, - {}, - {}, - ); - provider?.streamSimple?.( - { provider: "grok-cli", id: "grok-composer-2.5-fast" }, - {}, - {}, - ); - const notify = await runStatus(extension); - - expect(notify.mock.calls.at(-1)?.[0]).toContain("grok-build:\n Cached:"); - expect(notify.mock.calls.at(-1)?.[0]).toContain( - "grok-composer-2.5-fast:\n Cached:", - ); - expect(notify.mock.calls.at(-1)?.[0]).toContain( - "Requests: 179/180 remaining", - ); - }); - - it("persists cached quotas to the global pi config directory", async () => { - delete process.env.GROK_CLI_OAUTH_TOKEN; - const home = setupHome(); - const extension = await setupExtension(); - extension.providers - .get("grok-cli") - ?.streamSimple?.({ provider: "grok-cli", id: "grok-build" }, {}, {}); - - expect( - JSON.parse(readFileSync(join(home, ".pi", "grok-cli-quota.json"), "utf8")) - .models["grok-build"].remainingRequests, - ).toBe(179); - }); - - it("loads cached quotas from the global pi config directory", async () => { - delete process.env.GROK_CLI_OAUTH_TOKEN; - 
const home = setupHome(); - writeFileSync( - join(home, ".pi", "grok-cli-quota.json"), - JSON.stringify({ - version: 1, - models: { - "grok-build": { - remainingRequests: 42, - limitRequests: 180, - remainingTokens: 1_000, - limitTokens: 2_000, - contextWindow: 512_000, - zeroDataRetention: true, - capturedAt: Date.now(), - }, - }, - }), - ); - const extension = await setupExtension(); - const notify = await runStatus(extension); - - expect(notify.mock.calls.at(-1)?.[0]).toContain( - "Requests: 42/180 remaining", - ); - }); -}); diff --git a/tests/models.test.ts b/tests/models.test.ts deleted file mode 100644 index 932a206..0000000 --- a/tests/models.test.ts +++ /dev/null @@ -1,51 +0,0 @@ -import { afterEach, describe, expect, it } from "vitest"; -import { resolveModels, supportsReasoningEffort } from "../src/models.js"; - -const originalEnv = { ...process.env }; - -afterEach(() => { - process.env = { ...originalEnv }; -}); - -describe("model catalog", () => { - it("reports reasoning-effort support by normalized model name", () => { - expect(supportsReasoningEffort("grok-4.3")).toBe(true); - expect(supportsReasoningEffort("grok-cli/GROK-COMPOSER-2.5-fast")).toBe( - true, - ); - expect(supportsReasoningEffort("grok-4.20-0309-non-reasoning")).toBe(false); - }); - - it("uses fallback models when no override is configured", () => { - delete process.env.PI_GROK_CLI_MODELS; - - expect(resolveModels().map((model) => model.id)).toEqual([ - "grok-composer-2.5-fast", - "grok-build", - "grok-4.3", - "grok-4.20-0309-reasoning", - "grok-4.20-0309-non-reasoning", - "grok-4.20-multi-agent-0309", - ]); - }); - - it("filters, reorders, and fills unknown model overrides", () => { - process.env.PI_GROK_CLI_MODELS = " custom-model , grok-build ,, grok-4.3 "; - - const models = resolveModels(); - - expect(models.map((model) => model.id)).toEqual([ - "custom-model", - "grok-build", - "grok-4.3", - ]); - expect(models[0]).toMatchObject({ - name: "custom-model", - reasoning: true, - 
input: ["text"], - contextWindow: 1_000_000, - maxTokens: 30_000, - }); - expect(models[1].name).toBe("Grok Build"); - }); -}); diff --git a/tests/models/catalog.test.ts b/tests/models/catalog.test.ts new file mode 100644 index 0000000..0f9ea87 --- /dev/null +++ b/tests/models/catalog.test.ts @@ -0,0 +1,53 @@ +import { afterEach, describe, expect, it } from 'vitest'; +import { resolveModels, supportsReasoningEffort } from '../../src/models/catalog.js'; + +const originalEnv = { ...process.env }; + +afterEach(() => { + process.env = { ...originalEnv }; +}); + +describe('model catalog', () => { + it('reports reasoning-effort support by normalized model name', () => { + expect(supportsReasoningEffort('grok-4.3')).toBe(true); + expect(supportsReasoningEffort('grok-cli/GROK-COMPOSER-2.5-fast')).toBe(false); + expect(supportsReasoningEffort('grok-4.20-0309-non-reasoning')).toBe(false); + }); + + it('uses fallback models when no override is configured', () => { + delete process.env.PI_GROK_CLI_MODELS; + + const models = resolveModels(); + + expect(models.map((model) => model.id)).toEqual([ + 'grok-composer-2.5-fast', + 'grok-build', + 'grok-4.3', + 'grok-4.20-0309-reasoning', + 'grok-4.20-0309-non-reasoning', + 'grok-4.20-multi-agent-0309', + ]); + expect(models.find((model) => model.id === 'grok-composer-2.5-fast')).toMatchObject({ + contextWindow: 200_000, + }); + expect(models.find((model) => model.id === 'grok-build')).toMatchObject({ + contextWindow: 512_000, + }); + }); + + it('filters, reorders, and fills unknown model overrides', () => { + process.env.PI_GROK_CLI_MODELS = ' custom-model , grok-build ,, grok-4.3 '; + + const models = resolveModels(); + + expect(models.map((model) => model.id)).toEqual(['custom-model', 'grok-build', 'grok-4.3']); + expect(models[0]).toMatchObject({ + name: 'custom-model', + reasoning: true, + input: ['text'], + contextWindow: 1_000_000, + maxTokens: 30_000, + }); + expect(models[1].name).toBe('Grok Build'); + }); +}); diff --git 
a/tests/oauth.test.ts b/tests/oauth.test.ts deleted file mode 100644 index e11023e..0000000 --- a/tests/oauth.test.ts +++ /dev/null @@ -1,286 +0,0 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; -import { XaiErrorCode } from "../src/errors.js"; -import { getBaseUrl, login, refresh } from "../src/oauth.js"; - -const originalEnv = { ...process.env }; -const originalFetch = globalThis.fetch; -const storedRefreshCredentials = { - access: "access-token", - refresh: "refresh-token", - expires: 0, - tokenEndpoint: "https://auth.x.ai/oauth/token", -}; -const credentialsWithoutEndpoint = { - access: "old-access", - refresh: "old-refresh", - expires: 0, -}; -const discoveryDocument = { - authorization_endpoint: "https://auth.x.ai/oauth/authorize", - token_endpoint: "https://auth.x.ai/oauth/token", -}; - -afterEach(() => { - process.env = { ...originalEnv }; - globalThis.fetch = originalFetch; - vi.restoreAllMocks(); - vi.useRealTimers(); -}); - -describe("OAuth helpers without network access", () => { - it("resolves and trims the configured base URL", () => { - delete process.env.GROK_CLI_BASE_URL; - delete process.env.PI_GROK_CLI_BASE_URL; - expect(getBaseUrl()).toBe("https://cli-chat-proxy.grok.com/v1"); - - process.env.GROK_CLI_BASE_URL = "https://example.invalid/v1///"; - expect(getBaseUrl()).toBe("https://example.invalid/v1"); - - process.env.PI_GROK_CLI_BASE_URL = "https://override.invalid/api//"; - expect(getBaseUrl()).toBe("https://override.invalid/api"); - }); - - it("rejects refresh credentials with no refresh token before fetching", async () => { - const fetchMock = vi.fn(); - globalThis.fetch = fetchMock; - - await expect( - refresh({ - access: "access-token", - refresh: "", - expires: 0, - tokenEndpoint: "https://auth.x.ai/oauth/token", - }), - ).rejects.toMatchObject({ - code: XaiErrorCode.REFRESH_MISSING, - reloginRequired: true, - }); - expect(fetchMock).not.toHaveBeenCalled(); - }); - - it("refreshes credentials with the configured token 
endpoint", async () => { - vi.useFakeTimers(); - vi.setSystemTime(1_700_000_000_000); - process.env.PI_GROK_CLI_BASE_URL = "https://proxy.example/v1//"; - const fetchMock = vi.fn(async () => - Response.json({ - access_token: "new-access", - refresh_token: "new-refresh", - expires_in: 600, - id_token: "new-id", - token_type: "DPoP", - }), - ); - globalThis.fetch = fetchMock; - - await expect( - refresh({ - access: "old-access", - refresh: "old-refresh", - expires: 0, - tokenEndpoint: "https://auth.x.ai/oauth/token", - idToken: "old-id", - tokenType: "Bearer", - }), - ).resolves.toMatchObject({ - access: "new-access", - refresh: "new-refresh", - expires: 1_700_000_480_000, - tokenEndpoint: "https://auth.x.ai/oauth/token", - idToken: "new-id", - tokenType: "DPoP", - baseUrl: "https://proxy.example/v1", - }); - - expect(fetchMock).toHaveBeenCalledOnce(); - expect(fetchMock.mock.calls[0]?.[0]).toBe("https://auth.x.ai/oauth/token"); - expect(fetchMock.mock.calls[0]?.[1]).toMatchObject({ - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - Accept: "application/json", - }, - }); - expect( - (fetchMock.mock.calls[0]?.[1]?.body as URLSearchParams).toString(), - ).toBe( - "grant_type=refresh_token&client_id=b1a00492-073a-47ea-816f-4c329264a828&refresh_token=old-refresh", - ); - }); - - it("keeps the existing refresh token and metadata when refresh omits optional fields", async () => { - const fetchMock = vi.fn(async () => - Response.json({ access_token: "new-access", expires_in: "900" }), - ); - globalThis.fetch = fetchMock; - - await expect( - refresh({ - access: "old-access", - refresh: "old-refresh", - expires: 0, - discovery: { - authorization_endpoint: "https://auth.x.ai/oauth/authorize", - token_endpoint: "https://accounts.x.ai/oauth/token", - }, - idToken: "old-id", - tokenType: "Bearer", - }), - ).resolves.toMatchObject({ - access: "new-access", - refresh: "old-refresh", - tokenEndpoint: "https://accounts.x.ai/oauth/token", - 
idToken: "old-id", - tokenType: "Bearer", - }); - }); - - it("marks unauthorized refresh failures as requiring login", async () => { - const fetchMock = vi.fn( - async () => new Response("revoked", { status: 401 }), - ); - globalThis.fetch = fetchMock; - - await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ - code: XaiErrorCode.REFRESH_FAILED, - reloginRequired: true, - message: "xAI token refresh failed: 401 revoked", - }); - }); - - it("keeps server refresh failures retryable", async () => { - const fetchMock = vi.fn( - async () => new Response("temporarily unavailable", { status: 500 }), - ); - globalThis.fetch = fetchMock; - - await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ - code: XaiErrorCode.REFRESH_FAILED, - reloginRequired: false, - message: "xAI token refresh failed: 500 temporarily unavailable", - }); - }); - - it("rejects refresh responses without an access token", async () => { - const fetchMock = vi.fn(async () => Response.json({})); - globalThis.fetch = fetchMock; - - await expect(refresh(storedRefreshCredentials)).rejects.toMatchObject({ - code: XaiErrorCode.REFRESH_FAILED, - reloginRequired: true, - message: "xAI token refresh did not return access_token.", - }); - }); - - it("rejects unsafe token endpoints before fetching", async () => { - const fetchMock = vi.fn(); - globalThis.fetch = fetchMock; - - await expect( - refresh({ - ...storedRefreshCredentials, - tokenEndpoint: "https://evil.example/oauth/token", - }), - ).rejects.toMatchObject({ - code: XaiErrorCode.DISCOVERY_INVALID_ORIGIN, - message: - "Refusing non-xAI OAuth token_endpoint: https://evil.example/oauth/token", - }); - expect(fetchMock).not.toHaveBeenCalled(); - }); - - it("discovers the token endpoint when credentials do not include it", async () => { - const fetchMock = vi.fn(async (input) => { - if (input === "https://auth.x.ai/.well-known/openid-configuration") { - return Response.json(discoveryDocument); - } - return Response.json({ 
access_token: "new-access" }); - }); - globalThis.fetch = fetchMock; - - await expect(refresh(credentialsWithoutEndpoint)).resolves.toMatchObject({ - access: "new-access", - refresh: "old-refresh", - tokenEndpoint: "https://auth.x.ai/oauth/token", - }); - expect(fetchMock.mock.calls.map((call) => call[0])).toEqual([ - "https://auth.x.ai/.well-known/openid-configuration", - "https://auth.x.ai/oauth/token", - ]); - }); - - it("wraps discovery network failures", async () => { - globalThis.fetch = vi.fn(async () => { - throw new Error("network down"); - }); - - await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ - code: XaiErrorCode.DISCOVERY_FAILED, - message: "xAI OIDC discovery failed: network down", - }); - }); - - it("rejects failed and invalid discovery responses", async () => { - globalThis.fetch = vi.fn( - async () => new Response("unavailable", { status: 503 }), - ); - await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ - code: XaiErrorCode.DISCOVERY_FAILED, - message: "xAI OIDC discovery returned 503", - }); - - globalThis.fetch = vi.fn(async () => - Response.json({ - authorization_endpoint: "http://auth.x.ai/oauth/authorize", - token_endpoint: "https://auth.x.ai/oauth/token", - }), - ); - await expect(refresh(credentialsWithoutEndpoint)).rejects.toMatchObject({ - code: XaiErrorCode.DISCOVERY_INVALID_ORIGIN, - message: - "xAI OAuth authorization_endpoint must use HTTPS: http://auth.x.ai/oauth/authorize", - }); - }); - - it("logs in with a loopback callback and exchanges the authorization code", async () => { - vi.useFakeTimers(); - vi.setSystemTime(1_700_000_000_000); - const fetchMock = vi.fn(async (input) => { - if (input === "https://auth.x.ai/.well-known/openid-configuration") { - return Response.json(discoveryDocument); - } - return Response.json({ - access_token: "login-access", - refresh_token: "login-refresh", - expires_in: 900, - id_token: "login-id", - token_type: "Bearer", - }); - }); - globalThis.fetch = 
fetchMock; - - await expect( - login({ - onAuth: (auth) => { - const url = new URL(auth.url); - void originalFetch( - `${url.searchParams.get("redirect_uri")}?code=callback-code&state=${url.searchParams.get("state")}`, - ); - }, - }), - ).resolves.toMatchObject({ - access: "login-access", - refresh: "login-refresh", - expires: 1_700_000_780_000, - tokenEndpoint: "https://auth.x.ai/oauth/token", - discovery: discoveryDocument, - idToken: "login-id", - tokenType: "Bearer", - }); - - expect(fetchMock.mock.calls[1]?.[0]).toBe("https://auth.x.ai/oauth/token"); - expect( - (fetchMock.mock.calls[1]?.[1]?.body as URLSearchParams).get("code"), - ).toBe("callback-code"); - }); -}); diff --git a/tests/package.test.ts b/tests/package.test.ts deleted file mode 100644 index 215dd7e..0000000 --- a/tests/package.test.ts +++ /dev/null @@ -1,58 +0,0 @@ -import { existsSync, globSync, readFileSync } from "node:fs"; -import { describe, expect, it } from "vitest"; - -const packageJson = JSON.parse( - readFileSync(new URL("../package.json", import.meta.url), "utf8"), -); - -describe("npm package manifest", () => { - it("declares a pi package entry point", () => { - expect(packageJson.name).toBe("pi-grok-cli"); - expect(packageJson.keywords).toContain("pi-package"); - expect(packageJson.pi?.extensions).toEqual(["./src/index.ts"]); - expect(packageJson.main).toBe("./src/index.ts"); - expect(packageJson.files).toEqual(["README.md", "src", "tsconfig.json"]); - }); - - it("runs publish checks before packing", () => { - expect(packageJson.scripts?.test).toBe("vitest run --reporter=agent"); - expect(packageJson.scripts?.coverage).toBe( - "vitest run --reporter=agent --coverage", - ); - expect(packageJson.scripts?.typecheck).toBe("tsc --noEmit"); - expect(packageJson.scripts?.prepack).toBe( - "bun run test && bun run coverage && bun run typecheck", - ); - expect(packageJson.devDependencies?.vitest).toBeDefined(); - expect(packageJson.devDependencies?.["@vitest/coverage-v8"]).toBeDefined(); - 
expect(existsSync(new URL("../vitest.config.ts", import.meta.url))).toBe( - true, - ); - }); -}); - -describe("repository layout", () => { - it("keeps extension source files under src", () => { - for (const file of [ - "index.ts", - "models.ts", - "oauth.ts", - "sanitize.ts", - "errors.ts", - ]) { - expect(existsSync(new URL(`../src/${file}`, import.meta.url))).toBe(true); - expect(existsSync(new URL(`../${file}`, import.meta.url))).toBe(false); - } - }); - - it("contains the expected source files", () => { - const sourceFiles = globSync("src/*.ts").sort(); - expect(sourceFiles).toEqual([ - "src/errors.ts", - "src/index.ts", - "src/models.ts", - "src/oauth.ts", - "src/sanitize.ts", - ]); - }); -}); diff --git a/tests/payload/sanitize.test.ts b/tests/payload/sanitize.test.ts new file mode 100644 index 0000000..f3935a5 --- /dev/null +++ b/tests/payload/sanitize.test.ts @@ -0,0 +1,241 @@ +import { mkdirSync, mkdtempSync, rmSync, writeFileSync } from 'node:fs'; +import { tmpdir } from 'node:os'; +import { join } from 'node:path'; +import { describe, expect, it } from 'vitest'; +import { sanitizePayload } from '../../src/payload/sanitize.js'; + +describe('payload sanitization', () => { + it('removes unsupported items and moves all instructions', () => { + const payload = sanitizePayload( + { + instructions: 'existing instruction', + input: [ + { role: 'system', content: 'system instruction' }, + { + role: 'developer', + content: [ + { type: 'input_text', text: 'developer instruction' }, + { type: 'output_text', text: 'output text instruction' }, + ], + }, + { type: 'reasoning', content: 'cached reasoning' }, + { role: 'user', content: '' }, + { role: 'user', content: 'hello' }, + { role: 'system', content: 'later system instruction' }, + ], + include: ['reasoning.encrypted_content', 'message.output_text'], + prompt_cache_retention: '24h', + reasoning: { effort: 'minimal', summary: 'auto' }, + response_format: { type: 'json_object' }, + }, + 'grok-4.3', + 'session-123', 
+ process.cwd(), + ); + + expect(payload.instructions).toBe( + 'existing instruction\n\nsystem instruction\n\ndeveloper instruction\noutput text instruction\n\nlater system instruction', + ); + expect(payload.input).toEqual([{ role: 'user', content: 'hello' }]); + expect(payload.include).toEqual(['message.output_text']); + expect(payload.prompt_cache_retention).toBeUndefined(); + expect(payload.reasoning).toEqual({ effort: 'low' }); + expect(payload.text).toEqual({ format: { type: 'json_object' } }); + expect(payload.response_format).toBeUndefined(); + expect(payload.prompt_cache_key).toBe('session-123'); + }); + + it('preserves existing text while removing response_format', () => { + const payload = sanitizePayload( + { + input: 'plain prompt', + text: { format: { type: 'text' } }, + response_format: { type: 'json_object' }, + }, + 'grok-4.3', + undefined, + process.cwd(), + ); + + expect(payload.text).toEqual({ format: { type: 'text' } }); + expect(payload.response_format).toBeUndefined(); + }); + + it('strips reasoning fields for models that do not accept reasoning effort', () => { + const payload = sanitizePayload( + { + input: 'plain prompt', + include: ['reasoning.encrypted_content'], + reasoning: { effort: 'high' }, + reasoningEffort: 'high', + prompt_cache_key: 'existing-session', + }, + 'grok-build', + 'new-session', + process.cwd(), + ); + + expect(payload.input).toBe('plain prompt'); + expect(payload.reasoning).toBeUndefined(); + expect(payload.reasoningEffort).toBeUndefined(); + expect(payload.include).toBeUndefined(); + expect(payload.prompt_cache_key).toBe('existing-session'); + }); + + it('normalizes image parts and rewrites image tool output', () => { + const payload = sanitizePayload( + { + input: [ + { + role: 'user', + content: [ + { type: 'image', data: 'ZmFrZQ==', mimeType: 'image/png' }, + { + type: 'image_url', + image_url: { + url: 'https://example.invalid/image.png', + detail: 'high', + }, + }, + ], + }, + { + type: 'function_call_output', 
+ call_id: 'call_1', + output: [ + { type: 'input_text', text: 'tool text' }, + { type: 'input_image', image_url: 'data:image/png;base64,aW1n' }, + ], + }, + ], + }, + 'grok-composer-2.5-fast', + undefined, + process.cwd(), + ); + + expect(payload.input).toEqual([ + { + role: 'user', + content: [ + { + type: 'input_image', + image_url: 'data:image/png;base64,ZmFrZQ==', + detail: 'auto', + }, + { + type: 'input_image', + image_url: 'https://example.invalid/image.png', + detail: 'high', + }, + ], + }, + { type: 'function_call_output', call_id: 'call_1', output: 'tool text' }, + { + role: 'user', + content: [ + { + type: 'input_text', + text: 'The previous tool result (call_1) included 1 image. Use the attached image as the visual output from that tool.', + }, + { + type: 'input_image', + image_url: 'data:image/png;base64,aW1n', + detail: 'auto', + }, + ], + }, + ]); + }); + + it('resolves local image paths to data URLs', () => { + const dir = mkdtempSync(join(tmpdir(), 'pi-grok-cli-test-')); + const imagePath = join(dir, 'sample image.png'); + writeFileSync(imagePath, Buffer.from('png image bytes')); + + try { + const payload = sanitizePayload( + { + input: [ + { + role: 'user', + content: [ + { + type: 'input_image', + image_url: `'${imagePath}'`, + }, + ], + }, + ], + }, + 'grok-4.3', + undefined, + dir, + ); + + expect(payload.input).toEqual([ + { + role: 'user', + content: [ + { + type: 'input_image', + image_url: `data:image/png;base64,${Buffer.from('png image bytes').toString('base64')}`, + detail: 'auto', + }, + ], + }, + ]); + } finally { + rmSync(dir, { recursive: true, force: true }); + } + }); + + it('rejects missing or unsupported local images', () => { + expect(() => + sanitizePayload( + { + input: [ + { + role: 'user', + content: [{ type: 'input_image', image_url: 'missing.png' }], + }, + ], + }, + 'grok-4.3', + undefined, + process.cwd(), + ), + ).toThrow('Image file does not exist or is not a valid URL: missing.png'); + }); + + it('rejects local image 
paths outside the workspace', () => { + const dir = mkdtempSync(join(tmpdir(), 'pi-grok-cli-test-')); + const workspace = join(dir, 'workspace'); + const originalCwd = process.cwd(); + writeFileSync(join(dir, 'secret.png'), Buffer.from('png image bytes')); + mkdirSync(workspace); + + try { + process.chdir(workspace); + + expect(() => + sanitizePayload( + { + input: [ + { + role: 'user', + content: [{ type: 'input_image', image_url: join('..', 'secret.png') }], + }, + ], + }, + 'grok-4.3', + undefined, + process.cwd(), + ), + ).toThrow('Image path is outside the workspace'); + } finally { + process.chdir(originalCwd); + rmSync(dir, { recursive: true, force: true }); + } + }); +}); diff --git a/tests/provider/package.test.ts b/tests/provider/package.test.ts new file mode 100644 index 0000000..3ff0c1a --- /dev/null +++ b/tests/provider/package.test.ts @@ -0,0 +1,60 @@ +import { existsSync, globSync, readFileSync } from 'node:fs'; +import { describe, expect, it } from 'vitest'; + +const packageJson = JSON.parse( + readFileSync(new URL('../../package.json', import.meta.url), 'utf8'), +); + +describe('npm package manifest', () => { + it('declares a pi package entry point', () => { + expect(packageJson.name).toBe('pi-grok-cli'); + expect(packageJson.keywords).toContain('pi-package'); + expect(packageJson.pi?.extensions).toEqual(['./src/index.ts']); + expect(packageJson.main).toBe('./src/index.ts'); + expect(packageJson.files).toEqual(['README.md', 'src', 'tsconfig.json']); + }); + + it('runs publish checks before packing', () => { + expect(packageJson.scripts?.test).toBe('vitest run --reporter=agent'); + expect(packageJson.scripts?.coverage).toBe('vitest run --reporter=agent --coverage'); + expect(packageJson.scripts?.typecheck).toBe('tsc --noEmit'); + expect(packageJson.scripts?.prepack).toBe( + 'bun run test && bun run coverage && bun run typecheck', + ); + expect(packageJson.devDependencies?.vitest).toBeDefined(); + 
expect(packageJson.devDependencies?.['@vitest/coverage-v8']).toBeDefined(); + expect(existsSync(new URL('../../vitest.config.ts', import.meta.url))).toBe(true); + }); +}); + +describe('repository layout', () => { + it('keeps the extension entrypoint at src/index.ts', () => { + expect(existsSync(new URL('../../src/index.ts', import.meta.url))).toBe(true); + }); + + it('contains the expected domain source files', () => { + expect(globSync('src/**/*.ts').sort()).toEqual([ + 'src/auth/oauth.ts', + 'src/index.ts', + 'src/models/catalog.ts', + 'src/payload/sanitize.ts', + 'src/provider/quota.ts', + 'src/provider/register.ts', + 'src/provider/status.ts', + 'src/provider/stream.ts', + 'src/provider/toolScope.ts', + 'src/shared/errors.ts', + 'src/tools/files.ts', + 'src/tools/register.ts', + 'src/tools/rendering.ts', + 'src/tools/search.ts', + 'src/tools/shell.ts', + ]); + }); + + it('does not keep top-level helper compatibility wrappers', () => { + for (const file of ['errors.ts', 'models.ts', 'oauth.ts', 'sanitize.ts']) { + expect(existsSync(new URL(`../../src/${file}`, import.meta.url))).toBe(false); + } + }); +}); diff --git a/tests/provider/register.test.ts b/tests/provider/register.test.ts new file mode 100644 index 0000000..ad5a3f0 --- /dev/null +++ b/tests/provider/register.test.ts @@ -0,0 +1,612 @@ +import { existsSync, mkdirSync, mkdtempSync, readFileSync, rmSync, writeFileSync } from 'node:fs'; +import { tmpdir } from 'node:os'; +import { join } from 'node:path'; +import type { ExtensionAPI, ProviderConfig } from '@earendil-works/pi-coding-agent'; +import { afterEach, describe, expect, it, vi } from 'vitest'; + +const streamSimpleOpenAIResponses = vi.fn( + ( + _model: unknown, + _context: unknown, + options?: { + onResponse?: (response: { headers: Record }) => void; + }, + ) => { + options?.onResponse?.({ + headers: { + 'x-ratelimit-remaining-requests': '179', + 'x-ratelimit-limit-requests': '180', + 'x-ratelimit-remaining-tokens': '7500000', + 
'x-ratelimit-limit-tokens': '7500000', + 'x-grok-context-window': '512000', + 'x-zero-data-retention': 'true', + }, + }); + return {}; + }, +); + +vi.mock('@earendil-works/pi-ai', async (importOriginal) => ({ + ...(await importOriginal()), + streamSimpleOpenAIResponses, +})); + +interface CommandConfig { + handler: (args: string[], ctx: TestContext) => Promise; +} + +interface RegisteredTool { + name: string; + renderCall?: (...args: unknown[]) => Renderable; + renderResult?: (...args: unknown[]) => Renderable; +} + +interface Renderable { + render: (width: number) => string[]; +} + +interface TestContext { + cwd?: string; + modelRegistry: { + getAll: () => { provider: string; id: string }[]; + getApiKeyForProvider?: (provider: string) => Promise; + }; + model?: { provider: string; id: string }; + sessionManager?: { + getSessionId: () => string; + }; + ui: { + notify: (message: string, level: string) => void; + }; +} + +type ExtensionHandler = (event: unknown, ctx: TestContext) => unknown; + +const grokToolNames = [ + 'Grep', + 'Glob', + 'LS', + 'Read', + 'Write', + 'StrReplace', + 'Edit', + 'Delete', + 'Shell', +]; + +const originalFetch = globalThis.fetch; +const originalHome = process.env.HOME; +const originalToken = process.env.GROK_CLI_OAUTH_TOKEN; +const tempDirs: string[] = []; + +afterEach(() => { + vi.resetModules(); + streamSimpleOpenAIResponses.mockClear(); + globalThis.fetch = originalFetch; + if (originalHome === undefined) { + delete process.env.HOME; + } else { + process.env.HOME = originalHome; + } + if (originalToken === undefined) { + delete process.env.GROK_CLI_OAUTH_TOKEN; + } else { + process.env.GROK_CLI_OAUTH_TOKEN = originalToken; + } + for (const dir of tempDirs.splice(0)) rmSync(dir, { recursive: true }); +}); + +async function setupExtension(initialActiveTools = ['read', 'bash']) { + const commands = new Map(); + const providers = new Map(); + const tools = new Map(); + const handlers = new Map(); + let activeTools = initialActiveTools; + 
const setActiveTools = vi.fn((toolNames: string[]) => { + activeTools = toolNames; + }); + const registerGrokCli = (await import('../../src/index.js')).default; + registerGrokCli({ + registerProvider(name: string, config: ProviderConfig) { + providers.set(name, config); + }, + on(event: string, handler: ExtensionHandler) { + handlers.set(event, handler); + }, + registerCommand(name: string, config: unknown) { + commands.set(name, config as CommandConfig); + }, + registerTool(tool: RegisteredTool) { + tools.set(tool.name, tool); + }, + getActiveTools() { + return activeTools; + }, + setActiveTools, + } as unknown as ExtensionAPI); + return { commands, providers, tools, handlers, setActiveTools }; +} + +function statusContext(notify: TestContext['ui']['notify']): TestContext { + return { + modelRegistry: { + getAll: () => [ + { provider: 'grok-cli', id: 'grok-build' }, + { provider: 'grok-cli', id: 'grok-composer-2.5-fast' }, + ], + }, + ui: { notify }, + }; +} + +function emptyStatusContext(notify: TestContext['ui']['notify']): TestContext { + return { + modelRegistry: { getAll: () => [] }, + ui: { notify }, + }; +} + +function contextForModel(provider: string): TestContext { + return { + model: { provider, id: `${provider}-model` }, + modelRegistry: { getAll: () => [] }, + ui: { notify: vi.fn() }, + }; +} + +function renderText(component: Renderable): string { + return component + .render(120) + .map((line) => line.trimEnd()) + .join('\n'); +} + +const theme = { + bold: (text: string) => text, + fg: (_name: string, text: string) => text, +}; + +function setupHome() { + const dir = mkdtempSync(join(tmpdir(), 'pi-grok-cli-home-')); + mkdirSync(join(dir, '.pi')); + tempDirs.push(dir); + process.env.HOME = dir; + return dir; +} + +async function runStatus(extension: Awaited>) { + const notify = vi.fn(); + await extension.commands.get('grok-cli-status')?.handler([], statusContext(notify)); + return notify; +} + +describe('Grok CLI status command', () => { + it('uses 
only cached quota data and tells users to make requests first', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + setupHome(); + const fetchMock = vi.fn(); + globalThis.fetch = fetchMock; + const extension = await setupExtension(); + const notify = await runStatus(extension); + + expect(fetchMock).not.toHaveBeenCalled(); + expect(notify.mock.calls.at(-1)?.[0]).toBe( + [ + ' Quota:', + '', + ' grok-build:', + ' no cached quota data — make a request with this model first', + '', + ' grok-composer-2.5-fast:', + ' no cached quota data — make a request with this model first', + ].join('\n'), + ); + }); + + it('shows separate cached quotas for build and composer', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + setupHome(); + const extension = await setupExtension(); + const provider = extension.providers.get('grok-cli'); + provider?.streamSimple?.({ provider: 'grok-cli', id: 'grok-build' }, {}, {}); + provider?.streamSimple?.({ provider: 'grok-cli', id: 'grok-composer-2.5-fast' }, {}, {}); + const notify = await runStatus(extension); + + expect(notify.mock.calls.at(-1)?.[0]).toContain('grok-build:\n Cached:'); + expect(notify.mock.calls.at(-1)?.[0]).toContain('grok-composer-2.5-fast:\n Cached:'); + expect(notify.mock.calls.at(-1)?.[0]).toContain('Requests: 179/180 remaining'); + }); + + it('shows cached quotas for registered Grok models instead of hard-coded names', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + setupHome(); + const extension = await setupExtension(); + extension.providers + .get('grok-cli') + ?.streamSimple?.({ provider: 'grok-cli', id: 'custom' }, {}, {}); + const notify = vi.fn(); + + await extension.commands.get('grok-cli-status')?.handler([], { + modelRegistry: { + getAll: () => [{ provider: 'grok-cli', id: 'custom' }], + }, + ui: { notify }, + }); + + expect(notify.mock.calls.at(-1)?.[0]).toContain('custom:\n Cached:'); + expect(notify.mock.calls.at(-1)?.[0]).not.toContain('grok-build:'); + }); + + it('persists 
cached quotas to the global pi config directory', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + const home = setupHome(); + const extension = await setupExtension(); + extension.providers + .get('grok-cli') + ?.streamSimple?.({ provider: 'grok-cli', id: 'grok-build' }, {}, {}); + + expect( + JSON.parse(readFileSync(join(home, '.pi', 'grok-cli-quota.json'), 'utf8')).models[ + 'grok-build' + ].remainingRequests, + ).toBe(179); + }); + + it('ignores incomplete quota headers instead of caching NaN values', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + const home = setupHome(); + streamSimpleOpenAIResponses.mockImplementationOnce((_model, _context, options) => { + options?.onResponse?.({ + headers: { + 'x-ratelimit-remaining-tokens': '7500000', + 'x-ratelimit-limit-tokens': '7500000', + }, + }); + return {}; + }); + const extension = await setupExtension(); + extension.providers + .get('grok-cli') + ?.streamSimple?.({ provider: 'grok-cli', id: 'grok-build' }, {}, {}); + const notify = await runStatus(extension); + + expect(existsSync(join(home, '.pi', 'grok-cli-quota.json'))).toBe(false); + expect(notify.mock.calls.at(-1)?.[0]).not.toContain('NaN'); + expect(notify.mock.calls.at(-1)?.[0]).toContain( + 'no cached quota data — make a request with this model first', + ); + }); + + it('loads cached quotas from the global pi config directory', async () => { + delete process.env.GROK_CLI_OAUTH_TOKEN; + const home = setupHome(); + writeFileSync( + join(home, '.pi', 'grok-cli-quota.json'), + JSON.stringify({ + version: 1, + models: { + 'grok-build': { + remainingRequests: 42, + limitRequests: 180, + remainingTokens: 1_000, + limitTokens: 2_000, + contextWindow: 512_000, + zeroDataRetention: true, + capturedAt: Date.now(), + }, + }, + }), + ); + const extension = await setupExtension(); + const notify = await runStatus(extension); + + expect(notify.mock.calls.at(-1)?.[0]).toContain('Requests: 42/180 remaining'); + }); + + it('warns when no Grok models 
are registered', async () => { + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands.get('grok-cli-status')?.handler([], emptyStatusContext(notify)); + + expect(notify).toHaveBeenCalledOnce(); + expect(notify).toHaveBeenCalledWith( + 'Grok CLI: no models registered. Run /login grok-cli first.', + 'warning', + ); + }); + + it('shows env-token bypass and truncates long model lists', async () => { + process.env.GROK_CLI_OAUTH_TOKEN = 'token'; + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands.get('grok-cli-status')?.handler([], { + modelRegistry: { + getAll: () => + Array.from({ length: 7 }, (_value, index) => ({ + provider: 'grok-cli', + id: `grok-model-${index + 1}`, + })), + }, + ui: { notify }, + }); + + expect(notify.mock.calls[0]).toEqual([ + '⚠️ Grok CLI: using GROK_CLI_OAUTH_TOKEN env bypass — no auto-refresh available', + 'warning', + ]); + expect(notify.mock.calls[1]).toEqual([ + '✓ Grok CLI: 7 models available (grok-model-1, grok-model-2, grok-model-3, grok-model-4, grok-model-5 (+2 more))', + 'info', + ]); + }); + + it('reports registry errors as status warnings', async () => { + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands.get('grok-cli-status')?.handler([], { + modelRegistry: { + getAll: () => { + throw new Error('registry unavailable'); + }, + }, + ui: { notify }, + }); + + expect(notify).toHaveBeenCalledWith('Grok CLI: registry unavailable', 'warning'); + }); + + it('includes OAuth error codes in status warnings', async () => { + const { XaiOAuthError } = await import('../../src/shared/errors.js'); + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.commands.get('grok-cli-status')?.handler([], { + modelRegistry: { + getAll: () => { + throw new XaiOAuthError('refresh failed', 'refresh_failed', true); + }, + }, + ui: { notify }, + }); + + expect(notify).toHaveBeenCalledWith( 
+ 'Grok CLI: refresh failed (code: refresh_failed)', + 'warning', + ); + }); +}); + +describe('Grok CLI provider registration', () => { + it('registers provider metadata and OAuth helpers', async () => { + const extension = await setupExtension(); + const provider = extension.providers.get('grok-cli'); + + expect(provider?.name).toBe('Grok CLI'); + expect(provider?.api).toBe('openai-responses'); + expect(provider?.apiKey).toBe('$GROK_CLI_OAUTH_TOKEN'); + expect(provider?.models.map((model) => model.id)).toContain('grok-build'); + expect(provider?.oauth?.getApiKey({ access: 'access-token' })).toBe('access-token'); + expect( + provider?.oauth?.modifyModels( + [ + { provider: 'grok-cli', id: 'grok-build', baseUrl: 'old' }, + { provider: 'openai', id: 'gpt-4', baseUrl: 'keep' }, + ], + { + access: 'access-token', + refresh: 'refresh-token', + expires: 123, + baseUrl: 'https://example.invalid/custom///', + }, + ), + ).toEqual([ + { + provider: 'grok-cli', + id: 'grok-build', + baseUrl: 'https://example.invalid/custom', + }, + { provider: 'openai', id: 'gpt-4', baseUrl: 'keep' }, + ]); + }); + + it('sanitizes Grok provider requests with the current session id', async () => { + const extension = await setupExtension(); + const result = extension.handlers.get('before_provider_request')?.( + { + payload: { + input: [{ role: 'system', content: 'system instruction' }], + }, + }, + { + cwd: process.cwd(), + model: { provider: 'grok-cli', id: 'grok-4.3' }, + modelRegistry: { getAll: () => [] }, + sessionManager: { getSessionId: () => 'session-123' }, + ui: { notify: vi.fn() }, + }, + ); + + expect(result).toEqual({ + input: [], + instructions: 'system instruction', + prompt_cache_key: 'session-123', + }); + }); + + it('leaves non-Grok provider requests untouched', async () => { + const extension = await setupExtension(); + const payload = { input: [{ role: 'system', content: 'keep' }] }; + const result = extension.handlers.get('before_provider_request')?.( + { payload }, + { + 
model: { provider: 'openai', id: 'gpt-4' }, + modelRegistry: { getAll: () => [] }, + sessionManager: { getSessionId: () => 'session-123' }, + ui: { notify: vi.fn() }, + }, + ); + + expect(result).toBeUndefined(); + expect(payload).toEqual({ input: [{ role: 'system', content: 'keep' }] }); + }); + + it('warns at session start when env-token bypass is active', async () => { + process.env.GROK_CLI_OAUTH_TOKEN = 'token'; + const extension = await setupExtension(); + const notify = vi.fn(); + + await extension.handlers.get('session_start')?.( + {}, + { + modelRegistry: { getAll: () => [] }, + ui: { notify }, + }, + ); + + expect(notify).toHaveBeenCalledWith( + '[pi-grok-cli] Using GROK_CLI_OAUTH_TOKEN bypass — no auto-refresh, no model discovery', + 'warning', + ); + }); +}); + +describe('Grok CLI tool scoping', () => { + it('registers the Grok/Cursor-native tool shims', async () => { + const extension = await setupExtension(); + + expect([...extension.tools.keys()].sort()).toEqual([...grokToolNames].sort()); + }); + + it('enables Grok tools for Grok models while preserving other active tools', async () => { + const extension = await setupExtension(['read', 'custom_tool']); + + await extension.handlers.get('model_select')?.( + { model: { provider: 'grok-cli', id: 'grok-build' } }, + contextForModel('grok-cli'), + ); + + expect(extension.setActiveTools).toHaveBeenLastCalledWith([ + 'read', + 'custom_tool', + ...grokToolNames, + ]); + }); + + it('removes Grok tools for non-Grok models while preserving other active tools', async () => { + const extension = await setupExtension(['read', 'Grep', 'custom_tool', 'Shell']); + + await extension.handlers.get('model_select')?.( + { model: { provider: 'openai', id: 'gpt-4' } }, + contextForModel('openai'), + ); + + expect(extension.setActiveTools).toHaveBeenLastCalledWith(['read', 'custom_tool']); + }); + + it('syncs tool scope before each agent turn from the current context model', async () => { + const extension = await 
setupExtension(['read']); + + await extension.handlers.get('before_agent_start')?.({}, contextForModel('grok-cli')); + + expect(extension.setActiveTools).toHaveBeenLastCalledWith(['read', ...grokToolNames]); + }); + + it('does not update active tools when the selection is already correct', async () => { + const extension = await setupExtension(['read', ...grokToolNames]); + + await extension.handlers.get('before_agent_start')?.({}, contextForModel('grok-cli')); + + expect(extension.setActiveTools).not.toHaveBeenCalled(); + }); +}); + +describe('Grok CLI tool rendering', () => { + it('adds renderers to every Grok tool shim', async () => { + const extension = await setupExtension(); + + for (const name of grokToolNames) { + expect(extension.tools.get(name)?.renderCall).toBeTypeOf('function'); + expect(extension.tools.get(name)?.renderResult).toBeTypeOf('function'); + } + }); + + it('keeps collapsed search output compact and expands to full output', async () => { + const extension = await setupExtension(); + const grep = extension.tools.get('Grep'); + const result = { + content: [{ type: 'text', text: 'src/a.ts:1:match\nsrc/b.ts:2:match' }], + details: { matchCount: 2 }, + }; + + const collapsed = renderText( + grep?.renderResult?.(result, { expanded: false, isPartial: false }, theme, {}) as Renderable, + ); + const expanded = renderText( + grep?.renderResult?.(result, { expanded: true, isPartial: false }, theme, {}) as Renderable, + ); + + expect(collapsed).toBe('2 match(es)'); + expect(collapsed).not.toContain('src/a.ts'); + expect(expanded).toContain('src/a.ts:1:match'); + }); + + it('renders compact summaries for file mutations, delete, and shell tools', async () => { + const extension = await setupExtension(); + + expect( + renderText( + extension.tools.get('Write')?.renderResult?.( + { + content: [{ type: 'text', text: 'long write output' }], + details: { bytesWritten: 42 }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + 
).toBe('42 bytes written'); + expect( + renderText( + extension.tools.get('StrReplace')?.renderResult?.( + { + content: [{ type: 'text', text: 'long replace output' }], + details: { replacements: 3 }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + ).toBe('3 replacement(s)'); + expect( + renderText( + extension.tools.get('Delete')?.renderResult?.( + { + content: [{ type: 'text', text: 'long delete output' }], + details: { deleted: true }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + ).toBe('Deleted'); + expect( + renderText( + extension.tools.get('Shell')?.renderResult?.( + { + content: [{ type: 'text', text: 'long shell output' }], + details: { exitCode: 2 }, + }, + { expanded: false, isPartial: false }, + theme, + {}, + ) as Renderable, + ), + ).toBe('Exit 2'); + }); +}); diff --git a/tests/sanitize.test.ts b/tests/sanitize.test.ts deleted file mode 100644 index 4358dc8..0000000 --- a/tests/sanitize.test.ts +++ /dev/null @@ -1,185 +0,0 @@ -import { mkdtempSync, rmSync, writeFileSync } from "node:fs"; -import { tmpdir } from "node:os"; -import { join } from "node:path"; -import { describe, expect, it } from "vitest"; -import { sanitizePayload } from "../src/sanitize.js"; - -describe("payload sanitization", () => { - it("removes unsupported items and moves leading instructions", () => { - const payload = sanitizePayload( - { - instructions: "existing instruction", - input: [ - { role: "system", content: "system instruction" }, - { - role: "developer", - content: [ - { type: "input_text", text: "developer instruction" }, - { type: "output_text", text: "output text instruction" }, - ], - }, - { type: "reasoning", content: "cached reasoning" }, - { role: "user", content: "" }, - { role: "user", content: "hello" }, - ], - include: ["reasoning.encrypted_content", "message.output_text"], - prompt_cache_retention: "24h", - reasoning: { effort: "minimal", summary: "auto" }, - response_format: 
{ type: "json_object" }, - }, - "grok-4.3", - "session-123", - ); - - expect(payload.instructions).toBe( - "existing instruction\n\nsystem instruction\n\ndeveloper instruction\noutput text instruction", - ); - expect(payload.input).toEqual([{ role: "user", content: "hello" }]); - expect(payload.include).toEqual(["message.output_text"]); - expect(payload.prompt_cache_retention).toBeUndefined(); - expect(payload.reasoning).toEqual({ effort: "minimal" }); - expect(payload.text).toEqual({ format: { type: "json_object" } }); - expect(payload.response_format).toBeUndefined(); - expect(payload.prompt_cache_key).toBe("session-123"); - }); - - it("strips reasoning fields for models that do not accept reasoning effort", () => { - const payload = sanitizePayload( - { - input: "plain prompt", - include: ["reasoning.encrypted_content"], - reasoning: { effort: "high" }, - reasoningEffort: "high", - prompt_cache_key: "existing-session", - }, - "grok-build", - "new-session", - ); - - expect(payload.input).toBe("plain prompt"); - expect(payload.reasoning).toBeUndefined(); - expect(payload.reasoningEffort).toBeUndefined(); - expect(payload.include).toBeUndefined(); - expect(payload.prompt_cache_key).toBe("existing-session"); - }); - - it("normalizes image parts and rewrites image tool output", () => { - const payload = sanitizePayload( - { - input: [ - { - role: "user", - content: [ - { type: "image", data: "ZmFrZQ==", mimeType: "image/png" }, - { - type: "image_url", - image_url: { - url: "https://example.invalid/image.png", - detail: "high", - }, - }, - ], - }, - { - type: "function_call_output", - call_id: "call_1", - output: [ - { type: "input_text", text: "tool text" }, - { type: "input_image", image_url: "data:image/png;base64,aW1n" }, - ], - }, - ], - }, - "grok-composer-2.5-fast", - ); - - expect(payload.input).toEqual([ - { - role: "user", - content: [ - { - type: "input_image", - image_url: "data:image/png;base64,ZmFrZQ==", - detail: "auto", - }, - { - type: "input_image", 
- image_url: "https://example.invalid/image.png", - detail: "high", - }, - ], - }, - { type: "function_call_output", call_id: "call_1", output: "tool text" }, - { - role: "user", - content: [ - { - type: "input_text", - text: "The previous tool result (call_1) included 1 image. Use the attached image as the visual output from that tool.", - }, - { - type: "input_image", - image_url: "data:image/png;base64,aW1n", - detail: "auto", - }, - ], - }, - ]); - }); - - it("resolves local image paths to data URLs", () => { - const dir = mkdtempSync(join(tmpdir(), "pi-grok-cli-test-")); - const imagePath = join(dir, "sample image.png"); - writeFileSync(imagePath, Buffer.from("png image bytes")); - - try { - const payload = sanitizePayload( - { - input: [ - { - role: "user", - content: [ - { - type: "input_image", - image_url: `'${imagePath}'`, - }, - ], - }, - ], - }, - "grok-4.3", - ); - - expect(payload.input).toEqual([ - { - role: "user", - content: [ - { - type: "input_image", - image_url: `data:image/png;base64,${Buffer.from("png image bytes").toString("base64")}`, - detail: "auto", - }, - ], - }, - ]); - } finally { - rmSync(dir, { recursive: true, force: true }); - } - }); - - it("rejects missing or unsupported local images", () => { - expect(() => - sanitizePayload( - { - input: [ - { - role: "user", - content: [{ type: "input_image", image_url: "missing.png" }], - }, - ], - }, - "grok-4.3", - ), - ).toThrow("Image file does not exist or is not a valid URL: missing.png"); - }); -}); diff --git a/tests/shared/errors.test.ts b/tests/shared/errors.test.ts new file mode 100644 index 0000000..5b5863d --- /dev/null +++ b/tests/shared/errors.test.ts @@ -0,0 +1,14 @@ +import { describe, expect, it } from 'vitest'; +import { XaiErrorCode, XaiOAuthError } from '../../src/shared/errors.js'; + +describe('OAuth errors', () => { + it('keeps machine-readable code and relogin state', () => { + const error = new XaiOAuthError('Refresh token was revoked', XaiErrorCode.REFRESH_FAILED, 
true); + + expect(error).toBeInstanceOf(Error); + expect(error.name).toBe('XaiOAuthError'); + expect(error.message).toBe('Refresh token was revoked'); + expect(error.code).toBe('refresh_failed'); + expect(error.reloginRequired).toBe(true); + }); +}); diff --git a/tests/tools/files.test.ts b/tests/tools/files.test.ts new file mode 100644 index 0000000..3e5ca45 --- /dev/null +++ b/tests/tools/files.test.ts @@ -0,0 +1,540 @@ +import { + existsSync, + mkdirSync, + readFileSync, + realpathSync, + symlinkSync, + writeFileSync, +} from 'node:fs'; +import { join } from 'node:path'; +import { describe, expect, it, vi } from 'vitest'; +import { registerFileTools } from '../../src/tools/files.js'; +import { + collectTools, + executePreparedTool, + executeTool, + firstText, + renderToolCall, + renderToolResult, + type ToolResult, + tempDir, +} from './toolTestHelpers.js'; + +function expectStoryState(result: ToolResult, cwd: string, replacements: number, content: string) { + expect(result.details).toEqual({ + path: expectedPath(cwd, 'story.txt'), + replacements, + }); + expect(readFileSync(join(cwd, 'story.txt'), 'utf-8')).toBe(content); +} + +function expectedPath(cwd: string, ...parts: string[]) { + return join(realpathSync(cwd), ...parts); +} + +function strReplace(cwd: string, old_str: string, new_str: string) { + return executeTool( + collectTools(registerFileTools).get('StrReplace'), + { path: 'story.txt', old_str, new_str }, + cwd, + ); +} + +function strReplaceWithPreparedArgs(cwd: string, params: Record) { + return executePreparedTool( + collectTools(registerFileTools).get('StrReplace'), + { path: 'story.txt', ...params }, + cwd, + ); +} + +describe('file tools', () => { + it('lists directory contents including hidden files', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, '.hidden'), 'secret', 'utf-8'); + writeFileSync(join(cwd, 'visible.txt'), 'visible', 'utf-8'); + + const result = await 
executeTool(collectTools(registerFileTools).get('LS'), { path: '.' }, cwd); + + expect(firstText(result)).toContain('.hidden'); + expect(firstText(result)).toContain('visible.txt'); + expect(result.details).toEqual({ path: realpathSync(cwd) }); + }); + + it('lists directory contents when Unix ls is not on PATH', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + const oldPath = process.env.PATH; + process.env.PATH = tempDir('pi-grok-cli-empty-bin-'); + vi.resetModules(); + writeFileSync(join(cwd, 'visible.txt'), 'visible', 'utf-8'); + + try { + const result = await executeTool( + collectTools((await import('../../src/tools/files.js')).registerFileTools).get('LS'), + { path: '.' }, + cwd, + ); + + expect(firstText(result)).toContain('visible.txt'); + expect(result.details).toEqual({ path: realpathSync(cwd) }); + } finally { + process.env.PATH = oldPath; + vi.resetModules(); + } + }); + + it('reports filesystem errors for invalid file operations', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + mkdirSync(join(cwd, 'dir')); + writeFileSync(join(cwd, 'blocked'), 'not a directory', 'utf-8'); + const tools = collectTools(registerFileTools); + + const lsResult = await executeTool(tools.get('LS'), { path: 'missing-dir' }, cwd); + const readResult = await executeTool(tools.get('Read'), { path: 'dir' }, cwd); + const writeResult = await executeTool( + tools.get('Write'), + { path: 'blocked/file.txt', content: 'content' }, + cwd, + ); + const replaceResult = await executeTool( + tools.get('StrReplace'), + { path: 'dir', old_str: 'old', new_str: 'new' }, + cwd, + ); + const deleteResult = await executeTool(tools.get('Delete'), { path: 'dir' }, cwd); + + expect(firstText(lsResult).startsWith('LS error:')).toBe(true); + expect(firstText(readResult).startsWith('Read error:')).toBe(true); + expect(firstText(writeResult).startsWith('Write error:')).toBe(true); + expect(firstText(replaceResult).startsWith('StrReplace error:')).toBe(true); + 
expect(firstText(deleteResult).startsWith('Delete error:')).toBe(true); + expect(writeResult.details).toEqual({ + path: join(cwd, 'blocked', 'file.txt'), + bytesWritten: 0, + failed: true, + error: expect.stringContaining('EEXIST: file already exists, mkdir'), + }); + expect(replaceResult.details).toEqual({ + path: join(cwd, 'dir'), + replacements: 0, + failed: true, + error: expect.stringContaining('EISDIR: illegal operation on a directory, read'), + }); + expect(deleteResult.details).toEqual({ + path: join(cwd, 'dir'), + deleted: false, + failed: true, + error: expect.stringMatching( + /EISDIR: illegal operation on a directory|operation not permitted/, + ), + }); + }); + + it('writes a nested file and reads a requested line window', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + const tools = collectTools(registerFileTools); + + const writeResult = await executeTool( + tools.get('Write'), + { path: 'nested/notes.txt', content: 'alpha\nbeta\ngamma\ndelta' }, + cwd, + ); + + expect(firstText(writeResult)).toBe('Successfully wrote 22 bytes to nested/notes.txt'); + expect(writeResult.details).toEqual({ + path: expectedPath(cwd, 'nested/notes.txt'), + bytesWritten: 22, + }); + + const readResult = await executeTool( + tools.get('Read'), + { path: 'nested/notes.txt', offset: 1, limit: 2 }, + cwd, + ); + + expect(firstText(readResult)).toBe( + '2\tbeta\n3\tgamma\n\n[Showing lines 2-3 of 4 total lines. 
Use offset to see more.]', + ); + expect(readResult.details).toEqual({ + path: expectedPath(cwd, 'nested/notes.txt'), + totalLines: 4, + }); + }); + + it('writes Cursor-style contents arguments', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + + const result = await executePreparedTool( + collectTools(registerFileTools).get('Write'), + { path: 'nested/notes.txt', contents: 'alpha\nbeta' }, + cwd, + ); + + expect(firstText(result)).toBe('Successfully wrote 10 bytes to nested/notes.txt'); + expect(readFileSync(join(cwd, 'nested/notes.txt'), 'utf-8')).toBe('alpha\nbeta'); + expect(result.details).toEqual({ + path: expectedPath(cwd, 'nested/notes.txt'), + bytesWritten: 10, + }); + }); + + it('reports UTF-8 bytes written for multibyte content', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + const result = await executeTool( + collectTools(registerFileTools).get('Write'), + { path: 'emoji.txt', content: 'a🙂漢' }, + cwd, + ); + + expect(firstText(result)).toBe('Successfully wrote 8 bytes to emoji.txt'); + expect(result.details).toEqual({ + path: expectedPath(cwd, 'emoji.txt'), + bytesWritten: 8, + }); + }); + + it('honors a zero read limit', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'notes.txt'), 'alpha\nbeta', 'utf-8'); + const result = await executeTool( + collectTools(registerFileTools).get('Read'), + { path: 'notes.txt', limit: 0 }, + cwd, + ); + + expect(firstText(result)).toBe( + '\n\n[Showing lines 1-0 of 2 total lines. 
Use offset to see more.]', + ); + expect(result.details).toEqual({ + path: expectedPath(cwd, 'notes.txt'), + totalLines: 2, + }); + }); + + it('does not add a blank numbered line for files ending with a newline', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'notes.txt'), 'alpha\nbeta\n', 'utf-8'); + const result = await executeTool( + collectTools(registerFileTools).get('Read'), + { path: 'notes.txt' }, + cwd, + ); + + expect(firstText(result)).toBe('1\talpha\n2\tbeta'); + expect(result.details).toEqual({ + path: expectedPath(cwd, 'notes.txt'), + totalLines: 2, + }); + }); + + it('reports missing files without throwing', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + const result = await executeTool( + collectTools(registerFileTools).get('Read'), + { path: 'missing.txt' }, + cwd, + ); + + expect(firstText(result)).toBe(`File not found: ${join(cwd, 'missing.txt')}`); + expect(result.details).toEqual({ + path: join(cwd, 'missing.txt'), + exists: false, + totalLines: 0, + }); + }); + + it('rejects paths that escape the workspace', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + const outside = tempDir('pi-grok-cli-files-outside-'); + writeFileSync(join(outside, 'secret.txt'), 'secret', 'utf-8'); + symlinkSync(outside, join(cwd, 'outside')); + + const readResult = await executeTool( + collectTools(registerFileTools).get('Read'), + { path: 'outside/secret.txt' }, + cwd, + ); + const writeResult = await executeTool( + collectTools(registerFileTools).get('Write'), + { path: '../escape.txt', content: 'escape' }, + cwd, + ); + + expect(firstText(readResult)).toBe('Read error: Path is outside the workspace'); + expect(readResult.details).toEqual({ + path: join(cwd, 'outside', 'secret.txt'), + exists: true, + totalLines: 0, + failed: true, + error: 'Path is outside the workspace', + }); + expect(firstText(writeResult)).toBe('Write error: Path is outside the workspace'); + expect(writeResult.details).toEqual({ 
+ path: join(cwd, '..', 'escape.txt'), + bytesWritten: 0, + failed: true, + error: 'Path is outside the workspace', + }); + expect(existsSync(join(cwd, '..', 'escape.txt'))).toBe(false); + }); + + it('renders read errors for existing paths without claiming the file is missing', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + mkdirSync(join(cwd, 'dir')); + const tools = collectTools(registerFileTools); + const result = await executeTool(tools.get('Read'), { path: 'dir' }, cwd); + + expect(firstText(result).startsWith('Read error:')).toBe(true); + expect(result.details).toEqual({ + path: join(cwd, 'dir'), + exists: true, + totalLines: 0, + failed: true, + error: expect.stringContaining('EISDIR: illegal operation on a directory, read'), + }); + expect(renderToolResult(tools.get('Read'), result)).toBe('0 line(s)'); + }); + + it('replaces every exact string occurrence', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'red blue red', 'utf-8'); + + const result = await strReplace(cwd, 'red', 'green'); + + expect(firstText(result)).toBe('Replaced 2 occurrence(s) in story.txt'); + expectStoryState(result, cwd, 2, 'green blue green'); + }); + + it('rejects empty replacement search strings without changing files', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'red blue red', 'utf-8'); + + const result = await strReplace(cwd, '', 'green'); + + expect(firstText(result)).toBe('StrReplace error: old_str must not be empty'); + expectStoryState(result, cwd, 0, 'red blue red'); + }); + + it('treats replacement text as a literal string', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'abc', 'utf-8'); + + const result = await strReplace(cwd, 'a', '$&'); + + expect(firstText(result)).toBe('Replaced 1 occurrence(s) in story.txt'); + expectStoryState(result, cwd, 1, '$&bc'); + }); + + it('replaces string occurrences with 
Grok and Cursor argument variants', async () => { + const oldStringCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(oldStringCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const oldStringResult = await strReplaceWithPreparedArgs(oldStringCwd, { + old_string: 'red', + new_string: 'green', + }); + + expect(firstText(oldStringResult)).toBe('Replaced 2 occurrence(s) in story.txt'); + expectStoryState(oldStringResult, oldStringCwd, 2, 'green blue green'); + + const oldTextCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(oldTextCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const oldTextResult = await strReplaceWithPreparedArgs(oldTextCwd, { + oldText: 'red', + newText: 'green', + }); + + expect(firstText(oldTextResult)).toBe('Replaced 2 occurrence(s) in story.txt'); + expectStoryState(oldTextResult, oldTextCwd, 2, 'green blue green'); + + const nestedCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(nestedCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const nestedResult = await strReplaceWithPreparedArgs(nestedCwd, { + strReplace: { oldText: 'red', newText: 'green' }, + }); + + expect(firstText(nestedResult)).toBe('Replaced 2 occurrence(s) in story.txt'); + expectStoryState(nestedResult, nestedCwd, 2, 'green blue green'); + }); + + it('edits files with single, multiple, and stringified replacement inputs', async () => { + const singleCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(singleCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const singleResult = await executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { path: 'story.txt', oldText: 'red', newText: 'green' }, + singleCwd, + ); + + expect(firstText(singleResult)).toBe('Applied 2 replacement(s) in story.txt'); + expectStoryState(singleResult, singleCwd, 2, 'green blue green'); + + const multipleCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(multipleCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const multipleResult = await 
executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { + path: 'story.txt', + edits: [ + { oldText: 'red', newText: 'green' }, + { oldText: 'blue', newText: 'yellow' }, + ], + }, + multipleCwd, + ); + + expect(firstText(multipleResult)).toBe('Applied 3 replacement(s) in story.txt'); + expectStoryState(multipleResult, multipleCwd, 3, 'green yellow green'); + + const stringifiedCwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(stringifiedCwd, 'story.txt'), 'red blue red', 'utf-8'); + + const stringifiedResult = await executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { + path: 'story.txt', + edits: JSON.stringify([{ oldText: 'red', newText: 'green' }]), + }, + stringifiedCwd, + ); + + expect(firstText(stringifiedResult)).toBe('Applied 2 replacement(s) in story.txt'); + expectStoryState(stringifiedResult, stringifiedCwd, 2, 'green blue green'); + }); + + it('edits files with literal replacement text', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'abc', 'utf-8'); + + const result = await executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { path: 'story.txt', oldText: 'a', newText: '$&' }, + cwd, + ); + + expect(firstText(result)).toBe('Applied 1 replacement(s) in story.txt'); + expectStoryState(result, cwd, 1, '$&bc'); + }); + + it('rejects empty edit search strings without changing files', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'red blue red', 'utf-8'); + + const result = await executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { path: 'story.txt', oldText: '', newText: 'green' }, + cwd, + ); + + expect(firstText(result)).toBe('Edit error: oldText must not be empty'); + expectStoryState(result, cwd, 0, 'red blue red'); + }); + + it('reports unsupported edit strategies without changing files', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 
'story.txt'), 'red blue red', 'utf-8'); + + const result = await executePreparedTool( + collectTools(registerFileTools).get('Edit'), + { path: 'story.txt', applyPatch: { patchContent: 'patch' } }, + cwd, + ); + + expect(firstText(result)).toBe( + 'Edit error: applyPatch is not supported by this Grok tool shim', + ); + expectStoryState(result, cwd, 0, 'red blue red'); + }); + + it('leaves files unchanged when the replacement string is absent', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'story.txt'), 'red blue red', 'utf-8'); + + const result = await strReplace(cwd, 'purple', 'green'); + + expect(firstText(result)).toBe('String not found in story.txt: "purple"'); + expectStoryState(result, cwd, 0, 'red blue red'); + }); + + it('deletes existing files and reports missing files', async () => { + const cwd = tempDir('pi-grok-cli-files-'); + writeFileSync(join(cwd, 'remove.txt'), 'delete me', 'utf-8'); + const tools = collectTools(registerFileTools); + + const deletedResult = await executeTool(tools.get('Delete'), { path: 'remove.txt' }, cwd); + + expect(firstText(deletedResult)).toBe('Successfully deleted remove.txt'); + expect(deletedResult.details).toEqual({ + path: expectedPath(cwd, 'remove.txt'), + deleted: true, + }); + expect(existsSync(join(cwd, 'remove.txt'))).toBe(false); + + const missingResult = await executeTool(tools.get('Delete'), { path: 'remove.txt' }, cwd); + + expect(firstText(missingResult)).toBe(`File not found: ${join(cwd, 'remove.txt')}`); + expect(missingResult.details).toEqual({ + path: join(cwd, 'remove.txt'), + deleted: false, + }); + }); + + it('renders file tool calls and result states', () => { + const tools = collectTools(registerFileTools); + + expect(renderToolCall(tools.get('LS'), { path: '.' 
})).toBe('LS .'); + expect( + renderToolCall(tools.get('Read'), { + path: 'notes.txt', + offset: 5, + limit: 10, + }), + ).toBe('Read notes.txt (from 5, 10 lines)'); + expect(renderToolCall(tools.get('StrReplace'), { path: 'notes.txt' })).toBe( + 'StrReplace notes.txt', + ); + expect(renderToolCall(tools.get('Delete'), { path: 'notes.txt' })).toBe('Delete notes.txt'); + expect( + renderToolResult(tools.get('Read'), { + content: [{ type: 'text', text: 'missing' }], + details: { exists: false, totalLines: 0 }, + }), + ).toBe('File not found'); + expect( + renderToolResult(tools.get('StrReplace'), { + content: [{ type: 'text', text: 'no replacement' }], + details: { replacements: 0 }, + }), + ).toBe('No replacements'); + expect( + renderToolResult(tools.get('Delete'), { + content: [{ type: 'text', text: 'not deleted' }], + details: { deleted: false }, + }), + ).toBe('Not deleted'); + expect( + renderToolResult( + tools.get('LS'), + { + content: [{ type: 'text', text: 'full listing' }], + details: { path: '/tmp/project' }, + }, + { expanded: true, isPartial: false }, + ), + ).toBe('full listing'); + expect( + renderToolResult( + tools.get('Write'), + { + content: [{ type: 'text', text: 'writing' }], + details: { bytesWritten: 10 }, + }, + { expanded: false, isPartial: true }, + ), + ).toBe('Running...'); + }); +}); diff --git a/tests/tools/register.test.ts b/tests/tools/register.test.ts new file mode 100644 index 0000000..04b8690 --- /dev/null +++ b/tests/tools/register.test.ts @@ -0,0 +1,19 @@ +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { describe, expect, it } from 'vitest'; +import { GROK_TOOL_NAMES, registerGrokTools } from '../../src/tools/register.js'; + +describe('Grok tool registration', () => { + it('registers all Grok/Cursor-native tool shims with renderers', () => { + const toolNames: string[] = []; + + registerGrokTools({ + registerTool(tool: { name: string; renderCall?: unknown; renderResult?: unknown }) { + 
toolNames.push(tool.name); + expect(tool.renderCall).toBeTypeOf('function'); + expect(tool.renderResult).toBeTypeOf('function'); + }, + } as unknown as ExtensionAPI); + + expect(toolNames.sort()).toEqual([...GROK_TOOL_NAMES].sort()); + }); +}); diff --git a/tests/tools/rendering.test.ts b/tests/tools/rendering.test.ts new file mode 100644 index 0000000..2f673f2 --- /dev/null +++ b/tests/tools/rendering.test.ts @@ -0,0 +1,92 @@ +import { describe, expect, it } from 'vitest'; +import { + booleanDetail, + detailRecord, + fileError, + fileNotFound, + MAX_LINES, + MAX_OUTPUT_CHARS, + numberDetail, + renderResultSummary, + renderResultText, + renderRunning, + stringDetail, + text, + toolError, + truncateChars, + truncateLines, +} from '../../src/tools/rendering.js'; +import { renderText } from './toolTestHelpers.js'; + +describe('tool rendering helpers', () => { + it('truncates long result lists and large output', () => { + expect(truncateLines(['one', 'two'])).toBe('one\ntwo'); + expect( + truncateLines(Array.from({ length: MAX_LINES + 1 }, String)).endsWith( + `\n\n[Showing first ${MAX_LINES} of ${MAX_LINES + 1} results. Refine your pattern to narrow results.]`, + ), + ).toBe(true); + expect(truncateChars('short')).toBe('short'); + expect(truncateChars('x'.repeat(MAX_OUTPUT_CHARS + 1))).toBe( + `${'x'.repeat(MAX_OUTPUT_CHARS)}\n\n[Output truncated at 50KB]`, + ); + }); + + it('renders summaries, expanded text, missing text fallback, and partial state', () => { + const result = { + content: [{ type: 'text', text: 'full output' }], + details: {}, + }; + + expect(renderText(text('plain'))).toBe('plain'); + expect(renderText(renderResultText(result, false, 'summary'))).toBe('summary'); + expect(renderText(renderResultText(result, true, 'summary'))).toBe('full output'); + expect( + renderText(renderResultText({ content: [{ type: 'image' }], details: {} }, true, 'summary')), + ).toBe('summary'); + expect(renderText(renderRunning(true) ?? 
text(''))).toBe('Running...'); + expect(renderRunning(false)).toBeUndefined(); + expect(renderText(renderResultSummary(result, false, true, 'summary'))).toBe('Running...'); + }); + + it('reads typed detail values with defaults for absent or invalid details', () => { + const result = { + content: [{ type: 'text', text: '' }], + details: { count: 2, path: 'file.txt', deleted: true, invalid: null }, + }; + + expect(detailRecord(result)).toEqual(result.details); + expect(detailRecord({ details: null })).toEqual({}); + expect(numberDetail(result, 'count')).toBe(2); + expect(numberDetail(result, 'path')).toBe(0); + expect(stringDetail(result, 'path')).toBe('file.txt'); + expect(stringDetail(result, 'count')).toBe(''); + expect(booleanDetail(result, 'deleted')).toBe(true); + expect(booleanDetail(result, 'invalid')).toBe(false); + }); + + it('formats file and command errors with stable empty details', () => { + expect(fileNotFound('/tmp/missing.txt', { deleted: false })).toEqual({ + content: [{ type: 'text', text: 'File not found: /tmp/missing.txt' }], + details: { path: '/tmp/missing.txt', deleted: false }, + }); + expect(fileError({}, 'Read', '/tmp/file.txt', { totalLines: 0 })).toEqual({ + content: [{ type: 'text', text: 'Read error: Unknown error' }], + details: { path: '/tmp/file.txt', totalLines: 0, failed: true, error: 'Unknown error' }, + }); + expect(toolError({ code: 1 }, 'Grep', { matchCount: 0 })).toEqual({ + content: [{ type: 'text', text: 'No matches found' }], + details: { matchCount: 0 }, + }); + expect( + toolError({ code: 1, message: 'find: missing: No such file' }, 'Glob', { fileCount: 0 }), + ).toEqual({ + content: [{ type: 'text', text: 'Glob error: find: missing: No such file' }], + details: { fileCount: 0, failed: true, error: 'find: missing: No such file' }, + }); + expect(toolError({}, 'Grep', { matchCount: 0 })).toEqual({ + content: [{ type: 'text', text: 'Grep error: Unknown error' }], + details: { matchCount: 0, failed: true, error: 'Unknown 
error' }, + }); + }); +}); diff --git a/tests/tools/search.test.ts b/tests/tools/search.test.ts new file mode 100644 index 0000000..8237b7a --- /dev/null +++ b/tests/tools/search.test.ts @@ -0,0 +1,338 @@ +import { mkdirSync, rmSync, symlinkSync, utimesSync, writeFileSync } from 'node:fs'; +import { join } from 'node:path'; +import { describe, expect, it, vi } from 'vitest'; +import { registerSearchTools, sortByModifiedNewest } from '../../src/tools/search.js'; +import { + collectTools, + executePreparedTool, + executeTool, + firstText, + plainTheme, + renderText, + type ToolResult, + tempDir, +} from './toolTestHelpers.js'; + +function setupProject() { + const dir = tempDir('pi-grok-cli-search-'); + mkdirSync(join(dir, 'src')); + writeFileSync(join(dir, 'src', 'alpha.ts'), 'needle\nhaystack\n', 'utf-8'); + writeFileSync(join(dir, 'src', 'beta.md'), 'needle in docs\n', 'utf-8'); + writeFileSync(join(dir, 'src', 'gamma.ts'), 'plain text\n', 'utf-8'); + return dir; +} + +function expectGrepResult(cwd: string, result: ToolResult) { + expect(firstText(result)).toContain(`${join(cwd, 'src', 'alpha.ts')}:1:needle`); + expect(firstText(result)).not.toContain('beta.md'); + expect(result.details).toEqual({ matchCount: 1 }); +} + +function expectGlobResult(cwd: string, result: ToolResult) { + expect(firstText(result)).toContain(join(cwd, 'src', 'alpha.ts')); + expect(firstText(result)).toContain(join(cwd, 'src', 'gamma.ts')); + expect(firstText(result)).not.toContain('beta.md'); + expect(result.details).toEqual({ fileCount: 2 }); +} + +async function withFindFallbackTools( + run: (tools: ReturnType) => Promise, +) { + const bin = tempDir('pi-grok-cli-search-bin-'); + symlinkSync('/usr/bin/find', join(bin, 'find')); + const oldPath = process.env.PATH; + process.env.PATH = bin; + vi.resetModules(); + try { + await run(collectTools((await import('../../src/tools/search.js')).registerSearchTools)); + } finally { + process.env.PATH = oldPath; + vi.resetModules(); + } +} + +async 
function withNoSearchBinaries( + run: (tools: ReturnType) => Promise, +) { + const oldPath = process.env.PATH; + process.env.PATH = tempDir('pi-grok-cli-empty-bin-'); + vi.resetModules(); + try { + await run(collectTools((await import('../../src/tools/search.js')).registerSearchTools)); + } finally { + process.env.PATH = oldPath; + vi.resetModules(); + } +} + +describe('search tools', () => { + it('greps matching file contents with include filters', async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: 'needle', path: 'src', include: '*.ts' }, + cwd, + ); + + expectGrepResult(cwd, result); + }); + + it('greps matching file contents with Cursor-style glob filters', async () => { + const cwd = setupProject(); + const result = await executePreparedTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: 'needle', path: 'src', glob_filter: '*.ts' }, + cwd, + ); + + expectGrepResult(cwd, result); + }); + + it('greps patterns that start with a dash', async () => { + const cwd = setupProject(); + writeFileSync(join(cwd, 'src', 'dash.ts'), '-export const value = 1\n', 'utf-8'); + + const result = await executeTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: '-export', path: 'src/dash.ts' }, + cwd, + ); + + expect(firstText(result)).toBe(`${join(cwd, 'src', 'dash.ts')}:1:-export const value = 1`); + expect(result.details).toEqual({ matchCount: 1 }); + }); + + it('includes file paths when grepping a single file', async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: 'needle', path: 'src/alpha.ts' }, + cwd, + ); + + expect(firstText(result)).toBe(`${join(cwd, 'src', 'alpha.ts')}:1:needle`); + expect(result.details).toEqual({ matchCount: 1 }); + }); + + it('reports no grep matches as an empty result', async () => { + const cwd = setupProject(); + const result = await 
executeTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: 'absent', path: 'src' }, + cwd, + ); + + expect(firstText(result)).toBe('No matches found'); + expect(result.details).toEqual({ matchCount: 0 }); + }); + + it('reports grep command errors with empty match details', async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get('Grep'), + { pattern: '[', path: 'src' }, + cwd, + ); + + expect(firstText(result).startsWith('Grep error:')).toBe(true); + expect(result.details).toEqual({ + matchCount: 0, + failed: true, + error: expect.stringMatching(/regex parse error|Invalid regular expression/), + }); + }); + + it('globs files under the requested path', async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get('Glob'), + { pattern: '**/*.ts', path: 'src' }, + cwd, + ); + + expectGlobResult(cwd, result); + }); + + it('globs files with Cursor-style glob pattern arguments', async () => { + const cwd = setupProject(); + const result = await executePreparedTool( + collectTools(registerSearchTools).get('Glob'), + { glob_pattern: '**/*.ts', path: 'src' }, + cwd, + ); + + expectGlobResult(cwd, result); + }); + + it('reports empty glob command results', async () => { + const cwd = setupProject(); + const result = await executeTool( + collectTools(registerSearchTools).get('Glob'), + { pattern: '**/*.json', path: 'src' }, + cwd, + ); + + expect(firstText(result)).toBe('No files found'); + expect(result.details).toEqual({ fileCount: 0 }); + }); + + it('globs path-containing patterns through the find fallback', async () => { + const cwd = setupProject(); + await withFindFallbackTools(async (fallbackTools) => { + const result = await executeTool(fallbackTools.get('Glob'), { pattern: 'src/**/*.ts' }, cwd); + + expectGlobResult(cwd, result); + }); + }); + + it('globs basename-only patterns through the find fallback', async () => { + const cwd = 
setupProject(); + await withFindFallbackTools(async (fallbackTools) => { + const result = await executeTool(fallbackTools.get('Glob'), { pattern: '*.ts' }, cwd); + + expectGlobResult(cwd, result); + }); + }); + + it('globs files without ripgrep or Unix find on PATH', async () => { + const cwd = setupProject(); + await withNoSearchBinaries(async (fallbackTools) => { + const result = await executeTool(fallbackTools.get('Glob'), { pattern: 'src/**/*.ts' }, cwd); + + expectGlobResult(cwd, result); + }); + }); + + it('greps files without ripgrep or Unix grep on PATH', async () => { + const cwd = setupProject(); + await withNoSearchBinaries(async (fallbackTools) => { + const result = await executeTool( + fallbackTools.get('Grep'), + { pattern: 'needle', path: 'src', include: '*.ts' }, + cwd, + ); + + expectGrepResult(cwd, result); + }); + }); + + it('sorts glob results by modification time newest first', async () => { + const cwd = setupProject(); + const oldTime = new Date('2024-01-01T00:00:00.000Z'); + const newTime = new Date('2024-01-02T00:00:00.000Z'); + utimesSync(join(cwd, 'src', 'alpha.ts'), oldTime, oldTime); + utimesSync(join(cwd, 'src', 'gamma.ts'), newTime, newTime); + const result = await executeTool( + collectTools(registerSearchTools).get('Glob'), + { pattern: '**/*.ts', path: 'src' }, + cwd, + ); + + expect(firstText(result).split('\n')).toEqual([ + join(cwd, 'src', 'gamma.ts'), + join(cwd, 'src', 'alpha.ts'), + ]); + }); + + it('sorts existing glob results when another match is deleted before stat', () => { + const cwd = setupProject(); + const deleted = join(cwd, 'src', 'deleted.ts'); + writeFileSync(deleted, 'deleted\n', 'utf-8'); + rmSync(deleted); + + expect(sortByModifiedNewest([deleted, join(cwd, 'src', 'alpha.ts')])).toEqual([ + join(cwd, 'src', 'alpha.ts'), + deleted, + ]); + }); + + it('renders grep calls and result states', () => { + const grep = collectTools(registerSearchTools).get('Grep'); + const result = { + content: [{ type: 'text', text: 
'src/alpha.ts:1:needle' }], + details: { matchCount: 1 }, + }; + + expect( + renderText( + grep?.renderCall?.({ pattern: 'needle', path: 'src', include: '*.ts' }, plainTheme) ?? { + render: () => [], + }, + ), + ).toBe('Grep "needle" in src [*.ts]'); + expect( + renderText( + grep?.renderResult?.(result, { expanded: false, isPartial: false }, plainTheme, {}) ?? { + render: () => [], + }, + ), + ).toBe('1 match(es)'); + expect( + renderText( + grep?.renderResult?.(result, { expanded: true, isPartial: false }, plainTheme, {}) ?? { + render: () => [], + }, + ), + ).toBe('src/alpha.ts:1:needle'); + expect( + renderText( + grep?.renderResult?.( + { + content: [{ type: 'text', text: 'No matches found' }], + details: {}, + }, + { expanded: false, isPartial: false }, + plainTheme, + {}, + ) ?? { render: () => [] }, + ), + ).toBe('No matches'); + expect( + renderText( + grep?.renderResult?.(result, { expanded: false, isPartial: true }, plainTheme, {}) ?? { + render: () => [], + }, + ), + ).toBe('Running...'); + }); + + it('renders glob calls and result states', () => { + const glob = collectTools(registerSearchTools).get('Glob'); + const result = { + content: [{ type: 'text', text: 'src/alpha.ts\nsrc/gamma.ts' }], + details: { fileCount: 2 }, + }; + + expect( + renderText( + glob?.renderCall?.({ pattern: '**/*.ts', path: 'src' }, plainTheme) ?? { + render: () => [], + }, + ), + ).toBe('Glob **/*.ts in src'); + expect( + renderText( + glob?.renderResult?.(result, { expanded: false, isPartial: false }, plainTheme, {}) ?? { + render: () => [], + }, + ), + ).toBe('2 file(s)'); + expect( + renderText( + glob?.renderResult?.( + { content: [{ type: 'text', text: 'No files found' }], details: {} }, + { expanded: false, isPartial: false }, + plainTheme, + {}, + ) ?? { render: () => [] }, + ), + ).toBe('No files'); + expect( + renderText( + glob?.renderResult?.(result, { expanded: false, isPartial: true }, plainTheme, {}) ?? 
{ + render: () => [], + }, + ), + ).toBe('Running...'); + }); +}); diff --git a/tests/tools/shell.test.ts b/tests/tools/shell.test.ts new file mode 100644 index 0000000..2fc519b --- /dev/null +++ b/tests/tools/shell.test.ts @@ -0,0 +1,148 @@ +import { writeFileSync } from 'node:fs'; +import { join } from 'node:path'; +import { describe, expect, it } from 'vitest'; +import { registerShellTool } from '../../src/tools/shell.js'; +import { + collectTools, + executeTool, + firstText, + renderToolCall, + renderToolResult, + tempDir, +} from './toolTestHelpers.js'; + +describe('shell tool', () => { + it('returns stdout, stderr, and exit zero details', async () => { + const cwd = tempDir('pi-grok-cli-shell-'); + const result = await executeTool( + collectTools(registerShellTool).get('Shell'), + { command: 'printf stdout && printf stderr >&2' }, + cwd, + ); + + expect(firstText(result)).toBe('stdout\n[stderr]\nstderr'); + expect(result.details).toEqual({ + exitCode: 0, + command: 'printf stdout && printf stderr >&2', + }); + }); + + it('runs commands in a resolved working directory', async () => { + const cwd = tempDir('pi-grok-cli-shell-'); + writeFileSync(join(cwd, 'target.txt'), 'from cwd', 'utf-8'); + const result = await executeTool( + collectTools(registerShellTool).get('Shell'), + { command: 'cat target.txt', working_directory: '.' 
}, + cwd, + ); + + expect(firstText(result)).toBe('from cwd'); + expect(result.details).toEqual({ + exitCode: 0, + command: 'cat target.txt', + }); + }); + + it('returns a clear placeholder when commands produce no output', async () => { + const cwd = tempDir('pi-grok-cli-shell-'); + const result = await executeTool( + collectTools(registerShellTool).get('Shell'), + { command: 'true' }, + cwd, + ); + + expect(firstText(result)).toBe('(no output)'); + expect(result.details).toEqual({ exitCode: 0, command: 'true' }); + }); + + it('includes exit code, error message, and captured output on failure', async () => { + const cwd = tempDir('pi-grok-cli-shell-'); + const result = await executeTool( + collectTools(registerShellTool).get('Shell'), + { command: 'printf before && printf problem >&2 && exit 7' }, + cwd, + ); + + expect(firstText(result)).toContain('Shell error (exit code 7):'); + expect(firstText(result)).toContain('before\n[stderr]\nproblem'); + expect(result.details).toEqual({ + exitCode: 7, + command: 'printf before && printf problem >&2 && exit 7', + }); + }); + + it('truncates large successful and failed output', async () => { + const cwd = tempDir('pi-grok-cli-shell-'); + const tools = collectTools(registerShellTool); + const largeOutput = "head -c 50001 /dev/zero | tr '\\0' x"; + + const successResult = await executeTool(tools.get('Shell'), { command: largeOutput }, cwd); + const failureResult = await executeTool( + tools.get('Shell'), + { command: `${largeOutput}; exit 9` }, + cwd, + ); + + expect(firstText(successResult)).toHaveLength('\n\n[Output truncated at 50KB]'.length + 50_000); + expect(firstText(successResult).endsWith('[Output truncated at 50KB]')).toBe(true); + expect(firstText(failureResult)).toContain('Shell error (exit code 9):'); + expect(firstText(failureResult).endsWith('[Output truncated at 50KB]')).toBe(true); + }); + + it('truncates multibyte output by characters without hitting exec buffer limits', async () => { + const cwd = 
tempDir('pi-grok-cli-shell-'); + const result = await executeTool( + collectTools(registerShellTool).get('Shell'), + { command: 'perl -e \'print "漢" x 50001\'' }, + cwd, + ); + + expect(firstText(result)).toHaveLength('\n\n[Output truncated at 50KB]'.length + 50_000); + expect(firstText(result).startsWith('Shell error')).toBe(false); + expect(firstText(result).endsWith('[Output truncated at 50KB]')).toBe(true); + }); + + it('renders shell calls and result states', () => { + const shell = collectTools(registerShellTool).get('Shell'); + + expect( + renderToolCall(shell, { + command: 'pwd', + working_directory: 'src', + }), + ).toBe('Shell pwd in src'); + expect(renderToolCall(shell, { command: 'pwd' })).toBe('Shell pwd'); + expect( + renderToolResult(shell, { + content: [{ type: 'text', text: 'full output' }], + details: { exitCode: 0 }, + }), + ).toBe('Exit 0'); + expect( + renderToolResult(shell, { + content: [{ type: 'text', text: 'spawn failed' }], + details: { exitCode: 'ENOENT' }, + }), + ).toBe('Exit 1'); + expect( + renderToolResult( + shell, + { + content: [{ type: 'text', text: 'full output' }], + details: { exitCode: 0 }, + }, + { expanded: true, isPartial: false }, + ), + ).toBe('full output'); + expect( + renderToolResult( + shell, + { + content: [{ type: 'text', text: 'still running' }], + details: { exitCode: 0 }, + }, + { expanded: false, isPartial: true }, + ), + ).toBe('Running...'); + }); +}); diff --git a/tests/tools/toolTestHelpers.ts b/tests/tools/toolTestHelpers.ts new file mode 100644 index 0000000..eca9206 --- /dev/null +++ b/tests/tools/toolTestHelpers.ts @@ -0,0 +1,118 @@ +import { mkdtempSync, rmSync } from 'node:fs'; +import { tmpdir } from 'node:os'; +import { join } from 'node:path'; +import type { ExtensionAPI } from '@earendil-works/pi-coding-agent'; +import { afterEach } from 'vitest'; + +const tempDirs: string[] = []; + +afterEach(() => { + for (const dir of tempDirs.splice(0)) rmSync(dir, { recursive: true }); +}); + +export type 
ToolResult = { + content: { type: string; text?: string }[]; + details: Record; +}; + +type Renderable = { render: (width: number) => string[] }; + +type ToolTheme = { + bold: (text: string) => string; + fg: (name: string, text: string) => string; +}; + +type RegisteredTool = { + name: string; + prepareArguments?: (params: Record) => Record; + execute: ( + toolCallId: string, + params: Record, + signal: AbortSignal, + onUpdate: () => void, + ctx: { cwd: string }, + ) => Promise; + renderCall?: (args: Record, theme: ToolTheme) => Renderable; + renderResult?: ( + result: ToolResult, + state: { expanded: boolean; isPartial: boolean }, + theme: ToolTheme, + args: Record, + ) => Renderable; +}; + +export function collectTools(registerTools: (pi: ExtensionAPI) => void) { + const tools = new Map(); + registerTools({ + registerTool(tool: RegisteredTool) { + tools.set(tool.name, tool); + }, + } as unknown as ExtensionAPI); + return tools; +} + +export async function executeTool( + tool: RegisteredTool | undefined, + params: Record, + cwd: string, +) { + if (!tool) throw new Error('Tool was not registered'); + return tool.execute('tool-call-id', params, new AbortController().signal, () => {}, { + cwd, + }); +} + +export function prepareToolArguments( + tool: RegisteredTool | undefined, + params: Record, +) { + if (!tool) throw new Error('Tool was not registered'); + return tool.prepareArguments?.(params) ?? params; +} + +export async function executePreparedTool( + tool: RegisteredTool | undefined, + params: Record, + cwd: string, +) { + if (!tool) throw new Error('Tool was not registered'); + return executeTool(tool, prepareToolArguments(tool, params), cwd); +} + +export function firstText(result: ToolResult) { + return result.content[0]?.text ?? 
''; +} + +export function renderText(component: { render: (width: number) => string[] }) { + return component + .render(120) + .map((line) => line.trimEnd()) + .join('\n'); +} + +export const plainTheme = { + bold: (text: string) => text, + fg: (_name: string, text: string) => text, +}; + +export function renderToolCall(tool: RegisteredTool | undefined, args: Record) { + if (!tool?.renderCall) throw new Error('Tool call renderer was not registered'); + return renderText(tool.renderCall(args, plainTheme)); +} + +export function renderToolResult( + tool: RegisteredTool | undefined, + result: ToolResult, + state = { expanded: false, isPartial: false }, +) { + if (!tool?.renderResult) { + throw new Error('Tool result renderer was not registered'); + } + return renderText(tool.renderResult(result, state, plainTheme, {})); +} + +export function tempDir(prefix: string) { + const dir = mkdtempSync(join(tmpdir(), prefix)); + tempDirs.push(dir); + return dir; +} diff --git a/tsconfig.json b/tsconfig.json index 73fa4c9..e9b2d7a 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,18 +1,18 @@ { - "compilerOptions": { - "target": "ES2022", - "module": "ES2022", - "moduleResolution": "bundler", - "strict": true, - "esModuleInterop": true, - "skipLibCheck": true, - "forceConsistentCasingInFileNames": true, - "resolveJsonModule": true, - "declaration": true, - "declarationMap": true, - "sourceMap": true, - "outDir": "./dist" - }, - "include": ["src/**/*.ts"], - "exclude": ["node_modules", "dist"] + "compilerOptions": { + "target": "ES2022", + "module": "ES2022", + "moduleResolution": "bundler", + "strict": true, + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true, + "declaration": true, + "declarationMap": true, + "sourceMap": true, + "outDir": "./dist" + }, + "include": ["src/**/*.ts"], + "exclude": ["node_modules", "dist"] } diff --git a/vitest.config.ts b/vitest.config.ts index 3f9296d..0e6e4d1 100644 --- 
a/vitest.config.ts +++ b/vitest.config.ts @@ -1,12 +1,12 @@ -import { defineConfig } from "vitest/config"; +import { defineConfig } from 'vitest/config'; export default defineConfig({ - test: { - coverage: { - provider: "v8", - reporter: ["text", "lcov"], - include: ["src/**/*.ts"], - exclude: ["src/index.ts"], - }, - }, + test: { + coverage: { + provider: 'v8', + reporter: ['text', 'lcov'], + include: ['src/**/*.ts'], + exclude: ['src/index.ts'], + }, + }, });