diff --git a/src/index.ts b/src/index.ts index db3514e..2a9553e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -105,7 +105,7 @@ export class AwsAgentCoreAdapter implements BackendAdapter { } const runtimeArn = this.requiredRuntimeArn(); - const runtimeSessionId = createAgentCoreSessionId(); + const runtimeSessionId = createAgentCoreSessionId(request.task.id); const timestamp = nowIso(); const session: SessionRecord = { id: createId("session"), @@ -149,7 +149,7 @@ export class AwsAgentCoreAdapter implements BackendAdapter { } async cancel(taskId: string): Promise { - const session = this.sessions.get(taskId); + const session = this.sessions.get(taskId) ?? this.fallbackSession(taskId); if (!session) { return { status: "not_found" }; } @@ -215,7 +215,7 @@ export class AwsAgentCoreAdapter implements BackendAdapter { } as any)); await this.waitForEndpoint(agentRuntimeId, endpointName); - const runtimeSessionId = createAgentCoreSessionId(); + const runtimeSessionId = createAgentCoreSessionId(request.task.id); const timestamp = nowIso(); const runtime: RuntimeRecord = { id: createId("runtime"), @@ -320,6 +320,15 @@ export class AwsAgentCoreAdapter implements BackendAdapter { return this.config.runtimeArn; } + private fallbackSession(taskId: string): { runtimeSessionId: string; runtimeArn: string; qualifier?: string; target?: RuntimeTarget } | undefined { + if (!this.config.runtimeArn) return undefined; + return { + runtimeSessionId: createAgentCoreSessionId(taskId), + runtimeArn: this.config.runtimeArn, + qualifier: this.config.qualifier + }; + } + private async waitForEndpoint(agentRuntimeId: string, endpointName: string): Promise { for (let attempt = 0; attempt < 60; attempt += 1) { const endpoint: any = await this.clients.control.send(new GetAgentRuntimeEndpointCommand({ agentRuntimeId, endpointName })); @@ -351,8 +360,8 @@ export class AwsAgentCoreAdapter implements BackendAdapter { } } -function createAgentCoreSessionId(): string { - return createId("agentcore_session"); +function createAgentCoreSessionId(taskId: string): string { + return `ad-${createHash("sha256").update(taskId).digest("hex").slice(0, 32)}`; } function stringDetail(details: Record, key: string): string | undefined { diff --git a/test/adapter.test.ts b/test/adapter.test.ts index 8783454..9c9a7a9 100644 --- a/test/adapter.test.ts +++ b/test/adapter.test.ts @@ -199,10 +199,34 @@ describe("AwsAgentCoreAdapter", () => { const request = createRequest("agent.run", "session"); const target = (await adapter.resolveTarget(request)).target; const task = createTask(request); - await adapter.provision({ dispatch: request, task, target }); + const provisioned = await adapter.provision({ dispatch: request, task, target }); await expect(adapter.cancel(task.id)).resolves.toMatchObject({ status: "cancelled" }); expect(data.commands.map((command) => command.constructor.name)).toContain("StopRuntimeSessionCommand"); + expect(data.commands.find((command) => command.constructor.name === "StopRuntimeSessionCommand").input.runtimeSessionId) + .toBe(provisioned.session?.providerRefs.runtimeSessionId); + }); + + it("can cancel a session-mode task after adapter restart", async () => { + const firstData = new FakeDataClient(); + const firstAdapter = createAdapter(firstData); + const request = createRequest("agent.run", "session"); + const target = (await firstAdapter.resolveTarget(request)).target; + const task = createTask(request); + const provisioned = await firstAdapter.provision({ dispatch: request, task, target }); + + const restartedData = new FakeDataClient(); + const restartedAdapter = createAdapter(restartedData); + await expect(restartedAdapter.cancel(task.id)).resolves.toMatchObject({ + status: "cancelled", + providerRefs: { runtimeSessionId: provisioned.session?.providerRefs.runtimeSessionId } + }); + + const stop = restartedData.commands.find((command) => command.constructor.name === "StopRuntimeSessionCommand"); + expect(stop.input).toMatchObject({ + agentRuntimeArn: "arn:aws:bedrock-agentcore:us-west-2:123456789012:agent/00000000-0000-0000-0000-000000000000:1", + runtimeSessionId: provisioned.session?.providerRefs.runtimeSessionId + }); }); });