Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/src/auth/auth_credential.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export interface OAuth2Auth {
* verify the state
*/
authUri?: string;
nonce?: string;
state?: string;
codeVerifier?: string;
/**
Expand All @@ -54,8 +55,11 @@ export interface OAuth2Auth {
authCode?: string;
accessToken?: string;
refreshToken?: string;
idToken?: string;
expiresAt?: number;
expiresIn?: number;
audience?: string;
tokenEndpointAuthMethod?: string;
}

/**
Expand Down
88 changes: 85 additions & 3 deletions core/src/auth/auth_handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
*/

import {State} from '../sessions/state.js';
import {randomUUID} from '../utils/env_aware_utils.js';

import {AuthCredential} from './auth_credential.js';
import {AuthConfig} from './auth_tool.js';
import {OAuth2CredentialExchanger} from './oauth2/oauth2_credential_exchanger.js';

// TODO(b/425992518): Implement the rest
/**
* A handler that handles the auth flow in Agent Development Kit to help
* orchestrates the credential request and response flow (e.g. OAuth flow)
Expand All @@ -24,6 +25,28 @@ export class AuthHandler {
return state.get<AuthCredential>(credentialKey);
}

async parseAndStoreAuthResponse(state: State): Promise<void> {
const credentialKey = 'temp:' + this.authConfig.credentialKey;

if (this.authConfig.exchangedAuthCredential) {
state.set(credentialKey, this.authConfig.exchangedAuthCredential);
}

const authSchemeType = this.authConfig.authScheme.type;
if (!['oauth2', 'openIdConnect'].includes(authSchemeType)) {
return;
}

if (this.authConfig.exchangedAuthCredential) {
const exchanger = new OAuth2CredentialExchanger();
const exchangedCredential = await exchanger.exchange({
authCredential: this.authConfig.exchangedAuthCredential,
authScheme: this.authConfig.authScheme,
});
state.set(credentialKey, exchangedCredential.credential);
}
}

generateAuthRequest(): AuthConfig {
const authSchemeType = this.authConfig.authScheme.type;

Expand Down Expand Up @@ -79,7 +102,66 @@ export class AuthHandler {
* auth scheme.
*/
generateAuthUri(): AuthCredential | undefined {
return this.authConfig.rawAuthCredential;
// TODO - b/425992518: Implement the rest of the function
const authScheme = this.authConfig.authScheme;
const authCredential = this.authConfig.rawAuthCredential;

if (!authCredential || !authCredential.oauth2) {
return authCredential;
}

let authorizationEndpoint = '';
let scopes: string[] = [];

if ('authorizationEndpoint' in authScheme) {
authorizationEndpoint = authScheme.authorizationEndpoint;
scopes = authScheme.scopes || [];
} else if (authScheme.type === 'oauth2' && authScheme.flows) {
const flows = authScheme.flows;
const flow =
flows.implicit ||
flows.authorizationCode ||
flows.clientCredentials ||
flows.password;

if (flow) {
if ('authorizationUrl' in flow && flow.authorizationUrl) {
authorizationEndpoint = flow.authorizationUrl;
} else if ('tokenUrl' in flow && flow.tokenUrl) {
authorizationEndpoint = flow.tokenUrl;
}

if (flow.scopes) {
scopes = Object.keys(flow.scopes);
}
}
}

if (!authorizationEndpoint) {
throw new Error('Authorization endpoint not configured in auth scheme.');
}

const state = randomUUID();
const url = new URL(authorizationEndpoint);
url.searchParams.set('client_id', authCredential.oauth2.clientId || '');
url.searchParams.set(
'redirect_uri',
authCredential.oauth2.redirectUri || '',
);
url.searchParams.set('response_type', 'code');
url.searchParams.set('scope', scopes.join(' '));
url.searchParams.set('state', state);
url.searchParams.set('access_type', 'offline');
url.searchParams.set('prompt', 'consent');

const exchangedAuthCredential: AuthCredential = {
...authCredential,
oauth2: {
...authCredential.oauth2,
authUri: url.toString(),
state,
},
};

return exchangedAuthCredential;
}
}
191 changes: 191 additions & 0 deletions core/src/auth/auth_preprocessor.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/

import {
REQUEST_EUC_FUNCTION_CALL_NAME,
handleFunctionCallsAsync,
} from '../agents/functions.js';
import {InvocationContext} from '../agents/invocation_context.js';
import {isLlmAgent} from '../agents/llm_agent.js';
import {BaseLlmRequestProcessor} from '../agents/processors/base_llm_processor.js';
import {ReadonlyContext} from '../agents/readonly_context.js';
import {
Event,
getFunctionCalls,
getFunctionResponses,
} from '../events/event.js';
import {State} from '../sessions/state.js';
import {BaseTool} from '../tools/base_tool.js';
import {AuthHandler} from './auth_handler.js';
import {AuthConfig, AuthToolArguments} from './auth_tool.js';

const TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_';

