Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 60 additions & 5 deletions packages/api/src/controllers/user/oauthEndpoints.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { test } from "node:test";
import { createPglite } from "../../db/pglite.ts";
import { authCodes, clients, organizationMembers, organizations, users } from "../../db/schema.ts";
import { InvalidGrantError, UnauthorizedClientError } from "../../errors.ts";
import { generateEdDSAKeyPair, signJWT, storeKeyPair } from "../../services/jwks.ts";
import { generateEdDSAKeyPair, signJWT, storeKeyPair, verifyJWT } from "../../services/jwks.ts";
import {
createSession,
getActiveRefreshTokenSession,
Expand All @@ -18,7 +18,7 @@ import type { Context } from "../../types.ts";
import { sha256Base64Url } from "../../utils/crypto.ts";
import { postIntrospect } from "./introspect.ts";
import { postRevoke } from "./revoke.ts";
import { postToken } from "./token.ts";
import { postToken, postTokenOrganization } from "./token.ts";
import { handleUserinfo } from "./userinfo.ts";
import { getWellKnownOpenidConfiguration } from "./wellKnownOpenid.ts";

Expand Down Expand Up @@ -127,12 +127,12 @@ async function createUser(context: Context) {
});
}

async function createUserOrganization(context: Context) {
async function createUserOrganization(context: Context, slug = "test-org") {
const [organization] = await context.db
.insert(organizations)
.values({
slug: "test-org",
name: "Test Org",
slug,
name: slug === "test-org" ? "Test Org" : "Second Org",
createdByUserSub: "user-sub",
})
.returning();
Expand Down Expand Up @@ -464,6 +464,61 @@ test("token allows hosted first-party cookie refresh for public SPA clients", as
}
});

test("token organization switch mints target-org tokens from current app token", async () => {
const { context, cleanup } = await createContext();
try {
await createUser(context);
await createUserOrganization(context, "first-org");
const targetOrganization = await createUserOrganization(context, "second-org");
await createPublicRefreshClient(context);
const accessToken = await signJWT(
context,
{
iss: context.config.issuer,
sub: "user-sub",
aud: "public-refresh-client",
azp: "public-refresh-client",
scope: "openid profile",
token_use: "access",
grant_type: "authorization_code",
},
"5m"
);
const request = createRequest({
method: "POST",
url: "/token/organization",
authorization: `Bearer ${accessToken}`,
body: JSON.stringify({
organization_id: targetOrganization.id,
client_id: "public-refresh-client",
}),
});
const response = createResponse();

await postTokenOrganization(context, request, response);

const json = response.json as Record<string, unknown>;
const idTokenClaims = await verifyJWT(context, json.id_token as string, "public-refresh-client");
const accessTokenClaims = await verifyJWT(
context,
json.access_token as string,
"public-refresh-client"
);
const setCookie = response.getHeader("set-cookie");
assert.equal(response.statusCode, 200);
assert.equal(json.token_type, "Bearer");
assert.equal(json.scope, "openid profile");
assert.equal(typeof json.refresh_token, "string");
assert.equal(idTokenClaims.org_id, targetOrganization.id);
assert.equal(idTokenClaims.org_slug, "second-org");
assert.equal(accessTokenClaims.org_id, targetOrganization.id);
assert.ok(Array.isArray(setCookie));
assert.ok(setCookie.some((value) => value.includes(USER_REFRESH_COOKIE_NAME)));
} finally {
await cleanup();
}
});

