Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

## [Unreleased]

### Added

- Added support for a custom `userInfo` function in OAuth provider configuration, enabling callers to perform the user info request themselves. The `userInfo` option continues to accept either a URL string or an object with a `url` and optional request options (for example, custom headers). [#182](https://github.com/aura-stack-ts/auth/pull/182)

---

## [0.7.2] - 2026-06-05
Expand Down
42 changes: 42 additions & 0 deletions packages/core/src/@types/oauth.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,45 @@
import type { infer as Infer } from "zod"
import type { User } from "@/@types/session.ts"
import type { LiteralUnion } from "@/@types/utility.ts"
import type { BuiltInOAuthProvider } from "@/oauth/index.ts"
import type { OAuthAccessTokenResponse } from "@/schemas.ts"

export type { BuiltInOAuthProvider } from "@/oauth/index.ts"

export type OAuthAccessTokenResponseType = Infer<typeof OAuthAccessTokenResponse>

export type AccessTokenContext = {
/**
* Access token string returned by the OAuth provider's token endpoint. The token
* must be used to exchange for user information from the provider's userinfo endpoint.
*/
accessToken: string
/**
* The access token type returned by the OAuth provider's token endpoint, typically "Bearer".
*/
tokenType?: string | undefined
/**
* The number of seconds until the access token expires, as returned by the OAuth provider's
* token endpoint.
*/
expiresIn?: number | undefined
/**
* Optional refresh token returned by the OAuth provider's token endpoint, which can be
* used to obtain a new access token when the current one expires.
*/
refreshToken?: string | undefined
/**
* The scopes granted by the user for the access token, as returned by the OAuth provider's
* token endpoint.
*/
scope?: string | string[] | null | undefined
/**
* The userinfo endpoint URL of the OAuth provider. This is required to fetch user
* information using the access token.
*/
userInfoURL: string
}

/** Known query parameter names supported when building an OAuth authorization URL. */
export type AuthorizeParams = LiteralUnion<
"clientId" | "prompt" | "scope" | "responseMode" | "audience" | "loginHint" | "nonce" | "display"
Expand Down Expand Up @@ -43,6 +79,10 @@ export interface OAuthProviderConfig<Profile extends object = Record<string, any
headers?: Record<string, string>
method?: string
}
| {
url: string
request: (context: AccessTokenContext) => Profile | Promise<Profile>
}
/**
* @deprecated
* use `authorize.params.scope` instead of `scope`
Expand Down Expand Up @@ -84,3 +124,5 @@ export type OAuthProviderRecord<DefaultUser extends User = User> = Record<
LiteralUnion<BuiltInOAuthProvider>,
OAuthProviderCredentials<any, DefaultUser>
>

export type CustomUserInfoFunction = Extract<OAuthProviderConfig["userInfo"], { request: (context: AccessTokenContext) => any }>
2 changes: 1 addition & 1 deletion packages/core/src/actions/callback/callback.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ export const callbackAction = (oauth: OAuthProviderRecord) => {
}
}

const userInfo = await getUserInfo(oauthConfig, accessToken.access_token, logger)
const userInfo = await getUserInfo(oauthConfig, accessToken, logger)
const session = await context.sessionStrategy.createSession(userInfo)
const csrfToken = await createCSRF(jose)

Expand Down
75 changes: 62 additions & 13 deletions packages/core/src/actions/callback/userinfo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@ import { fetchAsync } from "@/shared/fetch-async.ts"
import { AURA_AUTH_VERSION } from "@/shared/utils.ts"
import { OAuthErrorResponse } from "@/schemas.ts"
import { isNativeError, isOAuthProtocolError, OAuthProtocolError } from "@/shared/errors.ts"
import type { InternalLogger, OAuthProviderCredentials, User } from "@/@types/index.ts"
import type {
AccessTokenContext,
InternalLogger,
OAuthAccessTokenResponseType,
OAuthProviderCredentials,
User,
} from "@/@types/index.ts"
import { isCustomUserInfoFunction } from "@/shared/assert.ts"

/**
* Map the default user information fields from the OAuth provider's userinfo response
Expand All @@ -24,17 +31,11 @@ const getDefaultUserInfo = (profile: Record<string, string>): User => {
}
}

/**
* Get user information from the OAuth provider's userinfo endpoint using the provided access token.
* The response by default is mapped to the standardized `User` format unless a custom
* `profile` function is provided in the `oauthConfig`.
*
* @param oauthConfig - OAuth provider configuration
* @param accessToken - Access Token to access the userinfo endpoint
* @param logger - Optional logger instance
* @returns The user information retrieved from the userinfo endpoint
*/
export const getUserInfo = async (oauthConfig: OAuthProviderCredentials, accessToken: string, logger?: InternalLogger) => {
type ProviderConfig = {
userInfo: Exclude<OAuthProviderCredentials["userInfo"], { request: (context: AccessTokenContext) => any }>
} & Omit<OAuthProviderCredentials, "userInfo">

const createUserInfoRequest = async (oauthConfig: ProviderConfig, accessToken: string, logger?: InternalLogger) => {
const userInfoConfig = oauthConfig.userInfo
const userinfoURL = typeof userInfoConfig === "string" ? userInfoConfig : userInfoConfig.url
const extraHeaders = typeof userInfoConfig === "string" ? undefined : userInfoConfig.headers
Expand Down Expand Up @@ -74,8 +75,56 @@ export const getUserInfo = async (oauthConfig: OAuthProviderCredentials, accessT
}
logger?.log("OAUTH_USERINFO_SUCCESS")

const userInfo = oauthConfig?.profile ? oauthConfig.profile(json) : getDefaultUserInfo(json)
return json
} catch (error) {
if (isOAuthProtocolError(error)) {
throw error
}
logger?.log("OAUTH_USERINFO_REQUEST_FAILED")
if (isNativeError(error)) {
throw new OAuthProtocolError("SERVER_ERROR", "Failed to fetch user information from OAuth provider", "", {
cause: error,
})
}
throw new OAuthProtocolError("SERVER_ERROR", "Failed to fetch user information", "", { cause: error })
}
}