async function storeAuthAndCollectResumeTargets(
events: Event[],
authFcIds: Set<string>,
authResponses: Record<string, unknown>,
state: State,
): Promise<Set<string>> {
const requestedAuthConfigById: Record<string, AuthConfig> = {};
for (const event of events) {
const eventFunctionCalls = getFunctionCalls(event);
for (const functionCall of eventFunctionCalls) {
if (
functionCall.id &&
authFcIds.has(functionCall.id) &&
functionCall.name === REQUEST_EUC_FUNCTION_CALL_NAME
) {
const args = functionCall.args as unknown as AuthToolArguments;
if (args && args.authConfig) {
requestedAuthConfigById[functionCall.id] = args.authConfig;
}
}
}
}

for (const fcId of authFcIds) {
if (!(fcId in authResponses)) {
continue;
}
const authConfig = authResponses[fcId] as AuthConfig;
const requestedAuthConfig = requestedAuthConfigById[fcId];
if (requestedAuthConfig && requestedAuthConfig.credentialKey) {
authConfig.credentialKey = requestedAuthConfig.credentialKey;
}
await new AuthHandler(authConfig).parseAndStoreAuthResponse(state);
}

const toolsToResume: Set<string> = new Set();
for (const fcId of authFcIds) {
const requestedAuthConfig = requestedAuthConfigById[fcId];
if (!requestedAuthConfig) {
continue;
}
for (const event of events) {
const eventFunctionCalls = getFunctionCalls(event);
for (const functionCall of eventFunctionCalls) {
if (
functionCall.id === fcId &&
functionCall.name === REQUEST_EUC_FUNCTION_CALL_NAME
) {
const args = functionCall.args as unknown as AuthToolArguments;
if (args && args.functionCallId) {
if (
args.functionCallId.startsWith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX)
) {
continue;
}
toolsToResume.add(args.functionCallId);
}
}
}
}
}

return toolsToResume;
}

export class AuthPreprocessor extends BaseLlmRequestProcessor {
override async *runAsync(
invocationContext: InvocationContext,
): AsyncGenerator<Event, void, void> {
const agent = invocationContext.agent;
if (!isLlmAgent(agent)) {
return;
}

const events = invocationContext.session.events;
if (!events || events.length === 0) {
return;
}

let lastEventWithContent = null;
for (let i = events.length - 1; i >= 0; i--) {
const event = events[i];
if (event.content !== undefined) {
lastEventWithContent = event;
break;
}
}

if (!lastEventWithContent || lastEventWithContent.author !== 'user') {
return;
}

const responses = getFunctionResponses(lastEventWithContent);
if (!responses || responses.length === 0) {
return;
}

const authFcIds: Set<string> = new Set();
const authResponses: Record<string, unknown> = {};

for (const functionCallResponse of responses) {
if (functionCallResponse.name !== REQUEST_EUC_FUNCTION_CALL_NAME) {
continue;
}
if (functionCallResponse.id) {
authFcIds.add(functionCallResponse.id);
authResponses[functionCallResponse.id] = functionCallResponse.response;
}
}

if (authFcIds.size === 0) {
return;
}

const state = new State(invocationContext.session.state);
const toolsToResume = await storeAuthAndCollectResumeTargets(
events,
authFcIds,
authResponses,
state,
);

if (toolsToResume.size === 0) {
return;
}

for (let i = events.length - 2; i >= 0; i--) {
const event = events[i];
const functionCalls = getFunctionCalls(event);
if (!functionCalls || functionCalls.length === 0) {
continue;
}

const hasMatchingCall = functionCalls.some((call) =>
call.id ? toolsToResume.has(call.id) : false,
);

if (hasMatchingCall) {
const canonicalTools = await agent.canonicalTools(
new ReadonlyContext(invocationContext),
);
const toolsDict: Record<string, BaseTool> = {};
for (const tool of canonicalTools) {
toolsDict[tool.name] = tool;
}

const functionResponseEvent = await handleFunctionCallsAsync({
invocationContext,
functionCallEvent: event,
toolsDict,
beforeToolCallbacks: agent.canonicalBeforeToolCallbacks,
afterToolCallbacks: agent.canonicalAfterToolCallbacks,
filters: toolsToResume,
});

if (functionResponseEvent) {
yield functionResponseEvent;
}
return;
}
}
}
}

export const AUTH_PREPROCESSOR = new AuthPreprocessor();
2 changes: 2 additions & 0 deletions core/src/auth/oauth2/oauth2_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ interface OAuth2TokenResponse {
access_token?: string;
refresh_token?: string;
expires_in?: number;
id_token?: string;
}

/**
Expand Down Expand Up @@ -67,6 +68,7 @@ export async function fetchOAuth2Tokens(
return {
accessToken: data.access_token,
refreshToken: data.refresh_token,
idToken: data.id_token,
expiresIn: data.expires_in,
expiresAt: data.expires_in
? Date.now() + data.expires_in * 1000
Expand Down
Loading
Loading