Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ export class ConversationSessionSupervisor {
return pty;
}

detachActive(sessionId: string): Pty | undefined {
const runtime = this.runtimes.get(sessionId);
if (!runtime) return undefined;

runtime.spawnInFlight = undefined;
this.clearRecoveryGraceTimer(runtime);
const pty = runtime.active?.pty;
runtime.active = undefined;
return pty;
}

isDesired(sessionId: string): boolean {
return this.runtimes.get(sessionId)?.desired === true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ const buildCommandMock = vi.hoisted(() =>
);
const installPluginMock = vi.hoisted(() => vi.fn(async () => []));
const writeHooksMock = vi.hoisted(() => vi.fn(async () => []));
const sshConnectionManagerMock = vi.hoisted(() => ({
handlers: [] as Array<(event: { type: string; connectionId: string }) => void>,
}));

vi.mock('@main/core/dependencies/host-dependency-store', () => ({
hostDependencyStore: {
Expand Down Expand Up @@ -94,6 +97,17 @@ vi.mock('@main/core/pty/ssh2-pty', () => ({
openSsh2Pty,
}));

vi.mock('@main/core/ssh/lifecycle/production-ssh-connection-manager', () => ({
sshConnectionManager: {
on: vi.fn(
(_event: string, handler: (event: { type: string; connectionId: string }) => void) => {
sshConnectionManagerMock.handlers.push(handler);
}
),
off: vi.fn(),
},
}));

vi.mock('./keystroke-injection', () => ({
scheduleInitialPromptInjection: vi.fn(),
}));
Expand Down Expand Up @@ -153,6 +167,8 @@ vi.mock('@main/core/settings/settings-service', () => ({
const { events } = await import('@main/lib/events');
const { agentHookService } = await import('@main/core/agent-hooks/agent-hook-service');
const { appSettingsService } = await import('@main/core/settings/settings-service');
const { sshConnectionManager } =
await import('@main/core/ssh/lifecycle/production-ssh-connection-manager');

type RespawnState = {
knownSessionIds: Set<string>;
Expand Down Expand Up @@ -192,9 +208,11 @@ function sshProvider(
{
tmux = false,
ctx = {} as never,
connectionId = 'ssh-1',
}: {
tmux?: boolean;
ctx?: ConstructorParameters<typeof SshConversationProvider>[0]['ctx'];
connectionId?: string;
} = {}
) {
return new SshConversationProvider({
Expand All @@ -204,6 +222,7 @@ function sshProvider(
tmux,
ctx,
proxy: proxy as never,
connectionId,
});
}

Expand Down Expand Up @@ -255,6 +274,8 @@ describe('conversation provider respawn state', () => {
installPluginMock.mockResolvedValue([]);
writeHooksMock.mockReset();
writeHooksMock.mockResolvedValue([]);
sshConnectionManagerMock.handlers.length = 0;
vi.mocked(sshConnectionManager.off).mockClear();
mockSettings();
vi.mocked(events.emit).mockClear();
vi.mocked(agentHookService.getPort).mockReturnValue(0);
Expand Down Expand Up @@ -995,4 +1016,120 @@ describe('conversation provider respawn state', () => {
expect((provider as unknown as RespawnState).sessions.get(sessionId)).toBe(secondPty);
expect(events.emit).not.toHaveBeenCalledWith(agentSessionExitedChannel, expect.anything());
});

it('detaches stale SSH conversations on disconnect and resumes them when connected', async () => {
const firstExitHandlers: Array<(info: PtyExitInfo) => void> = [];
const secondExitHandlers: Array<(info: PtyExitInfo) => void> = [];
const firstPty = fakePty(firstExitHandlers);
const secondPty = fakePty(secondExitHandlers);
openSsh2Pty
.mockResolvedValueOnce({ success: true, data: firstPty })
.mockResolvedValueOnce({ success: true, data: secondPty });
const provider = sshProvider(undefined, { connectionId: 'ssh-1' });
const item = conversation();
const sessionId = makePtySessionId(item.projectId, item.taskId, item.id);

await provider.startSession(item);
expect((provider as unknown as RespawnState).sessions.get(sessionId)).toBe(firstPty);

for (const handler of sshConnectionManagerMock.handlers) {
handler({ type: 'disconnected', connectionId: 'ssh-1' });
}

expect((provider as unknown as RespawnState).sessions.has(sessionId)).toBe(false);
expect(ptySessionRegistry.get(sessionId)).toBeUndefined();
expect(firstPty.kill).toHaveBeenCalledOnce();

for (const handler of sshConnectionManagerMock.handlers) {
handler({ type: 'connected', connectionId: 'ssh-1' });
}
await new Promise((resolve) => setImmediate(resolve));

expect(openSsh2Pty).toHaveBeenCalledTimes(2);
expect((provider as unknown as RespawnState).sessions.get(sessionId)).toBe(secondPty);
expect(ptySessionRegistry.get(sessionId)).toBe(secondPty);
});

it('clears in-flight SSH conversation starts on disconnect so reconnect can resume', async () => {
const firstExitHandlers: Array<(info: PtyExitInfo) => void> = [];
const secondExitHandlers: Array<(info: PtyExitInfo) => void> = [];
const firstPty = fakePty(firstExitHandlers);
const secondPty = fakePty(secondExitHandlers);
let resolveFirstOpen: ((value: { success: true; data: Pty }) => void) | undefined;
openSsh2Pty
.mockImplementationOnce(
() =>
new Promise((resolve) => {
resolveFirstOpen = resolve;
})
)
.mockResolvedValueOnce({ success: true, data: secondPty });
const provider = sshProvider(undefined, { connectionId: 'ssh-1' });
const item = conversation();
const sessionId = makePtySessionId(item.projectId, item.taskId, item.id);

const firstStart = provider.startSession(item);
await new Promise((resolve) => setImmediate(resolve));
expect(openSsh2Pty).toHaveBeenCalledTimes(1);

for (const handler of sshConnectionManagerMock.handlers) {
handler({ type: 'disconnected', connectionId: 'ssh-1' });
}
for (const handler of sshConnectionManagerMock.handlers) {
handler({ type: 'reconnected', connectionId: 'ssh-1' });
}
await new Promise((resolve) => setImmediate(resolve));

expect(openSsh2Pty).toHaveBeenCalledTimes(2);
expect((provider as unknown as RespawnState).sessions.get(sessionId)).toBe(secondPty);

resolveFirstOpen?.({ success: true, data: firstPty });
await firstStart;

expect(firstPty.kill).toHaveBeenCalledOnce();
expect((provider as unknown as RespawnState).sessions.get(sessionId)).toBe(secondPty);
});

it('cancels in-flight SSH rehydrate starts after detachAll', async () => {
const firstExitHandlers: Array<(info: PtyExitInfo) => void> = [];
const rehydratedExitHandlers: Array<(info: PtyExitInfo) => void> = [];
const firstPty = fakePty(firstExitHandlers);
const rehydratedPty = fakePty(rehydratedExitHandlers);
let resolveRehydrateOpen: ((value: { success: true; data: Pty }) => void) | undefined;
openSsh2Pty.mockResolvedValueOnce({ success: true, data: firstPty }).mockImplementationOnce(
() =>
new Promise((resolve) => {
resolveRehydrateOpen = resolve;
})
);
const provider = sshProvider(undefined, { connectionId: 'ssh-1' });
const item = conversation();
const sessionId = makePtySessionId(item.projectId, item.taskId, item.id);

await provider.startSession(item);
for (const handler of sshConnectionManagerMock.handlers) {
handler({ type: 'disconnected', connectionId: 'ssh-1' });
}
for (const handler of sshConnectionManagerMock.handlers) {
handler({ type: 'reconnected', connectionId: 'ssh-1' });
}
await new Promise((resolve) => setImmediate(resolve));

await provider.detachAll();
resolveRehydrateOpen?.({ success: true, data: rehydratedPty });
await new Promise((resolve) => setImmediate(resolve));

expect(openSsh2Pty).toHaveBeenCalledTimes(2);
expect(rehydratedPty.kill).toHaveBeenCalledOnce();
expect((provider as unknown as RespawnState).sessions.has(sessionId)).toBe(false);
expect(ptySessionRegistry.get(sessionId)).toBeUndefined();
});

it('unsubscribes SSH connection listeners when detached', async () => {
const provider = sshProvider(undefined, { connectionId: 'ssh-1' });

await provider.detachAll();

expect(sshConnectionManager.off).toHaveBeenCalledWith('connection-event', expect.any(Function));
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import { openSsh2Pty } from '@main/core/pty/ssh2-pty';
import { getTerminalColorEnv } from '@main/core/pty/terminal-color-scheme';
import { killTmuxSession, makeTmuxSessionName } from '@main/core/pty/tmux-session-name';
import { providerOverrideSettings } from '@main/core/settings/provider-settings-service';
import { sshConnectionManager } from '@main/core/ssh/lifecycle/production-ssh-connection-manager';
import type { SshClientProxy } from '@main/core/ssh/lifecycle/ssh-client-proxy';
import type { SshConnectionManagerEvent } from '@main/core/ssh/lifecycle/ssh-connection-manager';
import { events } from '@main/lib/events';
import { log } from '@main/lib/logger';
import { telemetryService } from '@main/lib/telemetry';
Expand All @@ -36,7 +38,10 @@ function parseExtraArgs(value: string | undefined): string[] {
export class SshConversationProvider implements ConversationProvider {
private sessions = new Map<string, Pty>();
private knownSessionIds = new Set<string>();
private conversations = new Map<string, Conversation>();
private reconnectSizes = new Map<string, { cols: number; rows: number }>();
private supervisor = new ConversationSessionSupervisor();
private detached = false;
private readonly projectId: string;
private readonly taskPath: string;
private readonly taskId: string;
Expand All @@ -45,6 +50,8 @@ export class SshConversationProvider implements ConversationProvider {
private readonly shellSetup?: string;
private readonly ctx: IExecutionContext;
private readonly proxy: SshClientProxy;
private readonly connectionId: string;
private readonly _handleReconnect: (evt: SshConnectionManagerEvent) => void;

constructor({
projectId,
Expand All @@ -55,6 +62,7 @@ export class SshConversationProvider implements ConversationProvider {
shellSetup,
ctx,
proxy,
connectionId,
}: {
projectId: string;
taskPath: string;
Expand All @@ -64,6 +72,7 @@ export class SshConversationProvider implements ConversationProvider {
shellSetup?: string;
ctx: IExecutionContext;
proxy: SshClientProxy;
connectionId: string;
}) {
this.projectId = projectId;
this.taskPath = taskPath;
Expand All @@ -73,6 +82,24 @@ export class SshConversationProvider implements ConversationProvider {
this.shellSetup = shellSetup;
this.ctx = ctx;
this.proxy = proxy;
this.connectionId = connectionId;
this._handleReconnect = (evt: SshConnectionManagerEvent) => {
if (evt.connectionId !== this.connectionId) return;
if (evt.type === 'disconnected') {
this.detachStaleSessionsForReconnect();
return;
}
if (evt.type === 'connected' || evt.type === 'reconnected') {
this.rehydrate().catch((e: unknown) => {
log.error('SshConversationProvider: rehydrate failed after reconnect', {
taskId: this.taskId,
connectionId: this.connectionId,
error: String(e),
});
});
}
};
sshConnectionManager.on('connection-event', this._handleReconnect);
}

async startSession(
Expand All @@ -97,12 +124,14 @@ export class SshConversationProvider implements ConversationProvider {
requireDesired: boolean,
options: { shellRefreshRetried: boolean }
): Promise<void> {
if (this.detached) return;
const sessionId = makePtySessionId(
conversation.projectId,
conversation.taskId,
conversation.id
);
this.knownSessionIds.add(sessionId);
this.conversations.set(sessionId, conversation);

const spawnSize = ptySessionRegistry.getLastSize(sessionId) ?? initialSize;
const spawnToken = this.supervisor.beginStart(sessionId, {
Expand Down Expand Up @@ -271,6 +300,7 @@ export class SshConversationProvider implements ConversationProvider {
},
});
this.sessions.set(sessionId, pty);
this.reconnectSizes.delete(sessionId);
scheduleInitialPromptInjection({
pty,
conversation,
Expand Down Expand Up @@ -312,6 +342,8 @@ export class SshConversationProvider implements ConversationProvider {
this.knownSessionIds.delete(sessionId);
this.supervisor.forget(sessionId);
}
this.conversations.delete(sessionId);
this.reconnectSizes.delete(sessionId);
}

async stopSession(conversationId: string): Promise<void> {
Expand All @@ -334,6 +366,8 @@ export class SshConversationProvider implements ConversationProvider {
await killTmuxSession(this.ctx, makeTmuxSessionName(sessionId));
}
this.supervisor.forget(sessionId);
this.conversations.delete(sessionId);
this.reconnectSizes.delete(sessionId);
}

async destroyAll(): Promise<void> {
Expand All @@ -346,11 +380,17 @@ export class SshConversationProvider implements ConversationProvider {
this.supervisor.forget(sessionId);
}
this.knownSessionIds.clear();
this.conversations.clear();
this.reconnectSizes.clear();
}

async detachAll(): Promise<void> {
for (const [sessionId, pty] of this.sessions) {
this.detached = true;
sshConnectionManager.off('connection-event', this._handleReconnect);
for (const sessionId of this.knownSessionIds) {
this.supervisor.stop(sessionId);
}
for (const [sessionId, pty] of this.sessions) {
try {
pty.kill();
} catch {}
Expand All @@ -359,6 +399,47 @@ export class SshConversationProvider implements ConversationProvider {
this.sessions.clear();
}
Comment thread
janburzinski marked this conversation as resolved.

private detachStaleSessionsForReconnect(): void {
for (const [sessionId] of this.conversations) {
const lastSize = ptySessionRegistry.getLastSize(sessionId);
if (lastSize) this.reconnectSizes.set(sessionId, lastSize);
const pty = this.supervisor.detachActive(sessionId) ?? this.sessions.get(sessionId);
this.sessions.delete(sessionId);
if (!pty) continue;
ptySessionRegistry.unregister(sessionId, { pty });
try {
pty.kill();
} catch (e) {
log.warn('SshConversationProvider: error detaching stale PTY after disconnect', {
sessionId,
error: String(e),
});
}
}
}

private async rehydrate(): Promise<void> {
if (this.detached) return;
await Promise.all(
Array.from(this.conversations.entries()).map(async ([sessionId, conversation]) => {
if (this.detached) return;
if (this.sessions.has(sessionId) || !this.supervisor.isDesired(sessionId)) return;
const initialSize = this.reconnectSizes.get(sessionId) ?? {
cols: DEFAULT_COLS,
rows: DEFAULT_ROWS,
};
await this.startSessionInternal(conversation, initialSize, true, undefined, true, {
shellRefreshRetried: false,
}).catch((e) => {
log.error('SshConversationProvider: rehydrate failed', {
conversationId: conversation.id,
error: String(e),
});
});
})
);
}

private scheduleShellRefreshRetry({
conversation,
sessionId,
Expand Down
Loading
Loading