Skip to content
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ DATABASE_URL=postgres://postgres:postgres@localhost:5432/scrawn
CLICKHOUSE_URL=http://default:clickhouse@localhost:8123/scrawn

HMAC_SECRET=
MASTER_API_KEY_HASH= # hex(HMAC-SHA256(HMAC_SECRET, <raw master key>)) — computed with HMAC_SECRET as the signing key (used for project creation during onboarding)
APP_URL=http://localhost:8060 # URL the Scrawn backend is hosted on

# SENTRY_DSN=
Expand Down
3 changes: 2 additions & 1 deletion src/__tests__/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ import { getPostgresDB } from "../storage/db/postgres/db";
import { webhookEndpointsTable } from "../storage/db/postgres/schema";
import { DateTime } from "luxon";
import { clearDatabase } from "./db";
import { insertKey } from "./fixtures/apiKey";
import { insertKey, TEST_PROJECT_ID } from "./fixtures/apiKey";

async function insertWebhookEndpoint(apiKeyId: string): Promise<void> {
const db = getPostgresDB();
await db.insert(webhookEndpointsTable).values({
projectId: TEST_PROJECT_ID,
apiKeyId,
url: "https://example.com/webhook",
privateKey: "test-private-key",
Expand Down
4 changes: 3 additions & 1 deletion src/__tests__/createAPIKey.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
registerEvent,
} from "./fixtures/grpc";
import { verifyApiKeyCreated } from "./assertions/events";
import { createTestApiKey } from "./fixtures/apiKey";
import { createTestApiKey, TEST_PROJECT_ID } from "./fixtures/apiKey";
import { getPostgresDB } from "../storage/db/postgres/db";
import { hashAPIKey } from "../utils/hashAPIKey";
import {
Expand Down Expand Up @@ -44,6 +44,7 @@ async function createDashboardApiKey(): Promise<{
key: hashAPIKey(rawKey),
role: "dashboard",
expiresAt: DateTime.utc().plus({ years: 1 }).toISO(),
projectId: TEST_PROJECT_ID,
})
.returning({ id: apiKeysTable.id });
return { rawKey, id: key!.id };
Expand Down Expand Up @@ -149,6 +150,7 @@ describe("AuthService", () => {

const db = getPostgresDB();
await db.insert(webhookEndpointsTable).values({
projectId: TEST_PROJECT_ID,
apiKeyId: res.apiKeyId,
url: "https://example.com/webhook",
privateKey: "test-private-key",
Expand Down
3 changes: 2 additions & 1 deletion src/__tests__/db/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ export async function clearDatabase() {
users,
tags,
metadata,
expressions
expressions,
projects
RESTART IDENTITY CASCADE
`);

Expand Down
23 changes: 23 additions & 0 deletions src/__tests__/fixtures/apiKey.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,35 @@
import { getPostgresDB } from "../../storage/db/postgres/db";
import {
projectsTable,
apiKeysTable,
webhookEndpointsTable,
} from "../../storage/db/postgres/schema";
import { eq } from "drizzle-orm";
import { hashAPIKey } from "../../utils/hashAPIKey";
import { DateTime } from "luxon";

export const TEST_PROJECT_ID = "00000000-0000-0000-0000-000000000001";

async function ensureTestProject(): Promise<void> {
const db = getPostgresDB();
const [existing] = await db
.select({ id: projectsTable.id })
.from(projectsTable)
.where(eq(projectsTable.id, TEST_PROJECT_ID))
.limit(1);
if (existing) return;
await db.insert(projectsTable).values({
id: TEST_PROJECT_ID,
name: "test-project",
});
}

export async function createTestApiKey(): Promise<{
rawKey: string;
id: string;
}> {
const db = getPostgresDB();
await ensureTestProject();
const rawKey = `scrn_test_${crypto.randomUUID().replace(/-/g, "").slice(0, 32)}`;
const [key] = await db
.insert(apiKeysTable)
Expand All @@ -19,10 +38,12 @@ export async function createTestApiKey(): Promise<{
key: hashAPIKey(rawKey),
role: "test",
expiresAt: DateTime.utc().plus({ years: 1 }).toISO(),
projectId: TEST_PROJECT_ID,
})
.returning({ id: apiKeysTable.id });

await db.insert(webhookEndpointsTable).values({
projectId: TEST_PROJECT_ID,
apiKeyId: key!.id,
url: "https://example.com/webhook",
privateKey: "test-private-key",
Expand All @@ -38,6 +59,7 @@ export async function insertKey(
overrides: Partial<{ revoked: boolean; expiresAt: string }> = {}
): Promise<string> {
const db = getPostgresDB();
await ensureTestProject();
const [key] = await db
.insert(apiKeysTable)
.values({
Expand All @@ -47,6 +69,7 @@ export async function insertKey(
expiresAt:
overrides.expiresAt ?? DateTime.utc().plus({ years: 1 }).toISO(),
revoked: overrides.revoked ?? false,
projectId: TEST_PROJECT_ID,
})
.returning({ id: apiKeysTable.id });
return key!.id;
Expand Down
1 change: 1 addition & 0 deletions src/context/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ export interface AuthContext {
apiKeyId: string;
role: ApiKeyRole;
mode: "production" | "test" | null;
projectId: string;
}
6 changes: 5 additions & 1 deletion src/interceptors/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ export function authInterceptor<Req, Res>(
apiKeyId: cached.id,
role: cached.role,
mode: cached.mode,
projectId: cached.projectId,
};
wideEventBuilder?.setAuth(cached.id, true);

Expand Down Expand Up @@ -201,7 +202,7 @@ export function authInterceptor<Req, Res>(

if (
DateTime.utc() >
DateTime.fromSQL(apiKeyRecord.expiresAt, { zone: "utc" })
DateTime.fromISO(apiKeyRecord.expiresAt, { zone: "utc" })
) {
return callback?.(AuthError.expiredAPIKey());
}
Expand All @@ -220,13 +221,15 @@ export function authInterceptor<Req, Res>(
id: apiKeyRecord.id,
role: apiKeyRecord.role as ApiKeyRole,
mode: recordMode,
projectId: apiKeyRecord.projectId,
expiresAt: apiKeyRecord.expiresAt,
});

call[apiKeyContextKey] = {
apiKeyId: apiKeyRecord.id,
role: apiKeyRecord.role as ApiKeyRole,
mode: recordMode,
projectId: apiKeyRecord.projectId,
};
wideEventBuilder?.setAuth(apiKeyRecord.id, false);

Expand Down Expand Up @@ -270,6 +273,7 @@ async function lookupApiKey(apiKeyHash: string) {
role: apiKeysTable.role,
expiresAt: apiKeysTable.expiresAt,
revoked: apiKeysTable.revoked,
projectId: apiKeysTable.projectId,
})
.from(apiKeysTable)
.where(eq(apiKeysTable.key, apiKeyHash))
Expand Down
1 change: 1 addition & 0 deletions src/routes/gRPC/auth/createAPIKey.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ export async function createAPIKey(
key: apiKeyHash,
role: validatedData.role,
expiresAt: expiresAt.toISO(),
projectId: auth.projectId,
});

if (!keyEventData) {
Expand Down
34 changes: 19 additions & 15 deletions src/routes/gRPC/data/query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ import type { sendUnaryData } from "@grpc/grpc-js";
import { QueryRequest, QueryResponse, Row } from "../../../gen/data/v1/data";
import { dataQuerySchema, type DataQueryRequest } from "../../../zod/data";
import { EventError } from "../../../errors/event";
import { AuthError } from "../../../errors/auth";
import { formatZodError } from "../../../utils/formatZodError";
import { getPostgresDB } from "../../../storage/db/postgres/db";
import {
usersTable,
sessionsTable,
tagsTable,
expressionsTable,
metadataTable,
} from "../../../storage/db/postgres/schema";
import {
eq,
Expand All @@ -29,21 +29,23 @@ import type { SQL } from "drizzle-orm";
import type { AnyPgColumn } from "drizzle-orm/pg-core";
import type { WideEventBuilder } from "../../../context/requestContext";
import { wideEventContextKey } from "../../../context/requestContext";
import { apiKeyContextKey } from "../../../context/auth";
import type { ContextUnaryCall } from "../../../interface/types/context.js";

interface FieldDef {
col: AnyPgColumn;
cast: "text" | "integer" | "uuid" | "timestamptz" | "boolean";
}

type ScopedTable =
| typeof usersTable
| typeof sessionsTable
| typeof tagsTable
| typeof expressionsTable;

interface TableDef {
tableName: string;
table:
| typeof usersTable
| typeof sessionsTable
| typeof tagsTable
| typeof expressionsTable
| typeof metadataTable;
table: ScopedTable;
fields: Record<string, FieldDef>;
}

Expand Down Expand Up @@ -95,13 +97,6 @@ const TABLE_REGISTRY: Record<string, TableDef> = {
expr: { col: expressionsTable.expr, cast: "text" },
},
},
metadata: {
tableName: "metadata",
table: metadataTable,
fields: {
id: { col: metadataTable.id, cast: "uuid" },
},
},
};

function castValue(
Expand Down Expand Up @@ -206,6 +201,11 @@ export async function queryData(
| undefined;

try {
const auth = call[apiKeyContextKey];
if (!auth) {
return callback?.(AuthError.invalidAPIKey("API key context not found"));
}

const req = { ...call.request } as Record<string, unknown>;

const validated = dataQuerySchema.parse(req);
Expand All @@ -223,7 +223,11 @@ export async function queryData(
}

const db = getPostgresDB();
const whereClause = buildWhere(validated.where, tableDef);
const userWhere = buildWhere(validated.where, tableDef);
const projectFilter = eq(tableDef.table.projectId, auth.projectId) as SQL;
const whereClause = userWhere
? and(projectFilter, userWhere)
: projectFilter;
const selectCols = buildSelect(tableDef);
const columns = Object.keys(tableDef.fields);

Expand Down
6 changes: 4 additions & 2 deletions src/routes/gRPC/events/registerEvent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
import type { WideEventBuilder } from "../../../context/requestContext";
import { apiKeyContextKey } from "../../../context/auth";
import { wideEventContextKey } from "../../../context/requestContext";
import { registerEventSchema } from "../../../zod/event";
import { createRegisterEventSchema } from "../../../zod/event";
import { EventError } from "../../../errors/event";
import { AuthError } from "../../../errors/auth";
import { createEventInstance, storeEvent } from "../../../utils/eventHelpers";
Expand Down Expand Up @@ -33,7 +33,9 @@ export async function registerEvent(
);
}

const eventSkeleton = await registerEventSchema.parseAsync({ ...req });
const eventSkeleton = await createRegisterEventSchema(
auth.projectId
).parseAsync({ ...req });

wideEventBuilder?.setUser(eventSkeleton.userId);
wideEventBuilder?.setEventContext({ eventType: eventSkeleton.type });
Expand Down
6 changes: 4 additions & 2 deletions src/routes/gRPC/events/streamEvents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
import { EventError } from "../../../errors/event";
import { AuthError } from "../../../errors/auth";
import { StorageError } from "../../../errors/storage";
import { streamEventSchema } from "../../../zod/event";
import { createStreamEventSchema } from "../../../zod/event";
import { createEventInstance, storeEvent } from "../../../utils/eventHelpers";
import { apiKeyContextKey } from "../../../context/auth";
import { wideEventContextKey } from "../../../context/requestContext";
Expand Down Expand Up @@ -88,7 +88,9 @@ export async function streamEvents(

for await (const req of call) {
try {
const eventSkeleton = await streamEventSchema.parseAsync({ ...req });
const eventSkeleton = await createStreamEventSchema(
auth.projectId
).parseAsync({ ...req });

wideEventBuilder?.setUser(eventSkeleton.userId);
wideEventBuilder?.setEventContext({ eventType: "AI_TOKEN_USAGE" });
Expand Down
24 changes: 19 additions & 5 deletions src/routes/gRPC/payment/createCheckoutLink.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import { getPostgresDB } from "../../../storage/db/postgres/db";
import { checkIfExistingCheckoutLink } from "../../../storage/db/postgres/helpers/sessions";
import { ensureUserExists } from "../../../storage/db/postgres/helpers/users";
import { usersTable } from "../../../storage/db/postgres/schema";
import { eq } from "drizzle-orm";
import { eq, and } from "drizzle-orm";

export async function createCheckoutLink(
call: ContextUnaryCall<CreateCheckoutLinkRequest, CreateCheckoutLinkResponse>,
Expand Down Expand Up @@ -64,7 +64,7 @@ export async function createCheckoutLink(

const mode = auth.mode;

const config = await getPaymentProviderConfig(mode);
const config = await getPaymentProviderConfig(auth.projectId, mode);
const validatedData = validateRequest(req);
wideEventBuilder?.setUser(validatedData.userId);

Expand All @@ -80,6 +80,7 @@ export async function createCheckoutLink(
wideEventBuilder?.setPaymentContext({ priceAmount: custom_price });

const checkoutResult = await createCheckoutSession(
auth.projectId,
config,
custom_price,
validatedData.userId,
Expand All @@ -92,16 +93,22 @@ export async function createCheckoutLink(
db,
"create checkout link",
async (txn) => {
await ensureUserExists(validatedData.userId, txn);
await ensureUserExists(auth.projectId, validatedData.userId, txn);

await txn
.select({ id: usersTable.id })
.from(usersTable)
.where(eq(usersTable.id, validatedData.userId))
.where(
and(
eq(usersTable.projectId, auth.projectId),
eq(usersTable.id, validatedData.userId)
)
)
.for("update");

const existingId = await checkIfExistingCheckoutLink(
txn,
auth.projectId,
validatedData.userId,
mode
);
Expand All @@ -112,6 +119,7 @@ export async function createCheckoutLink(
}

const sessionResult = await handleAddSession(
auth.projectId,
validatedData.userId,
checkoutResult.sessionId,
beforeTimestamp,
Expand Down Expand Up @@ -164,6 +172,7 @@ async function calculatePrice(
}

async function createCheckoutSession(
projectId: string,
config: PaymentProviderConfig,
customPrice: number,
userId: string,
Expand All @@ -177,7 +186,12 @@ async function createCheckoutSession(
apiKeyId,
};

const checkoutResult = await createProviderCheckout(config, params, mode);
const checkoutResult = await createProviderCheckout(
projectId,
config,
params,
mode
);

if (
!checkoutResult.checkoutUrl ||
Expand Down
Loading
Loading