Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions src/cli/aws/__tests__/agentcore-a2a-bearer.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import { invokeA2ARuntime } from '../agentcore.js';
import type { A2AInvokeOptions } from '../agentcore.js';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

// Mock the SDK so the SigV4 path doesn't need real credentials
const mockSdkSend = vi.fn();
vi.mock('@aws-sdk/client-bedrock-agentcore', () => {
class MockBedrockAgentCoreClient {
send = mockSdkSend;
middlewareStack = { add: vi.fn() };
// eslint-disable-next-line @typescript-eslint/no-empty-function
constructor(_config: unknown) {}
}
return {
BedrockAgentCoreClient: MockBedrockAgentCoreClient,
InvokeAgentRuntimeCommand: vi.fn(),
StopRuntimeSessionCommand: vi.fn(),
EvaluateCommand: vi.fn(),
};
});

vi.mock('../account.js', () => ({
getCredentialProvider: vi
.fn()
.mockReturnValue(() => Promise.resolve({ accessKeyId: 'test', secretAccessKey: 'test' })),
}));

const a2aResultBody = JSON.stringify({
jsonrpc: '2.0',
id: 1,
result: { artifacts: [{ parts: [{ kind: 'text', text: 'Hello from A2A' }] }] },
});

const baseOpts: A2AInvokeOptions = {
region: 'us-east-1',
runtimeArn: 'arn:aws:bedrock-agentcore:us-east-1:123456789:runtime/test-runtime',
userId: 'test-user',
};

async function drain(stream: AsyncGenerator<string, void, unknown>): Promise<string> {
let out = '';
for await (const chunk of stream) out += chunk;
return out;
}

describe('invokeA2ARuntime bearer-token auth path', () => {
let fetchSpy: ReturnType<typeof vi.spyOn>;
let capturedRequests: { url: string; init: RequestInit }[];

beforeEach(() => {
capturedRequests = [];
fetchSpy = vi.spyOn(globalThis, 'fetch').mockImplementation((input, init) => {
capturedRequests.push({ url: input as string, init: init! });
return Promise.resolve({
ok: true,
status: 200,
text: () => Promise.resolve(a2aResultBody),
headers: { get: () => null },
} as unknown as Response);
});
});

afterEach(() => {
fetchSpy.mockRestore();
vi.clearAllMocks();
});

it('uses fetch with Bearer Authorization header and never the SigV4 client', async () => {
const result = await invokeA2ARuntime({ ...baseOpts, bearerToken: 'test-jwt-token' }, 'hi');
const text = await drain(result.stream);

expect(fetchSpy).toHaveBeenCalledTimes(1);
expect(mockSdkSend).not.toHaveBeenCalled();

const headers = capturedRequests[0]!.init.headers as Record<string, string>;
expect(headers.Authorization).toBe('Bearer test-jwt-token');

// JSON-RPC message/send body is carried in the fetch payload
const body = JSON.parse(capturedRequests[0]!.init.body as string);
expect(body.method).toBe('message/send');
expect(body.params.message.parts[0].text).toBe('hi');

// Response is still routed through parseA2AResponse
expect(text).toBe('Hello from A2A');
});

it('falls back to the SigV4 client when no bearerToken is supplied', async () => {
mockSdkSend.mockResolvedValue({
response: { transformToByteArray: () => Promise.resolve(new TextEncoder().encode(a2aResultBody)) },
});

await invokeA2ARuntime(baseOpts, 'hi');

expect(mockSdkSend).toHaveBeenCalledTimes(1);
expect(fetchSpy).not.toHaveBeenCalled();
});
});
25 changes: 23 additions & 2 deletions src/cli/aws/agentcore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,8 @@ export interface A2AInvokeOptions {
logger?: SSELogger;
/** Custom headers to forward to the agent runtime */
headers?: Record<string, string>;
/** Bearer token for CUSTOM_JWT auth. When provided, uses raw HTTP with Authorization header instead of SigV4. */
bearerToken?: string;
}

let a2aRequestId = 1;
Expand All @@ -940,8 +942,6 @@ let a2aRequestId = 1;
* Streams text parts from the response artifacts.
*/
export async function invokeA2ARuntime(options: A2AInvokeOptions, message: string): Promise<StreamingInvokeResult> {
const client = createAgentCoreClient(options.region, options.headers);

const body = {
jsonrpc: '2.0',
id: a2aRequestId++,
Expand All @@ -957,6 +957,27 @@ export async function invokeA2ARuntime(options: A2AInvokeOptions, message: strin

options.logger?.logSSEEvent(`A2A request: ${JSON.stringify(body)}`);

if (options.bearerToken) {
const url = buildInvokeUrl(options.region, options.runtimeArn);
const headers = buildBearerInvokeHeaders(options, 'application/json, text/event-stream');

const res = await fetch(url, { method: 'POST', headers, body: JSON.stringify(body) });
if (!res.ok) {
const errBody = await res.text().catch(() => '');
throw new Error(`Invoke failed (${res.status}): ${errBody || res.statusText}`);
}

const text = await res.text();
options.logger?.logSSEEvent(`A2A response: ${text}`);

return {
stream: singleValueStream(parseA2AResponse(text)),
sessionId: res.headers.get('X-Amzn-Bedrock-AgentCore-Runtime-Session-Id') ?? undefined,
};
}

const client = createAgentCoreClient(options.region, options.headers);

const command = new InvokeAgentRuntimeCommand({
agentRuntimeArn: options.runtimeArn,
payload: new TextEncoder().encode(JSON.stringify(body)),
Expand Down
1 change: 1 addition & 0 deletions src/cli/commands/invoke/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,7 @@ export async function handleInvoke(context: InvokeContext, options: InvokeOption
userId: options.userId,
sessionId: options.sessionId,
headers: options.headers,
bearerToken: options.bearerToken,
},
options.prompt
);
Expand Down
2 changes: 2 additions & 0 deletions src/cli/operations/dev/web-ui/handlers/invocations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ async function handleDeployedInvocation(
prompt,
sessionId,
userId,
bearerToken: resolved.bearerToken,
});
} else if (protocol === 'AGUI') {
await handleDeployedAguiInvocation(ctx, res, origin, {
Expand Down Expand Up @@ -555,6 +556,7 @@ async function handleDeployedA2AInvocation(
runtimeArn: params.runtimeArn,
userId: params.userId,
sessionId: params.sessionId,
bearerToken: params.bearerToken,
},
params.prompt
);
Expand Down
1 change: 1 addition & 0 deletions src/cli/tui/screens/invoke/useInvokeFlow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ export function useInvokeFlow(options: InvokeFlowOptions = {}): InvokeFlowState
sessionId: sessionId ?? undefined,
logger,
headers,
bearerToken: bearerToken || undefined,
},
prompt
)
Expand Down
Loading