/**
* Get user information from the OAuth provider's userinfo endpoint using the provided access token.
* The response by default is mapped to the standardized `User` format unless a custom
* `profile` function is provided in the `oauthConfig`.
*
* @param oauthConfig - OAuth provider configuration
* @param accessToken - Access Token to access the userinfo endpoint
* @param logger - Optional logger instance
* @returns The user information retrieved from the userinfo endpoint
*/
export const getUserInfo = async (
oauthConfig: OAuthProviderCredentials,
accessToken: OAuthAccessTokenResponseType,
logger?: InternalLogger
) => {
try {
let userProfile: Record<string, any> = {}
if (isCustomUserInfoFunction(oauthConfig.userInfo)) {
logger?.log("OAUTH_USERINFO_REQUEST_INITIATED", {
structuredData: {
endpoint: oauthConfig.name,
},
})
userProfile = await oauthConfig.userInfo.request({
accessToken: accessToken.access_token,
expiresIn: accessToken?.expires_in,
refreshToken: accessToken?.refresh_token,
scope: accessToken?.scope,
tokenType: accessToken?.token_type,
userInfoURL: oauthConfig.userInfo.url,
})
} else {
userProfile = await createUserInfoRequest(oauthConfig as ProviderConfig, accessToken.access_token, logger)
}
const userInfo = oauthConfig?.profile ? oauthConfig.profile(userProfile) : getDefaultUserInfo(userProfile)
return userInfo
} catch (error) {
if (isOAuthProtocolError(error)) {
Expand Down
14 changes: 14 additions & 0 deletions packages/core/src/shared/assert.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import { Type as TypeboxType } from "typebox"
import type { BaseSchema, ObjectSchema } from "valibot"
import { equals, patternToRegex } from "@/shared/utils.ts"
import type {
AccessTokenContext,
AsymmetricKeyPair,
AsymmetricKeyPairFromEnv,
CryptoSecret,
JWTConfig,
JWTMode,
JWTPayloadWithToken,
OAuthProviderConfig,
SessionConfig,
} from "@/@types/index.ts"
import type { JWK } from "@aura-stack/jose/jose"
Expand Down Expand Up @@ -212,3 +214,15 @@ export const isTypeboxEntries = (value: unknown): value is TypeboxType.TProperti
Object.values(value).every((v) => typeof v === "object" && "type" in v)
)
}

type CustomUserInfoFunction = Extract<OAuthProviderConfig["userInfo"], { request: (context: AccessTokenContext) => any }>

export const isCustomUserInfoFunction = (value: OAuthProviderConfig["userInfo"]): value is CustomUserInfoFunction => {
return (
typeof value === "object" &&
value !== null &&
typeof value.url === "string" &&
"request" in value &&
typeof value.request === "function"
)
}
1 change: 1 addition & 0 deletions packages/core/src/shared/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
*/
export { createBasicAuthHeader } from "@/shared/utils.ts"
export { createSyslogMessage } from "@/shared/logger.ts"
export { fetchAsync } from "@/shared/fetch-async.ts"
104 changes: 93 additions & 11 deletions packages/core/test/actions/callback/userinfo.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ describe("getUserInfo", () => {
}))
)

