diff --git a/apps/code/src/renderer/features/sessions/service/service.test.ts b/apps/code/src/renderer/features/sessions/service/service.test.ts index a47d7ee56..b8c1cb089 100644 --- a/apps/code/src/renderer/features/sessions/service/service.test.ts +++ b/apps/code/src/renderer/features/sessions/service/service.test.ts @@ -296,6 +296,12 @@ describe("SessionService", () => { hasCodeAccess: true, needsScopeReauth: false, }); + mockTrpcAgent.onSessionEvent.subscribe.mockReturnValue({ + unsubscribe: vi.fn(), + }); + mockTrpcAgent.onPermissionRequest.subscribe.mockReturnValue({ + unsubscribe: vi.fn(), + }); mockTrpcCloudTask.onUpdate.subscribe.mockReturnValue({ unsubscribe: vi.fn(), }); @@ -1024,29 +1030,56 @@ describe("SessionService", () => { ); }); - it("sets session to error state on fatal error", async () => { + it("attempts automatic recovery on fatal error", async () => { const service = getSessionService(); - const mockSession = createMockSession(); + const mockSession = createMockSession({ + logUrl: "https://logs.example.com/run-123", + }); mockSessionStoreSetters.getSessionByTaskId.mockReturnValue(mockSession); mockSessionStoreSetters.getSessions.mockReturnValue({ "run-123": { ...mockSession, isPromptPending: false }, }); + mockTrpcWorkspace.verify.query.mockResolvedValue({ exists: true }); + mockTrpcLogs.readLocalLogs.query.mockResolvedValue(""); + mockTrpcAgent.reconnect.mutate.mockResolvedValue({ + sessionId: "run-123", + channel: "agent-event:run-123", + configOptions: [], + }); + + await service.connectToTask({ + task: createMockTask({ + latest_run: { + id: "run-123", + task: "task-123", + team: 123, + environment: "local", + status: "in_progress", + log_url: "https://logs.example.com/run-123", + error_message: null, + output: null, + state: {}, + branch: null, + created_at: "2024-01-01T00:00:00Z", + updated_at: "2024-01-01T00:00:00Z", + completed_at: null, + }, + }), + repoPath: "/repo", + }); + mockTrpcAgent.prompt.mutate.mockRejectedValue( new 
Error("Internal error: process exited"), ); await expect(service.sendPrompt("task-123", "Hello")).rejects.toThrow(); - - // Check that one of the updateSession calls set status to error - const updateCalls = mockSessionStoreSetters.updateSession.mock.calls as [ - string, - { status?: string }, - ][]; - const errorCall = updateCalls.find( - ([, updates]) => updates.status === "error", + expect(mockSessionStoreSetters.updateSession).toHaveBeenCalledWith( + "run-123", + expect.objectContaining({ + status: "disconnected", + errorMessage: expect.stringContaining("Reconnecting"), + }), ); - expect(errorCall).toBeDefined(); - expect(errorCall?.[0]).toBe("run-123"); }); }); @@ -1363,4 +1396,90 @@ describe("SessionService", () => { ).resolves.not.toThrow(); }); }); + + describe("automatic local recovery", () => { + it("reconnects automatically after a subscription error", async () => { + vi.useFakeTimers(); + const service = getSessionService(); + const mockSession = createMockSession({ + status: "connected", + logUrl: "https://logs.example.com/run-123", + }); + + mockSessionStoreSetters.getSessionByTaskId.mockReturnValue(mockSession); + mockSessionStoreSetters.getSessions.mockReturnValue({ + "run-123": mockSession, + }); + mockTrpcWorkspace.verify.query.mockResolvedValue({ exists: true }); + mockTrpcLogs.readLocalLogs.query.mockResolvedValue(""); + mockTrpcAgent.reconnect.mutate.mockResolvedValue({ + sessionId: "run-123", + channel: "agent-event:run-123", + configOptions: [], + }); + + await service.clearSessionError("task-123", "/repo"); + + const onError = mockTrpcAgent.onSessionEvent.subscribe.mock.calls[0]?.[1] + ?.onError as ((error: Error) => void) | undefined; + expect(onError).toBeDefined(); + + onError?.(new Error("connection dropped")); + await vi.runAllTimersAsync(); + + expect(mockTrpcAgent.reconnect.mutate).toHaveBeenCalledTimes(2); + expect(mockSessionStoreSetters.updateSession).toHaveBeenCalledWith( + "run-123", + expect.objectContaining({ + status: 
"disconnected", + errorMessage: expect.stringContaining("Reconnecting"), + }), + ); + + vi.useRealTimers(); + }); + + it("shows the error screen only after automatic reconnect attempts fail", async () => { + vi.useFakeTimers(); + const service = getSessionService(); + const mockSession = createMockSession({ + status: "connected", + logUrl: "https://logs.example.com/run-123", + }); + + mockSessionStoreSetters.getSessionByTaskId.mockReturnValue(mockSession); + mockSessionStoreSetters.getSessions.mockReturnValue({ + "run-123": mockSession, + }); + mockTrpcWorkspace.verify.query.mockResolvedValue({ exists: true }); + mockTrpcLogs.readLocalLogs.query.mockResolvedValue(""); + mockTrpcAgent.reconnect.mutate + .mockResolvedValueOnce({ + sessionId: "run-123", + channel: "agent-event:run-123", + configOptions: [], + }) + .mockResolvedValue(null); + + await service.clearSessionError("task-123", "/repo"); + + const onError = mockTrpcAgent.onSessionEvent.subscribe.mock.calls[0]?.[1] + ?.onError as ((error: Error) => void) | undefined; + expect(onError).toBeDefined(); + + onError?.(new Error("connection dropped")); + await vi.runAllTimersAsync(); + + expect(mockTrpcAgent.reconnect.mutate).toHaveBeenCalledTimes(4); + expect(mockSessionStoreSetters.setSession).toHaveBeenCalledWith( + expect.objectContaining({ + status: "error", + errorTitle: "Connection lost", + errorMessage: expect.any(String), + }), + ); + + vi.useRealTimers(); + }); + }); }); diff --git a/apps/code/src/renderer/features/sessions/service/service.ts b/apps/code/src/renderer/features/sessions/service/service.ts index 75d319dd3..7279df08b 100644 --- a/apps/code/src/renderer/features/sessions/service/service.ts +++ b/apps/code/src/renderer/features/sessions/service/service.ts @@ -54,6 +54,7 @@ import { ANALYTICS_EVENTS } from "@shared/types/analytics"; import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud"; import type { AcpMessage, StoredLogEntry } from "@shared/types/session-events"; import { 
isJsonRpcRequest } from "@shared/types/session-events"; +import { getBackoffDelay } from "@shared/utils/backoff"; import { buildPermissionToolMetadata, track } from "@utils/analytics"; import { logger } from "@utils/logger"; import { @@ -73,6 +74,15 @@ import { } from "@utils/session"; const log = logger.scope("session-service"); +const LOCAL_SESSION_RECONNECT_ATTEMPTS = 3; +const LOCAL_SESSION_RECONNECT_BACKOFF = { + initialDelayMs: 1_000, + maxDelayMs: 5_000, +}; +const LOCAL_SESSION_RECOVERY_MESSAGE = + "Lost connection to the agent. Reconnecting…"; +const LOCAL_SESSION_RECOVERY_FAILED_MESSAGE = + "Connection to the agent has been lost. Retry, or start a new session."; /** * Build default configOptions for cloud sessions so the mode switcher @@ -140,6 +150,8 @@ export function resetSessionService(): void { export class SessionService { private connectingTasks = new Map>(); + private localRepoPaths = new Map(); + private localRecoveryAttempts = new Map>(); private nextCloudTaskWatchToken = 0; private subscriptions = new Map< string, @@ -185,6 +197,7 @@ export class SessionService { async connectToTask(params: ConnectParams): Promise { const { task } = params; const taskId = task.id; + this.localRepoPaths.set(taskId, params.repoPath); log.info("Connecting to task", { taskId }); @@ -377,7 +390,7 @@ export class SessionService { sessionId?: string; adapter?: Adapter; }, - ): Promise { + ): Promise { const { rawEntries, sessionId, adapter } = prefetchedLogs ?? (await this.fetchSessionLogs(logUrl, taskRunId)); const events = convertStoredEntriesToEvents(rawEntries); @@ -493,6 +506,7 @@ export class SessionService { ), ); } + return true; } else { log.warn("Reconnect returned null", { taskId, taskRunId }); this.setErrorSession( @@ -501,6 +515,7 @@ export class SessionService { taskTitle, "Session could not be resumed.
Please retry or start a new session.", ); + return false; } } catch (error) { const errorMessage = @@ -513,10 +528,13 @@ export class SessionService { errorMessage || "Failed to reconnect. Please retry or start a new session.", ); + return false; } } private async teardownSession(taskRunId: string): Promise { + const session = this.getSessionByRunId(taskRunId); + try { await trpcClient.agent.cancel.mutate({ sessionId: taskRunId }); } catch (error) { @@ -528,6 +546,10 @@ export class SessionService { this.unsubscribeFromChannel(taskRunId); sessionStoreSetters.removeSession(taskRunId); + if (session) { + this.localRepoPaths.delete(session.taskId); + this.localRecoveryAttempts.delete(session.taskId); + } useSessionAdapterStore.getState().removeAdapter(taskRunId); removePersistedConfigOptions(taskRunId); } @@ -579,6 +601,133 @@ export class SessionService { sessionStoreSetters.setSession(session); } + private async tryAutoRecoverLocalSession( + taskId: string, + taskRunId: string, + reason: string, + ): Promise { + const existingRecovery = this.localRecoveryAttempts.get(taskId); + if (existingRecovery) { + return existingRecovery; + } + + const recoveryPromise = this.runAutoRecoverLocalSession( + taskId, + taskRunId, + reason, + ).finally(() => { + this.localRecoveryAttempts.delete(taskId); + }); + + this.localRecoveryAttempts.set(taskId, recoveryPromise); + return recoveryPromise; + } + + private async runAutoRecoverLocalSession( + taskId: string, + taskRunId: string, + reason: string, + ): Promise { + const repoPath = this.localRepoPaths.get(taskId); + const session = sessionStoreSetters.getSessionByTaskId(taskId); + if (!repoPath || !session || session.isCloud) { + return false; + } + + log.warn("Attempting automatic local session recovery", { + taskId, + taskRunId, + reason, + }); + + sessionStoreSetters.updateSession(taskRunId, { + status: "disconnected", + errorTitle: undefined, + errorMessage: LOCAL_SESSION_RECOVERY_MESSAGE, + isPromptPending: false, + 
isCompacting: false, + promptStartedAt: null, + }); + + for ( + let attempt = 0; + attempt < LOCAL_SESSION_RECONNECT_ATTEMPTS; + attempt++ + ) { + const currentSession = sessionStoreSetters.getSessionByTaskId(taskId); + if (!currentSession || currentSession.taskRunId !== taskRunId) { + return false; + } + + if (attempt > 0) { + const delay = getBackoffDelay( + attempt - 1, + LOCAL_SESSION_RECONNECT_BACKOFF, + ); + await new Promise((resolve) => setTimeout(resolve, delay)); + } + + const recovered = await this.reconnectInPlace(taskId, repoPath); + if (recovered) { + log.info("Automatic local session recovery succeeded", { + taskId, + taskRunId, + attempt: attempt + 1, + }); + return true; + } + } + + const latestSession = sessionStoreSetters.getSessionByTaskId(taskId); + if (latestSession?.taskRunId === taskRunId) { + this.setErrorSession( + taskId, + taskRunId, + latestSession.taskTitle, + LOCAL_SESSION_RECOVERY_FAILED_MESSAGE, + "Connection lost", + ); + } + + log.warn("Automatic local session recovery exhausted", { + taskId, + taskRunId, + }); + + return false; + } + + private startAutoRecoverLocalSession( + taskId: string, + taskRunId: string, + taskTitle: string, + reason: string, + fallbackMessage: string, + ): void { + void this.tryAutoRecoverLocalSession(taskId, taskRunId, reason).then( + (recovered) => { + if (recovered) { + return; + } + + const latestSession = sessionStoreSetters.getSessionByTaskId(taskId); + if (!latestSession || latestSession.taskRunId !== taskRunId) { + return; + } + + if (latestSession.status !== "error") { + this.setErrorSession( + taskId, + taskRunId, + taskTitle, + fallbackMessage, + "Connection lost", + ); + } + }, + ); + } + private async createNewLocalSession( taskId: string, taskTitle: string, @@ -700,11 +849,23 @@ export class SessionService { }, onError: (err) => { log.error("Session subscription error", { taskRunId, error: err }); - sessionStoreSetters.updateSession(taskRunId, { - status: "error", - errorMessage: - "Lost 
connection to the agent. Please restart the task.", - }); + const session = this.getSessionByRunId(taskRunId); + if (!session || session.isCloud) { + sessionStoreSetters.updateSession(taskRunId, { + status: "error", + errorMessage: + "Lost connection to the agent. Please restart the task.", + }); + return; + } + + this.startAutoRecoverLocalSession( + session.taskId, + taskRunId, + session.taskTitle, + "subscription_error", + "Lost connection to the agent. Please retry or start a new session.", + ); }, }, ); @@ -760,6 +921,8 @@ export class SessionService { } this.connectingTasks.clear(); + this.localRepoPaths.clear(); + this.localRecoveryAttempts.clear(); this.cloudPermissionRequestIds.clear(); this.idleKilledSubscription?.unsubscribe(); this.idleKilledSubscription = null; @@ -1175,21 +1338,19 @@ export class SessionService { sessionStoreSetters.clearOptimisticItems(session.taskRunId); if (isFatalSessionError(errorMessage, errorDetails)) { - log.error("Fatal prompt error, setting session to error state", { + log.error("Fatal prompt error, attempting recovery", { taskRunId: session.taskRunId, errorMessage, errorDetails, }); - sessionStoreSetters.updateSession(session.taskRunId, { - status: "error", - errorMessage: - errorDetails || + this.startAutoRecoverLocalSession( + session.taskId, + session.taskRunId, + session.taskTitle, + errorDetails || errorMessage, + errorDetails || "Session connection lost. Please retry or start a new session.", - isPromptPending: false, - isCompacting: false, - promptStartedAt: null, - initialPrompt: undefined, - }); + ); } else { sessionStoreSetters.updateSession(session.taskRunId, { isPromptPending: false, @@ -1947,6 +2108,7 @@ export class SessionService { * to an empty session. 
*/ async clearSessionError(taskId: string, repoPath: string): Promise { + this.localRepoPaths.set(taskId, repoPath); const session = sessionStoreSetters.getSessionByTaskId(taskId); if (session?.initialPrompt?.length) { const { taskTitle, initialPrompt } = session; @@ -1975,6 +2137,7 @@ export class SessionService { * session instead of attempting to resume the stale one. */ async resetSession(taskId: string, repoPath: string): Promise { + this.localRepoPaths.set(taskId, repoPath); await this.reconnectInPlace(taskId, repoPath, null); } @@ -1992,9 +2155,10 @@ export class SessionService { taskId: string, repoPath: string, overrideSessionId?: string | null, - ): Promise { + ): Promise { + this.localRepoPaths.set(taskId, repoPath); const session = sessionStoreSetters.getSessionByTaskId(taskId); - if (!session) return; + if (!session) return false; const { taskRunId, taskTitle, logUrl } = session; @@ -2020,7 +2184,7 @@ export class SessionService { ? undefined : (overrideSessionId ?? prefetchedLogs.sessionId); - await this.reconnectToLocalSession( + return this.reconnectToLocalSession( taskId, taskRunId, taskTitle, @@ -2607,6 +2771,11 @@ export class SessionService { }; } + private getSessionByRunId(taskRunId: string): AgentSession | undefined { + const sessions = sessionStoreSetters.getSessions(); + return sessions[taskRunId]; + } + private async appendAndPersist( taskId: string, session: AgentSession,