diff --git a/.env.example b/.env.example index 9b4d310..39b7c27 100644 --- a/.env.example +++ b/.env.example @@ -2,6 +2,7 @@ DATABASE_URL=postgres://postgres:postgres@localhost:5432/scrawn CLICKHOUSE_URL=http://default:clickhouse@localhost:8123/scrawn HMAC_SECRET= +MASTER_API_KEY_HASH= # hex(HMAC-SHA256(HMAC_SECRET, )) — computed with HMAC_SECRET as the signing key (used for project creation during onboarding) APP_URL=http://localhost:8060 # URL the Scrawn backend is hosted on # SENTRY_DSN= diff --git a/src/__tests__/auth.test.ts b/src/__tests__/auth.test.ts index 71819ac..f8bd322 100644 --- a/src/__tests__/auth.test.ts +++ b/src/__tests__/auth.test.ts @@ -15,11 +15,12 @@ import { getPostgresDB } from "../storage/db/postgres/db"; import { webhookEndpointsTable } from "../storage/db/postgres/schema"; import { DateTime } from "luxon"; import { clearDatabase } from "./db"; -import { insertKey } from "./fixtures/apiKey"; +import { insertKey, TEST_PROJECT_ID } from "./fixtures/apiKey"; async function insertWebhookEndpoint(apiKeyId: string): Promise { const db = getPostgresDB(); await db.insert(webhookEndpointsTable).values({ + projectId: TEST_PROJECT_ID, apiKeyId, url: "https://example.com/webhook", privateKey: "test-private-key", diff --git a/src/__tests__/createAPIKey.test.ts b/src/__tests__/createAPIKey.test.ts index e0a2521..e691a3a 100644 --- a/src/__tests__/createAPIKey.test.ts +++ b/src/__tests__/createAPIKey.test.ts @@ -14,7 +14,7 @@ import { registerEvent, } from "./fixtures/grpc"; import { verifyApiKeyCreated } from "./assertions/events"; -import { createTestApiKey } from "./fixtures/apiKey"; +import { createTestApiKey, TEST_PROJECT_ID } from "./fixtures/apiKey"; import { getPostgresDB } from "../storage/db/postgres/db"; import { hashAPIKey } from "../utils/hashAPIKey"; import { @@ -44,6 +44,7 @@ async function createDashboardApiKey(): Promise<{ key: hashAPIKey(rawKey), role: "dashboard", expiresAt: DateTime.utc().plus({ years: 1 }).toISO(), + projectId: TEST_PROJECT_ID, }) .returning({ id: apiKeysTable.id }); return { rawKey, id: key!.id }; @@ -149,6 +150,7 @@ describe("AuthService", () => { const db = getPostgresDB(); await db.insert(webhookEndpointsTable).values({ + projectId: TEST_PROJECT_ID, apiKeyId: res.apiKeyId, url: "https://example.com/webhook", privateKey: "test-private-key", diff --git a/src/__tests__/db/index.ts b/src/__tests__/db/index.ts index ada334a..c8178e8 100644 --- a/src/__tests__/db/index.ts +++ b/src/__tests__/db/index.ts @@ -31,7 +31,8 @@ export async function clearDatabase() { users, tags, metadata, - expressions + expressions, + projects RESTART IDENTITY CASCADE `); diff --git a/src/__tests__/fixtures/apiKey.ts b/src/__tests__/fixtures/apiKey.ts index dbdb718..5b41df9 100644 --- a/src/__tests__/fixtures/apiKey.ts +++ b/src/__tests__/fixtures/apiKey.ts @@ -1,16 +1,35 @@ import { getPostgresDB } from "../../storage/db/postgres/db"; import { + projectsTable, apiKeysTable, webhookEndpointsTable, } from "../../storage/db/postgres/schema"; +import { eq } from "drizzle-orm"; import { hashAPIKey } from "../../utils/hashAPIKey"; import { DateTime } from "luxon"; +export const TEST_PROJECT_ID = "00000000-0000-0000-0000-000000000001"; + +async function ensureTestProject(): Promise { + const db = getPostgresDB(); + const [existing] = await db + .select({ id: projectsTable.id }) + .from(projectsTable) + .where(eq(projectsTable.id, TEST_PROJECT_ID)) + .limit(1); + if (existing) return; + await db.insert(projectsTable).values({ + id: TEST_PROJECT_ID, + name: "test-project", + }); +} + export async function createTestApiKey(): Promise<{ rawKey: string; id: string; }> { const db = getPostgresDB(); + await ensureTestProject(); const rawKey = `scrn_test_${crypto.randomUUID().replace(/-/g, "").slice(0, 32)}`; const [key] = await db .insert(apiKeysTable) @@ -19,10 +38,12 @@ export async function createTestApiKey(): Promise<{ key: hashAPIKey(rawKey), role: "test", expiresAt: DateTime.utc().plus({ years: 1 }).toISO(), + projectId: TEST_PROJECT_ID, }) .returning({ id: apiKeysTable.id }); await db.insert(webhookEndpointsTable).values({ + projectId: TEST_PROJECT_ID, apiKeyId: key!.id, url: "https://example.com/webhook", privateKey: "test-private-key", @@ -38,6 +59,7 @@ export async function insertKey( overrides: Partial<{ revoked: boolean; expiresAt: string }> = {} ): Promise { const db = getPostgresDB(); + await ensureTestProject(); const [key] = await db .insert(apiKeysTable) .values({ @@ -47,6 +69,7 @@ export async function insertKey( expiresAt: overrides.expiresAt ?? DateTime.utc().plus({ years: 1 }).toISO(), revoked: overrides.revoked ?? false, + projectId: TEST_PROJECT_ID, }) .returning({ id: apiKeysTable.id }); return key!.id; diff --git a/src/context/auth.ts b/src/context/auth.ts index 8bfe9da..bea042c 100644 --- a/src/context/auth.ts +++ b/src/context/auth.ts @@ -6,4 +6,5 @@ export interface AuthContext { apiKeyId: string; role: ApiKeyRole; mode: "production" | "test" | null; + projectId: string; } diff --git a/src/interceptors/auth.ts b/src/interceptors/auth.ts index 076b955..b271258 100644 --- a/src/interceptors/auth.ts +++ b/src/interceptors/auth.ts @@ -167,6 +167,7 @@ export function authInterceptor( apiKeyId: cached.id, role: cached.role, mode: cached.mode, + projectId: cached.projectId, }; wideEventBuilder?.setAuth(cached.id, true); @@ -201,7 +202,7 @@ export function authInterceptor( if ( DateTime.utc() > - DateTime.fromSQL(apiKeyRecord.expiresAt, { zone: "utc" }) + DateTime.fromISO(apiKeyRecord.expiresAt, { zone: "utc" }) ) { return callback?.(AuthError.expiredAPIKey()); } @@ -220,6 +221,7 @@ export function authInterceptor( id: apiKeyRecord.id, role: apiKeyRecord.role as ApiKeyRole, mode: recordMode, + projectId: apiKeyRecord.projectId, expiresAt: apiKeyRecord.expiresAt, }); @@ -227,6 +229,7 @@ export function authInterceptor( apiKeyId: apiKeyRecord.id, role: apiKeyRecord.role as ApiKeyRole, mode: recordMode, + projectId: apiKeyRecord.projectId, }; wideEventBuilder?.setAuth(apiKeyRecord.id, false); @@ -270,6 +273,7 @@ async function lookupApiKey(apiKeyHash: string) { role: apiKeysTable.role, expiresAt: apiKeysTable.expiresAt, revoked: apiKeysTable.revoked, + projectId: apiKeysTable.projectId, }) .from(apiKeysTable) .where(eq(apiKeysTable.key, apiKeyHash)) diff --git a/src/routes/gRPC/auth/createAPIKey.ts b/src/routes/gRPC/auth/createAPIKey.ts index 462e521..535214c 100644 --- a/src/routes/gRPC/auth/createAPIKey.ts +++ b/src/routes/gRPC/auth/createAPIKey.ts @@ -70,6 +70,7 @@ export async function createAPIKey( key: apiKeyHash, role: validatedData.role, expiresAt: expiresAt.toISO(), + projectId: auth.projectId, }); if (!keyEventData) { diff --git a/src/routes/gRPC/data/query.ts b/src/routes/gRPC/data/query.ts index 9f2799f..f52c852 100644 --- a/src/routes/gRPC/data/query.ts +++ b/src/routes/gRPC/data/query.ts @@ -2,6 +2,7 @@ import type { sendUnaryData } from "@grpc/grpc-js"; import { QueryRequest, QueryResponse, Row } from "../../../gen/data/v1/data"; import { dataQuerySchema, type DataQueryRequest } from "../../../zod/data"; import { EventError } from "../../../errors/event"; +import { AuthError } from "../../../errors/auth"; import { formatZodError } from "../../../utils/formatZodError"; import { getPostgresDB } from "../../../storage/db/postgres/db"; import { @@ -9,7 +10,6 @@ import { sessionsTable, tagsTable, expressionsTable, - metadataTable, } from "../../../storage/db/postgres/schema"; import { eq, @@ -29,6 +29,7 @@ import type { SQL } from "drizzle-orm"; import type { AnyPgColumn } from "drizzle-orm/pg-core"; import type { WideEventBuilder } from "../../../context/requestContext"; import { wideEventContextKey } from "../../../context/requestContext"; +import { apiKeyContextKey } from "../../../context/auth"; import type { ContextUnaryCall } from "../../../interface/types/context.js"; interface FieldDef { @@ -36,14 +37,15 @@ interface FieldDef { cast: "text" | "integer" | "uuid" | "timestamptz" | "boolean"; } +type ScopedTable = + | typeof usersTable + | typeof sessionsTable + | typeof tagsTable + | typeof expressionsTable; + interface TableDef { tableName: string; - table: - | typeof usersTable - | typeof sessionsTable - | typeof tagsTable - | typeof expressionsTable - | typeof metadataTable; + table: ScopedTable; fields: Record; } @@ -95,13 +97,6 @@ const TABLE_REGISTRY: Record = { expr: { col: expressionsTable.expr, cast: "text" }, }, }, - metadata: { - tableName: "metadata", - table: metadataTable, - fields: { - id: { col: metadataTable.id, cast: "uuid" }, - }, - }, }; function castValue( @@ -206,6 +201,11 @@ export async function queryData( | undefined; try { + const auth = call[apiKeyContextKey]; + if (!auth) { + return callback?.(AuthError.invalidAPIKey("API key context not found")); + } + const req = { ...call.request } as Record; const validated = dataQuerySchema.parse(req); @@ -223,7 +223,11 @@ export async function queryData( } const db = getPostgresDB(); - const whereClause = buildWhere(validated.where, tableDef); + const userWhere = buildWhere(validated.where, tableDef); + const projectFilter = eq(tableDef.table.projectId, auth.projectId) as SQL; + const whereClause = userWhere + ? and(projectFilter, userWhere) + : projectFilter; const selectCols = buildSelect(tableDef); const columns = Object.keys(tableDef.fields); diff --git a/src/routes/gRPC/events/registerEvent.ts b/src/routes/gRPC/events/registerEvent.ts index 5ca981d..5973823 100644 --- a/src/routes/gRPC/events/registerEvent.ts +++ b/src/routes/gRPC/events/registerEvent.ts @@ -5,7 +5,7 @@ import { import type { WideEventBuilder } from "../../../context/requestContext"; import { apiKeyContextKey } from "../../../context/auth"; import { wideEventContextKey } from "../../../context/requestContext"; -import { registerEventSchema } from "../../../zod/event"; +import { createRegisterEventSchema } from "../../../zod/event"; import { EventError } from "../../../errors/event"; import { AuthError } from "../../../errors/auth"; import { createEventInstance, storeEvent } from "../../../utils/eventHelpers"; @@ -33,7 +33,9 @@ export async function registerEvent( ); } - const eventSkeleton = await registerEventSchema.parseAsync({ ...req }); + const eventSkeleton = await createRegisterEventSchema( + auth.projectId + ).parseAsync({ ...req }); wideEventBuilder?.setUser(eventSkeleton.userId); wideEventBuilder?.setEventContext({ eventType: eventSkeleton.type }); diff --git a/src/routes/gRPC/events/streamEvents.ts b/src/routes/gRPC/events/streamEvents.ts index c91faeb..e97e91d 100644 --- a/src/routes/gRPC/events/streamEvents.ts +++ b/src/routes/gRPC/events/streamEvents.ts @@ -8,7 +8,7 @@ import { import { EventError } from "../../../errors/event"; import { AuthError } from "../../../errors/auth"; import { StorageError } from "../../../errors/storage"; -import { streamEventSchema } from "../../../zod/event"; +import { createStreamEventSchema } from "../../../zod/event"; import { createEventInstance, storeEvent } from "../../../utils/eventHelpers"; import { apiKeyContextKey } from "../../../context/auth"; import { wideEventContextKey } from "../../../context/requestContext"; @@ -88,7 +88,9 @@ export async function streamEvents( for await (const req of call) { try { - const eventSkeleton = await streamEventSchema.parseAsync({ ...req }); + const eventSkeleton = await createStreamEventSchema( + auth.projectId + ).parseAsync({ ...req }); wideEventBuilder?.setUser(eventSkeleton.userId); wideEventBuilder?.setEventContext({ eventType: "AI_TOKEN_USAGE" }); diff --git a/src/routes/gRPC/payment/createCheckoutLink.ts b/src/routes/gRPC/payment/createCheckoutLink.ts index 097573b..6973f22 100644 --- a/src/routes/gRPC/payment/createCheckoutLink.ts +++ b/src/routes/gRPC/payment/createCheckoutLink.ts @@ -32,7 +32,7 @@ import { getPostgresDB } from "../../../storage/db/postgres/db"; import { checkIfExistingCheckoutLink } from "../../../storage/db/postgres/helpers/sessions"; import { ensureUserExists } from "../../../storage/db/postgres/helpers/users"; import { usersTable } from "../../../storage/db/postgres/schema"; -import { eq } from "drizzle-orm"; +import { eq, and } from "drizzle-orm"; export async function createCheckoutLink( call: ContextUnaryCall, @@ -64,7 +64,7 @@ export async function createCheckoutLink( const mode = auth.mode; - const config = await getPaymentProviderConfig(mode); + const config = await getPaymentProviderConfig(auth.projectId, mode); const validatedData = validateRequest(req); wideEventBuilder?.setUser(validatedData.userId); @@ -80,6 +80,7 @@ export async function createCheckoutLink( wideEventBuilder?.setPaymentContext({ priceAmount: custom_price }); const checkoutResult = await createCheckoutSession( + auth.projectId, config, custom_price, validatedData.userId, @@ -92,16 +93,22 @@ export async function createCheckoutLink( db, "create checkout link", async (txn) => { - await ensureUserExists(validatedData.userId, txn); + await ensureUserExists(auth.projectId, validatedData.userId, txn); await txn .select({ id: usersTable.id }) .from(usersTable) - .where(eq(usersTable.id, validatedData.userId)) + .where( + and( + eq(usersTable.projectId, auth.projectId), + eq(usersTable.id, validatedData.userId) + ) + ) .for("update"); const existingId = await checkIfExistingCheckoutLink( txn, + auth.projectId, validatedData.userId, mode ); @@ -112,6 +119,7 @@ export async function createCheckoutLink( } const sessionResult = await handleAddSession( + auth.projectId, validatedData.userId, checkoutResult.sessionId, beforeTimestamp, @@ -164,6 +172,7 @@ async function calculatePrice( } async function createCheckoutSession( + projectId: string, config: PaymentProviderConfig, customPrice: number, userId: string, @@ -177,7 +186,12 @@ async function createCheckoutSession( apiKeyId, }; - const checkoutResult = await createProviderCheckout(config, params, mode); + const checkoutResult = await createProviderCheckout( + projectId, + config, + params, + mode + ); if ( !checkoutResult.checkoutUrl || diff --git a/src/routes/gRPC/payment/paymentProvider.ts b/src/routes/gRPC/payment/paymentProvider.ts index 3f938e9..f151748 100644 --- a/src/routes/gRPC/payment/paymentProvider.ts +++ b/src/routes/gRPC/payment/paymentProvider.ts @@ -3,60 +3,63 @@ import { PaymentError } from "../../../errors/payment"; import { getMetadata } from "../../../storage/db/postgres/helpers/metadata"; import { decrypt } from "../../../utils/encryptMetadata.ts"; -let liveClient: DodoPayments | null = null; -let testClient: DodoPayments | null = null; +const clients = new Map(); -function clearClients(): void { - liveClient = null; - testClient = null; +function clientKey(projectId: string, mode: string): string { + return `${projectId}:${mode}`; +} + +export function clearClients(): void { + clients.clear(); +} + +export function removeClient(projectId: string): void { + clients.delete(clientKey(projectId, "test")); + clients.delete(clientKey(projectId, "production")); } export async function getDodoClient( + projectId: string, mode?: "test" | "production" ): Promise { if (!mode) { mode = process.env.NODE_ENV === "production" ? "production" : "test"; } - if (mode === "production") { - if (liveClient) return liveClient; - - const metadata = await getMetadata(); - const apiKey = metadata?.dodo_live_api_key; - if (!apiKey) { - throw PaymentError.missingApiKey(); - } - - liveClient = new DodoPayments({ - bearerToken: decrypt(apiKey), - environment: "live_mode", - webhookKey: metadata?.dodo_live_webhook_secret - ? decrypt(metadata.dodo_live_webhook_secret) - : undefined, - }); - return liveClient; + const key = clientKey(projectId, mode); + const cached = clients.get(key); + if (cached) return cached; + + const metadata = await getMetadata(projectId); + + if (!metadata) { + throw PaymentError.missingMetadata(); } - if (testClient) return testClient; + const encryptedApiKey = + mode === "production" + ? metadata.dodo_live_api_key + : metadata.dodo_test_api_key; + const encryptedWebhookSecret = + mode === "production" + ? metadata.dodo_live_webhook_secret + : metadata.dodo_test_webhook_secret; - const metadata = await getMetadata(); - const apiKey = metadata?.dodo_test_api_key; - if (!apiKey) { + if (!encryptedApiKey) { throw PaymentError.missingApiKey(); } - testClient = new DodoPayments({ - bearerToken: decrypt(apiKey), - environment: "test_mode", - webhookKey: metadata?.dodo_test_webhook_secret - ? decrypt(metadata.dodo_test_webhook_secret) + const client = new DodoPayments({ + bearerToken: decrypt(encryptedApiKey), + environment: mode === "production" ? "live_mode" : "test_mode", + webhookKey: encryptedWebhookSecret + ? decrypt(encryptedWebhookSecret) : undefined, }); - return testClient; -} -// Re-export for callers who need to invalidate cached clients after onboarding updates -export { clearClients }; + clients.set(key, client); + return client; +} export interface PaymentProviderConfig { productId: string; @@ -76,13 +79,14 @@ export interface CheckoutResult { } export async function getPaymentProviderConfig( + projectId: string, mode: "test" | "production" ): Promise { if (!mode) { mode = process.env.NODE_ENV === "production" ? "production" : "test"; } - const metadata = await getMetadata(); + const metadata = await getMetadata(projectId); if (!metadata) { throw PaymentError.missingMetadata(); @@ -90,9 +94,9 @@ export async function getPaymentProviderConfig( const productId = mode === "production" - ? metadata?.dodo_live_product_id - : metadata?.dodo_test_product_id; - const returnUrl = metadata?.redirect_url ?? null; + ? metadata.dodo_live_product_id + : metadata.dodo_test_product_id; + const returnUrl = metadata.redirect_url ?? null; if (!productId) { throw PaymentError.missingProductId(); @@ -102,11 +106,12 @@ export async function getPaymentProviderConfig( } export async function createProviderCheckout( + projectId: string, config: PaymentProviderConfig, params: CheckoutParams, mode: "test" | "production" ): Promise { - const client = await getDodoClient(mode); + const client = await getDodoClient(projectId, mode); const session = await client.checkoutSessions.create({ product_cart: [ diff --git a/src/routes/http/api/apiKeys.ts b/src/routes/http/api/apiKeys.ts index 2a18fd5..2dc084d 100644 --- a/src/routes/http/api/apiKeys.ts +++ b/src/routes/http/api/apiKeys.ts @@ -16,6 +16,7 @@ import { createApiKey } from "../../../storage/db/postgres/helpers/apiKeys"; import { upsertWebhookEndpoint } from "../../../storage/db/postgres/helpers/webhookEndpoints"; import { generateWebhookKeyPair } from "../../../utils/generateWebhookKeyPair"; import { getPostgresDB } from "../../../storage/db/postgres/db"; +import { executeInTransaction } from "../../../storage/adapter/postgres/handlers/addEventUtils"; import { apiKeysTable, webhookEndpointsTable, @@ -23,6 +24,7 @@ import { import { eq, and, isNull, ne, sql } from "drizzle-orm"; import type { ApiKeyRole } from "../../../utils/keyFormat"; import { invalidateWebhookEndpointCache } from "../../../interceptors/auth"; +import { apiKeyCache } from "../../../utils/apiKeyCache"; const createApiKeySchema = z.object({ name: z.string().min(1, "Name is required").max(255), @@ -47,29 +49,61 @@ export async function handleCreateApiKey( try { const auth = await authenticateHttpApiKey(request.headers.authorization); + if (auth.role !== "dashboard") { + throw AuthError.permissionDenied( + "Only dashboard keys can manage API keys" + ); + } builder.setApiKeyContext({ name: `create-key:${auth.apiKeyId}` }); const body = await request.body; const validated = createApiKeySchema.parse(body); + if ( + validated.role === "production" && + !validated.webhookUrl.startsWith("https://") + ) { + builder.setError(400, { + type: "ValidationError", + message: "Production webhook URLs must use HTTPS", + }); + reply.code(400); + return { error: "Production webhook URLs must use HTTPS" }; + } + const apiKey = generateAPIKey(validated.role as ApiKeyRole); const apiKeyHash = hashAPIKey(apiKey); const now = DateTime.utc(); const expiresAt = now.plus({ seconds: validated.expiresIn }); - const keyRecord = await createApiKey({ - name: validated.name, - key: apiKeyHash, - role: validated.role, - expiresAt: expiresAt.toISO(), - }); + const db = getPostgresDB(); + const { keyRecord, endpoint } = await executeInTransaction( + db, + "create API key", + async (txn) => { + const rec = await createApiKey( + { + name: validated.name, + key: apiKeyHash, + role: validated.role, + expiresAt: expiresAt.toISO(), + projectId: auth.projectId, + }, + txn + ); + + const keyPair = generateWebhookKeyPair(); + const ep = await upsertWebhookEndpoint( + auth.projectId, + rec.id, + validated.webhookUrl, + keyPair.privateKeyPem, + keyPair.publicKeyPrefixed, + txn + ); - const keyPair = generateWebhookKeyPair(); - const endpoint = await upsertWebhookEndpoint( - keyRecord.id, - validated.webhookUrl, - keyPair.privateKeyPem, - keyPair.publicKeyPrefixed + return { keyRecord: rec, endpoint: ep }; + } ); invalidateWebhookEndpointCache(keyRecord.id); @@ -127,7 +161,12 @@ export async function handleListApiKeys( ); try { - await authenticateHttpApiKey(request.headers.authorization); + const auth = await authenticateHttpApiKey(request.headers.authorization); + if (auth.role !== "dashboard") { + throw AuthError.permissionDenied( + "Only dashboard keys can manage API keys" + ); + } const db = getPostgresDB(); const keys = await db @@ -151,7 +190,11 @@ export async function handleListApiKeys( ) ) .where( - and(ne(apiKeysTable.role, "dashboard"), eq(apiKeysTable.revoked, false)) + and( + eq(apiKeysTable.projectId, auth.projectId), + ne(apiKeysTable.role, "dashboard"), + eq(apiKeysTable.revoked, false) + ) ) .orderBy(apiKeysTable.createdAt); @@ -189,20 +232,31 @@ export async function handleRevokeApiKey( ); try { - await authenticateHttpApiKey(request.headers.authorization); + const auth = await authenticateHttpApiKey(request.headers.authorization); + if (auth.role !== "dashboard") { + throw AuthError.permissionDenied( + "Only dashboard keys can manage API keys" + ); + } const params = request.params as { id: string }; const db = getPostgresDB(); const now = DateTime.utc().toISO(); - const result = await db + const [revokedRow] = await db .update(apiKeysTable) .set({ revoked: true, revokedAt: now }) .where( - and(eq(apiKeysTable.id, params.id), eq(apiKeysTable.revoked, false)) - ); + and( + eq(apiKeysTable.projectId, auth.projectId), + eq(apiKeysTable.id, params.id), + eq(apiKeysTable.revoked, false), + ne(apiKeysTable.role, "dashboard") + ) + ) + .returning({ key: apiKeysTable.key }); - if ((result.count ?? 0) === 0) { + if (!revokedRow) { builder.setError(404, { type: "NotFoundError", message: "API key not found or already revoked", @@ -211,6 +265,8 @@ export async function handleRevokeApiKey( return { error: "API key not found or already revoked" }; } + apiKeyCache.delete(revokedRow.key); + builder.setSuccess(200); reply.code(200); return { message: "API key revoked" }; diff --git a/src/routes/http/api/expressions.ts b/src/routes/http/api/expressions.ts index 59e63d9..1b3cfeb 100644 --- a/src/routes/http/api/expressions.ts +++ b/src/routes/http/api/expressions.ts @@ -45,9 +45,9 @@ export async function handleListExpressions( try { const authHeader = request.headers.authorization; - await authenticateHttpApiKey(authHeader); + const auth = await authenticateHttpApiKey(authHeader); - const expressions = await listExpressions(); + const expressions = await listExpressions(auth.projectId); builder.setSuccess(200).addContext({ expressionCount: expressions.length }); reply.code(200); @@ -84,15 +84,15 @@ export async function handleCreateExpression( try { const authHeader = request.headers.authorization; - await authenticateHttpApiKey(authHeader); + const auth = await authenticateHttpApiKey(authHeader); const body = await request.body; const validated = createExpressionSchema.parse(body); validateExprSyntax(validated.expr); - await resolveExprRefsInExpression(validated.expr); + await resolveExprRefsInExpression(validated.expr, auth.projectId); - await createExpression(validated.key, validated.expr); + await createExpression(auth.projectId, validated.key, validated.expr); builder.setSuccess(200); reply.code(200); @@ -147,10 +147,10 @@ export async function handleDeleteExpression( try { const authHeader = request.headers.authorization; - await authenticateHttpApiKey(authHeader); + const auth = await authenticateHttpApiKey(authHeader); const params = request.params as { key: string }; - const deleted = await deleteExpression(params.key); + const deleted = await deleteExpression(auth.projectId, params.key); if (!deleted) { builder.setError(404, { diff --git a/src/routes/http/api/onboarding.ts b/src/routes/http/api/onboarding.ts index f3d5721..b347638 100644 --- a/src/routes/http/api/onboarding.ts +++ b/src/routes/http/api/onboarding.ts @@ -1,3 +1,4 @@ +import { randomUUID } from "crypto"; import type { FastifyRequest, FastifyReply } from "fastify"; import * as Sentry from "@sentry/bun"; import { ZodError } from "zod"; @@ -9,18 +10,30 @@ import { } from "../../../context/requestContext.ts"; import { logger } from "../../../errors/logger.ts"; import { AuthError } from "../../../errors/auth"; +import { authenticateMasterApiKey } from "../../../utils/authenticateMasterApiKey.ts"; import { authenticateHttpApiKey } from "../../../utils/authenticateHttpApiKey.ts"; +import { generateAPIKey } from "../../../utils/generateAPIKey"; +import { hashAPIKey } from "../../../utils/hashAPIKey"; +import { encrypt, decrypt } from "../../../utils/encryptMetadata.ts"; +import { getPostgresDB } from "../../../storage/db/postgres/db"; +import { + projectsTable, + apiKeysTable, + metadataTable, +} from "../../../storage/db/postgres/schema"; import { - upsertMetadata, getMetadata, -} from "../../../storage/db/postgres/helpers/metadata.ts"; -import { clearClients } from "../../gRPC/payment/paymentProvider.ts"; -import { encrypt, decrypt } from "../../../utils/encryptMetadata.ts"; + getAnyMetadata, +} from "../../../storage/db/postgres/helpers/metadata"; +import { removeClient } from "../../gRPC/payment/paymentProvider.ts"; +import { DateTime } from "luxon"; +import { eq } from "drizzle-orm"; +import { executeInTransaction } from "../../../storage/adapter/postgres/handlers/addEventUtils"; export async function handleOnboarding( request: FastifyRequest, reply: FastifyReply -): Promise> { +): Promise> { const builder = createWideEventBuilder( generateRequestId(), request.method, @@ -29,7 +42,7 @@ export async function handleOnboarding( try { const authHeader = request.headers.authorization; - await authenticateHttpApiKey(authHeader); + authenticateMasterApiKey(authHeader); const body = await request.body; const validated = onboardingSchema.parse(body); @@ -44,6 +57,23 @@ export async function handleOnboarding( return {}; } + const projectId = randomUUID(); + + const existing = await getPostgresDB() + .select({ id: projectsTable.id }) + .from(projectsTable) + .where(eq(projectsTable.name, validated.name)) + .limit(1); + + if (existing.length > 0) { + builder.setError(409, { + type: "ConflictError", + message: `Project with name '${validated.name}' already exists`, + }); + reply.code(409); + return {}; + } + const liveClient = new DodoPayments({ bearerToken: validated.dodoLiveApiKey, environment: "live_mode", @@ -55,23 +85,41 @@ export async function handleOnboarding( let liveSecret: string; let testSecret: string; + let liveWebhookId: string | undefined; + let testWebhookId: string | undefined; try { const liveWebhook = await liveClient.webhooks.create({ - url: `${appUrl}/webhooks/payment/createdCheckout?mode=production`, + url: `${appUrl}/webhooks/payment/createdCheckout?mode=production&projectId=${projectId}`, description: "Scrawn live payment webhook", filter_types: ["payment.succeeded", "payment.failed"], }); + liveWebhookId = liveWebhook.id; liveSecret = (await liveClient.webhooks.retrieveSecret(liveWebhook.id)) .secret; const testWebhook = await testClient.webhooks.create({ - url: `${appUrl}/webhooks/payment/createdCheckout?mode=test`, + url: `${appUrl}/webhooks/payment/createdCheckout?mode=test&projectId=${projectId}`, description: "Scrawn test payment webhook", filter_types: ["payment.succeeded", "payment.failed"], }); + testWebhookId = testWebhook.id; testSecret = (await testClient.webhooks.retrieveSecret(testWebhook.id)) .secret; } catch (error) { + if (liveWebhookId) { + liveClient.webhooks.delete(liveWebhookId).catch((e) => + Sentry.captureException(e, { + extra: { context: "rollback: failed to delete live webhook" }, + }) + ); + } + if (testWebhookId) { + testClient.webhooks.delete(testWebhookId).catch((e) => + Sentry.captureException(e, { + extra: { context: "rollback: failed to delete test webhook" }, + }) + ); + } const errMsg = error instanceof Error ? error.message : String(error); Sentry.captureException(error, { extra: { context: "dodo webhook registration during onboarding" }, @@ -84,23 +132,68 @@ export async function handleOnboarding( return {}; } - await upsertMetadata({ - dodo_live_api_key: encrypt(validated.dodoLiveApiKey), - dodo_test_api_key: encrypt(validated.dodoTestApiKey), - dodo_live_product_id: validated.dodoLiveProductId, - dodo_test_product_id: validated.dodoTestProductId, - dodo_live_webhook_secret: encrypt(liveSecret), - dodo_test_webhook_secret: encrypt(testSecret), - currency: validated.currency, - redirect_url: validated.redirectUrl, - }); + const dashboardKey = generateAPIKey("dashboard"); + const dashboardKeyHash = hashAPIKey(dashboardKey); + const expiresAt = DateTime.utc().plus({ years: 10 }).toISO(); - clearClients(); + const db = getPostgresDB(); + try { + await executeInTransaction(db, "create project", async (txn) => { + await txn.insert(projectsTable).values({ + id: projectId, + name: validated.name, + }); - builder.setSuccess(200); + await txn.insert(metadataTable).values({ + projectId, + dodo_live_api_key: encrypt(validated.dodoLiveApiKey), + dodo_test_api_key: encrypt(validated.dodoTestApiKey), + dodo_live_product_id: validated.dodoLiveProductId, + dodo_test_product_id: validated.dodoTestProductId, + dodo_live_webhook_secret: encrypt(liveSecret), + dodo_test_webhook_secret: encrypt(testSecret), + currency: validated.currency, + redirect_url: validated.redirectUrl, + }); + + await txn.insert(apiKeysTable).values({ + projectId, + name: "Default dashboard key", + key: dashboardKeyHash, + role: "dashboard", + expiresAt, + }); + }); + } catch (txnError) { + if (liveWebhookId) { + liveClient.webhooks.delete(liveWebhookId).catch((e) => + Sentry.captureException(e, { + extra: { + context: + "rollback: failed to delete live webhook after DB failure", + }, + }) + ); + } + if (testWebhookId) { + testClient.webhooks.delete(testWebhookId).catch((e) => + Sentry.captureException(e, { + extra: { + context: + "rollback: failed to delete test webhook after DB failure", + }, + }) + ); + } + throw txnError; + } + + removeClient(projectId); + + builder.setSuccess(201); reply.code(201); - return {}; + return { projectId, apiKey: dashboardKey }; } catch (error) { Sentry.captureException(error, { extra: { context: "onboarding route handler" }, @@ -157,9 +250,27 @@ export async function handleGetConfig( try { const authHeader = request.headers.authorization; - await authenticateHttpApiKey(authHeader); - const metadata = await getMetadata(); + let projectId: string | undefined; + let isMasterKey = false; + try { + authenticateMasterApiKey(authHeader); + isMasterKey = true; + } catch (masterErr) { + if (!(masterErr instanceof AuthError)) { + throw masterErr; + } + + const auth = await authenticateHttpApiKey(authHeader); + if (auth.role !== "dashboard") { + throw AuthError.permissionDenied("Only dashboard keys can read config"); + } + projectId = auth.projectId; + } + + const metadata = isMasterKey + ? await getAnyMetadata() + : await getMetadata(projectId!); if (!metadata) { builder.setSuccess(200); diff --git a/src/routes/http/api/tags.ts b/src/routes/http/api/tags.ts index 6be7c28..444f875 100644 --- a/src/routes/http/api/tags.ts +++ b/src/routes/http/api/tags.ts @@ -47,9 +47,9 @@ export async function handleListTags( try { const authHeader = request.headers.authorization; - await authenticateHttpApiKey(authHeader); + const auth = await authenticateHttpApiKey(authHeader); - const tags = await listTags(); + const tags = await listTags(auth.projectId); builder.setSuccess(200).addContext({ tagCount: tags.length }); reply.code(200); @@ -86,12 +86,12 @@ export async function handleCreateTag( try { const authHeader = request.headers.authorization; - await authenticateHttpApiKey(authHeader); + const auth = await authenticateHttpApiKey(authHeader); const body = await request.body; const validated = createTagSchema.parse(body); - await createTag(validated.key, validated.amount); + await createTag(auth.projectId, validated.key, validated.amount); builder.setSuccess(200); reply.code(200); @@ -137,10 +137,10 @@ export async function handleDeleteTag( try { const authHeader = request.headers.authorization; - await authenticateHttpApiKey(authHeader); + const auth = await authenticateHttpApiKey(authHeader); const params = tagParamsSchema.parse(request.params); - const deleted = await deleteTag(params.key); + const deleted = await deleteTag(auth.projectId, params.key); if (!deleted) { builder.setError(404, { diff --git a/src/routes/http/api/webhookDeliveries.ts b/src/routes/http/api/webhookDeliveries.ts index a72fbfb..0bca7fb 100644 --- a/src/routes/http/api/webhookDeliveries.ts +++ b/src/routes/http/api/webhookDeliveries.ts @@ -16,6 +16,7 @@ import { apiKeysTable, } from "../../../storage/db/postgres/schema"; import { and, eq, desc, inArray, sql } from "drizzle-orm"; +import type { SQL } from "drizzle-orm"; const listDeliveriesQuerySchema = z.object({ apiKeyId: z.string().uuid("Invalid API key ID").optional(), @@ -37,47 +38,54 @@ export async function handleListDeliveries( ); try { - await authenticateHttpApiKey(request.headers.authorization); + const auth = await authenticateHttpApiKey(request.headers.authorization); const query = listDeliveriesQuerySchema.parse(request.query); const db = getPostgresDB(); - let conditions = undefined; + let conditions: SQL | undefined = eq( + webhookDeliveriesTable.projectId, + auth.projectId + ); if (query.apiKeyId) { const endpoints = await db .select({ id: webhookEndpointsTable.id }) .from(webhookEndpointsTable) - .where(eq(webhookEndpointsTable.apiKeyId, query.apiKeyId)); - const ids = endpoints.map((e) => e.id); - if (ids.length > 0) { - conditions = inArray(webhookDeliveriesTable.endpointId, ids); - } else { - conditions = eq( - webhookDeliveriesTable.id, - "00000000-0000-0000-0000-000000000000" + .where( + and( + eq(webhookEndpointsTable.projectId, auth.projectId), + eq(webhookEndpointsTable.apiKeyId, query.apiKeyId) + ) ); - } + const ids = endpoints.map((e) => e.id); + conditions = + ids.length > 0 + ? and(conditions, inArray(webhookDeliveriesTable.endpointId, ids)) + : and( + conditions, + eq( + webhookDeliveriesTable.id, + "00000000-0000-0000-0000-000000000000" + ) + ); } if (query.eventType) { - conditions = conditions - ? and(conditions, eq(webhookDeliveriesTable.eventType, query.eventType)) - : eq(webhookDeliveriesTable.eventType, query.eventType); + conditions = and( + conditions, + eq(webhookDeliveriesTable.eventType, query.eventType) + ); } if (query.status) { - conditions = conditions - ? and( - conditions, - sql`${webhookDeliveriesTable.status} = ${query.status}` - ) - : sql`${webhookDeliveriesTable.status} = ${query.status}`; + conditions = and( + conditions, + sql`${webhookDeliveriesTable.status} = ${query.status}` + ); } if (query.role) { - conditions = conditions - ? and(conditions, sql`${apiKeysTable.role} = ${query.role}`) - : sql`${apiKeysTable.role} = ${query.role}`; + conditions = and(conditions, sql`${apiKeysTable.role} = ${query.role}`); } const rows = await db diff --git a/src/routes/http/api/webhookEndpoints.ts b/src/routes/http/api/webhookEndpoints.ts index a02547b..a3ef8f3 100644 --- a/src/routes/http/api/webhookEndpoints.ts +++ b/src/routes/http/api/webhookEndpoints.ts @@ -106,6 +106,15 @@ export async function handleCreateWebhookEndpoint( return { error: "Target API key not found" }; } + if (targetKey.projectId !== auth.projectId) { + builder.setError(403, { + type: "PermissionDenied", + message: "Target API key does not belong to this project", + }); + reply.code(403); + return { error: "Target API key does not belong to this project" }; + } + if (targetKey.role === "dashboard") { builder.setError(400, { type: "ValidationError", @@ -130,6 +139,7 @@ export async function handleCreateWebhookEndpoint( const keyPair = generateWebhookKeyPair(); const endpoint = await upsertWebhookEndpoint( + auth.projectId, targetApiKeyId, validated.url, keyPair.privateKeyPem, @@ -184,7 +194,10 @@ export async function handleGetWebhookEndpoint( const auth = await authenticateHttpApiKey(request.headers.authorization); builder.setApiKeyContext({ name: `webhook:${auth.apiKeyId}` }); - const endpoint = await getWebhookEndpointByApiKeyId(auth.apiKeyId); + const endpoint = await getWebhookEndpointByApiKeyId( + auth.projectId, + auth.apiKeyId + ); const endpoints: WebhookEndpointResponse[] = endpoint ? [toEndpointResponse(endpoint)] @@ -227,7 +240,7 @@ export async function handleDeleteWebhookEndpoint( const auth = await authenticateHttpApiKey(request.headers.authorization); builder.setApiKeyContext({ name: `webhook:${auth.apiKeyId}` }); - const deleted = await deleteWebhookEndpoint(auth.apiKeyId); + const deleted = await deleteWebhookEndpoint(auth.projectId, auth.apiKeyId); if (!deleted) { builder.setError(404, { @@ -299,6 +312,15 @@ export async function handleSendTestWebhook( return { error: "API key not found" }; } + if (targetKey.projectId !== auth.projectId) { + builder.setError(403, { + type: "PermissionDenied", + message: "API key does not belong to this project", + }); + reply.code(403); + return { error: "API key does not belong to this project" }; + } + if (targetKey.role !== "test") { builder.setError(400, { type: "ValidationError", @@ -308,7 +330,10 @@ export async function handleSendTestWebhook( return { error: "Can only send test webhooks to test API keys" }; } - const endpoint = await getWebhookEndpointByApiKeyId(validated.apiKeyId); + const endpoint = await getWebhookEndpointByApiKeyId( + auth.projectId, + validated.apiKeyId + ); if (!endpoint) { builder.setError(404, { @@ -321,7 +346,7 @@ export async function handleSendTestWebhook( const now = DateTime.utc(); - await forwardWebhook(validated.apiKeyId, { + await forwardWebhook(auth.projectId, validated.apiKeyId, { eventType: "payment.succeeded", resource: "payment", action: "succeeded", @@ -374,7 +399,10 @@ export async function handleGetPublicKey( const auth = await authenticateHttpApiKey(request.headers.authorization); builder.setApiKeyContext({ name: `webhook:${auth.apiKeyId}` }); - const endpoint = await getWebhookEndpointByApiKeyId(auth.apiKeyId); + const endpoint = await getWebhookEndpointByApiKeyId( + auth.projectId, + auth.apiKeyId + ); if (!endpoint) { builder.setError(404, { diff --git a/src/routes/http/createdCheckout.ts b/src/routes/http/createdCheckout.ts index 878b648..84700d3 100644 --- a/src/routes/http/createdCheckout.ts +++ b/src/routes/http/createdCheckout.ts @@ -72,10 +72,12 @@ export async function handleDodoWebhook( timestamp: string | undefined, webhookId: string | undefined, mode: "production" | "test", + projectId: string, builder: WideEventBuilder ): Promise { try { const client = await getDodoClient( + projectId, mode === "production" ? "production" : "test" ); const headers = buildWebhookHeaders(signature, timestamp, webhookId); @@ -133,6 +135,15 @@ export async function handleDodoWebhook( ); } + if (session.projectId !== projectId) { + return errorResponse( + 404, + "NotFoundError", + `Session not found for checkout_session_id: ${checkout_session_id}`, + builder + ); + } + if (session.processed !== "pending") { Sentry.captureMessage( `Webhook received for session ${checkout_session_id} with non-pending status: ${session.processed}`, @@ -158,7 +169,7 @@ export async function handleDodoWebhook( } builder.setSuccess(200); - forwardWebhook(session.apiKeyId, { + forwardWebhook(session.projectId, session.apiKeyId, { eventType: "payment.failed", resource: "payment", action: "failed", @@ -190,8 +201,14 @@ export async function handleDodoWebhook( txn ); if (!claimed) return; - await updateUserBilledTimestamp(userId, billed_upto, txn); + await updateUserBilledTimestamp( + session.projectId, + userId, + billed_upto, + txn + ); await handleAddPayment( + session.projectId, userId, creditAmount, apiKeyId, @@ -212,7 +229,7 @@ export async function handleDodoWebhook( builder.setPaymentContext({ creditAmount }); builder.setSuccess(200); - forwardWebhook(apiKeyId, { + forwardWebhook(session.projectId, apiKeyId, { eventType: "payment.succeeded", resource: "payment", action: "succeeded", diff --git a/src/routes/http/forwardWebhook.ts b/src/routes/http/forwardWebhook.ts index d27b4b2..ade2c7c 100644 --- a/src/routes/http/forwardWebhook.ts +++ b/src/routes/http/forwardWebhook.ts @@ -38,10 +38,11 @@ export interface WebhookForwardEvent { } export async function forwardWebhook( + projectId: string, apiKeyId: string, event: WebhookForwardEvent ): Promise { - const endpoint = await getWebhookEndpointByApiKeyId(apiKeyId); + const endpoint = await getWebhookEndpointByApiKeyId(projectId, apiKeyId); if (!endpoint) { return; @@ -71,7 +72,7 @@ export async function forwardWebhook( Sentry.captureException(error, { extra: { context: "webhook signing failed", error: errorMsg }, }); - await recordDelivery(endpoint.id, webhookId, event, "failed", { + await recordDelivery(projectId, endpoint.id, webhookId, event, "failed", { error: errorMsg, }); return; @@ -117,6 +118,7 @@ export async function forwardWebhook( } await recordDelivery( + projectId, endpoint.id, webhookId, event, @@ -129,6 +131,7 @@ export async function forwardWebhook( } async function recordDelivery( + projectId: string, endpointId: string, eventId: string, event: WebhookForwardEvent, @@ -141,6 +144,7 @@ async function recordDelivery( try { const db = getPostgresDB(); await db.insert(webhookDeliveriesTable).values({ + projectId, endpointId, eventId, eventType: event.eventType, diff --git a/src/routes/http/registerWebhookRoutes.ts b/src/routes/http/registerWebhookRoutes.ts index 0e15cd3..84502e8 100644 --- a/src/routes/http/registerWebhookRoutes.ts +++ b/src/routes/http/registerWebhookRoutes.ts @@ -36,7 +36,10 @@ export async function registerWebhookRoutes( ); try { - const mode = (request.query as Record)?.mode; + const query = request.query as Record; + const mode = query?.mode; + const projectId = query?.projectId; + if (mode !== "production" && mode !== "test") { builder.setError(400, { type: "ValidationError", @@ -47,6 +50,28 @@ export async function registerWebhookRoutes( return { error: "Invalid mode query parameter" }; } + if (!projectId) { + builder.setError(400, { + type: "ValidationError", + message: "Missing 'projectId' query parameter.", + }); + reply.code(400); + return { error: "Missing projectId query parameter" }; + } + + if ( + !/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i.test( + projectId + ) + ) { + builder.setError(400, { + type: "ValidationError", + message: "Invalid 'projectId' query parameter.", + }); + reply.code(400); + return { error: "Invalid projectId query parameter" }; + } + const signatureHeader = request.headers["webhook-signature"]; const timestampHeader = request.headers["webhook-timestamp"]; const webhookIdHeader = request.headers["webhook-id"]; @@ -72,6 +97,7 @@ export async function registerWebhookRoutes( timestamp, webhookId, mode, + projectId, builder ); diff --git a/src/storage/adapter/clickhouse/handlers/addAiTokenUsage.ts b/src/storage/adapter/clickhouse/handlers/addAiTokenUsage.ts index 052bde7..1061e1d 100644 --- a/src/storage/adapter/clickhouse/handlers/addAiTokenUsage.ts +++ b/src/storage/adapter/clickhouse/handlers/addAiTokenUsage.ts @@ -135,6 +135,7 @@ function buildAiTokenInsertRows( return { id: index === 0 ? firstId : crypto.randomUUID(), + project_id: auth.projectId, event_id: aggEvent.eventId, idempotency_key: aggEvent.idempotencyKey, user_id: aggEvent.userId, @@ -166,7 +167,7 @@ export async function handleAddAiTokenUsage( const firstEvent = events[0]; if (firstEvent) { - await ensureUserExists(firstEvent.userId); + await ensureUserExists(auth.projectId, firstEvent.userId); } const aggregatedEvents = aggregateAiTokenEvents(events); diff --git a/src/storage/adapter/clickhouse/handlers/addBasicUsage.ts b/src/storage/adapter/clickhouse/handlers/addBasicUsage.ts index 55a474d..6c2221c 100644 --- a/src/storage/adapter/clickhouse/handlers/addBasicUsage.ts +++ b/src/storage/adapter/clickhouse/handlers/addBasicUsage.ts @@ -27,7 +27,7 @@ export async function handleAddBasicUsage( } const reportedTimestamp = toClickHouseDateTime(event_data.reported_timestamp); - await ensureUserExists(event_data.userId); + await ensureUserExists(auth.projectId, event_data.userId); const id = crypto.randomUUID(); @@ -37,6 +37,7 @@ export async function handleAddBasicUsage( values: [ { id, + project_id: auth.projectId, event_id: event_data.eventId, idempotency_key: event_data.idempotencyKey, user_id: event_data.userId, diff --git a/src/storage/adapter/clickhouse/handlers/priceRequestAiTokenUsage.ts b/src/storage/adapter/clickhouse/handlers/priceRequestAiTokenUsage.ts index 13310f9..7cd4f13 100644 --- a/src/storage/adapter/clickhouse/handlers/priceRequestAiTokenUsage.ts +++ b/src/storage/adapter/clickhouse/handlers/priceRequestAiTokenUsage.ts @@ -5,8 +5,8 @@ import { runClickHousePriceQuery } from "../utils"; const VALUE_EXPR = "JSONExtractInt(metrics, 'debit_amount', 'input') + JSONExtractInt(metrics, 'debit_amount', 'input_cache') + JSONExtractInt(metrics, 'debit_amount', 'output')"; -const BASE_QUERY = `SELECT sum(${VALUE_EXPR}) as total FROM ai_token_usage_events WHERE user_id = {userId:String} AND mode = {mode:String} AND reported_timestamp < {before:DateTime64(3, 'UTC')}`; -const WINDOW_QUERY = `SELECT sum(${VALUE_EXPR}) as total FROM ai_token_usage_events WHERE user_id = {userId:String} AND mode = {mode:String} AND reported_timestamp > {lastBilled:DateTime64(3, 'UTC')} AND reported_timestamp < {before:DateTime64(3, 'UTC')}`; +const BASE_QUERY = `SELECT sum(${VALUE_EXPR}) as total FROM ai_token_usage_events WHERE project_id = {projectId:String} AND user_id = {userId:String} AND mode = {mode:String} AND reported_timestamp < {before:DateTime64(3, 'UTC')}`; +const WINDOW_QUERY = `SELECT sum(${VALUE_EXPR}) as total FROM ai_token_usage_events WHERE project_id = {projectId:String} AND user_id = {userId:String} AND mode = {mode:String} AND reported_timestamp > {lastBilled:DateTime64(3, 'UTC')} AND reported_timestamp < {before:DateTime64(3, 'UTC')}`; export async function handlePriceRequestAiTokenUsage( userId: UserId, diff --git a/src/storage/adapter/clickhouse/handlers/priceRequestBasicUsage.ts b/src/storage/adapter/clickhouse/handlers/priceRequestBasicUsage.ts index 86022de..6b7a8c5 100644 --- a/src/storage/adapter/clickhouse/handlers/priceRequestBasicUsage.ts +++ b/src/storage/adapter/clickhouse/handlers/priceRequestBasicUsage.ts @@ -4,9 +4,9 @@ import type { AuthContext } from "../../../../context/auth"; import { runClickHousePriceQuery } from "../utils"; const BASE_QUERY = - "SELECT sum(debit_amount) as total FROM basic_usage_events WHERE user_id = {userId:String} AND mode = {mode:String} AND reported_timestamp < {before:DateTime64(3, 'UTC')}"; + "SELECT sum(debit_amount) as total FROM basic_usage_events WHERE project_id = {projectId:String} AND user_id = {userId:String} AND mode = {mode:String} AND reported_timestamp < {before:DateTime64(3, 'UTC')}"; const WINDOW_QUERY = - "SELECT sum(debit_amount) as total FROM basic_usage_events WHERE user_id = {userId:String} AND mode = {mode:String} AND reported_timestamp > {lastBilled:DateTime64(3, 'UTC')} AND reported_timestamp < {before:DateTime64(3, 'UTC')}"; + "SELECT sum(debit_amount) as total FROM basic_usage_events WHERE project_id = {projectId:String} AND user_id = {userId:String} AND mode = {mode:String} AND reported_timestamp > {lastBilled:DateTime64(3, 'UTC')} AND reported_timestamp < {before:DateTime64(3, 'UTC')}"; export async function handlePriceRequestBasicUsage( userId: UserId, diff --git a/src/storage/adapter/clickhouse/handlers/queryEvents.ts b/src/storage/adapter/clickhouse/handlers/queryEvents.ts index 2f17e0e..abf3fa9 100644 --- a/src/storage/adapter/clickhouse/handlers/queryEvents.ts +++ b/src/storage/adapter/clickhouse/handlers/queryEvents.ts @@ -211,9 +211,9 @@ export async function handleQueryEvents( try { if (request.aggregation) { - return await handleAggregationQuery(request, tables); + return await handleAggregationQuery(request, tables, auth); } - return await handleListQuery(request, tables); + return await handleListQuery(request, tables, auth); } catch (e) { if ( e && @@ -232,11 +232,12 @@ export async function handleQueryEvents( async function handleListQuery( request: QueryRequest, - tables: EventTableName[] + tables: EventTableName[], + auth: AuthContext ): Promise { const client = getClickHouseDB(); const paramIndex = { value: 0 }; - const params: Record = {}; + const params: Record = { projectId: auth.projectId }; const queries = tables.map((t) => { const whereClause = buildWhereFromGroup( @@ -246,7 +247,8 @@ async function handleListQuery( paramIndex ); let q = `SELECT ${buildSelectColumns(t)} FROM ${t}`; - if (whereClause) q += ` WHERE ${whereClause}`; + q += ` WHERE project_id = {projectId:String}`; + if (whereClause) q += ` AND ${whereClause}`; return q; }); @@ -276,20 +278,21 @@ async function handleListQuery( data as unknown as Record[] ).map(normalizeRow); - const total = await getTotalCount(request, tables); + const total = await getTotalCount(request, tables, auth); return { rows, total }; } async function handleAggregationQuery( request: QueryRequest, - tables: EventTableName[] + tables: EventTableName[], + auth: AuthContext ): Promise { const client = getClickHouseDB(); const agg = request.aggregation!; const isSum = agg.type === "SUM"; const paramIndex = { value: 0 }; - const params: Record = {}; + const params: Record = { projectId: auth.projectId }; const subQueries = tables.map((t) => { const cols: string[] = []; @@ -325,7 +328,8 @@ async function handleAggregationQuery( paramIndex ); let q = `SELECT ${cols.join(", ")} FROM ${t}`; - if (whereClause) q += ` WHERE ${whereClause}`; + q += ` WHERE project_id = {projectId:String}`; + if (whereClause) q += ` AND ${whereClause}`; return q; }); @@ -364,11 +368,12 @@ async function handleAggregationQuery( async function getTotalCount( request: QueryRequest, - tables: EventTableName[] + tables: EventTableName[], + auth: AuthContext ): Promise { const client = getClickHouseDB(); const paramIndex = { value: 0 }; - const params: Record = {}; + const params: Record = { projectId: auth.projectId }; const subQueries = tables.map((t) => { const whereClause = buildWhereFromGroup( @@ -378,7 +383,8 @@ async function getTotalCount( paramIndex ); let q = `SELECT count() as cnt FROM ${t}`; - if (whereClause) q += ` WHERE ${whereClause}`; + q += ` WHERE project_id = {projectId:String}`; + if (whereClause) q += ` AND ${whereClause}`; return q; }); diff --git a/src/storage/adapter/clickhouse/schema.ts b/src/storage/adapter/clickhouse/schema.ts index ae12dcd..653c662 100644 --- a/src/storage/adapter/clickhouse/schema.ts +++ b/src/storage/adapter/clickhouse/schema.ts @@ -5,6 +5,7 @@ import { innerProduct } from "drizzle-orm"; const BASIC_USAGE_EVENTS_TABLE = ` CREATE TABLE IF NOT EXISTS basic_usage_events ( id UUID DEFAULT generateUUIDv4(), + project_id String, event_id String, idempotency_key String, user_id String, @@ -16,12 +17,13 @@ CREATE TABLE IF NOT EXISTS basic_usage_events ( debit_amount Int64, metadata JSON ) ENGINE = ReplacingMergeTree() -ORDER BY (idempotency_key, user_id) +ORDER BY (project_id, idempotency_key, user_id) `; const AI_TOKEN_USAGE_EVENTS_TABLE = ` CREATE TABLE IF NOT EXISTS ai_token_usage_events ( id UUID DEFAULT generateUUIDv4(), + project_id String, event_id String, idempotency_key String, user_id String, @@ -34,7 +36,7 @@ CREATE TABLE IF NOT EXISTS ai_token_usage_events ( metrics String, metadata JSON ) ENGINE = ReplacingMergeTree() -ORDER BY (idempotency_key, user_id) +ORDER BY (project_id, idempotency_key, user_id) `; export async function runClickHouseMigrations(): Promise { diff --git a/src/storage/adapter/clickhouse/utils.ts b/src/storage/adapter/clickhouse/utils.ts index 2858ba9..4873556 100644 --- a/src/storage/adapter/clickhouse/utils.ts +++ b/src/storage/adapter/clickhouse/utils.ts @@ -2,7 +2,7 @@ import type { DateTime } from "luxon"; import { DateTime as LuxonDateTime } from "luxon"; import { getPostgresDB } from "../../db/postgres/db"; import { usersTable } from "../../db/postgres/schema"; -import { eq } from "drizzle-orm"; +import { eq, and } from "drizzle-orm"; import { StorageError } from "../../../errors/storage"; import { getClickHouseDB } from "../../db/clickhouse"; import type { UserId } from "../../../config/identifiers"; @@ -12,13 +12,18 @@ export function toClickHouseDateTime(dt: DateTime): string { return dt.toUTC().toFormat("yyyy-MM-dd HH:mm:ss.SSS"); } -async function fetchLastBilled(userId: string): Promise { +async function fetchLastBilled( + userId: string, + projectId: string +): Promise { const pgDb = getPostgresDB(); try { const [user] = await pgDb .select({ lastBilled: usersTable.last_billed_timestamp }) .from(usersTable) - .where(eq(usersTable.id, userId)) + .where( + and(eq(usersTable.id, userId), eq(usersTable.projectId, projectId)) + ) .limit(1); return user?.lastBilled ?? null; } catch { @@ -49,7 +54,7 @@ export async function runClickHousePriceQuery( } const beforeTs = toClickHouseDateTime(beforeTimestamp); - const lastBilled = await fetchLastBilled(userId); + const lastBilled = await fetchLastBilled(userId, auth.projectId); try { let query: string; @@ -70,6 +75,7 @@ export async function runClickHousePriceQuery( query = baseQuery; } + params.projectId = auth.projectId; params.mode = auth.mode; const rs = await chClient.query({ diff --git a/src/storage/adapter/postgres/handlers/addAiTokenUsage.ts b/src/storage/adapter/postgres/handlers/addAiTokenUsage.ts index 09a8bb1..1117fa2 100644 --- a/src/storage/adapter/postgres/handlers/addAiTokenUsage.ts +++ b/src/storage/adapter/postgres/handlers/addAiTokenUsage.ts @@ -117,6 +117,7 @@ function buildAiTokenInsertValues( auth: AuthContext ) { return aggregatedEvents.map((aggEvent) => ({ + projectId: auth.projectId, eventId: aggEvent.eventId, idempotencyKey: aggEvent.idempotencyKey, reportedTimestamp: aggEvent.reported_timestamp, @@ -166,7 +167,7 @@ export async function handleAddAiTokenUsage( `storing ${events.length} AI_TOKEN_USAGE event(s)`, async (txn) => { if (firstEvent) { - await ensureUserExists(firstEvent.userId, txn); + await ensureUserExists(auth.projectId, firstEvent.userId, txn); } try { diff --git a/src/storage/adapter/postgres/handlers/addBasicUsage.ts b/src/storage/adapter/postgres/handlers/addBasicUsage.ts index 127b4c0..0de79d3 100644 --- a/src/storage/adapter/postgres/handlers/addBasicUsage.ts +++ b/src/storage/adapter/postgres/handlers/addBasicUsage.ts @@ -28,7 +28,11 @@ export async function handleAddBasicUsage( connectionObject, "storing BASIC_USAGE event", async (txn) => { - const ensurePromise = ensureUserExists(event_data.userId, txn); + const ensurePromise = ensureUserExists( + auth.projectId, + event_data.userId, + txn + ); const reportedTimestamp = await validateAndPrepareTimestamp( event_data.reported_timestamp @@ -38,6 +42,7 @@ export async function handleAddBasicUsage( const [result] = await txn .insert(basicUsageEventsTable) .values({ + projectId: auth.projectId, eventId: event_data.eventId, idempotencyKey: event_data.idempotencyKey, reportedTimestamp, diff --git a/src/storage/adapter/postgres/handlers/priceRequest.ts b/src/storage/adapter/postgres/handlers/priceRequest.ts index 2cde117..5a2bfe2 100644 --- a/src/storage/adapter/postgres/handlers/priceRequest.ts +++ b/src/storage/adapter/postgres/handlers/priceRequest.ts @@ -37,7 +37,7 @@ export async function handlePriceRequest( let result; try { - const baseCondition = sql`${priceTable.reportedTimestamp} > ${usersTable.last_billed_timestamp} AND ${priceTable.userId} = ${userId} AND ${priceTable.mode} = ${auth.mode}`; + const baseCondition = sql`${priceTable.reportedTimestamp} > ${usersTable.last_billed_timestamp} AND ${priceTable.userId} = ${userId} AND ${priceTable.mode} = ${auth.mode} AND ${priceTable.projectId} = ${auth.projectId}`; const whereClause = beforeTimestamp ? and( baseCondition, @@ -50,7 +50,13 @@ export async function handlePriceRequest( price: sum(priceColumn), }) .from(priceTable) - .innerJoin(usersTable, eq(priceTable.userId, usersTable.id)) + .innerJoin( + usersTable, + and( + eq(priceTable.userId, usersTable.id), + eq(priceTable.projectId, usersTable.projectId) + ) + ) .where(whereClause) .groupBy(priceTable.userId); } catch (e) { @@ -82,15 +88,7 @@ export async function handlePriceRequest( return 0; } - let parsedPrice: number; - try { - parsedPrice = parseInt(priceValue); - } catch (e) { - throw StorageError.priceCalculationFailed( - userId, - new Error(`Failed to parse price value: ${priceValue}`) - ); - } + const parsedPrice = parseInt(priceValue, 10); if (isNaN(parsedPrice)) { throw StorageError.priceCalculationFailed( diff --git a/src/storage/adapter/postgres/handlers/queryEvents.ts b/src/storage/adapter/postgres/handlers/queryEvents.ts index 70bbfc8..3bcf970 100644 --- a/src/storage/adapter/postgres/handlers/queryEvents.ts +++ b/src/storage/adapter/postgres/handlers/queryEvents.ts @@ -255,9 +255,9 @@ export async function handleQueryEvents( try { if (request.aggregation) { - return await handleAggregationQuery(request, tables); + return await handleAggregationQuery(request, tables, auth); } - return await handleListQuery(request, tables); + return await handleListQuery(request, tables, auth); } catch (e) { if ( e && @@ -276,16 +276,21 @@ export async function handleQueryEvents( async function handleListQuery( request: QueryRequest, - tables: EventTableName[] + tables: EventTableName[], + auth: AuthContext ): Promise { const db = getPostgresDB(); const selectExpr = tables.map((t) => buildSelectColumns(t)); const whereExpr = tables.map((t) => buildWhereClause(request.where, t)); + const projectFilter = sql`project_id = ${auth.projectId}`; const subqueries = tables.map((t, i) => { const base = sql`SELECT ${selectExpr[i]} FROM ${sql.raw(t)}`; - return whereExpr[i] ? sql`${base} WHERE ${whereExpr[i]}` : base; + const fullWhere = whereExpr[i] + ? sql`${whereExpr[i]} AND ${projectFilter}` + : projectFilter; + return sql`${base} WHERE ${fullWhere}`; }); const unionQuery = sql.join(subqueries, sql` UNION ALL `); @@ -301,18 +306,20 @@ async function handleListQuery( const data = result as unknown as Record[]; const rows: QueryResultRow[] = data.map(normalizeRow); - const total = await getTotalCount(request, tables); + const total = await getTotalCount(request, tables, auth); return { rows, total }; } async function handleAggregationQuery( request: QueryRequest, - tables: EventTableName[] + tables: EventTableName[], + auth: AuthContext ): Promise { const db = getPostgresDB(); const agg = request.aggregation!; const isSum = agg.type === "SUM"; + const projectFilter = sql`project_id = ${auth.projectId}`; const subqueries = tables.map((t) => { const cols: SQL[] = []; @@ -349,7 +356,10 @@ async function handleAggregationQuery( const whereClause = buildWhereClause(request.where, t); const base = sql`SELECT ${sql.join(cols, sql`, `)} FROM ${sql.raw(t)}`; - return whereClause ? sql`${base} WHERE ${whereClause}` : base; + const fullWhere = whereClause + ? sql`${whereClause} AND ${projectFilter}` + : projectFilter; + return sql`${base} WHERE ${fullWhere}`; }); const unionQuery = sql.join(subqueries, sql` UNION ALL `); @@ -395,14 +405,19 @@ async function handleAggregationQuery( async function getTotalCount( request: QueryRequest, - tables: EventTableName[] + tables: EventTableName[], + auth: AuthContext ): Promise { const db = getPostgresDB(); + const projectFilter = sql`project_id = ${auth.projectId}`; const subqueries = tables.map((t) => { const whereClause = buildWhereClause(request.where, t); const base = sql`SELECT count(*)::int as cnt FROM ${sql.raw(t)}`; - return whereClause ? sql`${base} WHERE ${whereClause}` : base; + const fullWhere = whereClause + ? sql`${whereClause} AND ${projectFilter}` + : projectFilter; + return sql`${base} WHERE ${fullWhere}`; }); const countQuery = sql` diff --git a/src/storage/db/postgres/helpers/apiKeys.ts b/src/storage/db/postgres/helpers/apiKeys.ts index 5b36c3b..950ae8b 100644 --- a/src/storage/db/postgres/helpers/apiKeys.ts +++ b/src/storage/db/postgres/helpers/apiKeys.ts @@ -9,6 +9,7 @@ type CreateApiKeyInput = { key: string; role: string; expiresAt: string; + projectId: string; }; export async function createApiKey( @@ -37,6 +38,7 @@ export async function createApiKey( key: input.key, role: input.role as "dashboard" | "production" | "test", expiresAt: input.expiresAt, + projectId: input.projectId, }) .returning({ id: apiKeysTable.id }); @@ -74,16 +76,20 @@ type ApiKeyRecord = { role: string; expiresAt: string; revoked: boolean; + projectId: string; }; export async function getApiKeyRoleById( id: string -): Promise<{ role: "dashboard" | "production" | "test" } | null> { +): Promise<{ + role: "dashboard" | "production" | "test"; + projectId: string; +} | null> { const db = getPostgresDB(); try { const [record] = await db - .select({ role: apiKeysTable.role }) + .select({ role: apiKeysTable.role, projectId: apiKeysTable.projectId }) .from(apiKeysTable) .where(eq(apiKeysTable.id, id)) .limit(1); @@ -109,6 +115,7 @@ export async function findApiKeyByHash( role: apiKeysTable.role, expiresAt: apiKeysTable.expiresAt, revoked: apiKeysTable.revoked, + projectId: apiKeysTable.projectId, }) .from(apiKeysTable) .where(eq(apiKeysTable.key, apiKeyHash)) diff --git a/src/storage/db/postgres/helpers/expressions.ts b/src/storage/db/postgres/helpers/expressions.ts index 546b319..41639d6 100644 --- a/src/storage/db/postgres/helpers/expressions.ts +++ b/src/storage/db/postgres/helpers/expressions.ts @@ -4,14 +4,19 @@ import { eq, and, isNull } from "drizzle-orm"; import { StorageError } from "../../../../errors/storage"; import { DateTime } from "luxon"; -export async function listExpressions(): Promise { +export async function listExpressions(projectId: string): Promise { const db = getPostgresDB(); try { const rows = await db .select({ key: expressionsTable.key }) .from(expressionsTable) - .where(isNull(expressionsTable.deletedAt)); + .where( + and( + eq(expressionsTable.projectId, projectId), + isNull(expressionsTable.deletedAt) + ) + ); return rows.map((row) => row.key); } catch (e) { throw StorageError.queryFailed( @@ -21,7 +26,10 @@ export async function listExpressions(): Promise { } } -export async function findExpressionByKey(key: string): Promise { +export async function findExpressionByKey( + projectId: string, + key: string +): Promise { const db = getPostgresDB(); try { @@ -29,7 +37,11 @@ export async function findExpressionByKey(key: string): Promise { .select({ expr: expressionsTable.expr }) .from(expressionsTable) .where( - and(eq(expressionsTable.key, key), isNull(expressionsTable.deletedAt)) + and( + eq(expressionsTable.projectId, projectId), + eq(expressionsTable.key, key), + isNull(expressionsTable.deletedAt) + ) ) .limit(1); @@ -43,6 +55,7 @@ export async function findExpressionByKey(key: string): Promise { } export async function createExpression( + projectId: string, key: string, expr: string ): Promise { @@ -53,7 +66,11 @@ export async function createExpression( .select({ id: expressionsTable.id }) .from(expressionsTable) .where( - and(eq(expressionsTable.key, key), isNull(expressionsTable.deletedAt)) + and( + eq(expressionsTable.projectId, projectId), + eq(expressionsTable.key, key), + isNull(expressionsTable.deletedAt) + ) ) .limit(1); @@ -65,7 +82,7 @@ export async function createExpression( return; } - await db.insert(expressionsTable).values({ key, expr }); + await db.insert(expressionsTable).values({ projectId, key, expr }); } catch (e) { throw StorageError.insertFailed( `Failed to upsert expression '${key}'`, @@ -74,7 +91,10 @@ export async function createExpression( } } -export async function deleteExpression(key: string): Promise { +export async function deleteExpression( + projectId: string, + key: string +): Promise { const db = getPostgresDB(); try { @@ -83,7 +103,11 @@ export async function deleteExpression(key: string): Promise { .update(expressionsTable) .set({ deletedAt: now }) .where( - and(eq(expressionsTable.key, key), isNull(expressionsTable.deletedAt)) + and( + eq(expressionsTable.projectId, projectId), + eq(expressionsTable.key, key), + isNull(expressionsTable.deletedAt) + ) ); return (result.count ?? 0) > 0; diff --git a/src/storage/db/postgres/helpers/metadata.ts b/src/storage/db/postgres/helpers/metadata.ts index 4d83430..d14389b 100644 --- a/src/storage/db/postgres/helpers/metadata.ts +++ b/src/storage/db/postgres/helpers/metadata.ts @@ -2,73 +2,20 @@ import { getPostgresDB } from "../db"; import { metadataTable } from "../schema"; import { StorageError } from "../../../../errors/storage"; import { eq } from "drizzle-orm"; -import { executeInTransaction } from "../../../adapter/postgres/handlers/addEventUtils"; -export type UpsertMetadataInput = { - dodo_live_api_key?: string; - dodo_test_api_key?: string; - dodo_live_product_id?: string; - dodo_test_product_id?: string; - dodo_live_webhook_secret?: string; - dodo_test_webhook_secret?: string; - currency?: string; - redirect_url?: string; -}; - -export async function upsertMetadata( - input: UpsertMetadataInput -): Promise { +export async function getMetadata( + projectId: string +): Promise { const db = getPostgresDB(); - - await executeInTransaction(db, "upsert metadata", async (txn) => { - try { - const [existingMetadata] = await txn - .select({ id: metadataTable.id }) - .from(metadataTable) - .limit(1) - .for("update"); - - const setValues: Partial = {}; - if (input.dodo_live_api_key !== undefined) - setValues.dodo_live_api_key = input.dodo_live_api_key; - if (input.dodo_test_api_key !== undefined) - setValues.dodo_test_api_key = input.dodo_test_api_key; - if (input.dodo_live_product_id !== undefined) - setValues.dodo_live_product_id = input.dodo_live_product_id; - if (input.dodo_test_product_id !== undefined) - setValues.dodo_test_product_id = input.dodo_test_product_id; - if (input.dodo_live_webhook_secret !== undefined) - setValues.dodo_live_webhook_secret = input.dodo_live_webhook_secret; - if (input.dodo_test_webhook_secret !== undefined) - setValues.dodo_test_webhook_secret = input.dodo_test_webhook_secret; - if (input.currency !== undefined) setValues.currency = input.currency; - if (input.redirect_url !== undefined) - setValues.redirect_url = input.redirect_url; - - if (existingMetadata) { - if (Object.keys(setValues).length > 0) { - await txn - .update(metadataTable) - .set(setValues) - .where(eq(metadataTable.id, existingMetadata.id)); - } - return; - } - - const insertValues: typeof metadataTable.$inferInsert = { - ...setValues, - } as typeof metadataTable.$inferInsert; - await txn.insert(metadataTable).values(insertValues); - } catch (e) { - throw StorageError.insertFailed( - "Failed to upsert metadata record", - e instanceof Error ? e : new Error(String(e)) - ); - } - }); + const [metadata] = await db + .select() + .from(metadataTable) + .where(eq(metadataTable.projectId, projectId)) + .limit(1); + return metadata; } -export async function getMetadata(): Promise< +export async function getAnyMetadata(): Promise< typeof metadataTable.$inferSelect | undefined > { const db = getPostgresDB(); diff --git a/src/storage/db/postgres/helpers/payments.ts b/src/storage/db/postgres/helpers/payments.ts index 9af0688..507e8f9 100644 --- a/src/storage/db/postgres/helpers/payments.ts +++ b/src/storage/db/postgres/helpers/payments.ts @@ -5,6 +5,7 @@ import { DateTime } from "luxon"; import type { PgTransaction } from "drizzle-orm/pg-core"; export async function handleAddPayment( + projectId: string, userId: string, creditAmount: number, apiKeyId: string, @@ -30,6 +31,7 @@ export async function handleAddPayment( const [result] = await db .insert(paymentEventsTable) .values({ + projectId, reportedTimestamp: DateTime.utc().toISO()!, userId, apiKeyId, diff --git a/src/storage/db/postgres/helpers/sessions.ts b/src/storage/db/postgres/helpers/sessions.ts index 4e3cf69..8598390 100644 --- a/src/storage/db/postgres/helpers/sessions.ts +++ b/src/storage/db/postgres/helpers/sessions.ts @@ -32,6 +32,7 @@ export async function updateSessionStatus( export async function checkIfExistingCheckoutLink( txn: PgTransaction, + projectId: string, userId: UserId, mode: "test" | "production" ): Promise { @@ -45,6 +46,7 @@ export async function checkIfExistingCheckoutLink( .from(sessionsTable) .where( and( + eq(sessionsTable.projectId, projectId), eq(sessionsTable.userId, userId), eq(sessionsTable.processed, "pending"), eq(sessionsTable.mode, mode), @@ -68,6 +70,7 @@ export async function checkIfExistingCheckoutLink( } export async function handleAddSession( + projectId: string, userId: UserId, sessionId: string, billedUpto: DateTime, @@ -91,6 +94,7 @@ export async function handleAddSession( const insertResult = await connectionObject .insert(sessionsTable) .values({ + projectId, userId: userId, sessionId: sessionId, billed_upto: billedUptoStr, diff --git a/src/storage/db/postgres/helpers/tags.ts b/src/storage/db/postgres/helpers/tags.ts index 9056f0f..e20d072 100644 --- a/src/storage/db/postgres/helpers/tags.ts +++ b/src/storage/db/postgres/helpers/tags.ts @@ -5,14 +5,18 @@ import { StorageError } from "../../../../errors/storage"; import { DateTime } from "luxon"; import { tagCache } from "../../../../utils/tagCache"; -export async function listTags(): Promise<{ key: string; amount: number }[]> { +export async function listTags( + projectId: string +): Promise<{ key: string; amount: number }[]> { const db = getPostgresDB(); try { const rows = await db .select({ key: tagsTable.key, amount: tagsTable.amount }) .from(tagsTable) - .where(isNull(tagsTable.deletedAt)); + .where( + and(eq(tagsTable.projectId, projectId), isNull(tagsTable.deletedAt)) + ); return rows; } catch (e) { throw StorageError.queryFailed( @@ -22,14 +26,24 @@ export async function listTags(): Promise<{ key: string; amount: number }[]> { } } -export async function createTag(key: string, amount: number): Promise { +export async function createTag( + projectId: string, + key: string, + amount: number +): Promise { const db = getPostgresDB(); try { const existing = await db .select({ id: tagsTable.id }) .from(tagsTable) - .where(and(eq(tagsTable.key, key), isNull(tagsTable.deletedAt))) + .where( + and( + eq(tagsTable.projectId, projectId), + eq(tagsTable.key, key), + isNull(tagsTable.deletedAt) + ) + ) .limit(1); if (existing[0]) { @@ -37,12 +51,12 @@ export async function createTag(key: string, amount: number): Promise { .update(tagsTable) .set({ amount }) .where(eq(tagsTable.id, existing[0].id)); - tagCache.delete(key); + tagCache.delete(`${projectId}:${key}`); return; } - await db.insert(tagsTable).values({ key, amount }); - tagCache.delete(key); + await db.insert(tagsTable).values({ projectId, key, amount }); + tagCache.delete(`${projectId}:${key}`); } catch (e) { throw StorageError.insertFailed( `Failed to upsert tag '${key}'`, @@ -51,7 +65,10 @@ export async function createTag(key: string, amount: number): Promise { } } -export async function deleteTag(key: string): Promise { +export async function deleteTag( + projectId: string, + key: string +): Promise { const db = getPostgresDB(); try { @@ -59,10 +76,16 @@ export async function deleteTag(key: string): Promise { const result = await db .update(tagsTable) .set({ deletedAt: now }) - .where(and(eq(tagsTable.key, key), isNull(tagsTable.deletedAt))); + .where( + and( + eq(tagsTable.projectId, projectId), + eq(tagsTable.key, key), + isNull(tagsTable.deletedAt) + ) + ); if ((result.count ?? 0) > 0) { - tagCache.delete(key); + tagCache.delete(`${projectId}:${key}`); return true; } return false; diff --git a/src/storage/db/postgres/helpers/users.ts b/src/storage/db/postgres/helpers/users.ts index eab1db6..624a5a5 100644 --- a/src/storage/db/postgres/helpers/users.ts +++ b/src/storage/db/postgres/helpers/users.ts @@ -1,10 +1,11 @@ import { getPostgresDB } from "../db"; import { usersTable } from "../schema"; -import { eq } from "drizzle-orm"; +import { eq, and } from "drizzle-orm"; import { StorageError } from "../../../../errors/storage"; import type { PgTransaction } from "drizzle-orm/pg-core"; export async function updateUserBilledTimestamp( + projectId: string, userId: string, billedUpto: string, txn?: PgTransaction @@ -15,7 +16,9 @@ export async function updateUserBilledTimestamp( await db .update(usersTable) .set({ last_billed_timestamp: billedUpto }) - .where(eq(usersTable.id, userId)); + .where( + and(eq(usersTable.projectId, projectId), eq(usersTable.id, userId)) + ); } catch (e) { throw StorageError.queryFailed( "Failed to update user billed timestamp", @@ -24,17 +27,21 @@ export async function updateUserBilledTimestamp( } } -export async function userExists(userId: string): Promise { +export async function userExists( + projectId: string, + userId: string +): Promise { const db = getPostgresDB(); const result = await db .select({ id: usersTable.id }) .from(usersTable) - .where(eq(usersTable.id, userId)) + .where(and(eq(usersTable.projectId, projectId), eq(usersTable.id, userId))) .limit(1); return result.length > 0; } export async function ensureUserExists( + projectId: string, userId: string, txn?: PgTransaction ): Promise { @@ -43,15 +50,12 @@ export async function ensureUserExists( try { await db .insert(usersTable) - .values({ id: userId }) - .onConflictDoNothing({ target: usersTable.id }); + .values({ id: userId, projectId }) + .onConflictDoNothing(); } catch (e) { - if ( - e instanceof Error && - (e.message.includes("duplicate") || e.message.includes("unique")) - ) { - return; - } - throw e; + throw StorageError.queryFailed( + "Failed to ensure user exists", + e instanceof Error ? e : new Error(String(e)) + ); } } diff --git a/src/storage/db/postgres/helpers/webhookEndpoints.ts b/src/storage/db/postgres/helpers/webhookEndpoints.ts index 3d82ef6..341ae20 100644 --- a/src/storage/db/postgres/helpers/webhookEndpoints.ts +++ b/src/storage/db/postgres/helpers/webhookEndpoints.ts @@ -1,12 +1,14 @@ import { getPostgresDB } from "../db"; import { webhookEndpointsTable } from "../schema"; import { eq, and, isNull } from "drizzle-orm"; +import type { PgTransaction } from "drizzle-orm/pg-core"; import { StorageError } from "../../../../errors/storage"; import { DateTime } from "luxon"; export type WebhookEndpoint = typeof webhookEndpointsTable.$inferSelect; export async function getWebhookEndpointByApiKeyId( + projectId: string, apiKeyId: string ): Promise { const db = getPostgresDB(); @@ -17,6 +19,7 @@ export async function getWebhookEndpointByApiKeyId( .from(webhookEndpointsTable) .where( and( + eq(webhookEndpointsTable.projectId, projectId), eq(webhookEndpointsTable.apiKeyId, apiKeyId), isNull(webhookEndpointsTable.deletedAt) ) @@ -33,12 +36,14 @@ export async function getWebhookEndpointByApiKeyId( } export async function upsertWebhookEndpoint( + projectId: string, apiKeyId: string, url: string, privateKey: string, - publicKey: string + publicKey: string, + txn?: PgTransaction ): Promise { - const db = getPostgresDB(); + const db = txn ?? getPostgresDB(); try { const now = DateTime.utc().toISO(); @@ -46,6 +51,7 @@ export async function upsertWebhookEndpoint( const [result] = await db .insert(webhookEndpointsTable) .values({ + projectId, apiKeyId, url, privateKey, @@ -85,6 +91,7 @@ export async function upsertWebhookEndpoint( } export async function deleteWebhookEndpoint( + projectId: string, apiKeyId: string ): Promise { const db = getPostgresDB(); @@ -97,6 +104,7 @@ export async function deleteWebhookEndpoint( .set({ deletedAt: now }) .where( and( + eq(webhookEndpointsTable.projectId, projectId), eq(webhookEndpointsTable.apiKeyId, apiKeyId), isNull(webhookEndpointsTable.deletedAt) ) diff --git a/src/storage/db/postgres/schema.ts b/src/storage/db/postgres/schema.ts index c266d0a..510c9ff 100644 --- a/src/storage/db/postgres/schema.ts +++ b/src/storage/db/postgres/schema.ts @@ -9,26 +9,52 @@ import { boolean, jsonb, uniqueIndex, + primaryKey, + foreignKey, } from "drizzle-orm/pg-core"; import { USER_ID_CONFIG } from "../../../config/identifiers"; import { DateTime } from "luxon"; import { type Metrics } from "../../../zod/metrics"; -export const usersTable = pgTable("users", { - id: USER_ID_CONFIG.dbType("id").primaryKey(), - last_billed_timestamp: timestamp("last_billed_timestamp", { +export const projectsTable = pgTable("projects", { + id: uuid("id").primaryKey().defaultRandom(), + name: text("name").notNull(), + createdAt: timestamp("created_at", { withTimezone: true, mode: "string", }) - .default(DateTime.utc(1).toString()) + .defaultNow() .notNull(), - payment_provider_user_id: text("payment_provider_user_id"), - mode: text("mode", { enum: ["test", "production"] }) - .notNull() - .default("production"), }); -export const usersRelation = relations(usersTable, ({ many }) => ({ +export const usersTable = pgTable( + "users", + { + id: USER_ID_CONFIG.dbType("id").notNull(), + projectId: uuid("project_id") + .references(() => projectsTable.id) + .notNull(), + last_billed_timestamp: timestamp("last_billed_timestamp", { + withTimezone: true, + mode: "string", + }) + .default(DateTime.utc(1).toString()) + .notNull(), + payment_provider_user_id: text("payment_provider_user_id"), + mode: text("mode", { enum: ["test", "production"] }) + .notNull() + .default("production"), + }, + (table) => ({ + pk: primaryKey({ columns: [table.projectId, table.id] }), + }) +); + +export const usersRelation = relations(usersTable, ({ one, many }) => ({ + project: one(projectsTable, { + fields: [usersTable.projectId], + references: [projectsTable.id], + }), sessions: many(sessionsTable), basicUsageEvents: many(basicUsageEventsTable), paymentEvents: many(paymentEventsTable), @@ -39,13 +65,14 @@ export const sessionsTable = pgTable( "sessions", { proxy_link_id: uuid("proxy_link_id").primaryKey().defaultRandom(), + projectId: uuid("project_id") + .references(() => projectsTable.id) + .notNull(), sessionId: text("session_id").notNull().unique(), processed: text("processed", { enum: ["pending", "failed", "succeeded"] }) .default("pending") .notNull(), - userId: USER_ID_CONFIG.dbType("user_id") - .references(() => usersTable.id) - .notNull(), + userId: USER_ID_CONFIG.dbType("user_id").notNull(), apiKeyId: uuid("api_key_id") .references(() => apiKeysTable.id) .notNull(), @@ -66,13 +93,21 @@ export const sessionsTable = pgTable( }, (table) => ({ uniqueSessionId: uniqueIndex("unique_session_id").on(table.sessionId), + userFk: foreignKey({ + columns: [table.projectId, table.userId], + foreignColumns: [usersTable.projectId, usersTable.id], + }), }) ); export const sessionRelations = relations(sessionsTable, ({ one, many }) => ({ + project: one(projectsTable, { + fields: [sessionsTable.projectId], + references: [projectsTable.id], + }), user: one(usersTable, { - fields: [sessionsTable.userId], - references: [usersTable.id], + fields: [sessionsTable.projectId, sessionsTable.userId], + references: [usersTable.projectId, usersTable.id], }), apiKey: one(apiKeysTable, { fields: [sessionsTable.apiKeyId], @@ -85,6 +120,9 @@ export const apiKeysTable = pgTable( "api_keys", { id: uuid("id").primaryKey().defaultRandom(), + projectId: uuid("project_id") + .references(() => projectsTable.id) + .notNull(), name: text("name").notNull(), key: text("key").notNull().unique(), role: text("role", { enum: ["dashboard", "production", "test"] }) @@ -108,50 +146,68 @@ export const apiKeysTable = pgTable( }, (table) => ({ uniqueActiveName: uniqueIndex("unique_active_name") - .on(table.name) + .on(table.projectId, table.name) .where(sql`${table.revoked} = false`), }) ); -export const apiKeysRelation = relations(apiKeysTable, ({ many }) => ({ +export const apiKeysRelation = relations(apiKeysTable, ({ one, many }) => ({ + project: one(projectsTable, { + fields: [apiKeysTable.projectId], + references: [projectsTable.id], + }), sessions: many(sessionsTable), basicUsageEvents: many(basicUsageEventsTable), paymentEvents: many(paymentEventsTable), aiTokenUsageEvents: many(aiTokenUsageEventsTable), })); -export const basicUsageEventsTable = pgTable("basic_usage_events", { - id: uuid("id").primaryKey().defaultRandom(), - eventId: uuid("event_id").notNull(), - idempotencyKey: text("idempotency_key").notNull().unique(), - reportedTimestamp: timestamp("reported_timestamp", { - withTimezone: true, - mode: "string", - }).notNull(), - ingestedTimestamp: timestamp("ingested_timestamp", { - withTimezone: true, - mode: "string", +export const basicUsageEventsTable = pgTable( + "basic_usage_events", + { + id: uuid("id").primaryKey().defaultRandom(), + projectId: uuid("project_id") + .references(() => projectsTable.id) + .notNull(), + eventId: uuid("event_id").notNull(), + idempotencyKey: text("idempotency_key").notNull().unique(), + reportedTimestamp: timestamp("reported_timestamp", { + withTimezone: true, + mode: "string", + }).notNull(), + ingestedTimestamp: timestamp("ingested_timestamp", { + withTimezone: true, + mode: "string", + }) + .defaultNow() + .notNull(), + userId: USER_ID_CONFIG.dbType("user_id").notNull(), + apiKeyId: uuid("api_key_id") + .references(() => apiKeysTable.id) + .notNull(), + mode: text("mode", { enum: ["test", "production"] }).notNull(), + type: text("type", { enum: ["RAW", "MIDDLEWARE_CALL"] }).notNull(), + debitAmount: bigint("debit_amount", { mode: "number" }).notNull(), + metadata: jsonb("metadata").$type>(), + }, + (table) => ({ + userFk: foreignKey({ + columns: [table.projectId, table.userId], + foreignColumns: [usersTable.projectId, usersTable.id], + }), }) - .defaultNow() - .notNull(), - userId: USER_ID_CONFIG.dbType("user_id") - .references(() => usersTable.id) - .notNull(), - apiKeyId: uuid("api_key_id") - .references(() => apiKeysTable.id) - .notNull(), - mode: text("mode", { enum: ["test", "production"] }).notNull(), - type: text("type", { enum: ["RAW", "MIDDLEWARE_CALL"] }).notNull(), - debitAmount: bigint("debit_amount", { mode: "number" }).notNull(), - metadata: jsonb("metadata").$type>(), -}); +); export const basicUsageEventsRelation = relations( basicUsageEventsTable, ({ one }) => ({ + project: one(projectsTable, { + fields: [basicUsageEventsTable.projectId], + references: [projectsTable.id], + }), user: one(usersTable, { - fields: [basicUsageEventsTable.userId], - references: [usersTable.id], + fields: [basicUsageEventsTable.projectId, basicUsageEventsTable.userId], + references: [usersTable.projectId, usersTable.id], }), apiKey: one(apiKeysTable, { fields: [basicUsageEventsTable.apiKeyId], @@ -160,37 +216,51 @@ export const basicUsageEventsRelation = relations( }) ); -export const paymentEventsTable = pgTable("payment_events", { - id: uuid("id").primaryKey().defaultRandom(), - reportedTimestamp: timestamp("reported_timestamp", { - withTimezone: true, - mode: "string", - }).notNull(), - ingestedTimestamp: timestamp("ingested_timestamp", { - withTimezone: true, - mode: "string", +export const paymentEventsTable = pgTable( + "payment_events", + { + id: uuid("id").primaryKey().defaultRandom(), + projectId: uuid("project_id") + .references(() => projectsTable.id) + .notNull(), + reportedTimestamp: timestamp("reported_timestamp", { + withTimezone: true, + mode: "string", + }).notNull(), + ingestedTimestamp: timestamp("ingested_timestamp", { + withTimezone: true, + mode: "string", + }) + .defaultNow() + .notNull(), + userId: USER_ID_CONFIG.dbType("user_id").notNull(), + apiKeyId: uuid("api_key_id") + .references(() => apiKeysTable.id) + .notNull(), + mode: text("mode", { enum: ["test", "production"] }).notNull(), + creditAmount: bigint("credit_amount", { mode: "number" }).notNull(), + proxyId: uuid("proxy_id") + .references(() => sessionsTable.proxy_link_id) + .notNull(), + }, + (table) => ({ + userFk: foreignKey({ + columns: [table.projectId, table.userId], + foreignColumns: [usersTable.projectId, usersTable.id], + }), }) - .defaultNow() - .notNull(), - userId: USER_ID_CONFIG.dbType("user_id") - .references(() => usersTable.id) - .notNull(), - apiKeyId: uuid("api_key_id") - .references(() => apiKeysTable.id) - .notNull(), - mode: text("mode", { enum: ["test", "production"] }).notNull(), - creditAmount: bigint("credit_amount", { mode: "number" }).notNull(), - proxyId: uuid("proxy_id") - .references(() => sessionsTable.proxy_link_id) - .notNull(), -}); +); export const paymentEventsRelation = relations( paymentEventsTable, ({ one }) => ({ + project: one(projectsTable, { + fields: [paymentEventsTable.projectId], + references: [projectsTable.id], + }), user: one(usersTable, { - fields: [paymentEventsTable.userId], - references: [usersTable.id], + fields: [paymentEventsTable.projectId, paymentEventsTable.userId], + references: [usersTable.projectId, usersTable.id], }), apiKey: one(apiKeysTable, { fields: [paymentEventsTable.apiKeyId], @@ -203,39 +273,56 @@ export const paymentEventsRelation = relations( }) ); -export const aiTokenUsageEventsTable = pgTable("ai_token_usage_events", { - id: uuid("id").primaryKey().defaultRandom(), - eventId: uuid("event_id").notNull(), - idempotencyKey: text("idempotency_key").notNull().unique(), - reportedTimestamp: timestamp("reported_timestamp", { - withTimezone: true, - mode: "string", - }).notNull(), - ingestedTimestamp: timestamp("ingested_timestamp", { - withTimezone: true, - mode: "string", +export const aiTokenUsageEventsTable = pgTable( + "ai_token_usage_events", + { + id: uuid("id").primaryKey().defaultRandom(), + projectId: uuid("project_id") + .references(() => projectsTable.id) + .notNull(), + eventId: uuid("event_id").notNull(), + idempotencyKey: text("idempotency_key").notNull().unique(), + reportedTimestamp: timestamp("reported_timestamp", { + withTimezone: true, + mode: "string", + }).notNull(), + ingestedTimestamp: timestamp("ingested_timestamp", { + withTimezone: true, + mode: "string", + }) + .defaultNow() + .notNull(), + userId: USER_ID_CONFIG.dbType("user_id").notNull(), + apiKeyId: uuid("api_key_id") + .references(() => apiKeysTable.id) + .notNull(), + mode: text("mode", { enum: ["test", "production"] }).notNull(), + model: text("model").notNull(), + provider: text("provider").notNull(), + metrics: jsonb("metrics").$type().notNull(), + metadata: jsonb("metadata").$type>(), + }, + (table) => ({ + userFk: foreignKey({ + columns: [table.projectId, table.userId], + foreignColumns: [usersTable.projectId, usersTable.id], + }), }) - .defaultNow() - .notNull(), - userId: USER_ID_CONFIG.dbType("user_id") - .references(() => usersTable.id) - .notNull(), - apiKeyId: uuid("api_key_id") - .references(() => apiKeysTable.id) - .notNull(), - mode: text("mode", { enum: ["test", "production"] }).notNull(), - model: text("model").notNull(), - provider: text("provider").notNull(), - metrics: jsonb("metrics").$type().notNull(), - metadata: jsonb("metadata").$type>(), -}); +); export const aiTokenUsageEventsRelation = relations( aiTokenUsageEventsTable, ({ one }) => ({ + project: one(projectsTable, { + fields: [aiTokenUsageEventsTable.projectId], + references: [projectsTable.id], + }), user: one(usersTable, { - fields: [aiTokenUsageEventsTable.userId], - references: [usersTable.id], + fields: [ + aiTokenUsageEventsTable.projectId, + aiTokenUsageEventsTable.userId, + ], + references: [usersTable.projectId, usersTable.id], }), apiKey: one(apiKeysTable, { fields: [aiTokenUsageEventsTable.apiKeyId], @@ -246,6 +333,9 @@ export const aiTokenUsageEventsRelation = relations( export const tagsTable = pgTable("tags", { id: uuid("id").primaryKey().defaultRandom(), + projectId: uuid("project_id") + .references(() => projectsTable.id) + .notNull(), key: text("key").notNull(), amount: integer("amount").notNull(), deletedAt: timestamp("deleted_at", { @@ -254,24 +344,50 @@ export const tagsTable = pgTable("tags", { }), }); -export const metadataTable = pgTable("metadata", { - id: uuid("id").primaryKey().defaultRandom(), - last_run_at: timestamp("last_run_at", { - withTimezone: true, - mode: "string", +export const tagsRelation = relations(tagsTable, ({ one }) => ({ + project: one(projectsTable, { + fields: [tagsTable.projectId], + references: [projectsTable.id], }), - dodo_live_api_key: text("dodo_live_api_key").notNull(), - dodo_test_api_key: text("dodo_test_api_key").notNull(), - dodo_live_product_id: text("dodo_live_product_id").notNull(), - dodo_test_product_id: text("dodo_test_product_id").notNull(), - dodo_live_webhook_secret: text("dodo_live_webhook_secret").notNull(), - dodo_test_webhook_secret: text("dodo_test_webhook_secret").notNull(), - currency: text("currency").notNull().default("usd"), - redirect_url: text("redirect_url").notNull(), -}); +})); + +export const metadataTable = pgTable( + "metadata", + { + id: uuid("id").primaryKey().defaultRandom(), + projectId: uuid("project_id") + .references(() => projectsTable.id) + .notNull(), + last_run_at: timestamp("last_run_at", { + withTimezone: true, + mode: "string", + }), + dodo_live_api_key: text("dodo_live_api_key").notNull(), + dodo_test_api_key: text("dodo_test_api_key").notNull(), + dodo_live_product_id: text("dodo_live_product_id").notNull(), + dodo_test_product_id: text("dodo_test_product_id").notNull(), + dodo_live_webhook_secret: text("dodo_live_webhook_secret").notNull(), + dodo_test_webhook_secret: text("dodo_test_webhook_secret").notNull(), + currency: text("currency").notNull().default("usd"), + redirect_url: text("redirect_url").notNull(), + }, + (table) => ({ + uniqueProjectId: uniqueIndex("unique_project_id").on(table.projectId), + }) +); + +export const metadataRelation = relations(metadataTable, ({ one }) => ({ + project: one(projectsTable, { + fields: [metadataTable.projectId], + references: [projectsTable.id], + }), +})); export const expressionsTable = pgTable("expressions", { id: uuid("id").primaryKey().defaultRandom(), + projectId: uuid("project_id") + .references(() => projectsTable.id) + .notNull(), key: text("key").notNull(), expr: text("expr").notNull(), deletedAt: timestamp("deleted_at", { @@ -280,10 +396,20 @@ export const expressionsTable = pgTable("expressions", { }), }); +export const expressionsRelation = relations(expressionsTable, ({ one }) => ({ + project: one(projectsTable, { + fields: [expressionsTable.projectId], + references: [projectsTable.id], + }), +})); + export const webhookEndpointsTable = pgTable( "webhook_endpoints", { id: uuid("id").primaryKey().defaultRandom(), + projectId: uuid("project_id") + .references(() => projectsTable.id) + .notNull(), apiKeyId: uuid("api_key_id") .references(() => apiKeysTable.id) .notNull(), @@ -315,6 +441,10 @@ export const webhookEndpointsTable = pgTable( export const webhookEndpointsRelation = relations( webhookEndpointsTable, ({ one }) => ({ + project: one(projectsTable, { + fields: [webhookEndpointsTable.projectId], + references: [projectsTable.id], + }), apiKey: one(apiKeysTable, { fields: [webhookEndpointsTable.apiKeyId], references: [apiKeysTable.id], @@ -324,6 +454,9 @@ export const webhookEndpointsRelation = relations( export const webhookDeliveriesTable = pgTable("webhook_deliveries", { id: uuid("id").primaryKey().defaultRandom(), + projectId: uuid("project_id") + .references(() => projectsTable.id) + .notNull(), endpointId: uuid("endpoint_id") .references(() => webhookEndpointsTable.id) .notNull(), @@ -346,9 +479,27 @@ export const webhookDeliveriesTable = pgTable("webhook_deliveries", { export const webhookDeliveriesRelation = relations( webhookDeliveriesTable, ({ one }) => ({ + project: one(projectsTable, { + fields: [webhookDeliveriesTable.projectId], + references: [projectsTable.id], + }), endpoint: one(webhookEndpointsTable, { fields: [webhookDeliveriesTable.endpointId], references: [webhookEndpointsTable.id], }), }) ); + +export const projectsRelation = relations(projectsTable, ({ many }) => ({ + users: many(usersTable), + sessions: many(sessionsTable), + apiKeys: many(apiKeysTable), + basicUsageEvents: many(basicUsageEventsTable), + paymentEvents: many(paymentEventsTable), + aiTokenUsageEvents: many(aiTokenUsageEventsTable), + tags: many(tagsTable), + metadata: many(metadataTable), + expressions: many(expressionsTable), + webhookEndpoints: many(webhookEndpointsTable), + webhookDeliveries: many(webhookDeliveriesTable), +})); diff --git a/src/utils/apiKeyCache.ts b/src/utils/apiKeyCache.ts index 86f7cf9..d8bdfa3 100644 --- a/src/utils/apiKeyCache.ts +++ b/src/utils/apiKeyCache.ts @@ -6,6 +6,7 @@ interface CachedAPIKey { id: string; role: ApiKeyRole; mode: "production" | "test" | null; + projectId: string; expiresAt: string; } @@ -13,7 +14,7 @@ const store = Cache.getStore("api-keys", { max: 1000, ttlMs: 5 * 60 * 1000, validate: (value) => - DateTime.utc() <= DateTime.fromSQL(value.expiresAt, { zone: "utc" }), + DateTime.utc() <= DateTime.fromISO(value.expiresAt, { zone: "utc" }), }); export const apiKeyCache = { diff --git a/src/utils/authenticateHttpApiKey.ts b/src/utils/authenticateHttpApiKey.ts index a467705..142853a 100644 --- a/src/utils/authenticateHttpApiKey.ts +++ b/src/utils/authenticateHttpApiKey.ts @@ -44,7 +44,12 @@ export async function authenticateHttpApiKey( `Key prefix ${role} doesn't match stored role ${cached.role}` ); } - return { apiKeyId: cached.id, role: cached.role, mode: cached.mode }; + return { + apiKeyId: cached.id, + role: cached.role, + mode: cached.mode, + projectId: cached.projectId, + }; } const apiKeyRecord = await findApiKeyByHash(apiKeyHash); @@ -76,8 +81,14 @@ export async function authenticateHttpApiKey( id: apiKeyRecord.id, role: recordRole, mode, + projectId: apiKeyRecord.projectId, expiresAt: apiKeyRecord.expiresAt, }); - return { apiKeyId: apiKeyRecord.id, role: recordRole, mode }; + return { + apiKeyId: apiKeyRecord.id, + role: recordRole, + mode, + projectId: apiKeyRecord.projectId, + }; } diff --git a/src/utils/authenticateMasterApiKey.ts b/src/utils/authenticateMasterApiKey.ts new file mode 100644 index 0000000..a3ac643 --- /dev/null +++ b/src/utils/authenticateMasterApiKey.ts @@ -0,0 +1,40 @@ +import { createHmac, timingSafeEqual } from "crypto"; +import { AuthError } from "../errors/auth"; + +function getMasterKeyHash(): string { + const hash = process.env.MASTER_API_KEY_HASH; + if (!hash) { + throw new Error("MASTER_API_KEY_HASH environment variable is not set"); + } + return hash; +} + +export function authenticateMasterApiKey(authHeader: string | undefined): void { + if (!authHeader) { + throw AuthError.missingHeader(); + } + + if (!authHeader.startsWith("Bearer ")) { + throw AuthError.invalidHeaderFormat(); + } + + const apiKey = authHeader.slice("Bearer ".length).trim(); + + const hmacSecret = process.env.HMAC_SECRET; + if (!hmacSecret) { + throw new Error("HMAC_SECRET environment variable is not set"); + } + + const incomingHash = createHmac("sha256", hmacSecret) + .update(apiKey) + .digest("hex"); + + const storedHash = getMasterKeyHash(); + + if ( + incomingHash.length !== storedHash.length || + !timingSafeEqual(Buffer.from(incomingHash), Buffer.from(storedHash)) + ) { + throw AuthError.permissionDenied("Invalid master API key"); + } +} diff --git a/src/utils/fetchTagAmount.ts b/src/utils/fetchTagAmount.ts index 3b72d32..88ef60e 100644 --- a/src/utils/fetchTagAmount.ts +++ b/src/utils/fetchTagAmount.ts @@ -5,10 +5,12 @@ import { tagsTable } from "../storage/db/postgres/schema"; import { tagCache } from "./tagCache"; export async function fetchTagAmount( + projectId: string, tag: string, notFoundMessage: string ): Promise { - const cachedAmount = tagCache.get(tag); + const cacheKey = `${projectId}:${tag}`; + const cachedAmount = tagCache.get(cacheKey); if (cachedAmount !== undefined) { return cachedAmount; } @@ -17,13 +19,19 @@ export async function fetchTagAmount( const [tagRow] = await db .select() .from(tagsTable) - .where(and(eq(tagsTable.key, tag), isNull(tagsTable.deletedAt))) + .where( + and( + eq(tagsTable.projectId, projectId), + eq(tagsTable.key, tag), + isNull(tagsTable.deletedAt) + ) + ) .limit(1); if (!tagRow) { throw EventError.validationFailed(notFoundMessage); } - tagCache.set(tag, tagRow.amount); + tagCache.set(cacheKey, tagRow.amount); return tagRow.amount; } diff --git a/src/utils/generateInitialAPIKey.ts b/src/utils/generateInitialAPIKey.ts index ecff084..7815baf 100644 --- a/src/utils/generateInitialAPIKey.ts +++ b/src/utils/generateInitialAPIKey.ts @@ -1,6 +1,9 @@ +import { sql } from "drizzle-orm"; import { createHmac, randomUUID } from "crypto"; import { generateAPIKey } from "./generateAPIKey"; import { DateTime } from "luxon"; +import { getPostgresDB } from "../storage/db/postgres/db"; +import { projectsTable, apiKeysTable } from "../storage/db/postgres/schema"; const HMAC_SECRET = process.env.HMAC_SECRET; @@ -17,18 +20,19 @@ function hashAPIKey(apiKey: string): string { } export type InitialApiKeyData = { + projectId: string; apiKeyId: string; apiKey: string; apiKeyHash: string; name: string; - role: string; + role: "dashboard" | "production" | "test"; createdAt: string; expiresAt: string; - insertSql: string; authorizationHeader: string; }; export function generateInitialApiKeyData(): InitialApiKeyData { + const projectId = randomUUID(); const apiKeyId = randomUUID(); const apiKey = generateAPIKey("dashboard"); const apiKeyHash = hashAPIKey(apiKey); @@ -37,20 +41,8 @@ export function generateInitialApiKeyData(): InitialApiKeyData { const createdAt = DateTime.utc().toISO(); const expiresAt = DateTime.utc().plus({ days: 365 }).toISO(); - const insertSql = - "INSERT INTO api_keys (id, name, key, role, created_at, expires_at, revoked, revoked_at)\n" + - "VALUES (\n" + - ` '${apiKeyId}',\n` + - ` '${name}',\n` + - ` '${apiKeyHash}',\n` + - ` '${role}',\n` + - ` '${createdAt}',\n` + - ` '${expiresAt}',\n` + - " false,\n" + - " NULL\n" + - ");"; - return { + projectId, apiKeyId, apiKey, apiKeyHash, @@ -58,9 +50,49 @@ export function generateInitialApiKeyData(): InitialApiKeyData { role, createdAt, expiresAt, - insertSql, authorizationHeader: `Authorization: Bearer ${apiKey}`, }; } -console.log(generateInitialApiKeyData()); +async function insertInitialData(data: InitialApiKeyData) { + const db = getPostgresDB(process.env.DATABASE_URL); + + const existing = await db + .select({ id: projectsTable.id }) + .from(projectsTable) + .where(sql`name = 'Default Project'`) + .limit(1); + + if (existing.length > 0) { + console.log( + `Default Project already exists (id=${existing[0]!.id}). Skipping seed.` + ); + return; + } + + await db.insert(projectsTable).values({ + id: data.projectId, + name: "Default Project", + createdAt: data.createdAt, + }); + + await db.insert(apiKeysTable).values({ + id: data.apiKeyId, + projectId: data.projectId, + name: data.name, + key: data.apiKeyHash, + role: data.role, + createdAt: data.createdAt, + expiresAt: data.expiresAt, + revoked: false, + revokedAt: null, + }); +} + +const data = generateInitialApiKeyData(); + +await insertInitialData(data); + +console.log("Initial API key generation was successful.."); +console.log(data); +process.exit(0); diff --git a/src/utils/parseExpr.ts b/src/utils/parseExpr.ts index c50584b..8a5559a 100644 --- a/src/utils/parseExpr.ts +++ b/src/utils/parseExpr.ts @@ -37,10 +37,10 @@ const ALLOWED_FUNCTIONS = new Set([ "div", "tag", "expr", - "inputTokens", - "outputTokens", - "inputCacheTokens", - "outputCacheTokens", + "inputtokens", + "outputtokens", + "inputcachetokens", + "outputcachetokens", ]); /** @@ -184,6 +184,7 @@ export function validateExprSyntax(exprString: string): void { */ export async function resolveExprRefsInExpression( exprString: string, + projectId: string, resolving: Set = new Set() ): Promise { const refs = extractExprRefs(exprString); @@ -201,14 +202,18 @@ export async function resolveExprRefsInExpression( ); } - const storedExpr = await findExpressionByKey(refName); + const storedExpr = await findExpressionByKey(projectId, refName); if (!storedExpr) { throw EventError.validationFailed(`Expression not found: ${refName}`); } resolving.add(refName); - const expanded = await resolveExprRefsInExpression(storedExpr, resolving); + const expanded = await resolveExprRefsInExpression( + storedExpr, + projectId, + resolving + ); const refPattern = new RegExp(`expr\\(${refName}\\)`, "g"); resolved = resolved.replace(refPattern, `(${expanded})`); @@ -242,7 +247,10 @@ function extractExprRefs(exprString: string): string[] { * @returns The expression string with tags replaced by their numeric values * @throws EventError if any tag is not found */ -async function resolveTagsInExpression(exprString: string): Promise { +async function resolveTagsInExpression( + exprString: string, + projectId: string +): Promise { const tagNames = extractTagNames(exprString); if (tagNames.length === 0) { @@ -253,7 +261,11 @@ async function resolveTagsInExpression(exprString: string): Promise { const tagValues = new Map(); for (const tagName of tagNames) { - const value = await fetchTagAmount(tagName, `Tag not found: ${tagName}`); + const value = await fetchTagAmount( + projectId, + tagName, + `Tag not found: ${tagName}` + ); tagValues.set(tagName, value); } @@ -323,16 +335,20 @@ function resolveTokenPlaceholders( */ export async function parseAndEvaluateExpr( exprString: string, + projectId: string, tokenContext?: EvalTokenContext ): Promise { // Step 1: Validate syntax validateExprSyntax(exprString); // Step 2: Resolve all expr(NAME) references (recursive, from DB) - const expandedExpr = await resolveExprRefsInExpression(exprString); + const expandedExpr = await resolveExprRefsInExpression(exprString, projectId); // Step 3: Resolve all tags to their values - const tagResolvedExpr = await resolveTagsInExpression(expandedExpr); + const tagResolvedExpr = await resolveTagsInExpression( + expandedExpr, + projectId + ); // Step 4: Resolve token placeholders if context provided const finalExpr = tokenContext diff --git a/src/zod/event.ts b/src/zod/event.ts index c6d5e6a..bb58230 100644 --- a/src/zod/event.ts +++ b/src/zod/event.ts @@ -28,166 +28,195 @@ const BaseEvent = z.object({ idempotencyKey: z.string().min(1), }); -const BasicUsageDataSchema: z.ZodType = z - .object({ - basicUsageType: z.union([ - z - .literal(BasicUsageType.BASIC_USAGE_TYPE_UNSPECIFIED) - .transform(() => "RAW" as const), - z.literal(BasicUsageType.RAW).transform(() => "RAW" as const), - z - .literal(BasicUsageType.MIDDLEWARE_CALL) - .transform(() => "MIDDLEWARE_CALL" as const), - ]), - amount: z.number().optional(), - tag: z.string().optional(), - expr: z.string().optional(), - metadata: z.string().optional(), - }) - .transform(async (v): Promise => { - let debitAmount: number; - if (v.tag) { - debitAmount = await fetchTagAmount(v.tag, `Tag not found: ${v.tag}`); - } else if (v.expr) { - debitAmount = await parseAndEvaluateExpr(v.expr); - } else { - debitAmount = v.amount ?? 0; - } - return { - basicUsageType: v.basicUsageType, - debitAmount, - metadata: v.metadata ? parseMetadata(v.metadata) : undefined, - }; - }); - -const AITokenUsageDataSchema: z.ZodType = z - .object({ - model: z.string().min(1), - provider: z.string().optional().default("unknown"), - inputTokens: z.number().int().min(0), - inputCacheTokens: z.number().int().min(0), - outputTokens: z.number().int().min(0), - inputAmount: z.number().optional(), - inputTag: z.string().optional(), - inputExpr: z.string().optional(), - inputCacheAmount: z.number().optional(), - inputCacheTag: z.string().optional(), - inputCacheExpr: z.string().optional(), - outputCacheTokens: z.number().int().min(0), - outputCacheAmount: z.number().optional(), - outputCacheTag: z.string().optional(), - outputCacheExpr: z.string().optional(), - outputAmount: z.number().optional(), - outputTag: z.string().optional(), - outputExpr: z.string().optional(), - metadata: z.string().optional(), - }) - .transform(async (v): Promise => { - const tokenContext = { - inputTokens: v.inputTokens, - inputCacheTokens: v.inputCacheTokens, - outputTokens: v.outputTokens, - outputCacheTokens: v.outputCacheTokens, - }; - - let inputDebitAmount: number; - if (v.inputTag) { - inputDebitAmount = await fetchTagAmount( - v.inputTag, - `Input tag not found: ${v.inputTag}` - ); - } else if (v.inputExpr) { - inputDebitAmount = await parseAndEvaluateExpr(v.inputExpr, tokenContext); - } else { - inputDebitAmount = v.inputAmount ?? 0; - } +function createBasicUsageDataSchema( + projectId: string +): z.ZodType { + return z + .object({ + basicUsageType: z.union([ + z + .literal(BasicUsageType.BASIC_USAGE_TYPE_UNSPECIFIED) + .transform(() => "RAW" as const), + z.literal(BasicUsageType.RAW).transform(() => "RAW" as const), + z + .literal(BasicUsageType.MIDDLEWARE_CALL) + .transform(() => "MIDDLEWARE_CALL" as const), + ]), + amount: z.number().optional(), + tag: z.string().optional(), + expr: z.string().optional(), + metadata: z.string().optional(), + }) + .transform(async (v): Promise => { + let debitAmount: number; + if (v.tag) { + debitAmount = await fetchTagAmount( + projectId, + v.tag, + `Tag not found: ${v.tag}` + ); + } else if (v.expr) { + debitAmount = await parseAndEvaluateExpr(v.expr, projectId); + } else { + debitAmount = v.amount ?? 0; + } + return { + basicUsageType: v.basicUsageType, + debitAmount, + metadata: v.metadata ? parseMetadata(v.metadata) : undefined, + }; + }); +} - let inputCacheDebitAmount: number; - if (v.inputCacheTag) { - inputCacheDebitAmount = await fetchTagAmount( - v.inputCacheTag, - `Input cache tag not found: ${v.inputCacheTag}` - ); - } else if (v.inputCacheExpr) { - inputCacheDebitAmount = await parseAndEvaluateExpr( - v.inputCacheExpr, - tokenContext - ); - } else { - inputCacheDebitAmount = v.inputCacheAmount ?? 0; - } +function createAITokenUsageDataSchema( + projectId: string +): z.ZodType { + return z + .object({ + model: z.string().min(1), + provider: z.string().optional().default("unknown"), + inputTokens: z.number().int().min(0), + inputCacheTokens: z.number().int().min(0), + outputTokens: z.number().int().min(0), + outputCacheTokens: z.number().int().min(0), + inputTag: z.string().optional(), + inputExpr: z.string().optional(), + inputAmount: z.number().optional(), + inputCacheTag: z.string().optional(), + inputCacheExpr: z.string().optional(), + inputCacheAmount: z.number().optional(), + outputTag: z.string().optional(), + outputExpr: z.string().optional(), + outputAmount: z.number().optional(), + outputCacheTag: z.string().optional(), + outputCacheExpr: z.string().optional(), + outputCacheAmount: z.number().optional(), + metadata: z.string().optional(), + }) + .transform(async (v): Promise => { + const tokenContext = { + inputTokens: v.inputTokens, + inputCacheTokens: v.inputCacheTokens, + outputTokens: v.outputTokens, + outputCacheTokens: v.outputCacheTokens, + }; - let outputCacheDebitAmount: number; - if (v.outputCacheTag) { - outputCacheDebitAmount = await fetchTagAmount( - v.outputCacheTag, - `Output cache tag not found: ${v.outputCacheTag}` - ); - } else if (v.outputCacheExpr) { - outputCacheDebitAmount = await parseAndEvaluateExpr( - v.outputCacheExpr, - tokenContext - ); - } else { - outputCacheDebitAmount = v.outputCacheAmount ?? 0; - } + let inputDebitAmount: number; + if (v.inputTag) { + inputDebitAmount = await fetchTagAmount( + projectId, + v.inputTag, + `Input tag not found: ${v.inputTag}` + ); + } else if (v.inputExpr) { + inputDebitAmount = await parseAndEvaluateExpr( + v.inputExpr, + projectId, + tokenContext + ); + } else { + inputDebitAmount = v.inputAmount ?? 0; + } - let outputDebitAmount: number; - if (v.outputTag) { - outputDebitAmount = await fetchTagAmount( - v.outputTag, - `Output tag not found: ${v.outputTag}` - ); - } else if (v.outputExpr) { - outputDebitAmount = await parseAndEvaluateExpr( - v.outputExpr, - tokenContext - ); - } else { - outputDebitAmount = v.outputAmount ?? 0; - } + let inputCacheDebitAmount: number; + if (v.inputCacheTag) { + inputCacheDebitAmount = await fetchTagAmount( + projectId, + v.inputCacheTag, + `Input cache tag not found: ${v.inputCacheTag}` + ); + } else if (v.inputCacheExpr) { + inputCacheDebitAmount = await parseAndEvaluateExpr( + v.inputCacheExpr, + projectId, + tokenContext + ); + } else { + inputCacheDebitAmount = v.inputCacheAmount ?? 0; + } - return { - model: v.model, - provider: v.provider, - inputTokens: v.inputTokens, - inputCacheTokens: v.inputCacheTokens, - outputTokens: v.outputTokens, - outputCacheTokens: v.outputCacheTokens, - inputDebitAmount, - inputCacheDebitAmount, - outputCacheDebitAmount, - outputDebitAmount, - metadata: v.metadata ? parseMetadata(v.metadata) : undefined, - }; - }); + let outputCacheDebitAmount: number; + if (v.outputCacheTag) { + outputCacheDebitAmount = await fetchTagAmount( + projectId, + v.outputCacheTag, + `Output cache tag not found: ${v.outputCacheTag}` + ); + } else if (v.outputCacheExpr) { + outputCacheDebitAmount = await parseAndEvaluateExpr( + v.outputCacheExpr, + projectId, + tokenContext + ); + } else { + outputCacheDebitAmount = v.outputCacheAmount ?? 0; + } -const RegisterEventBasicUsage = BaseEvent.extend({ - type: z - .literal(EventType.BASIC_USAGE) - .transform(() => "BASIC_USAGE" as const), - basicUsage: BasicUsageDataSchema, -}); + let outputDebitAmount: number; + if (v.outputTag) { + outputDebitAmount = await fetchTagAmount( + projectId, + v.outputTag, + `Output tag not found: ${v.outputTag}` + ); + } else if (v.outputExpr) { + outputDebitAmount = await parseAndEvaluateExpr( + v.outputExpr, + projectId, + tokenContext + ); + } else { + outputDebitAmount = v.outputAmount ?? 0; + } -const StreamEventBasicUsage = BaseEvent.extend({ - type: z - .literal(EventType.BASIC_USAGE) - .transform(() => "BASIC_USAGE" as const), - basicUsage: BasicUsageDataSchema, -}); + return { + model: v.model, + provider: v.provider, + inputTokens: v.inputTokens, + inputCacheTokens: v.inputCacheTokens, + outputTokens: v.outputTokens, + outputCacheTokens: v.outputCacheTokens, + inputDebitAmount, + inputCacheDebitAmount, + outputCacheDebitAmount, + outputDebitAmount, + metadata: v.metadata ? parseMetadata(v.metadata) : undefined, + }; + }); +} -const StreamEventAITokenUsage = BaseEvent.extend({ - type: z - .literal(EventType.AI_TOKEN_USAGE) - .transform(() => "AI_TOKEN_USAGE" as const), - aiTokenUsage: AITokenUsageDataSchema, -}); +export function createRegisterEventSchema(projectId: string) { + const BasicUsageDataSchema = createBasicUsageDataSchema(projectId); + return BaseEvent.extend({ + type: z + .literal(EventType.BASIC_USAGE) + .transform(() => "BASIC_USAGE" as const), + basicUsage: BasicUsageDataSchema, + }); +} -export const registerEventSchema = RegisterEventBasicUsage; -export type RegisterEventSchemaType = z.output; +export function createStreamEventSchema(projectId: string) { + const BasicUsageDataSchema = createBasicUsageDataSchema(projectId); + const AITokenUsageDataSchema = createAITokenUsageDataSchema(projectId); + return z.union([ + BaseEvent.extend({ + type: z + .literal(EventType.BASIC_USAGE) + .transform(() => "BASIC_USAGE" as const), + basicUsage: BasicUsageDataSchema, + }), + BaseEvent.extend({ + type: z + .literal(EventType.AI_TOKEN_USAGE) + .transform(() => "AI_TOKEN_USAGE" as const), + aiTokenUsage: AITokenUsageDataSchema, + }), + ]); +} -export const streamEventSchema = z.union([ - StreamEventBasicUsage, - StreamEventAITokenUsage, -]); -export type StreamEventSchemaType = z.output; +export type RegisterEventSchemaType = z.output< + ReturnType +>; +export type StreamEventSchemaType = z.output< + ReturnType +>; diff --git a/src/zod/internals.ts b/src/zod/internals.ts index 938d40b..525b71d 100644 --- a/src/zod/internals.ts +++ b/src/zod/internals.ts @@ -33,6 +33,7 @@ export function createFilterGroupSchema( } export const onboardingSchema = z.object({ + name: z.string().min(1, "Project name is required").max(255), dodoLiveApiKey: z.string().min(1, "Dodo live API key is required"), dodoTestApiKey: z.string().min(1, "Dodo test API key is required"), dodoLiveProductId: z.string().min(1, "Dodo live product ID is required"),