const response = await getUserInfo(oauthCustomService, "access_token_123")
const response = await getUserInfo(oauthCustomService, {
access_token: "access_token_123",
})

expect(fetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", {
method: "GET",
Expand Down Expand Up @@ -51,8 +53,9 @@ describe("getUserInfo", () => {
}))
)

const oauthConfig: OAuthProviderConfig<{ username: string; avatar_url: string; uniqueId: string; email: string }> = {
...oauthCustomService,
type Profile = { username: string; avatar_url: string; uniqueId: string; email: string }
const oauthConfig: OAuthProviderConfig<Profile> = {
...(oauthCustomService as unknown as OAuthProviderCredentials<Profile>),
profile(profile) {
return {
sub: profile.uniqueId,
Expand All @@ -63,7 +66,9 @@ describe("getUserInfo", () => {
},
}

const response = await getUserInfo(oauthConfig as OAuthProviderCredentials, "access_token_123")
const response = await getUserInfo(oauthConfig as OAuthProviderCredentials, {
access_token: "access_token_123",
})

expect(fetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", {
method: "GET",
Expand Down Expand Up @@ -105,9 +110,11 @@ describe("getUserInfo", () => {
},
}

await expect(getUserInfo(oauthConfig as OAuthProviderCredentials, "access_token_123")).rejects.toThrow(
/Failed to fetch user information from OAuth provider/
)
await expect(
getUserInfo(oauthConfig as OAuthProviderCredentials, {
access_token: "access_token_123",
})
).rejects.toThrow(/Failed to fetch user information from OAuth provider/)

expect(fetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", {
method: "GET",
Expand All @@ -134,7 +141,11 @@ describe("getUserInfo", () => {
}))
)

await expect(getUserInfo(oauthCustomService, "invalid_access_token")).rejects.toThrow(/Invalid userinfo response format/)
await expect(
getUserInfo(oauthCustomService, {
access_token: "invalid_access_token",
})
).rejects.toThrow(/Invalid userinfo response format/)

expect(fetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", {
method: "GET",
Expand All @@ -155,9 +166,11 @@ describe("getUserInfo", () => {
})
)

await expect(getUserInfo(oauthCustomService, "access_token")).rejects.toThrow(
/Failed to fetch user information from OAuth provider/
)
await expect(
getUserInfo(oauthCustomService, {
access_token: "access_token",
})
).rejects.toThrow(/Failed to fetch user information from OAuth provider/)

expect(fetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", {
method: "GET",
Expand All @@ -169,4 +182,73 @@ describe("getUserInfo", () => {
signal: expect.any(AbortSignal),
})
})

test("with custom userInfo function", async () => {
const mockFetch = vi.fn().mockResolvedValue({
ok: true,
json: async () => ({ name: "John Doe", email: "johndoe@example.com" }),
})

vi.stubGlobal("fetch", mockFetch)

const oauthConfig: OAuthProviderConfig = {
...oauthCustomService,
userInfo: {
url: "https://example.com/oauth/userinfo",
request: async (ctx) => {
const response = await fetch(ctx.userInfoURL, {
method: "GET",
headers: {
"User-Agent": `Aura Auth/${AURA_AUTH_VERSION}`,
},
mode: "no-cors",
})
const json = await response.json()
return {
sub: ctx.accessToken,
name: json.name,
email: json.email,
image: "http://example.com/john-doe.jpg",
}
},
},
}

const profile = await getUserInfo(oauthConfig, {
access_token: "access_token",
})

expect(mockFetch).toHaveBeenCalledWith("https://example.com/oauth/userinfo", {
method: "GET",
headers: {
"User-Agent": `Aura Auth/${AURA_AUTH_VERSION}`,
},
mode: "no-cors",
})

expect(profile).toEqual({
sub: "access_token",
name: "John Doe",
email: "johndoe@example.com",
image: "http://example.com/john-doe.jpg",
})
})

test("custom userInfo function throws error", async () => {
const oauthConfig: OAuthProviderConfig = {
...oauthCustomService,
userInfo: {
url: "https://example.com/oauth/userinfo",
request: async () => {
throw new Error("Custom userInfo error")
},
},
}

await expect(
getUserInfo(oauthConfig, {
access_token: "access_token",
})
).rejects.toThrow(/Failed to fetch user information from OAuth provider/)
})
})
Loading