diff --git a/client/src/app/app.routes.ts b/client/src/app/app.routes.ts index 0ecb81d..76e4225 100644 --- a/client/src/app/app.routes.ts +++ b/client/src/app/app.routes.ts @@ -28,6 +28,11 @@ export const routes: Routes = [ canActivate: [authGuard], loadComponent: () => import('./features/chat/chat.component').then((m) => m.ChatComponent), }, + { + path: 'billing', + canActivate: [authGuard], + loadComponent: () => import('./pages/billing/billing').then((m) => m.BillingComponent), + }, { path: 'projects/:projectId/pipeline', canActivate: [authGuard], diff --git a/client/src/app/core/services/billing.service.ts b/client/src/app/core/services/billing.service.ts new file mode 100644 index 0000000..5dd4769 --- /dev/null +++ b/client/src/app/core/services/billing.service.ts @@ -0,0 +1,57 @@ +import { Injectable, inject, signal } from '@angular/core'; +import { HttpClient } from '@angular/common/http'; +import { firstValueFrom } from 'rxjs'; + +export interface SubscriptionInfo { + tier: 'Free' | 'Pro' | 'Enterprise'; + status: string; + currentPeriodEnd: string | null; + stripeCustomerId: string | null; +} + +export interface UsageInfo { + queriesUsed: number; + queriesLimit: number; + documentsUsed: number; + documentsLimit: number; + projectsUsed: number; + projectsLimit: number; + storageBytesUsed: number; + storageBytesLimit: number; + tier: 'Free' | 'Pro' | 'Enterprise'; +} + +@Injectable({ providedIn: 'root' }) +export class BillingService { + private http = inject(HttpClient); + + readonly subscription = signal(null); + readonly usage = signal(null); + readonly loading = signal(false); + + async loadSubscription(): Promise { + const sub = await firstValueFrom(this.http.get('/api/billing/subscription')); + this.subscription.set(sub); + } + + async loadUsage(): Promise { + const usage = await firstValueFrom(this.http.get('/api/billing/usage')); + this.usage.set(usage); + } + + async createCheckoutSession(tier: string): Promise { + const res = await firstValueFrom(this.http.post<{ url: string }>('/api/billing/create-checkout-session', { + tier, + successUrl: `${window.location.origin}/billing?success=true`, + cancelUrl: `${window.location.origin}/billing?cancelled=true`, + })); + return res.url; + } + + async createPortalSession(): Promise { + const res = await firstValueFrom(this.http.post<{ url: string }>('/api/billing/create-portal-session', { + returnUrl: `${window.location.origin}/billing`, + })); + return res.url; + } +} diff --git a/client/src/app/features/pipeline-builder/pipeline-block.model.spec.ts b/client/src/app/features/pipeline-builder/pipeline-block.model.spec.ts index 6349cd8..6e61bed 100644 --- a/client/src/app/features/pipeline-builder/pipeline-block.model.spec.ts +++ b/client/src/app/features/pipeline-builder/pipeline-block.model.spec.ts @@ -1,93 +1,58 @@ -import { defaultBlocks, getBlockFields, BLOCK_ICONS, BLOCK_LABELS, BlockType } from './pipeline-block.model'; +import { createDefaultBlocks, BlockType } from './pipeline-block.model'; describe('PipelineBlock Model', () => { - describe('defaultBlocks', () => { + describe('createDefaultBlocks', () => { it('should return 5 blocks', () => { - expect(defaultBlocks().length).toBe(5); + expect(createDefaultBlocks().length).toBe(5); }); it('should have correct order: source → chunking → embedding → retrieval → generation', () => { - const types = defaultBlocks().map(b => b.type); + const types = createDefaultBlocks().map(b => b.type); expect(types).toEqual(['source', 'chunking', 'embedding', 'retrieval', 'generation']); }); it('should have unique IDs', () => { - const ids = defaultBlocks().map(b => b.id); + const ids = createDefaultBlocks().map(b => b.id); expect(new Set(ids).size).toBe(5); }); it('should have default config values for chunking', () => { - const chunking = defaultBlocks().find(b => b.type === 'chunking')!; + const chunking = createDefaultBlocks().find(b => b.type === 'chunking')!; expect(chunking.config['chunkSize']).toBe(512); expect(chunking.config['chunkOverlap']).toBe(50); expect(chunking.config['strategy']).toBe('recursive'); }); it('should have default config values for embedding', () => { - const embedding = defaultBlocks().find(b => b.type === 'embedding')!; + const embedding = createDefaultBlocks().find(b => b.type === 'embedding')!; expect(embedding.config['model']).toBe('text-embedding-3-small'); expect(embedding.config['dimensions']).toBe(1536); }); it('should have default config values for generation', () => { - const gen = defaultBlocks().find(b => b.type === 'generation')!; + const gen = createDefaultBlocks().find(b => b.type === 'generation')!; expect(gen.config['model']).toBe('gpt-4o-mini'); expect(gen.config['temperature']).toBe(0.7); expect(gen.config['maxTokens']).toBe(2048); }); it('should return independent copies each call', () => { - const a = defaultBlocks(); - const b = defaultBlocks(); + const a = createDefaultBlocks(); + const b = createDefaultBlocks(); a[0].config['test'] = 'modified'; expect(b[0].config['test']).toBeUndefined(); }); - }); - describe('BLOCK_ICONS', () => { - it('should have icons for all 5 block types', () => { - const types: BlockType[] = ['source', 'chunking', 'embedding', 'retrieval', 'generation']; - for (const t of types) { - expect(BLOCK_ICONS[t]).toBeTruthy(); + it('should have icons for all blocks', () => { + for (const block of createDefaultBlocks()) { + expect(block.icon).toBeTruthy(); } }); - }); - - describe('BLOCK_LABELS', () => { - it('should have labels for all block types', () => { - expect(BLOCK_LABELS['source']).toBe('Data Source'); - expect(BLOCK_LABELS['generation']).toBe('Generation'); - }); - }); - - describe('getBlockFields', () => { - it('should return fields for chunking block', () => { - const fields = getBlockFields('chunking'); - expect(fields.length).toBe(3); - expect(fields.map(f => f.key)).toEqual(['strategy', 'chunkSize', 'chunkOverlap']); - }); - it('should return fields for embedding block', () => { - const fields = getBlockFields('embedding'); - expect(fields.length).toBe(2); - expect(fields[0].key).toBe('model'); - }); - - it('should return fields for retrieval block', () => { - const fields = getBlockFields('retrieval'); - expect(fields.find(f => f.key === 'topK')?.max).toBe(20); - }); - - it('should return fields for generation block', () => { - const fields = getBlockFields('generation'); - const modelField = fields.find(f => f.key === 'model')!; - expect(modelField.options!.length).toBeGreaterThanOrEqual(3); - }); - - it('should return fields for source block', () => { - const fields = getBlockFields('source'); - expect(fields.length).toBe(1); - expect(fields[0].key).toBe('sourceType'); + it('should have labels for all blocks', () => { + for (const block of createDefaultBlocks()) { + expect(block.label).toBeTruthy(); + } }); }); }); diff --git a/client/src/app/features/pipeline-builder/pipeline-block.model.ts b/client/src/app/features/pipeline-builder/pipeline-block.model.ts new file mode 100644 index 0000000..c76ddf0 --- /dev/null +++ b/client/src/app/features/pipeline-builder/pipeline-block.model.ts @@ -0,0 +1,25 @@ +export type BlockType = 'source' | 'chunking' | 'embedding' | 'retrieval' | 'generation'; + +export interface PipelineBlock { + id: string; + type: BlockType; + label: string; + icon: string; + config: Record; +} + +export interface ChunkPreview { + index: number; + text: string; + tokens: number; +} + +export function createDefaultBlocks(): PipelineBlock[] { + return [ + { id: crypto.randomUUID(), type: 'source', label: 'Document Source', icon: '📄', config: { sourceType: 'upload', fileTypes: '.pdf,.txt,.md,.docx' } }, + { id: crypto.randomUUID(), type: 'chunking', label: 'Text Chunking', icon: '✂️', config: { strategy: 'recursive', chunkSize: 512, chunkOverlap: 50 } }, + { id: crypto.randomUUID(), type: 'embedding', label: 'Embeddings', icon: '🧬', config: { model: 'text-embedding-3-small', dimensions: 1536, batchSize: 100 } }, + { id: crypto.randomUUID(), type: 'retrieval', label: 'Retrieval', icon: '🔍', config: { strategy: 'semantic', topK: 5, scoreThreshold: 0.7, reranking: false } }, + { id: crypto.randomUUID(), type: 'generation', label: 'Generation', icon: '🤖', config: { model: 'gpt-4o-mini', temperature: 0.7, maxTokens: 2048, systemPrompt: '' } }, + ]; +} diff --git a/client/src/app/features/pipeline-builder/pipeline.service.spec.ts b/client/src/app/features/pipeline-builder/pipeline.service.spec.ts new file mode 100644 index 0000000..0f81192 --- /dev/null +++ b/client/src/app/features/pipeline-builder/pipeline.service.spec.ts @@ -0,0 +1,198 @@ +import { TestBed } from '@angular/core/testing'; +import { HttpTestingController, provideHttpClientTesting } from '@angular/common/http/testing'; +import { provideHttpClient } from '@angular/common/http'; +import { PipelineService } from './pipeline.service'; + +describe('PipelineService', () => { + let service: PipelineService; + let httpMock: HttpTestingController; + + beforeEach(() => { + TestBed.configureTestingModule({ + providers: [PipelineService, provideHttpClient(), provideHttpClientTesting()], + }); + service = TestBed.inject(PipelineService); + httpMock = TestBed.inject(HttpTestingController); + }); + + afterEach(() => httpMock.verify()); + + it('should be created', () => { + expect(service).toBeTruthy(); + }); + + it('should have 5 default blocks', () => { + expect(service.blocks().length).toBe(5); + }); + + it('should start on source step', () => { + expect(service.currentStep()).toBe('source'); + }); + + it('should navigate steps forward', () => { + service.nextStep(); + expect(service.currentStep()).toBe('chunking'); + service.nextStep(); + expect(service.currentStep()).toBe('embedding'); + }); + + it('should navigate steps backward', () => { + service.goToStep('generation'); + service.prevStep(); + expect(service.currentStep()).toBe('retrieval'); + }); + + it('should not go before first step', () => { + service.prevStep(); + expect(service.currentStep()).toBe('source'); + }); + + it('should not go past last step', () => { + service.goToStep('review'); + service.nextStep(); + expect(service.currentStep()).toBe('review'); + }); + + it('should update block config', () => { + const chunkingId = service.blocks().find(b => b.type === 'chunking')!.id; + service.updateBlockConfig(chunkingId, 'chunkSize', 1024); + const chunking = service.blocks().find(b => b.type === 'chunking')!; + expect(chunking.config['chunkSize']).toBe(1024); + }); + + it('should mark dirty on config change', () => { + expect(service.hasUnsavedChanges()).toBeFalsy(); + const chunkingId = service.blocks().find(b => b.type === 'chunking')!.id; + service.updateBlockConfig(chunkingId, 'chunkSize', 1024); + expect(service.hasUnsavedChanges()).toBeTruthy(); + }); + + it('should select block', () => { + const embeddingId = service.blocks().find(b => b.type === 'embedding')!.id; + service.selectBlock(embeddingId); + expect(service.selectedBlockId()).toBe(embeddingId); + expect(service.selectedBlock()?.type).toBe('embedding'); + }); + + it('should move blocks', () => { + service.moveBlock(0, 2); + expect(service.blocks()[0].type).toBe('chunking'); + expect(service.blocks()[2].type).toBe('source'); + }); + + it('should calculate progress percent', () => { + // source is step 0, 1/6 ~ 17% + expect(service.progressPercent()).toBe(17); + service.goToStep('review'); + expect(service.progressPercent()).toBe(100); + }); + + it('should reset to defaults', () => { + const chunkingId = service.blocks().find(b => b.type === 'chunking')!.id; + service.updateBlockConfig(chunkingId, 'chunkSize', 2048); + service.resetDefaults(); + expect(service.blocks().find(b => b.type === 'chunking')!.config['chunkSize']).toBe(512); + expect(service.dirty()).toBeFalsy(); + }); + + it('should generate chunk previews reactively', () => { + const previews = service.chunkPreviews(); + expect(previews.length).toBeGreaterThan(0); + expect(previews[0].tokens).toBeGreaterThan(0); + }); + + it('should compute config from blocks', () => { + const cfg = service.config(); + expect(cfg.chunkSize).toBe(512); + expect(cfg.topK).toBe(5); + }); + + it('should load pipeline from API', async () => { + const promise = service.loadPipeline('proj-1'); + const pipeReq = httpMock.expectOne('/api/projects/proj-1/pipelines'); + pipeReq.flush([{ + id: 'pipe-1', projectId: 'proj-1', name: 'Test', description: null, + config: { chunkSize: 256, chunkOverlap: 30, embeddingModel: 'text-embedding-3-large', retrievalStrategy: 'hybrid', topK: 10, scoreThreshold: 0.8 }, + status: 'idle', createdAt: '2026-01-01T00:00:00Z', updatedAt: null, + }]); + // loadDocuments call + const docReq = httpMock.expectOne('/api/projects/proj-1/documents'); + docReq.flush([]); + await promise; + expect(service.pipelineId()).toBe('pipe-1'); + expect(service.blocks().find(b => b.type === 'chunking')!.config['chunkSize']).toBe(256); + }); + + it('should save pipeline via PUT', async () => { + // Load first + const loadPromise = service.loadPipeline('proj-1'); + httpMock.expectOne('/api/projects/proj-1/pipelines').flush([{ + id: 'pipe-1', projectId: 'proj-1', name: 'Test', description: null, + config: { chunkSize: 512, chunkOverlap: 50, embeddingModel: 'text-embedding-3-small', retrievalStrategy: 'semantic', topK: 5, scoreThreshold: 0.7 }, + status: 'idle', createdAt: '2026-01-01T00:00:00Z', updatedAt: null, + }]); + httpMock.expectOne('/api/projects/proj-1/documents').flush([]); + await loadPromise; + + const savePromise = service.savePipeline('proj-1'); + const saveReq = httpMock.expectOne('/api/projects/proj-1/pipelines/pipe-1'); + expect(saveReq.request.method).toBe('PUT'); + saveReq.flush({}); + await savePromise; + expect(service.dirty()).toBeFalsy(); + }); + + it('should handle save error', async () => { + const loadPromise = service.loadPipeline('proj-1'); + httpMock.expectOne('/api/projects/proj-1/pipelines').flush([{ + id: 'pipe-1', projectId: 'proj-1', name: 'Test', description: null, + config: { chunkSize: 512, chunkOverlap: 50, embeddingModel: 'text-embedding-3-small', retrievalStrategy: 'semantic', topK: 5, scoreThreshold: 0.7 }, + status: 'idle', createdAt: '2026-01-01T00:00:00Z', updatedAt: null, + }]); + httpMock.expectOne('/api/projects/proj-1/documents').flush([]); + await loadPromise; + + const savePromise = service.savePipeline('proj-1').catch(() => {}); + httpMock.expectOne('/api/projects/proj-1/pipelines/pipe-1').flush( + { error: 'Bad request' }, { status: 400, statusText: 'Bad Request' } + ); + await savePromise; + expect(service.saveError()).toBeTruthy(); + }); + + it('should run pipeline and set status', async () => { + service.pipelineId.set('pipe-1'); + const runPromise = service.runPipeline('proj-1'); + expect(service.pipelineStatus()).toBe('running'); + const runReq = httpMock.expectOne('/api/projects/proj-1/pipelines/pipe-1/run'); + expect(runReq.request.method).toBe('POST'); + runReq.flush({}); + await runPromise; + expect(service.pipelineStatus()).toBe('completed'); + expect(service.isRunning()).toBeFalsy(); + }); + + it('should handle run pipeline failure', async () => { + service.pipelineId.set('pipe-1'); + const runPromise = service.runPipeline('proj-1'); + httpMock.expectOne('/api/projects/proj-1/pipelines/pipe-1/run').flush({}, { status: 500, statusText: 'Error' }); + await runPromise; + expect(service.pipelineStatus()).toBe('failed'); + }); + + it('should go to step directly', () => { + service.goToStep('retrieval'); + expect(service.currentStep()).toBe('retrieval'); + expect(service.currentStepIndex()).toBe(3); + }); + + it('should load documents', async () => { + const promise = service.loadDocuments('proj-1'); + httpMock.expectOne('/api/projects/proj-1/documents').flush([ + { id: 'd1', fileName: 'test.pdf', fileType: 'application/pdf', fileSize: 1024, status: 'Processed' }, + ]); + await promise; + expect(service.documents().length).toBe(1); + expect(service.documents()[0].fileName).toBe('test.pdf'); + }); +}); diff --git a/client/src/app/features/pipeline-builder/pipeline.service.ts b/client/src/app/features/pipeline-builder/pipeline.service.ts index c6ee490..b931024 100644 --- a/client/src/app/features/pipeline-builder/pipeline.service.ts +++ b/client/src/app/features/pipeline-builder/pipeline.service.ts @@ -21,6 +21,9 @@ export interface PipelineDto { updatedAt: string | null; } +export type PipelineStep = 'source' | 'chunking' | 'embedding' | 'retrieval' | 'generation' | 'review'; +const STEPS: PipelineStep[] = ['source', 'chunking', 'embedding', 'retrieval', 'generation', 'review']; + const SAMPLE_TEXT = 'Retrieval-Augmented Generation (RAG) is a technique that combines information retrieval with text generation. ' + 'It first retrieves relevant documents from a knowledge base using semantic search, then feeds those documents ' + @@ -40,6 +43,17 @@ export class PipelineService { readonly saving = signal(false); readonly dirty = signal(false); readonly pipelineId = signal(null); + readonly pipelineStatus = signal('idle'); + readonly isRunning = signal(false); + readonly saveError = signal(null); + readonly documents = signal<{ id: string; fileName: string; fileType: string; fileSize: number; status: string }[]>([]); + readonly currentStep = signal('source'); + readonly currentStepIndex = computed(() => STEPS.indexOf(this.currentStep())); + readonly stepCount = STEPS.length; + readonly steps = STEPS; + readonly progressPercent = computed(() => Math.round(((this.currentStepIndex() + 1) / this.stepCount) * 100)); + readonly isSaving = this.saving; + readonly isLoading = this.loading; readonly selectedBlock = computed(() => { const id = this.selectedBlockId(); @@ -49,8 +63,8 @@ export class PipelineService { readonly chunkPreviews = computed(() => { const chunkingBlock = this.blocks().find((b: PipelineBlock) => b.type === 'chunking'); if (!chunkingBlock) return []; - const size: number = chunkingBlock.config['chunkSize'] ?? 512; - const overlap: number = chunkingBlock.config['chunkOverlap'] ?? 50; + const size = Number(chunkingBlock.config['chunkSize'] ?? 512); + const overlap = Number(chunkingBlock.config['chunkOverlap'] ?? 50); return this.generateChunkPreviews(SAMPLE_TEXT, size, overlap); }); @@ -59,12 +73,12 @@ export class PipelineService { const retrieval = this.blocks().find((b: PipelineBlock) => b.type === 'retrieval'); const generation = this.blocks().find((b: PipelineBlock) => b.type === 'generation'); return { - chunkSize: (chunking?.config['chunkSize'] as number) ?? 512, - chunkOverlap: (chunking?.config['chunkOverlap'] as number) ?? 50, - retrievalStrategy: (retrieval?.config['strategy'] as string) ?? 'semantic', - topK: (retrieval?.config['topK'] as number) ?? 5, - scoreThreshold: (retrieval?.config['scoreThreshold'] as number) ?? 0.7, - temperature: (generation?.config['temperature'] as number) ?? 0.7, + chunkSize: Number(chunking?.config['chunkSize'] ?? 512), + chunkOverlap: Number(chunking?.config['chunkOverlap'] ?? 50), + retrievalStrategy: String(retrieval?.config['strategy'] ?? 'semantic'), + topK: Number(retrieval?.config['topK'] ?? 5), + scoreThreshold: Number(retrieval?.config['scoreThreshold'] ?? 0.7), + temperature: Number(generation?.config['temperature'] ?? 0.7), }; }); @@ -135,6 +149,7 @@ export class PipelineService { retrieval.config['scoreThreshold'] = p.config.scoreThreshold; this.blocks.set(blocks); } + await this.loadDocuments(projectId); } finally { this.loading.set(false); this.dirty.set(false); @@ -151,9 +166,7 @@ export class PipelineService { config: { chunkSize: cfg.chunkSize, chunkOverlap: cfg.chunkOverlap, - embeddingModel: - (this.blocks().find((b: PipelineBlock) => b.type === 'embedding')?.config['model'] as string) ?? - 'text-embedding-3-small', + embeddingModel: String(this.blocks().find((b: PipelineBlock) => b.type === 'embedding')?.config['model'] ?? 'text-embedding-3-small'), retrievalStrategy: cfg.retrievalStrategy, topK: cfg.topK, scoreThreshold: cfg.scoreThreshold, @@ -175,6 +188,48 @@ export class PipelineService { } } + nextStep(): void { + const idx = this.currentStepIndex(); + if (idx < STEPS.length - 1) this.currentStep.set(STEPS[idx + 1]); + } + + prevStep(): void { + const idx = this.currentStepIndex(); + if (idx > 0) this.currentStep.set(STEPS[idx - 1]); + } + + goToStep(step: PipelineStep): void { + this.currentStep.set(step); + } + + async runPipeline(projectId: string): Promise { + const pid = this.pipelineId(); + if (!pid) return; + this.isRunning.set(true); + this.pipelineStatus.set('running'); + try { + await firstValueFrom(this.http.post('/api/projects/' + projectId + '/pipelines/' + pid + '/run', {})); + this.pipelineStatus.set('completed'); + } catch { + this.pipelineStatus.set('failed'); + } finally { + this.isRunning.set(false); + } + } + + async loadDocuments(projectId: string): Promise { + try { + const docs = await firstValueFrom( + this.http.get<{ id: string; fileName: string; fileType: string; fileSize: number; status: string }[]>( + '/api/projects/' + projectId + '/documents' + ) + ); + this.documents.set(docs); + } catch { + this.documents.set([]); + } + } + private generateChunkPreviews(text: string, size: number, overlap: number): ChunkPreview[] { const charSize = Math.max(size, 50); const charOverlap = Math.min(overlap, charSize - 10); diff --git a/client/src/app/pages/billing/billing.ts b/client/src/app/pages/billing/billing.ts new file mode 100644 index 0000000..f210062 --- /dev/null +++ b/client/src/app/pages/billing/billing.ts @@ -0,0 +1,176 @@ +import { Component, inject, OnInit, computed, signal } from '@angular/core'; +import { CommonModule } from '@angular/common'; +import { NavbarComponent } from '../../shared/components/navbar/navbar'; +import { BillingService, UsageInfo } from '../../core/services/billing.service'; +import { AuthService } from '../../core/services/auth.service'; + +@Component({ + selector: 'app-billing', + standalone: true, + imports: [CommonModule, NavbarComponent], + template: ` + +
+

Billing & Usage

+ + +
+
+
+

Current Plan

+

{{ currentTier() }}

+ @if (subscription()?.currentPeriodEnd) { +

+ Renews {{ subscription()!.currentPeriodEnd | date:'mediumDate' }} +

+ } +
+ @if (currentTier() !== 'Free') { + + } +
+
+ + + @if (usage()) { +
+
+

Queries Today

+

{{ usage()!.queriesUsed }}

+
+
+
+

of {{ formatLimit(usage()!.queriesLimit) }}

+
+
+

Documents

+

{{ usage()!.documentsUsed }}

+
+
+
+

of {{ formatLimit(usage()!.documentsLimit) }}

+
+
+

Projects

+

{{ usage()!.projectsUsed }}

+
+
+
+

of {{ formatLimit(usage()!.projectsLimit) }}

+
+
+

Storage

+

{{ formatBytes(usage()!.storageBytesUsed) }}

+
+
+
+

of {{ formatBytes(usage()!.storageBytesLimit) }}

+
+
+ } + + +
+

Plans

+
+ @for (plan of plans; track plan.tier) { +
+

{{ plan.tier }}

+

{{ plan.price }}

+

{{ plan.period }}

+
    + @for (f of plan.features; track f) { +
  • + {{ f }} +
  • + } +
+ @if (plan.tier === currentTier()) { + + } @else if (plan.tier === 'Free') { + + } @else { + + } +
+ } +
+
+
+ `, +}) +export class BillingComponent implements OnInit { + private billing = inject(BillingService); + private auth = inject(AuthService); + + subscription = this.billing.subscription; + usage = this.billing.usage; + currentTier = computed(() => this.auth.userTier()); + + queryPct = computed(() => this.pct(this.usage()?.queriesUsed, this.usage()?.queriesLimit)); + docPct = computed(() => this.pct(this.usage()?.documentsUsed, this.usage()?.documentsLimit)); + projPct = computed(() => this.pct(this.usage()?.projectsUsed, this.usage()?.projectsLimit)); + storagePct = computed(() => this.pct(this.usage()?.storageBytesUsed, this.usage()?.storageBytesLimit)); + + plans = [ + { + tier: 'Free', price: '$0', period: 'forever', + features: ['100 queries/day', '10 documents', '1 project', '50 MB storage'], + }, + { + tier: 'Pro', price: '$29', period: '/month', + features: ['10,000 queries/day', '1,000 documents', '20 projects', '5 GB storage'], + }, + { + tier: 'Enterprise', price: '$99', period: '/month', + features: ['Unlimited queries', 'Unlimited documents', 'Unlimited projects', 'Unlimited storage'], + }, + ]; + + ngOnInit() { + this.billing.loadSubscription(); + this.billing.loadUsage(); + } + + async upgrade(tier: string) { + const url = await this.billing.createCheckoutSession(tier); + window.location.href = url; + } + + async manageSubscription() { + const url = await this.billing.createPortalSession(); + window.location.href = url; + } + + formatLimit(n: number): string { + return n >= 2147483647 ? 'Unlimited' : n.toLocaleString(); + } + + formatBytes(bytes: number): string { + if (bytes >= 9007199254740991) return 'Unlimited'; + if (bytes >= 1073741824) return (bytes / 1073741824).toFixed(1) + ' GB'; + if (bytes >= 1048576) return (bytes / 1048576).toFixed(1) + ' MB'; + if (bytes >= 1024) return (bytes / 1024).toFixed(0) + ' KB'; + return bytes + ' B'; + } + + private pct(used?: number, limit?: number): number { + if (!used || !limit || limit >= 2147483647) return 0; + return Math.min(100, (used / limit) * 100); + } +} diff --git a/client/src/app/pages/pipeline/pipeline.html b/client/src/app/pages/pipeline/pipeline.html index e7bdad9..34dfea7 100644 --- a/client/src/app/pages/pipeline/pipeline.html +++ b/client/src/app/pages/pipeline/pipeline.html @@ -1,462 +1,228 @@ -
- -
-
-

Pipeline Builder

-

Configure your RAG pipeline step by step

-
-
- @if (svc.hasUnsavedChanges()) { - ● Unsaved changes - } - +
+
+
+
+

Pipeline Builder

+

Configure your RAG pipeline by arranging and configuring blocks

+
+
+ @if (svc.dirty()) { + ● Unsaved changes + } + +
- @if (svc.saveError()) { -
- {{ svc.saveError() }} + @if (svc.loading()) { +
+
- } - - @if (svc.isLoading()) { -
Loading pipeline…
} @else { - -
-
- @for (step of svc.steps; track step; let i = $index) { - - @if (i < svc.steps.length - 1) { -
- } - } -
- -
-
-
-
- -
- -
-
-

Pipeline Flow

-
- @for (block of svc.blocks(); track block.id) { -
-
- ⠿ -
-
- {{ block.icon }} -
-
-
- {{ block.label }} - - {{ block.type }} - +
+
+
+
+

Pipeline Flow

+
+ @for (block of svc.blocks(); track block.id; let last = $last) { +
+
+ + +
-

{{ blockSummary(block) }}

-
-
-
- - @if (!$last) { -
- - - +
+ {{ block.icon }} +
+
+ {{ block.type }} +

{{ block.label }}

+

{{ blockSummary(block) }}

+
+ @if (svc.selectedBlockId() === block.id) { +
+ } +
+ @if (!last) { +
+ + + +
+ } } - } +
- -
-
-
- Status: - @switch (svc.pipelineStatus()) { - @case ('idle') { Idle } - @case ('running') { Running… } - @case ('completed') { Completed } - @case ('failed') { Failed } - } -
- +
+

📋 Chunk Preview

+

Preview how documents will be split based on your chunking settings

+
+ @for (chunk of svc.chunkPreviews(); track chunk.index) { +
+
+ Chunk {{ chunk.index + 1 }} + ~{{ chunk.tokens }} tokens +
+

{{ chunk.text }}

+
+ }
-
- -
-
- @switch (svc.currentStep()) { - - @case ('source') { -

📄 Data Source

-

Upload documents or connect a data source for your chatbot.

- @if (svc.documents().length > 0) { -
- @for (doc of svc.documents(); track doc.id) { -
- - @switch (doc.fileType) { - @case ('application/pdf') { 📕 } - @case ('text/plain') { 📝 } - @case ('text/markdown') { 📝 } - @default { 📄 } - } - -
-

{{ doc.fileName }}

-

{{ (doc.fileSize / 1024).toFixed(1) }} KB

-
- - {{ doc.status }} - -
- } -
-

{{ svc.documents().length }} document(s) loaded. Configure chunking in the next step.

- } @else { -
-
📁
-

No documents uploaded yet

-

Upload documents from the project page first

+
+ @if (svc.selectedBlock(); as block) { +
+
+ {{ block.icon }} +
+

{{ block.label }}

+ {{ block.type }}
- } - @for (block of svc.blocks(); track block.id) { - @if (block.type === 'source') { -
- -
- } - } - } +
- - @case ('chunking') { -

✂️ Chunking

-

Configure how documents are split into smaller chunks for processing.

- @for (block of svc.blocks(); track block.id) { - @if (block.type === 'chunking') { -
- - - + @if (block.type === 'source') { +
+
+ +
- - - @if (svc.chunkPreview(); as preview) { -
-

Chunk Preview ({{ preview.totalCount }} total)

-
- @for (chunk of preview.chunks; track chunk.index) { -
-
- Chunk {{ chunk.index + 1 }} - {{ chunk.tokenCount }} tokens -
-

{{ chunk.content }}

-
- } -
-
- } - } +
+ + +
+
} - } - - @case ('embedding') { -

🧮 Embedding

-

Choose the embedding model that converts text into vector representations.

- @for (block of svc.blocks(); track block.id) { - @if (block.type === 'embedding') { -
- - - -
-
-

Quality

-

- @switch (block.config['model']) { - @case ('text-embedding-3-large') { ★★★ } - @case ('embed-v4') { ★★★ } - @case ('text-embedding-3-small') { ★★☆ } - @default { ★★☆ } - } -

-
-
-

Cost

-

- @switch (block.config['model']) { - @case ('text-embedding-3-large') { $$ } - @case ('embed-v4') { $ } - @case ('bge-m3') { Free } - @default { $ } - } -

-
-
+ @if (block.type === 'chunking') { +
+
+ +
- } +
+ + +
+
+ + +
+
} - } - - @case ('retrieval') { -

🔍 Retrieval

-

Configure how relevant chunks are retrieved for each query.

- @for (block of svc.blocks(); track block.id) { - @if (block.type === 'retrieval') { -
- - - + @if (block.type === 'embedding') { +
+
+ +
- } +
+ + +
+
+ + +
+
} - } - - @case ('generation') { -

🤖 Generation

-

Choose the LLM and configure how responses are generated.

- @for (block of svc.blocks(); track block.id) { - @if (block.type === 'generation') { -
- - - - + @if (block.type === 'retrieval') { +
+
+ +
- } +
+ + +
+
+ + +
+
+ + +
+
} - } - - @case ('review') { -

📋 Review & Run

-

Review your pipeline configuration before running.

-
- @for (block of svc.blocks(); track block.id) { -
- {{ block.icon }} -
-

{{ block.label }}

-

{{ blockSummary(block) }}

-
- + @if (block.type === 'generation') { +
+
+ +
- } -
-
- - -
- } - } - - - @if (svc.currentStep() !== 'review') { -
- - +
+ + +
+
+ + +
+
+ + +
+
+ } +
+ } @else { +
+
👆
+

Select a Block

+

Click on any pipeline block to configure its settings

}
diff --git a/client/src/app/pages/pipeline/pipeline.spec.ts b/client/src/app/pages/pipeline/pipeline.spec.ts new file mode 100644 index 0000000..822dd3e --- /dev/null +++ b/client/src/app/pages/pipeline/pipeline.spec.ts @@ -0,0 +1,73 @@ +import { ComponentFixture, TestBed } from '@angular/core/testing'; +import { provideHttpClient } from '@angular/common/http'; +import { provideHttpClientTesting } from '@angular/common/http/testing'; +import { ActivatedRoute } from '@angular/router'; +import { PipelineComponent } from './pipeline'; +import { PipelineService } from '../../features/pipeline-builder/pipeline.service'; + +describe('PipelineComponent', () => { + let component: PipelineComponent; + let fixture: ComponentFixture; + let svc: PipelineService; + + beforeEach(async () => { + await TestBed.configureTestingModule({ + imports: [PipelineComponent], + providers: [ + provideHttpClient(), + provideHttpClientTesting(), + { + provide: ActivatedRoute, + useValue: { snapshot: { paramMap: { get: () => 'test-project-id' } } }, + }, + ], + }).compileComponents(); + + fixture = TestBed.createComponent(PipelineComponent); + component = fixture.componentInstance; + svc = TestBed.inject(PipelineService); + }); + + it('should create', () => { + expect(component).toBeTruthy(); + }); + + it('should return correct block classes for selected block', () => { + const block = svc.blocks()[0]; // source + svc.selectBlock(block.id); + expect(component.blockClass(block)).toContain('border-emerald-500'); + }); + + it('should return correct block classes for unselected block', () => { + svc.selectBlock('nonexistent'); + const block = svc.blocks()[0]; + expect(component.blockClass(block)).toContain('border-gray-200'); + }); + + it('should return block icon background', () => { + const block = svc.blocks().find(b => b.type === 'embedding')!; + expect(component.blockIconBg(block)).toBe('bg-purple-100'); + }); + + it('should return block type badge color', () => { + const block = svc.blocks().find(b => b.type === 'retrieval')!; + expect(component.blockTypeBadge(block)).toBe('text-blue-700'); + }); + + it('should return block summary for chunking', () => { + const block = svc.blocks().find(b => b.type === 'chunking')!; + const summary = component.blockSummary(block); + expect(summary).toContain('recursive'); + expect(summary).toContain('512'); + }); + + it('should return block summary for generation', () => { + const block = svc.blocks().find(b => b.type === 'generation')!; + expect(component.blockSummary(block)).toContain('gpt-4o-mini'); + }); + + it('should return block summary for embedding', () => { + const block = svc.blocks().find(b => b.type === 'embedding')!; + expect(component.blockSummary(block)).toContain('1536d'); + }); +}); diff --git a/client/src/app/pages/pipeline/pipeline.ts b/client/src/app/pages/pipeline/pipeline.ts index edb4ac0..e781caf 100644 --- a/client/src/app/pages/pipeline/pipeline.ts +++ b/client/src/app/pages/pipeline/pipeline.ts @@ -16,27 +16,43 @@ export class PipelineComponent implements OnInit { private route = inject(ActivatedRoute); private projectId = ''; - ngOnInit() { + ngOnInit(): void { this.projectId = this.route.snapshot.paramMap.get('projectId') ?? ''; if (this.projectId) { this.svc.loadPipeline(this.projectId); } } - onDrop(event: CdkDragDrop) { + onDrop(event: CdkDragDrop): void { this.svc.moveBlock(event.previousIndex, event.currentIndex); } - save() { + save(): void { if (this.projectId) { this.svc.savePipeline(this.projectId); } } - runPipeline() { - if (this.projectId) { - this.svc.runPipeline(this.projectId); - } + onConfigChange(blockId: string, key: string, event: Event): void { + const el = event.target as HTMLInputElement | HTMLSelectElement | HTMLTextAreaElement; + this.svc.updateBlockConfig(blockId, key, el.value); + } + + onConfigChangeNum(blockId: string, key: string, event: Event): void { + const el = event.target as HTMLInputElement | HTMLSelectElement | HTMLTextAreaElement; + const parsed = parseInt(el.value, 10); + if (!isNaN(parsed)) this.svc.updateBlockConfig(blockId, key, parsed); + } + + onConfigChangeFloat(blockId: string, key: string, event: Event): void { + const el = event.target as HTMLInputElement | HTMLSelectElement | HTMLTextAreaElement; + const parsed = parseFloat(el.value); + if (!isNaN(parsed)) this.svc.updateBlockConfig(blockId, key, parsed); + } + + onConfigChangeBool(blockId: string, key: string, event: Event): void { + const el = event.target as HTMLInputElement; + this.svc.updateBlockConfig(blockId, key, el.checked); } blockClass(block: PipelineBlock): string { @@ -53,60 +69,29 @@ export class PipelineComponent implements OnInit { blockIconBg(block: PipelineBlock): string { const colors: Record = { - source: 'bg-emerald-100', - chunking: 'bg-amber-100', - embedding: 'bg-purple-100', - retrieval: 'bg-blue-100', - generation: 'bg-rose-100', + source: 'bg-emerald-100', chunking: 'bg-amber-100', embedding: 'bg-purple-100', + retrieval: 'bg-blue-100', generation: 'bg-rose-100', }; return colors[block.type]; } blockTypeBadge(block: PipelineBlock): string { const colors: Record = { - source: 'text-emerald-700', - chunking: 'text-amber-700', - embedding: 'text-purple-700', - retrieval: 'text-blue-700', - generation: 'text-rose-700', + source: 'text-emerald-700', chunking: 'text-amber-700', embedding: 'text-purple-700', + retrieval: 'text-blue-700', generation: 'text-rose-700', }; return colors[block.type]; } blockSummary(block: PipelineBlock): string { + const c = block.config; switch (block.type) { - case 'source': return `${block.config['sourceType']} · ${block.config['fileTypes']}`; - case 'chunking': return `${block.config['strategy']} · ${block.config['chunkSize']} tokens · ${block.config['chunkOverlap']} overlap`; - case 'embedding': return `${block.config['model']} · ${block.config['dimensions']}d`; - case 'retrieval': return `${block.config['strategy']} · top ${block.config['topK']} · threshold ${block.config['scoreThreshold']}`; - case 'generation': return `${block.config['model']} · temp ${block.config['temperature']}`; + case 'source': return c['sourceType'] + ' · ' + c['fileTypes']; + case 'chunking': return c['strategy'] + ' · ' + c['chunkSize'] + ' tokens · ' + c['chunkOverlap'] + ' overlap'; + case 'embedding': return c['model'] + ' · ' + c['dimensions'] + 'd'; + case 'retrieval': return c['strategy'] + ' · top ' + c['topK'] + ' · threshold ' + c['scoreThreshold']; + case 'generation': return c['model'] + ' · temp ' + c['temperature']; default: return ''; } } - - onConfigChange(blockId: string, key: string, event: Event) { - const el = event.target as HTMLInputElement | HTMLSelectElement | HTMLTextAreaElement; - this.svc.updateBlockConfig(blockId, key, el.value); - } - - onConfigChangeNum(blockId: string, key: string, event: Event) { - const el = event.target as HTMLInputElement; - const parsed = parseInt(el.value, 10); - if (!isNaN(parsed)) { - this.svc.updateBlockConfig(blockId, key, parsed); - } - } - - onConfigChangeFloat(blockId: string, key: string, event: Event) { - const el = event.target as HTMLInputElement; - const parsed = parseFloat(el.value); - if (!isNaN(parsed)) { - this.svc.updateBlockConfig(blockId, key, parsed); - } - } - - onConfigChangeBool(blockId: string, key: string, event: Event) { - const el = event.target as HTMLInputElement; - this.svc.updateBlockConfig(blockId, key, el.checked); - } } diff --git a/client/src/app/shared/components/navbar/navbar.ts b/client/src/app/shared/components/navbar/navbar.ts index 8bcca26..311aaae 100644 --- a/client/src/app/shared/components/navbar/navbar.ts +++ b/client/src/app/shared/components/navbar/navbar.ts @@ -14,6 +14,7 @@ import { AuthService } from '../../../core/services/auth.service';
+ Billing {{ userTier() }} diff --git a/docs/features/pr-009-billing-tier-enforcement.md b/docs/features/pr-009-billing-tier-enforcement.md new file mode 100644 index 0000000..66f8cf9 --- /dev/null +++ b/docs/features/pr-009-billing-tier-enforcement.md @@ -0,0 +1,53 @@ +# PR #9: Billing + Tier Enforcement (Stripe Integration) + +## Overview +Stripe-powered billing with tiered usage enforcement for PipeRAG. + +## Tier Limits + +| Feature | Free | Pro ($29/mo) | Enterprise ($99/mo) | +|---------|------|-------------|-------------------| +| Queries/day | 100 | 10,000 | Unlimited | +| Documents | 10 | 1,000 | Unlimited | +| Projects | 1 | 20 | Unlimited | +| Storage | 50 MB | 5 GB | Unlimited | + +## API Endpoints + +- `POST /api/billing/create-checkout-session` — Start Stripe Checkout for upgrade +- `POST /api/billing/create-portal-session` — Open Stripe Customer Portal +- `POST /api/billing/webhook` — Stripe webhook handler +- `GET /api/billing/subscription` — Current subscription status +- `GET /api/billing/usage` — Current usage stats + +## Architecture + +### Backend +- **StripeService** — Handles Stripe customer creation, checkout sessions, portal sessions, webhook processing +- **UsageTrackingService** — Tracks queries/day, documents, projects, storage per user +- **TierEnforcementMiddleware** — Intercepts POST requests and checks tier limits before allowing operations +- **Entities**: `Subscription`, `UsageRecord` tables + +### Frontend +- **Billing page** (`/billing`) — Shows current plan, usage dashboard with progress bars, pricing comparison table +- **BillingService** — Angular service for billing API calls + +## Webhook Events Handled +- `checkout.session.completed` — Activate subscription after payment +- `customer.subscription.updated` — Sync subscription status changes +- `customer.subscription.deleted` — Downgrade to Free tier on cancellation + +## Configuration +Required in `appsettings.json`: +```json +{ + "Stripe": { + "SecretKey": "sk_...", + "WebhookSecret": "whsec_...", + "ProPriceId": "price_...", + "EnterprisePriceId": "price_..." + } +} +``` + +> ⚠️ **Security:** Never commit real Stripe keys to source control. Use environment variables, user secrets (`dotnet user-secrets`), or a key vault for production deployments. diff --git a/src/PipeRAG.Api/Controllers/BillingController.cs b/src/PipeRAG.Api/Controllers/BillingController.cs new file mode 100644 index 0000000..b286703 --- /dev/null +++ b/src/PipeRAG.Api/Controllers/BillingController.cs @@ -0,0 +1,84 @@ +using System.Security.Claims; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using PipeRAG.Core.Enums; +using PipeRAG.Core.Interfaces; + +namespace PipeRAG.Api.Controllers; + +[ApiController] +[Route("api/billing")] +public class BillingController : ControllerBase +{ + private readonly IBillingService _billing; + private readonly IUsageTrackingService _usage; + + public BillingController(IBillingService billing, IUsageTrackingService usage) + { + _billing = billing; + _usage = usage; + } + + private Guid GetUserId() + { + var claim = User.FindFirstValue(ClaimTypes.NameIdentifier); + if (!Guid.TryParse(claim, out var userId)) + throw new UnauthorizedAccessException("Invalid or missing user identifier"); + return userId; + } + + [Authorize] + [HttpPost("create-checkout-session")] + public async Task CreateCheckoutSession([FromBody] CheckoutRequest request) + { + if (!Enum.TryParse(request.Tier, true, out var tier) || tier == UserTier.Free) + return BadRequest(new { error = "Invalid tier" }); + + var url = await _billing.CreateCheckoutSessionAsync( + GetUserId(), tier, request.SuccessUrl, request.CancelUrl); + return Ok(new { url }); + } + + [Authorize] + [HttpPost("create-portal-session")] + public async Task CreatePortalSession([FromBody] PortalRequest request) + { + var url = await _billing.CreatePortalSessionAsync(GetUserId(), request.ReturnUrl); + return Ok(new { url }); + } + + [HttpPost("webhook")] + public async Task Webhook() + { + var json = await new StreamReader(HttpContext.Request.Body).ReadToEndAsync(); + var signature = Request.Headers["Stripe-Signature"].ToString(); + try + { + await _billing.HandleWebhookAsync(json, signature); + return Ok(); + } + catch (Exception ex) + { + return BadRequest(new { error = ex.Message }); + } + } + + [Authorize] + [HttpGet("subscription")] + public async Task GetSubscription() + { + var sub = await _billing.GetSubscriptionAsync(GetUserId()); + return Ok(sub); + } + + [Authorize] + [HttpGet("usage")] + public async Task GetUsage() + { + var usage = await _usage.GetUsageAsync(GetUserId()); + return Ok(usage); + } +} + +public record CheckoutRequest(string Tier, string SuccessUrl, string CancelUrl); +public record PortalRequest(string ReturnUrl); diff --git a/src/PipeRAG.Api/Controllers/WidgetController.cs b/src/PipeRAG.Api/Controllers/WidgetController.cs new file mode 100644 index 0000000..ee7dfa7 --- /dev/null +++ b/src/PipeRAG.Api/Controllers/WidgetController.cs @@ -0,0 +1,105 @@ +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.EntityFrameworkCore; +using PipeRAG.Core.DTOs; +using PipeRAG.Core.Entities; +using PipeRAG.Infrastructure.Data; +using System.Security.Claims; + +namespace PipeRAG.Api.Controllers; + +/// +/// Manages widget configuration for a project (authenticated). +/// +[ApiController] +[Route("api/projects/{projectId:guid}/widget")] +[Authorize] +public class WidgetController : ControllerBase +{ + private readonly PipeRagDbContext _db; + + public WidgetController(PipeRagDbContext db) => _db = db; + + /// + /// Get the widget configuration for a project. + /// + [HttpGet] + public async Task> Get(Guid projectId, CancellationToken ct) + { + var project = await GetAuthorizedProjectAsync(projectId, ct); + if (project is null) return NotFound(new { error = "Project not found." }); + + var config = await _db.WidgetConfigs.FirstOrDefaultAsync(w => w.ProjectId == projectId, ct); + if (config is null) return NotFound(new { error = "Widget not configured." }); + + return Ok(MapToResponse(config)); + } + + /// + /// Create or update the widget configuration for a project. + /// + [HttpPut] + public async Task> Upsert(Guid projectId, [FromBody] WidgetConfigRequest request, CancellationToken ct) + { + var project = await GetAuthorizedProjectAsync(projectId, ct); + if (project is null) return NotFound(new { error = "Project not found." }); + + var config = await _db.WidgetConfigs.FirstOrDefaultAsync(w => w.ProjectId == projectId, ct); + if (config is null) + { + config = new WidgetConfig { ProjectId = projectId }; + _db.WidgetConfigs.Add(config); + } + + if (request.PrimaryColor is not null) config.PrimaryColor = request.PrimaryColor; + if (request.BackgroundColor is not null) config.BackgroundColor = request.BackgroundColor; + if (request.TextColor is not null) config.TextColor = request.TextColor; + if (request.Position is not null) config.Position = request.Position; + if (request.AvatarUrl is not null) config.AvatarUrl = request.AvatarUrl; + if (request.Title is not null) config.Title = request.Title; + if (request.Subtitle is not null) config.Subtitle = request.Subtitle; + if (request.PlaceholderText is not null) config.PlaceholderText = request.PlaceholderText; + if (request.AllowedOrigins is not null) config.AllowedOrigins = request.AllowedOrigins; + if (request.IsActive.HasValue) config.IsActive = request.IsActive.Value; + config.UpdatedAt = DateTime.UtcNow; + + await _db.SaveChangesAsync(ct); + return Ok(MapToResponse(config)); + } + + /// + /// Delete the widget configuration for a project. + /// + [HttpDelete] + public async Task Delete(Guid projectId, CancellationToken ct) + { + var project = await GetAuthorizedProjectAsync(projectId, ct); + if (project is null) return NotFound(new { error = "Project not found." }); + + var config = await _db.WidgetConfigs.FirstOrDefaultAsync(w => w.ProjectId == projectId, ct); + if (config is null) return NotFound(new { error = "Widget not configured." }); + + _db.WidgetConfigs.Remove(config); + await _db.SaveChangesAsync(ct); + return NoContent(); + } + + private async Task GetAuthorizedProjectAsync(Guid projectId, CancellationToken ct) + { + var project = await _db.Projects.FindAsync([projectId], ct); + if (project is null || project.OwnerId != GetUserId()) return null; + return project; + } + + private Guid GetUserId() + { + var claim = User.FindFirst(ClaimTypes.NameIdentifier)?.Value + ?? User.FindFirst("sub")?.Value; + return Guid.TryParse(claim, out var id) ? id : Guid.Empty; + } + + private static WidgetConfigResponse MapToResponse(WidgetConfig c) => new( + c.Id, c.ProjectId, c.PrimaryColor, c.BackgroundColor, c.TextColor, + c.Position, c.AvatarUrl, c.Title, c.Subtitle, c.PlaceholderText, + c.AllowedOrigins, c.IsActive, c.CreatedAt, c.UpdatedAt); +} diff --git a/src/PipeRAG.Api/Middleware/TierEnforcementMiddleware.cs b/src/PipeRAG.Api/Middleware/TierEnforcementMiddleware.cs new file mode 100644 index 0000000..08cbf4f --- /dev/null +++ b/src/PipeRAG.Api/Middleware/TierEnforcementMiddleware.cs @@ -0,0 +1,74 @@ +using System.Security.Claims; +using PipeRAG.Core.Interfaces; + +namespace PipeRAG.Api.Middleware; + +public class TierEnforcementMiddleware +{ + private readonly RequestDelegate _next; + private readonly ILogger _logger; + + private static readonly HashSet QueryPaths = ["/api/chat"]; + private static readonly HashSet DocumentPaths = ["/api/documents"]; + private static readonly HashSet ProjectPaths = ["/api/projects"]; + + public TierEnforcementMiddleware(RequestDelegate next, ILogger logger) + { + _next = next; + _logger = logger; + } + + public async Task InvokeAsync(HttpContext context, IUsageTrackingService usageService) + { + var path = context.Request.Path.Value?.ToLower() ?? ""; + var method = context.Request.Method; + + if (!HttpMethods.IsPost(method) || context.User.Identity?.IsAuthenticated != true) + { + await _next(context); + return; + } + + var userIdClaim = context.User.FindFirstValue(ClaimTypes.NameIdentifier); + if (!Guid.TryParse(userIdClaim, out var userId)) + { + await _next(context); + return; + } + + // Check query limits + if (QueryPaths.Any(p => path.StartsWith(p))) + { + if (!await usageService.CanPerformQueryAsync(userId)) + { + context.Response.StatusCode = 429; + await context.Response.WriteAsJsonAsync(new { error = "Daily query limit reached. Upgrade your plan." }); + return; + } + } + + // Check document limits + if (DocumentPaths.Any(p => path.StartsWith(p))) + { + if (!await usageService.CanCreateDocumentAsync(userId)) + { + context.Response.StatusCode = 429; + await context.Response.WriteAsJsonAsync(new { error = "Document limit reached. Upgrade your plan." }); + return; + } + } + + // Check project limits + if (ProjectPaths.Any(p => path.StartsWith(p))) + { + if (!await usageService.CanCreateProjectAsync(userId)) + { + context.Response.StatusCode = 429; + await context.Response.WriteAsJsonAsync(new { error = "Project limit reached. Upgrade your plan." }); + return; + } + } + + await _next(context); + } +} diff --git a/src/PipeRAG.Api/Program.cs b/src/PipeRAG.Api/Program.cs index 87d3004..8520393 100644 --- a/src/PipeRAG.Api/Program.cs +++ b/src/PipeRAG.Api/Program.cs @@ -51,6 +51,10 @@ builder.Services.AddScoped(); builder.Services.AddScoped(); +// Billing + Usage services +builder.Services.AddScoped(); +builder.Services.AddScoped(); + // JWT Authentication var jwtSection = builder.Configuration.GetSection("Jwt"); var jwtSecret = jwtSection["Secret"]; @@ -94,6 +98,7 @@ app.UseAuthentication(); app.UseAuthorization(); app.UseMiddleware(); +app.UseMiddleware(); app.MapControllers(); // Health check endpoint diff --git a/src/PipeRAG.Core/DTOs/WidgetDtos.cs b/src/PipeRAG.Core/DTOs/WidgetDtos.cs new file mode 100644 index 0000000..7f9aa21 --- /dev/null +++ b/src/PipeRAG.Core/DTOs/WidgetDtos.cs @@ -0,0 +1,35 @@ +namespace PipeRAG.Core.DTOs; + +/// +/// Request to create or update a widget configuration. +/// +public record WidgetConfigRequest( + string? PrimaryColor = null, + string? BackgroundColor = null, + string? TextColor = null, + string? Position = null, + string? AvatarUrl = null, + string? Title = null, + string? Subtitle = null, + string? PlaceholderText = null, + string? AllowedOrigins = null, + bool? IsActive = null); + +/// +/// Response with widget configuration details. +/// +public record WidgetConfigResponse( + Guid Id, + Guid ProjectId, + string PrimaryColor, + string BackgroundColor, + string TextColor, + string Position, + string? AvatarUrl, + string Title, + string Subtitle, + string PlaceholderText, + string AllowedOrigins, + bool IsActive, + DateTime CreatedAt, + DateTime? UpdatedAt); diff --git a/src/PipeRAG.Core/Entities/Subscription.cs b/src/PipeRAG.Core/Entities/Subscription.cs new file mode 100644 index 0000000..42614aa --- /dev/null +++ b/src/PipeRAG.Core/Entities/Subscription.cs @@ -0,0 +1,21 @@ +using PipeRAG.Core.Enums; + +namespace PipeRAG.Core.Entities; + +public class Subscription +{ + public Guid Id { get; set; } + public Guid UserId { get; set; } + public string StripeCustomerId { get; set; } = string.Empty; + public string StripeSubscriptionId { get; set; } = string.Empty; + public string StripePriceId { get; set; } = string.Empty; + public UserTier Tier { get; set; } = UserTier.Free; + public SubscriptionStatus Status { get; set; } = SubscriptionStatus.Active; + public DateTime? CurrentPeriodStart { get; set; } + public DateTime? CurrentPeriodEnd { get; set; } + public DateTime CreatedAt { get; set; } = DateTime.UtcNow; + public DateTime? UpdatedAt { get; set; } + public DateTime? CancelledAt { get; set; } + + public User User { get; set; } = null!; +} diff --git a/src/PipeRAG.Core/Entities/UsageRecord.cs b/src/PipeRAG.Core/Entities/UsageRecord.cs new file mode 100644 index 0000000..0e70f7b --- /dev/null +++ b/src/PipeRAG.Core/Entities/UsageRecord.cs @@ -0,0 +1,15 @@ +namespace PipeRAG.Core.Entities; + +public class UsageRecord +{ + public Guid Id { get; set; } + public Guid UserId { get; set; } + public DateTime Date { get; set; } = DateTime.UtcNow.Date; + public int QueryCount { get; set; } + public int DocumentCount { get; set; } + public int ProjectCount { get; set; } + public long StorageBytes { get; set; } + public DateTime UpdatedAt { get; set; } = DateTime.UtcNow; + + public User User { get; set; } = null!; +} diff --git a/src/PipeRAG.Core/Entities/User.cs b/src/PipeRAG.Core/Entities/User.cs index 0842467..4286121 100644 --- a/src/PipeRAG.Core/Entities/User.cs +++ b/src/PipeRAG.Core/Entities/User.cs @@ -23,4 +23,6 @@ public class User public ICollection ChatSessions { get; set; } = []; public ICollection AuditLogs { get; set; } = []; public ICollection RefreshTokens { get; set; } = []; + public Subscription? Subscription { get; set; } + public ICollection UsageRecords { get; set; } = []; } diff --git a/src/PipeRAG.Core/Entities/WidgetConfig.cs b/src/PipeRAG.Core/Entities/WidgetConfig.cs new file mode 100644 index 0000000..b685af7 --- /dev/null +++ b/src/PipeRAG.Core/Entities/WidgetConfig.cs @@ -0,0 +1,36 @@ +namespace PipeRAG.Core.Entities; + +/// +/// Widget configuration for an embeddable chat widget on a project. +/// +public class WidgetConfig +{ + public Guid Id { get; set; } + public Guid ProjectId { get; set; } + + // Theme + public string PrimaryColor { get; set; } = "#6366f1"; + public string BackgroundColor { get; set; } = "#1e1e2e"; + public string TextColor { get; set; } = "#ffffff"; + + // Position: bottom-right, bottom-left + public string Position { get; set; } = "bottom-right"; + + // Avatar + public string? AvatarUrl { get; set; } + + // Display + public string Title { get; set; } = "Chat with us"; + public string Subtitle { get; set; } = "Ask anything about our docs"; + public string PlaceholderText { get; set; } = "Type a message..."; + + // Security + public string AllowedOrigins { get; set; } = "*"; + + public bool IsActive { get; set; } = true; + public DateTime CreatedAt { get; set; } = DateTime.UtcNow; + public DateTime? UpdatedAt { get; set; } + + // Navigation + public Project Project { get; set; } = null!; +} diff --git a/src/PipeRAG.Core/Enums/Enums.cs b/src/PipeRAG.Core/Enums/Enums.cs index cc5cfc6..1e683db 100644 --- a/src/PipeRAG.Core/Enums/Enums.cs +++ b/src/PipeRAG.Core/Enums/Enums.cs @@ -47,6 +47,15 @@ public enum ChatMessageRole System } +public enum SubscriptionStatus +{ + Active, + PastDue, + Cancelled, + Incomplete, + Trialing +} + public enum AuditAction { Create, diff --git a/src/PipeRAG.Core/Interfaces/IBillingService.cs b/src/PipeRAG.Core/Interfaces/IBillingService.cs new file mode 100644 index 0000000..21de60b --- /dev/null +++ b/src/PipeRAG.Core/Interfaces/IBillingService.cs @@ -0,0 +1,17 @@ +using PipeRAG.Core.Enums; + +namespace PipeRAG.Core.Interfaces; + +public interface IBillingService +{ + Task CreateCheckoutSessionAsync(Guid userId, UserTier tier, string successUrl, string cancelUrl); + Task CreatePortalSessionAsync(Guid userId, string returnUrl); + Task HandleWebhookAsync(string json, string stripeSignature); + Task GetSubscriptionAsync(Guid userId); +} + +public record SubscriptionDto( + UserTier Tier, + string Status, + DateTime? CurrentPeriodEnd, + string? StripeCustomerId); diff --git a/src/PipeRAG.Core/Interfaces/IUsageTrackingService.cs b/src/PipeRAG.Core/Interfaces/IUsageTrackingService.cs new file mode 100644 index 0000000..5fdd2e8 --- /dev/null +++ b/src/PipeRAG.Core/Interfaces/IUsageTrackingService.cs @@ -0,0 +1,38 @@ +using PipeRAG.Core.Enums; + +namespace PipeRAG.Core.Interfaces; + +public interface IUsageTrackingService +{ + Task IncrementQueryCountAsync(Guid userId); + Task GetUsageAsync(Guid userId); + Task CanPerformQueryAsync(Guid userId); + Task CanCreateDocumentAsync(Guid userId); + Task CanCreateProjectAsync(Guid userId); + Task CanUploadStorageAsync(Guid userId, long additionalBytes); + Task RecalculateDocumentCountAsync(Guid userId); + Task RecalculateProjectCountAsync(Guid userId); + Task RecalculateStorageAsync(Guid userId); +} + +public record UsageDto( + int QueriesUsed, + int QueriesLimit, + int DocumentsUsed, + int DocumentsLimit, + int ProjectsUsed, + int ProjectsLimit, + long StorageBytesUsed, + long StorageBytesLimit, + UserTier Tier); + +public static class TierLimits +{ + public static (int QueriesPerDay, int MaxDocuments, int MaxProjects, long MaxStorageBytes) GetLimits(UserTier tier) => tier switch + { + UserTier.Free => (100, 10, 1, 50L * 1024 * 1024), + UserTier.Pro => (10_000, 1_000, 20, 5L * 1024 * 1024 * 1024), + UserTier.Enterprise => (int.MaxValue, int.MaxValue, int.MaxValue, long.MaxValue), + _ => (100, 10, 1, 50L * 1024 * 1024) + }; +} diff --git a/src/PipeRAG.Infrastructure/Data/PipeRagDbContext.cs b/src/PipeRAG.Infrastructure/Data/PipeRagDbContext.cs index 3b207bd..425a6c3 100644 --- a/src/PipeRAG.Infrastructure/Data/PipeRagDbContext.cs +++ b/src/PipeRAG.Infrastructure/Data/PipeRagDbContext.cs @@ -32,6 +32,9 @@ public PipeRagDbContext(DbContextOptions options) : base(optio public DbSet ApiKeys => Set(); public DbSet AuditLogs => Set(); public DbSet RefreshTokens => Set(); + public DbSet Subscriptions => Set(); + public DbSet UsageRecords => Set(); + public DbSet WidgetConfigs => Set(); protected override void OnModelCreating(ModelBuilder modelBuilder) { @@ -135,5 +138,27 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) e.HasIndex(r => r.UserId); e.HasOne(r => r.User).WithMany(u => u.RefreshTokens).HasForeignKey(r => r.UserId); }); + + modelBuilder.Entity(e => + { + e.HasIndex(s => s.UserId).IsUnique(); + e.HasIndex(s => s.StripeCustomerId); + e.HasIndex(s => s.StripeSubscriptionId); + e.Property(s => s.Tier).HasConversion().HasMaxLength(20); + e.Property(s => s.Status).HasConversion().HasMaxLength(20); + e.HasOne(s => s.User).WithOne(u => u.Subscription).HasForeignKey(s => s.UserId); + }); + + modelBuilder.Entity(e => + { + e.HasIndex(u => new { u.UserId, u.Date }).IsUnique(); + e.HasOne(u => u.User).WithMany(u => u.UsageRecords).HasForeignKey(u => u.UserId); + }); + + modelBuilder.Entity(e => + { + e.HasIndex(w => w.ProjectId).IsUnique(); + e.HasOne(w => w.Project).WithMany().HasForeignKey(w => w.ProjectId).OnDelete(DeleteBehavior.Cascade); + }); } } diff --git a/src/PipeRAG.Infrastructure/PipeRAG.Infrastructure.csproj b/src/PipeRAG.Infrastructure/PipeRAG.Infrastructure.csproj index 1be504b..f78f314 100644 --- a/src/PipeRAG.Infrastructure/PipeRAG.Infrastructure.csproj +++ b/src/PipeRAG.Infrastructure/PipeRAG.Infrastructure.csproj @@ -12,6 +12,7 @@ + diff --git a/src/PipeRAG.Infrastructure/Services/StripeService.cs b/src/PipeRAG.Infrastructure/Services/StripeService.cs new file mode 100644 index 0000000..b226eb6 --- /dev/null +++ b/src/PipeRAG.Infrastructure/Services/StripeService.cs @@ -0,0 +1,192 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using PipeRAG.Core.Entities; +using PipeRAG.Core.Enums; +using PipeRAG.Core.Interfaces; +using PipeRAG.Infrastructure.Data; +using Stripe; +using Stripe.Checkout; + +namespace PipeRAG.Infrastructure.Services; + +public class StripeService : IBillingService +{ + private readonly PipeRagDbContext _db; + private readonly IConfiguration _config; + private readonly ILogger _logger; + + public StripeService(PipeRagDbContext db, IConfiguration config, ILogger logger) + { + _db = db; + _config = config; + _logger = logger; + StripeConfiguration.ApiKey = _config["Stripe:SecretKey"]; + } + + public async Task CreateCheckoutSessionAsync(Guid userId, UserTier tier, string successUrl, string cancelUrl) + { + var user = await _db.Users.FindAsync(userId) + ?? throw new InvalidOperationException("User not found"); + + var sub = await _db.Subscriptions.FirstOrDefaultAsync(s => s.UserId == userId); + string customerId; + + if (sub?.StripeCustomerId is { Length: > 0 }) + { + customerId = sub.StripeCustomerId; + } + else + { + var customerService = new CustomerService(); + var customer = await customerService.CreateAsync(new CustomerCreateOptions + { + Email = user.Email, + Metadata = new Dictionary { ["userId"] = userId.ToString() } + }); + customerId = customer.Id; + } + + var priceId = tier switch + { + UserTier.Pro => _config["Stripe:ProPriceId"], + UserTier.Enterprise => _config["Stripe:EnterprisePriceId"], + _ => throw new ArgumentException("Cannot checkout for Free tier") + }; + + var sessionService = new SessionService(); + var session = await sessionService.CreateAsync(new SessionCreateOptions + { + Customer = customerId, + Mode = "subscription", + LineItems = [new SessionLineItemOptions { Price = priceId, Quantity = 1 }], + SuccessUrl = successUrl, + CancelUrl = cancelUrl, + Metadata = new Dictionary + { + ["userId"] = userId.ToString(), + ["tier"] = tier.ToString() + } + }); + + return session.Url; + } + + public async Task CreatePortalSessionAsync(Guid userId, string returnUrl) + { + var sub = await _db.Subscriptions.FirstOrDefaultAsync(s => s.UserId == userId) + ?? throw new InvalidOperationException("No subscription found"); + + var portalService = new Stripe.BillingPortal.SessionService(); + var session = await portalService.CreateAsync(new Stripe.BillingPortal.SessionCreateOptions + { + Customer = sub.StripeCustomerId, + ReturnUrl = returnUrl + }); + + return session.Url; + } + + public async Task HandleWebhookAsync(string json, string stripeSignature) + { + var webhookSecret = _config["Stripe:WebhookSecret"]; + var stripeEvent = EventUtility.ConstructEvent(json, stripeSignature, webhookSecret); + + _logger.LogInformation("Stripe webhook: {Type}", stripeEvent.Type); + + switch (stripeEvent.Type) + { + case EventTypes.CheckoutSessionCompleted: + await HandleCheckoutCompleted(stripeEvent); + break; + case EventTypes.CustomerSubscriptionUpdated: + await HandleSubscriptionUpdated(stripeEvent); + break; + case EventTypes.CustomerSubscriptionDeleted: + await HandleSubscriptionDeleted(stripeEvent); + break; + } + } + + public async Task GetSubscriptionAsync(Guid userId) + { + var sub = await _db.Subscriptions.FirstOrDefaultAsync(s => s.UserId == userId); + if (sub == null) + { + var user = await _db.Users.FindAsync(userId); + return user == null ? null : new SubscriptionDto(user.Tier, "active", null, null); + } + return new SubscriptionDto(sub.Tier, sub.Status.ToString().ToLower(), sub.CurrentPeriodEnd, sub.StripeCustomerId); + } + + private async Task HandleCheckoutCompleted(Event stripeEvent) + { + var session = stripeEvent.Data.Object as Stripe.Checkout.Session; + if (session == null) return; + + var userIdStr = session.Metadata.GetValueOrDefault("userId"); + var tierStr = session.Metadata.GetValueOrDefault("tier"); + if (!Guid.TryParse(userIdStr, out var userId) || !Enum.TryParse(tierStr, out var tier)) + return; + + var sub = await _db.Subscriptions.FirstOrDefaultAsync(s => s.UserId == userId); + if (sub == null) + { + sub = new Core.Entities.Subscription { UserId = userId }; + _db.Subscriptions.Add(sub); + } + + sub.StripeCustomerId = session.CustomerId; + sub.StripeSubscriptionId = session.SubscriptionId; + sub.Tier = tier; + sub.Status = SubscriptionStatus.Active; + sub.UpdatedAt = DateTime.UtcNow; + + var user = await _db.Users.FindAsync(userId); + if (user != null) user.Tier = tier; + + await _db.SaveChangesAsync(); + _logger.LogInformation("Checkout completed: user {UserId} -> {Tier}", userId, tier); + } + + private async Task HandleSubscriptionUpdated(Event stripeEvent) + { + if (stripeEvent.Data.Object is not Stripe.Subscription stripeSub) return; + + var sub = await _db.Subscriptions.FirstOrDefaultAsync(s => s.StripeSubscriptionId == stripeSub.Id); + if (sub == null) return; + + sub.Status = stripeSub.Status switch + { + "active" => SubscriptionStatus.Active, + "past_due" => SubscriptionStatus.PastDue, + "canceled" => SubscriptionStatus.Cancelled, + "incomplete" => SubscriptionStatus.Incomplete, + "trialing" => SubscriptionStatus.Trialing, + _ => SubscriptionStatus.Active + }; + sub.CurrentPeriodStart = stripeSub.CurrentPeriodStart; + sub.CurrentPeriodEnd = stripeSub.CurrentPeriodEnd; + sub.UpdatedAt = DateTime.UtcNow; + + await _db.SaveChangesAsync(); + } + + private async Task HandleSubscriptionDeleted(Event stripeEvent) + { + if (stripeEvent.Data.Object is not Stripe.Subscription stripeSub) return; + + var sub = await _db.Subscriptions.FirstOrDefaultAsync(s => s.StripeSubscriptionId == stripeSub.Id); + if (sub == null) return; + + sub.Status = SubscriptionStatus.Cancelled; + sub.CancelledAt = DateTime.UtcNow; + sub.UpdatedAt = DateTime.UtcNow; + + var user = await _db.Users.FindAsync(sub.UserId); + if (user != null) user.Tier = UserTier.Free; + + await _db.SaveChangesAsync(); + _logger.LogInformation("Subscription cancelled: user {UserId} -> Free", sub.UserId); + } +} diff --git a/src/PipeRAG.Infrastructure/Services/UsageTrackingService.cs b/src/PipeRAG.Infrastructure/Services/UsageTrackingService.cs new file mode 100644 index 0000000..8a3994f --- /dev/null +++ b/src/PipeRAG.Infrastructure/Services/UsageTrackingService.cs @@ -0,0 +1,127 @@ +using Microsoft.EntityFrameworkCore; +using PipeRAG.Core.Entities; +using PipeRAG.Core.Enums; +using PipeRAG.Core.Interfaces; +using PipeRAG.Infrastructure.Data; + +namespace PipeRAG.Infrastructure.Services; + +public class UsageTrackingService : IUsageTrackingService +{ + private readonly PipeRagDbContext _db; + + public UsageTrackingService(PipeRagDbContext db) => _db = db; + + public async Task IncrementQueryCountAsync(Guid userId) + { + var record = await GetOrCreateTodayRecord(userId); + record.QueryCount++; + record.UpdatedAt = DateTime.UtcNow; + await _db.SaveChangesAsync(); + } + + public async Task GetUsageAsync(Guid userId) + { + var user = await _db.Users.FindAsync(userId) + ?? throw new InvalidOperationException("User not found"); + var limits = TierLimits.GetLimits(user.Tier); + var today = await GetOrCreateTodayRecord(userId); + + var totalDocs = await _db.Documents + .CountAsync(d => d.Project.OwnerId == userId); + var totalProjects = await _db.Projects + .CountAsync(p => p.OwnerId == userId); + var totalStorage = await _db.Documents + .Where(d => d.Project.OwnerId == userId) + .SumAsync(d => (long)d.FileSizeBytes); + + return new UsageDto( + today.QueryCount, limits.QueriesPerDay, + totalDocs, limits.MaxDocuments, + totalProjects, limits.MaxProjects, + totalStorage, limits.MaxStorageBytes, + user.Tier); + } + + public async Task CanPerformQueryAsync(Guid userId) + { + var user = await _db.Users.FindAsync(userId); + if (user == null) return false; + if (user.Tier == UserTier.Enterprise) return true; + var limits = TierLimits.GetLimits(user.Tier); + var today = await GetOrCreateTodayRecord(userId); + return today.QueryCount < limits.QueriesPerDay; + } + + public async Task CanCreateDocumentAsync(Guid userId) + { + var user = await _db.Users.FindAsync(userId); + if (user == null) return false; + if (user.Tier == UserTier.Enterprise) return true; + var limits = TierLimits.GetLimits(user.Tier); + var count = await _db.Documents.CountAsync(d => d.Project.OwnerId == userId); + return count < limits.MaxDocuments; + } + + public async Task CanCreateProjectAsync(Guid userId) + { + var user = await _db.Users.FindAsync(userId); + if (user == null) return false; + if (user.Tier == UserTier.Enterprise) return true; + var limits = TierLimits.GetLimits(user.Tier); + var count = await _db.Projects.CountAsync(p => p.OwnerId == userId); + return count < limits.MaxProjects; + } + + public async Task CanUploadStorageAsync(Guid userId, long additionalBytes) + { + var user = await _db.Users.FindAsync(userId); + if (user == null) return false; + if (user.Tier == UserTier.Enterprise) return true; + var limits = TierLimits.GetLimits(user.Tier); + var used = await _db.Documents + .Where(d => d.Project.OwnerId == userId) + .SumAsync(d => (long)d.FileSizeBytes); + return (used + additionalBytes) <= limits.MaxStorageBytes; + } + + public async Task RecalculateDocumentCountAsync(Guid userId) + { + var record = await GetOrCreateTodayRecord(userId); + record.DocumentCount = await _db.Documents.CountAsync(d => d.Project.OwnerId == userId); + record.UpdatedAt = DateTime.UtcNow; + await _db.SaveChangesAsync(); + } + + public async Task RecalculateProjectCountAsync(Guid userId) + { + var record = await GetOrCreateTodayRecord(userId); + record.ProjectCount = await _db.Projects.CountAsync(p => p.OwnerId == userId); + record.UpdatedAt = DateTime.UtcNow; + await _db.SaveChangesAsync(); + } + + public async Task RecalculateStorageAsync(Guid userId) + { + var record = await GetOrCreateTodayRecord(userId); + record.StorageBytes = await _db.Documents + .Where(d => d.Project.OwnerId == userId) + .SumAsync(d => (long)d.FileSizeBytes); + record.UpdatedAt = DateTime.UtcNow; + await _db.SaveChangesAsync(); + } + + private async Task GetOrCreateTodayRecord(Guid userId) + { + var today = DateTime.UtcNow.Date; + var record = await _db.UsageRecords + .FirstOrDefaultAsync(r => r.UserId == userId && r.Date == today); + if (record == null) + { + record = new UsageRecord { UserId = userId, Date = today }; + _db.UsageRecords.Add(record); + await _db.SaveChangesAsync(); + } + return record; + } +}