diff --git a/packages/core/CHANGELOG.md b/packages/core/CHANGELOG.md index 87f79390..efa32ac9 100644 --- a/packages/core/CHANGELOG.md +++ b/packages/core/CHANGELOG.md @@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ## [Unreleased] +### Added + +- Added support for a custom `userInfo` function in OAuth provider configuration, enabling callers to perform the user info request themselves. The `userInfo` option continues to accept either a URL string or an object with a `url` and optional request options (for example, custom headers). [#182](https://github.com/aura-stack-ts/auth/pull/182) + --- ## [0.7.2] - 2026-06-05 diff --git a/packages/core/src/@types/oauth.ts b/packages/core/src/@types/oauth.ts index bbeea67b..bca81427 100644 --- a/packages/core/src/@types/oauth.ts +++ b/packages/core/src/@types/oauth.ts @@ -1,9 +1,45 @@ +import type { infer as Infer } from "zod" import type { User } from "@/@types/session.ts" import type { LiteralUnion } from "@/@types/utility.ts" import type { BuiltInOAuthProvider } from "@/oauth/index.ts" +import type { OAuthAccessTokenResponse } from "@/schemas.ts" export type { BuiltInOAuthProvider } from "@/oauth/index.ts" +export type OAuthAccessTokenResponseType = Infer + +export type AccessTokenContext = { + /** + * Access token string returned by the OAuth provider's token endpoint. The token + * must be used to exchange for user information from the provider's userinfo endpoint. + */ + accessToken: string + /** + * The access token type returned by the OAuth provider's token endpoint, typically "Bearer". + */ + tokenType?: string | undefined + /** + * The number of seconds until the access token expires, as returned by the OAuth provider's + * token endpoint. + */ + expiresIn?: number | undefined + /** + * Optional refresh token returned by the OAuth provider's token endpoint, which can be + * used to obtain a new access token when the current one expires. + */ + refreshToken?: string | undefined + /** + * The scopes granted by the user for the access token, as returned by the OAuth provider's + * token endpoint. + */ + scope?: string | string[] | null | undefined + /** + * The userinfo endpoint URL of the OAuth provider. This is required to fetch user + * information using the access token. + */ + userInfoURL: string +} + /** Known query parameter names supported when building an OAuth authorization URL. */ export type AuthorizeParams = LiteralUnion< "clientId" | "prompt" | "scope" | "responseMode" | "audience" | "loginHint" | "nonce" | "display" @@ -43,6 +79,10 @@ export interface OAuthProviderConfig method?: string } + | { + url: string + request: (context: AccessTokenContext) => Profile | Promise + } /** * @deprecated * use `authorize.params.scope` instead of `scope` @@ -84,3 +124,5 @@ export type OAuthProviderRecord = Record< LiteralUnion, OAuthProviderCredentials > + +export type CustomUserInfoFunction = Extract any }> diff --git a/packages/core/src/actions/callback/callback.ts b/packages/core/src/actions/callback/callback.ts index 6ccb1cc0..dcb9b2a4 100644 --- a/packages/core/src/actions/callback/callback.ts +++ b/packages/core/src/actions/callback/callback.ts @@ -124,7 +124,7 @@ export const callbackAction = (oauth: OAuthProviderRecord) => { } } - const userInfo = await getUserInfo(oauthConfig, accessToken.access_token, logger) + const userInfo = await getUserInfo(oauthConfig, accessToken, logger) const session = await context.sessionStrategy.createSession(userInfo) const csrfToken = await createCSRF(jose) diff --git a/packages/core/src/actions/callback/userinfo.ts b/packages/core/src/actions/callback/userinfo.ts index 87ab8b88..c7148d79 100644 --- a/packages/core/src/actions/callback/userinfo.ts +++ b/packages/core/src/actions/callback/userinfo.ts @@ -2,7 +2,14 @@ import { fetchAsync } from "@/shared/fetch-async.ts" import { AURA_AUTH_VERSION } from "@/shared/utils.ts" import { OAuthErrorResponse } from "@/schemas.ts" import { isNativeError, isOAuthProtocolError, OAuthProtocolError } from "@/shared/errors.ts" -import type { InternalLogger, OAuthProviderCredentials, User } from "@/@types/index.ts" +import type { + AccessTokenContext, + InternalLogger, + OAuthAccessTokenResponseType, + OAuthProviderCredentials, + User, +} from "@/@types/index.ts" +import { isCustomUserInfoFunction } from "@/shared/assert.ts" /** * Map the default user information fields from the OAuth provider's userinfo response @@ -24,17 +31,11 @@ const getDefaultUserInfo = (profile: Record): User => { } } -/** - * Get user information from the OAuth provider's userinfo endpoint using the provided access token. - * The response by default is mapped to the standardized `User` format unless a custom - * `profile` function is provided in the `oauthConfig`. - * - * @param oauthConfig - OAuth provider configuration - * @param accessToken - Access Token to access the userinfo endpoint - * @param logger - Optional logger instance - * @returns The user information retrieved from the userinfo endpoint - */ -export const getUserInfo = async (oauthConfig: OAuthProviderCredentials, accessToken: string, logger?: InternalLogger) => { +type ProviderConfig = { + userInfo: Exclude any }> +} & Omit + +const createUserInfoRequest = async (oauthConfig: ProviderConfig, accessToken: string, logger?: InternalLogger) => { const userInfoConfig = oauthConfig.userInfo const userinfoURL = typeof userInfoConfig === "string" ? userInfoConfig : userInfoConfig.url const extraHeaders = typeof userInfoConfig === "string" ? undefined : userInfoConfig.headers @@ -74,8 +75,56 @@ export const getUserInfo = async (oauthConfig: OAuthProviderCredentials, accessT } logger?.log("OAUTH_USERINFO_SUCCESS") - const userInfo = oauthConfig?.profile ? oauthConfig.profile(json) : getDefaultUserInfo(json) + return json + } catch (error) { + if (isOAuthProtocolError(error)) { + throw error + } + logger?.log("OAUTH_USERINFO_REQUEST_FAILED") + if (isNativeError(error)) { + throw new OAuthProtocolError("SERVER_ERROR", "Failed to fetch user information from OAuth provider", "", { + cause: error, + }) + } + throw new OAuthProtocolError("SERVER_ERROR", "Failed to fetch user information", "", { cause: error }) + } +} +/** + * Get user information from the OAuth provider's userinfo endpoint using the provided access token. + * The response by default is mapped to the standardized `User` format unless a custom + * `profile` function is provided in the `oauthConfig`. + * + * @param oauthConfig - OAuth provider configuration + * @param accessToken - Access Token to access the userinfo endpoint + * @param logger - Optional logger instance + * @returns The user information retrieved from the userinfo endpoint + */ +export const getUserInfo = async ( + oauthConfig: OAuthProviderCredentials, + accessToken: OAuthAccessTokenResponseType, + logger?: InternalLogger +) => { + try { + let userProfile: Record = {} + if (isCustomUserInfoFunction(oauthConfig.userInfo)) { + logger?.log("OAUTH_USERINFO_REQUEST_INITIATED", { + structuredData: { + endpoint: oauthConfig.name, + }, + }) + userProfile = await oauthConfig.userInfo.request({ + accessToken: accessToken.access_token, + expiresIn: accessToken?.expires_in, + refreshToken: accessToken?.refresh_token, + scope: accessToken?.scope, + tokenType: accessToken?.token_type, + userInfoURL: oauthConfig.userInfo.url, + }) + } else { + userProfile = await createUserInfoRequest(oauthConfig as ProviderConfig, accessToken.access_token, logger) + } + const userInfo = oauthConfig?.profile ? oauthConfig.profile(userProfile) : getDefaultUserInfo(userProfile) return userInfo } catch (error) { if (isOAuthProtocolError(error)) { diff --git a/packages/core/src/shared/assert.ts b/packages/core/src/shared/assert.ts index 6164209f..93e95b1a 100644 --- a/packages/core/src/shared/assert.ts +++ b/packages/core/src/shared/assert.ts @@ -4,12 +4,14 @@ import { Type as TypeboxType } from "typebox" import type { BaseSchema, ObjectSchema } from "valibot" import { equals, patternToRegex } from "@/shared/utils.ts" import type { + AccessTokenContext, AsymmetricKeyPair, AsymmetricKeyPairFromEnv, CryptoSecret, JWTConfig, JWTMode, JWTPayloadWithToken, + OAuthProviderConfig, SessionConfig, } from "@/@types/index.ts" import type { JWK } from "@aura-stack/jose/jose" @@ -212,3 +214,15 @@ export const isTypeboxEntries = (value: unknown): value is TypeboxType.TProperti Object.values(value).every((v) => typeof v === "object" && "type" in v) ) } + +type CustomUserInfoFunction = Extract any }> + +export const isCustomUserInfoFunction = (value: OAuthProviderConfig["userInfo"]): value is CustomUserInfoFunction => { + return ( + typeof value === "object" && + value !== null && + typeof value.url === "string" && + "request" in value && + typeof value.request === "function" + ) +} diff --git a/packages/core/src/shared/index.ts b/packages/core/src/shared/index.ts index ef0498a5..f2bc73ff 100644 --- a/packages/core/src/shared/index.ts +++ b/packages/core/src/shared/index.ts @@ -3,3 +3,4 @@ */ export { createBasicAuthHeader } from "@/shared/utils.ts" export { createSyslogMessage } from "@/shared/logger.ts" +export { fetchAsync } from "@/shared/fetch-async.ts" diff --git a/packages/core/test/actions/callback/userinfo.test.ts b/packages/core/test/actions/callback/userinfo.test.ts index a8c4d8fc..f3b73add 100644 --- a/packages/core/test/actions/callback/userinfo.test.ts +++ b/packages/core/test/actions/callback/userinfo.test.ts @@ -21,7 +21,9 @@ describe("getUserInfo", () => { })) ) - const response = await getUserInfo(oauthCustomService, "access_token_123") + const response = await getUserInfo(oauthCustomService, { + access_token: "access_token_123", + }) expect(fetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", { method: "GET", @@ -51,8 +53,9 @@ describe("getUserInfo", () => { })) ) - const oauthConfig: OAuthProviderConfig<{ username: string; avatar_url: string; uniqueId: string; email: string }> = { - ...oauthCustomService, + type Profile = { username: string; avatar_url: string; uniqueId: string; email: string } + const oauthConfig: OAuthProviderConfig = { + ...(oauthCustomService as unknown as OAuthProviderCredentials), profile(profile) { return { sub: profile.uniqueId, @@ -63,7 +66,9 @@ describe("getUserInfo", () => { }, } - const response = await getUserInfo(oauthConfig as OAuthProviderCredentials, "access_token_123") + const response = await getUserInfo(oauthConfig as OAuthProviderCredentials, { + access_token: "access_token_123", + }) expect(fetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", { method: "GET", @@ -105,9 +110,11 @@ describe("getUserInfo", () => { }, } - await expect(getUserInfo(oauthConfig as OAuthProviderCredentials, "access_token_123")).rejects.toThrow( - /Failed to fetch user information from OAuth provider/ - ) + await expect( + getUserInfo(oauthConfig as OAuthProviderCredentials, { + access_token: "access_token_123", + }) + ).rejects.toThrow(/Failed to fetch user information from OAuth provider/) expect(fetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", { method: "GET", @@ -134,7 +141,11 @@ describe("getUserInfo", () => { })) ) - await expect(getUserInfo(oauthCustomService, "invalid_access_token")).rejects.toThrow(/Invalid userinfo response format/) + await expect( + getUserInfo(oauthCustomService, { + access_token: "invalid_access_token", + }) + ).rejects.toThrow(/Invalid userinfo response format/) expect(fetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", { method: "GET", @@ -155,9 +166,11 @@ describe("getUserInfo", () => { }) ) - await expect(getUserInfo(oauthCustomService, "access_token")).rejects.toThrow( - /Failed to fetch user information from OAuth provider/ - ) + await expect( + getUserInfo(oauthCustomService, { + access_token: "access_token", + }) + ).rejects.toThrow(/Failed to fetch user information from OAuth provider/) expect(fetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", { method: "GET", @@ -169,4 +182,73 @@ describe("getUserInfo", () => { signal: expect.any(AbortSignal), }) }) + + test("with custom userInfo function", async () => { + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({ name: "John Doe", email: "johndoe@example.com" }), + }) + + vi.stubGlobal("fetch", mockFetch) + + const oauthConfig: OAuthProviderConfig = { + ...oauthCustomService, + userInfo: { + url: "https://example.com/oauth/userinfo", + request: async (ctx) => { + const response = await fetch(ctx.userInfoURL, { + method: "GET", + headers: { + "User-Agent": `Aura Auth/${AURA_AUTH_VERSION}`, + }, + mode: "no-cors", + }) + const json = await response.json() + return { + sub: ctx.accessToken, + name: json.name, + email: json.email, + image: "http://example.com/john-doe.jpg", + } + }, + }, + } + + const profile = await getUserInfo(oauthConfig, { + access_token: "access_token", + }) + + expect(mockFetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", { + method: "GET", + headers: { + "User-Agent": `Aura Auth/${AURA_AUTH_VERSION}`, + }, + mode: "no-cors", + }) + + expect(profile).toEqual({ + sub: "access_token", + name: "John Doe", + email: "johndoe@example.com", + image: "http://example.com/john-doe.jpg", + }) + }) + + test("custom userInfo function throws error", async () => { + const oauthConfig: OAuthProviderConfig = { + ...oauthCustomService, + userInfo: { + url: "https://example.com/oauth/userinfo", + request: async () => { + throw new Error("Custom userInfo error") + }, + }, + } + + await expect( + getUserInfo(oauthConfig, { + access_token: "access_token", + }) + ).rejects.toThrow(/Failed to fetch user information from OAuth provider/) + }) })