diff --git a/mobile/scripts/bundle-terminal.ts b/mobile/scripts/bundle-terminal.ts index faff7d69..e6b97518 100644 --- a/mobile/scripts/bundle-terminal.ts +++ b/mobile/scripts/bundle-terminal.ts @@ -45,7 +45,7 @@ ${umdContent} let ws = null; let fitAddon = null; - async function connect(wsUrl, initialCommand) { + async function connect(wsUrl, initialCommand, token) { const ghostty = await Ghostty.load(); term = new Terminal({ @@ -100,6 +100,9 @@ ${umdContent} ws.onopen = () => { window.ReactNativeWebView.postMessage(JSON.stringify({ type: 'connected' })); + if (token) { + ws.send(JSON.stringify({ type: 'auth', token: token })); + } const dims = term.getDimensions ? term.getDimensions() : { cols: term.cols, rows: term.rows }; ws.send(JSON.stringify({ type: 'resize', cols: dims.cols, rows: dims.rows })); if (initialCommand) { diff --git a/mobile/src/screens/TerminalScreen.tsx b/mobile/src/screens/TerminalScreen.tsx index c8ad2621..23559d69 100644 --- a/mobile/src/screens/TerminalScreen.tsx +++ b/mobile/src/screens/TerminalScreen.tsx @@ -12,7 +12,7 @@ import { import { useSafeAreaInsets } from 'react-native-safe-area-context' import { WebView } from 'react-native-webview' import { useQuery } from '@tanstack/react-query' -import { api, getTerminalHtml, getTerminalUrl, HOST_WORKSPACE_NAME } from '../lib/api' +import { api, getTerminalHtml, getTerminalUrl, getToken, HOST_WORKSPACE_NAME } from '../lib/api' import { ExtraKeysBar } from '../components/ExtraKeysBar' import { useTheme } from '../contexts/ThemeContext' @@ -103,11 +103,13 @@ export function TerminalScreen({ route, navigation }: any) { } const wsUrl = getTerminalUrl(name) + const token = getToken() const escapedCommand = initialCommand ? initialCommand.replace(/\\/g, '\\\\').replace(/'/g, "\\'") : '' + const escapedToken = token ? token.replace(/\\/g, '\\\\').replace(/'/g, "\\'") : '' const injectedJS = ` if (window.initTerminal) { - window.initTerminal('${wsUrl}', '${escapedCommand}'); + window.initTerminal('${wsUrl}', '${escapedCommand}', '${escapedToken}'); } true; ` diff --git a/src/agent/auth.ts b/src/agent/auth.ts index 855817ec..a0fec0b9 100644 --- a/src/agent/auth.ts +++ b/src/agent/auth.ts @@ -2,7 +2,7 @@ import { timingSafeEqual } from 'crypto'; import type { AgentConfig } from '../shared/types'; import { getTailscaleIdentity } from '../tailscale'; -function secureCompare(a: string, b: string): boolean { +export function secureCompare(a: string, b: string): boolean { if (a.length !== b.length) return false; return timingSafeEqual(Buffer.from(a), Buffer.from(b)); } diff --git a/src/agent/run.ts b/src/agent/run.ts index 6a215e45..f287278a 100644 --- a/src/agent/run.ts +++ b/src/agent/run.ts @@ -39,6 +39,7 @@ interface TailscaleInfo { interface WebSocketData { type: 'terminal'; workspaceName: string; + authenticated: boolean; } function createAgentServer( @@ -89,6 +90,7 @@ function createAgentServer( isWorkspaceRunning, isHostAccessAllowed: () => currentConfig.allowHostAccess === true, getPreferredShell, + getAuthToken: () => currentConfig.auth?.token, }); const triggerAutoSync = () => { @@ -143,24 +145,21 @@ function createAgentServer( return staticResponse; } - const authResult = checkAuth(req, currentConfig); - if (!authResult.ok) { - return unauthorizedResponse(); - } - const terminalMatch = pathname.match(/^\/rpc\/terminal\/([^/]+)$/); if (terminalMatch) { const type: WebSocketData['type'] = 'terminal'; const workspaceName = decodeURIComponent(terminalMatch[1]); + const authResult = checkAuth(req, currentConfig); + const running = await isWorkspaceRunning(workspaceName); if (!running) { return new Response('Not Found', { status: 404 }); } const upgraded = server.upgrade(req, { - data: { type, workspaceName }, + data: { type, workspaceName, authenticated: authResult.ok }, }); if (upgraded) { @@ -169,6 +168,11 @@ function createAgentServer( return new Response('WebSocket upgrade failed', { status: 400 }); } + const authResult = checkAuth(req, currentConfig); + if (!authResult.ok) { + return unauthorizedResponse(); + } + if (pathname === '/health' && method === 'GET') { const identity = getTailscaleIdentity(req); const response: Record = { status: 'ok', version: pkg.version }; @@ -198,9 +202,9 @@ function createAgentServer( websocket: { open(ws: ServerWebSocket) { - const { type, workspaceName } = ws.data; + const { type, workspaceName, authenticated } = ws.data; if (type === 'terminal') { - terminalHandler.handleOpen(ws, workspaceName); + terminalHandler.handleOpen(ws, workspaceName, authenticated); } }, diff --git a/src/client/api.ts b/src/client/api.ts index 786b73cf..5be5a209 100644 --- a/src/client/api.ts +++ b/src/client/api.ts @@ -187,6 +187,10 @@ export class ApiClient { return `${wsUrl}/rpc/terminal/${encodeURIComponent(name)}`; } + getAuthToken(): string | undefined { + return this.token; + } + getOpencodeUrl(name: string): string { const wsUrl = this.baseUrl.replace(/^http/, 'ws'); return `${wsUrl}/rpc/opencode/${encodeURIComponent(name)}`; diff --git a/src/client/ws-shell.ts b/src/client/ws-shell.ts index 6313b53b..1b119e1c 100644 --- a/src/client/ws-shell.ts +++ b/src/client/ws-shell.ts @@ -5,6 +5,7 @@ import { DEFAULT_AGENT_PORT } from '../shared/constants'; export interface WSShellOptions { url: string; + token?: string; onConnect?: () => void; onDisconnect?: (code: number) => void; onError?: (error: Error) => void; @@ -121,7 +122,7 @@ export async function openTailscaleSSH(options: TailscaleSSHOptions): Promise { - const { url, onConnect, onDisconnect, onError } = options; + const { url, token, onConnect, onDisconnect, onError } = options; return new Promise((resolve, reject) => { const ws = new WebSocket(url); @@ -155,6 +156,10 @@ export async function openWSShell(options: WSShellOptions): Promise { } stdin.resume(); + if (token) { + safeSend(JSON.stringify({ type: 'auth', token })); + } + sendResize(); if (onConnect) { diff --git a/src/index.ts b/src/index.ts index fc29c504..454d509d 100644 --- a/src/index.ts +++ b/src/index.ts @@ -466,6 +466,7 @@ program try { const agentHost = await getAgentWithFallback(); const client = await createClient(); + const token = await getToken(); const workspace = await client.getWorkspace(name); if (workspace.status !== 'running') { @@ -495,6 +496,7 @@ program const wsUrl = getTerminalWSUrl(agentHost, name); await openWSShell({ url: wsUrl, + token: token || undefined, onError: (err) => { console.error(`\nConnection error: ${err.message}`); }, diff --git a/src/terminal/bun-handler.ts b/src/terminal/bun-handler.ts index 9eaa26e3..3fd30351 100644 --- a/src/terminal/bun-handler.ts +++ b/src/terminal/bun-handler.ts @@ -1,7 +1,8 @@ import type { ServerWebSocket } from 'bun'; import { createTerminalSession, TerminalSession } from './handler'; import { createHostTerminalSession, HostTerminalSession } from './host-handler'; -import { isControlMessage } from './types'; +import { isControlMessage, isAuthMessage } from './types'; +import { secureCompare } from '../agent/auth'; import { HOST_WORKSPACE_NAME } from '../shared/client-types'; type AnyTerminalSession = TerminalSession | HostTerminalSession; @@ -11,6 +12,7 @@ interface TerminalConnection { session: AnyTerminalSession | null; workspaceName: string; started: boolean; + authenticated: boolean; } export interface TerminalHandlerOptions { @@ -18,6 +20,7 @@ export interface TerminalHandlerOptions { isWorkspaceRunning: (workspaceName: string) => Promise; isHostAccessAllowed?: () => boolean; getPreferredShell?: () => string | undefined; + getAuthToken?: () => string | undefined; } export class TerminalHandler { @@ -25,14 +28,16 @@ export class TerminalHandler { private getContainerName: (workspaceName: string) => string; private isHostAccessAllowed: () => boolean; private getPreferredShell: () => string | undefined; + private getAuthToken: () => string | undefined; constructor(options: TerminalHandlerOptions) { this.getContainerName = options.getContainerName; this.isHostAccessAllowed = options.isHostAccessAllowed || (() => false); this.getPreferredShell = options.getPreferredShell || (() => undefined); + this.getAuthToken = options.getAuthToken || (() => undefined); } - handleOpen(ws: ServerWebSocket, workspaceName: string): void { + handleOpen(ws: ServerWebSocket, workspaceName: string, authenticated = true): void { const isHostMode = workspaceName === HOST_WORKSPACE_NAME; if (isHostMode && !this.isHostAccessAllowed()) { @@ -40,11 +45,15 @@ export class TerminalHandler { return; } + const authToken = this.getAuthToken(); + const isAuthenticated = authenticated || !authToken; + const connection: TerminalConnection = { ws, session: null, workspaceName, started: false, + authenticated: isAuthenticated, }; this.connections.set(ws, connection); } @@ -53,6 +62,24 @@ export class TerminalHandler { const connection = this.connections.get(ws); if (!connection) return; + if (!connection.authenticated) { + try { + const message = JSON.parse(data); + if (isAuthMessage(message)) { + const authToken = this.getAuthToken(); + if (authToken && secureCompare(message.token, authToken)) { + connection.authenticated = true; + return; + } + } + } catch { + // invalid JSON, reject + } + ws.close(4001, 'Authentication failed'); + this.connections.delete(ws); + return; + } + if (data.startsWith('{')) { try { const message = JSON.parse(data); diff --git a/src/terminal/types.ts b/src/terminal/types.ts index debb2ed9..6ad13f97 100644 --- a/src/terminal/types.ts +++ b/src/terminal/types.ts @@ -26,3 +26,18 @@ export function isControlMessage(data: unknown): data is ControlMessage { typeof (data as ControlMessage).rows === 'number' ); } + +export interface AuthMessage { + type: 'auth'; + token: string; +} + +export function isAuthMessage(data: unknown): data is AuthMessage { + const msg = data as AuthMessage; + return ( + typeof data === 'object' && + data !== null && + msg.type === 'auth' && + typeof msg.token === 'string' + ); +} diff --git a/test/helpers/agent.ts b/test/helpers/agent.ts index 90e7fe94..e2352c90 100644 --- a/test/helpers/agent.ts +++ b/test/helpers/agent.ts @@ -116,10 +116,18 @@ export async function waitForHealthy(baseUrl: string, timeout = 10000): Promise< return false; } -export function createApiClient(baseUrl: string): ApiClient { +export function createApiClient(baseUrl: string, token?: string): ApiClient { type Client = RouterClient; const link = new RPCLink({ url: `${baseUrl}/rpc`, + fetch: (url, init) => { + const reqInit = init as RequestInit; + const headers = new Headers(reqInit?.headers); + if (token) { + headers.set('Authorization', `Bearer ${token}`); + } + return fetch(url, { ...reqInit, headers }); + }, }); const client = createORPCClient(link); @@ -363,7 +371,7 @@ export async function startTestAgent(options: TestAgentOptions = {}): Promise { }); describe('WebSocket Endpoints', () => { - it('rejects WebSocket upgrade without auth', async () => { - const wsUrl = `${agent.baseUrl.replace('http', 'ws')}/rpc/terminal/test-workspace`; + it('returns 404 for non-existent workspace regardless of auth', async () => { + const wsUrl = `${agent.baseUrl.replace('http', 'ws')}/rpc/terminal/nonexistent-workspace`; const result = await new Promise<{ error: Error | null; code?: number }>((resolve) => { const ws = new WebSocket(wsUrl); @@ -97,32 +97,10 @@ describe('Auth Middleware Integration', () => { }); }); - expect(result.code).toBe(401); - }); - - it('rejects WebSocket upgrade with wrong token', async () => { - const wsUrl = `${agent.baseUrl.replace('http', 'ws')}/rpc/terminal/test-workspace`; - - const result = await new Promise<{ error: Error | null; code?: number }>((resolve) => { - const ws = new WebSocket(wsUrl, { - headers: { Authorization: 'Bearer wrong-token' }, - }); - ws.on('error', (err) => { - resolve({ error: err }); - }); - ws.on('unexpected-response', (_, res) => { - resolve({ error: null, code: res.statusCode }); - }); - ws.on('open', () => { - ws.close(); - resolve({ error: new Error('WebSocket should not have opened') }); - }); - }); - - expect(result.code).toBe(401); + expect(result.code).toBe(404); }); - it('accepts WebSocket upgrade with valid token (returns 404 for non-existent workspace)', async () => { + it('returns 404 for non-existent workspace with valid Bearer token', async () => { const wsUrl = `${agent.baseUrl.replace('http', 'ws')}/rpc/terminal/nonexistent-workspace`; const result = await new Promise<{ error: Error | null; code?: number }>((resolve) => { @@ -137,7 +115,7 @@ describe('Auth Middleware Integration', () => { }); ws.on('open', () => { ws.close(); - resolve({ error: null, code: 200 }); + resolve({ error: new Error('WebSocket should not have opened') }); }); }); diff --git a/test/integration/terminal-auth.test.ts b/test/integration/terminal-auth.test.ts new file mode 100644 index 00000000..2b7ea2df --- /dev/null +++ b/test/integration/terminal-auth.test.ts @@ -0,0 +1,84 @@ +import { describe, it, expect, beforeAll, afterAll } from 'vitest'; +import WebSocket from 'ws'; +import { startTestAgent, type TestAgent } from '../helpers/agent'; + +const TEST_TOKEN = 'test-auth-token-12345'; + +function waitForOpen(ws: WebSocket, timeout = 5000): Promise { + return new Promise((resolve, reject) => { + const timer = setTimeout(() => reject(new Error('Timeout waiting for connection')), timeout); + ws.once('open', () => { + clearTimeout(timer); + resolve(); + }); + ws.once('error', (err) => { + clearTimeout(timer); + reject(err); + }); + }); +} + +function collectMessages(ws: WebSocket, durationMs: number): Promise { + return new Promise((resolve) => { + let output = ''; + const handler = (data: Buffer | string) => { + output += data.toString(); + }; + ws.on('message', handler); + setTimeout(() => { + ws.off('message', handler); + resolve(output); + }, durationMs); + }); +} + +describe('Terminal WebSocket - First Message Auth', () => { + let agent: TestAgent; + let workspaceName: string; + beforeAll(async () => { + agent = await startTestAgent({ + config: { + auth: { token: TEST_TOKEN }, + }, + }); + workspaceName = agent.generateWorkspaceName(); + const result = await agent.api.createWorkspace({ name: workspaceName }); + expect(result.status).toBe(201); + }, 120000); + + afterAll(async () => { + await agent.cleanup(); + }); + + it('authenticates WebSocket via first auth message and can execute a command', async () => { + const wsUrl = `ws://127.0.0.1:${agent.port}/rpc/terminal/${workspaceName}`; + const ws = new WebSocket(wsUrl); + + await waitForOpen(ws, 15000); + ws.send(JSON.stringify({ type: 'auth', token: TEST_TOKEN })); + ws.send(JSON.stringify({ type: 'resize', cols: 80, rows: 24 })); + await new Promise((r) => setTimeout(r, 300)); + + const outputPromise = collectMessages(ws, 2500); + ws.send('echo "FIRST_MSG_AUTH_OK"\n'); + + const output = await outputPromise; + expect(output).toContain('FIRST_MSG_AUTH_OK'); + + ws.close(); + }, 30000); + + it('closes connection with 4001 when sending resize without auth', async () => { + const wsUrl = `ws://127.0.0.1:${agent.port}/rpc/terminal/${workspaceName}`; + const ws = new WebSocket(wsUrl); + + await waitForOpen(ws, 15000); + ws.send(JSON.stringify({ type: 'resize', cols: 80, rows: 24 })); + + const closeCode = await new Promise((resolve) => { + ws.on('close', (code) => resolve(code)); + }); + + expect(closeCode).toBe(4001); + }, 15000); +}); diff --git a/web/src/components/Terminal.tsx b/web/src/components/Terminal.tsx index 5c26972a..0888bf94 100644 --- a/web/src/components/Terminal.tsx +++ b/web/src/components/Terminal.tsx @@ -1,6 +1,6 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { Ghostty, Terminal as GhosttyTerminal, FitAddon } from 'ghostty-web' -import { getTerminalUrl } from '@/lib/api' +import { getTerminalUrl, getToken } from '@/lib/api' interface TerminalProps { workspaceName: string @@ -137,6 +137,10 @@ function TerminalInstance({ workspaceName, initialCommand, runId }: TerminalProp ws.onopen = () => { if (cancelled.current) return setIsConnected(true) + const token = getToken() + if (token) { + ws.send(JSON.stringify({ type: 'auth', token })) + } const { cols, rows } = cached.terminal ws.send(JSON.stringify({ type: 'resize', cols, rows }))