test("introspect returns active metadata for same-client access tokens", async () => {
const { context, cleanup } = await createContext();
try {
Expand Down
197 changes: 195 additions & 2 deletions packages/api/src/controllers/user/token.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ import type { IncomingMessage, ServerResponse } from "node:http";
import { eq } from "drizzle-orm";
import { z } from "zod/v4";
import { users } from "../../db/schema.ts";
import { InvalidGrantError, InvalidRequestError, UnauthorizedClientError } from "../../errors.ts";
import {
ForbiddenError,
InvalidGrantError,
InvalidRequestError,
UnauthorizedClientError,
UnauthorizedError,
} from "../../errors.ts";
import { genericErrors } from "../../http/openapi-helpers.ts";
import { getCachedBody, withRateLimit } from "../../middleware/rateLimit.ts";
import { signJWT } from "../../services/jwks.ts";
import { signJWT, verifyJWT } from "../../services/jwks.ts";
import {
clearRefreshTokenCookie,
createSession,
Expand Down Expand Up @@ -41,6 +47,11 @@ import {
} from "../../utils/http.ts";
import { verifyCodeChallenge } from "../../utils/pkce.ts";

const TokenOrganizationRequestSchema = z.object({
organization_id: z.string().uuid(),
client_id: z.string().min(1),
});

async function resolveIssuer(context: Context): Promise<string> {
const issuerSetting = await getSetting(context, "issuer");
if (typeof issuerSetting === "string" && issuerSetting.length > 0) return issuerSetting;
Expand Down Expand Up @@ -163,6 +174,34 @@ function resolveDelegatedPermissions(permissions: string[], grantedScopes: strin
return permissions.filter((permission) => grantedScopeSet.has(permission));
}

function getBearerToken(request: IncomingMessage): string {
const auth = request.headers.authorization;
if (typeof auth !== "string") throw new UnauthorizedError("Bearer token required");
const [scheme, token] = auth.split(" ");
if (!scheme || scheme.toLowerCase() !== "bearer" || !token) {
throw new UnauthorizedError("Bearer token required");
}
return token;
}

function tokenAudienceMatches(audience: unknown, clientId: string): boolean {
if (typeof audience === "string") return audience === clientId;
if (Array.isArray(audience)) return audience.some((item) => item === clientId);
return false;
}

function assertTokenIssuedToClient(payload: import("jose").JWTPayload, clientId: string): void {
if (typeof payload.azp === "string" && payload.azp !== clientId) {
throw new ForbiddenError("Token was not issued to this client");
}
if (!tokenAudienceMatches(payload.aud, clientId)) {
throw new ForbiddenError("Token was not issued to this client");
}
if (payload.grant_type === "client_credentials") {
throw new ForbiddenError("User token required");
}
}

export function resolveSessionClientId(sessionData: unknown): string | null {
if (!sessionData || typeof sessionData !== "object") return null;
const maybeClientId = (sessionData as { clientId?: unknown }).clientId;
Expand Down Expand Up @@ -228,6 +267,160 @@ export const TokenRequestSchema = z.union([
}),
]);

export const postTokenOrganization = withRateLimit("token")(
withAudit({
eventType: "TOKEN_ISSUED",
resourceType: "token",
extractResourceId: (body) =>
body && typeof body === "object" && "client_id" in body
? (body as { client_id?: string }).client_id
: undefined,
})(
async (
context: Context,
request: IncomingMessage,
response: ServerResponse,
..._params: unknown[]
): Promise<void> => {
const body = await getCachedBody(request);
let rawRequest: unknown;
try {
rawRequest = JSON.parse(body);
} catch {
throw new InvalidRequestError("Invalid JSON body");
}
const parsedRequest = TokenOrganizationRequestSchema.safeParse(rawRequest);
if (!parsedRequest.success) {
throw new InvalidRequestError(parsedRequest.error.issues[0]?.message || "Invalid request");
}

const token = getBearerToken(request);
const clientId = parsedRequest.data.client_id;
const requestedOrganizationId = parsedRequest.data.organization_id;
const payload = await verifyJWT(context, token, clientId);
if (typeof payload.sub !== "string" || payload.sub.length === 0) {
throw new UnauthorizedError("User token required");
}
assertTokenIssuedToClient(payload, clientId);

const { getClient } = await import("../../models/clients.ts");
const client = await getClient(context, clientId);
if (!client) throw new UnauthorizedClientError("Unknown client");
if (!client.grantTypes.includes("authorization_code")) {
throw new UnauthorizedClientError("authorization_code grant not allowed for this client");
}

const { getUserBySub } = await import("../../models/users.ts");
const user = await getUserBySub(context, payload.sub);
if (!user) throw new InvalidGrantError("User not found");

const { getUserOrgAccess, resolveOrganizationContext } = await import(
"../../models/rbac.ts"
);
const { organizationId, organizationSlug } = await resolveOrganizationContext(
context,
user.sub,
requestedOrganizationId
);
const { roleKeys, permissions: organizationPermissions } = await getUserOrgAccess(
context,
user.sub,
organizationId
);
const directPermissionRows = await context.db.query.userPermissions.findMany({
where: (table, { eq }) => eq(table.userSub, user.sub),
});
const uniquePermissions = Array.from(
new Set([
...organizationPermissions,
...directPermissionRows.map((row) => row.permissionKey),
])
).sort();

const allowedScopes = resolveClientScopeKeys(client.scopes);
const grantedScope =
typeof payload.scope === "string" && payload.scope.length > 0
? resolveGrantedScopes(allowedScopes, payload.scope).join(" ")
: resolveGrantedScopes(allowedScopes).join(" ");
const grantedScopes = parseScopeString(grantedScope);
const delegatedPermissions = resolveDelegatedPermissions(uniquePermissions, grantedScopes);
const now = Math.floor(Date.now() / 1000);
const idTokenTtl =
client.idTokenLifetimeSeconds && client.idTokenLifetimeSeconds > 0
? client.idTokenLifetimeSeconds
: 300;
const issuer = await resolveIssuer(context);
const idTokenClaims = buildUserIdTokenClaims({
issuer,
subject: user.sub,
audience: clientId,
expiresAtSeconds: now + idTokenTtl,
issuedAtSeconds: now,
email: user.email,
name: user.name,
orgId: organizationId,
orgSlug: organizationSlug,
roles: roleKeys,
permissions: uniquePermissions,
amr: ["pwd"],
});
const idToken = await signJWT(
context,
idTokenClaims as import("jose").JWTPayload,
`${idTokenTtl}s`
);
const accessTokenTtl = resolveAccessTokenLifetimeSeconds(client);
const accessTokenClaims = buildUserAccessTokenClaims({
issuer,
subject: user.sub,
audience: clientId,
authorizedParty: clientId,
expiresAtSeconds: now + accessTokenTtl,
issuedAtSeconds: now,
scope: grantedScope,
grantType: "authorization_code",
orgId: organizationId,
orgSlug: organizationSlug,
roles: roleKeys,
permissions: delegatedPermissions,
});
const accessToken = await signJWT(
context,
accessTokenClaims as import("jose").JWTPayload,
`${accessTokenTtl}s`
);
const tokenResponse: TokenResponse = {
access_token: accessToken,
id_token: idToken,
token_type: "Bearer",
expires_in: accessTokenTtl,
scope: grantedScope,
};

if (shouldIssueRefreshTokenForClient(client.grantTypes)) {
const sessionData = {
sub: user.sub,
email: user.email || undefined,
name: user.name || undefined,
organizationId,
organizationSlug: organizationSlug || undefined,
clientId,
scope: grantedScope,
keyState: "locked",
} satisfies SessionData;
const s = await createSession(context, "user", sessionData);
tokenResponse.refresh_token = s.refreshToken;
const ttlSeconds = await getSessionTtlSeconds(context, "user");
const refreshTtlSeconds = await getRefreshTokenTtlSeconds(context, "user");
issueSessionCookies(response, s.sessionId, ttlSeconds, false);
issueRefreshTokenCookie(response, s.refreshToken, refreshTtlSeconds, false);
}

sendJson(response, 200, tokenResponse);
}
)
);

export const postToken = withRateLimit("token")(
withAudit({
eventType: "TOKEN_ISSUED",
Expand Down
4 changes: 4 additions & 0 deletions packages/api/src/http/createServer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ test("user CORS allows SDK user endpoints for registered public SPA origins", ()
isUserCorsOriginAllowed("/api/user/session", "https://atlas.wylde.net", corsPolicy),
true
);
assert.equal(
isUserCorsOriginAllowed("/api/token/organization", "https://atlas.wylde.net", corsPolicy),
true
);
});

test("user CORS rejects SDK user endpoints for unregistered origins", () => {
Expand Down
2 changes: 2 additions & 0 deletions packages/api/src/http/createServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ function isPublicSpaCorsPath(pathname: string): boolean {
return (
pathname === "/token" ||
pathname === "/api/token" ||
pathname === "/token/organization" ||
pathname === "/api/token/organization" ||
pathname === "/userinfo" ||
pathname === "/api/userinfo" ||
pathname === "/revoke" ||
Expand Down
11 changes: 9 additions & 2 deletions packages/api/src/http/routers/userRouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ import {
import { postRevoke } from "../../controllers/user/revoke.ts";
import { getScopeDescriptions } from "../../controllers/user/scopeDescriptions.ts";
import { getSession, postSessionOrganization } from "../../controllers/user/session.ts";
import { postToken } from "../../controllers/user/token.ts";
import { postToken, postTokenOrganization } from "../../controllers/user/token.ts";
import {
getDeviceApprovalRequests,
getTrustedDevices,
Expand Down Expand Up @@ -224,7 +224,10 @@ export function createUserRouter(context: Context) {
"/webauthn/login/finish",
].includes(pathname);
const isOAuthPost =
method === "POST" && ["/token", "/userinfo", "/introspect", "/revoke"].includes(pathname);
method === "POST" &&
["/token", "/token/organization", "/userinfo", "/introspect", "/revoke"].includes(
pathname
);
const isScimRequest = pathname.startsWith("/scim/v2/");
const needsCsrf =
!["GET", "HEAD", "OPTIONS"].includes(method) &&
Expand Down Expand Up @@ -547,6 +550,10 @@ export function createUserRouter(context: Context) {
return await postToken(context, request, response);
}

if (method === "POST" && pathname === "/token/organization") {
return await postTokenOrganization(context, request, response);
}

if ((method === "GET" || method === "POST") && pathname === "/userinfo") {
return await handleUserinfo(context, request, response);
}
Expand Down
Loading
Loading