From 3530c140f0d06d7881fa19bb520161199b759e99 Mon Sep 17 00:00:00 2001 From: Prathik Shetty Date: Wed, 11 Feb 2026 23:36:06 +0530 Subject: [PATCH] fix(api): Enforce provider exclusivity and Gemini streaming --- .../ai/__tests__/ai.controller.spec.ts | 229 +++++++++++ .../mukti-api/src/modules/ai/ai.controller.ts | 264 +++++++++--- .../src/modules/ai/dto/ai-settings.dto.ts | 3 +- .../modules/ai/services/ai-policy.service.ts | 138 ++++++- .../__tests__/conversation.controller.spec.ts | 5 + .../conversations/conversation.controller.ts | 65 ++- .../conversations/conversations.module.ts | 3 + .../conversations/dto/send-message.dto.ts | 2 +- .../__tests__/queue.gemini-stream.spec.ts | 226 +++++++++++ .../services/__tests__/queue.service.spec.ts | 18 + .../conversations/services/gemini.service.ts | 76 ++++ .../conversations/services/queue.service.ts | 154 +++++-- .../conversations/services/stream.service.ts | 381 ++++++++++++++---- packages/mukti-api/src/schemas/user.schema.ts | 1 + .../src/app/dashboard/settings/page.tsx | 51 ++- .../src/components/ai/model-selector.tsx | 10 +- packages/mukti-web/src/lib/api/ai.ts | 49 ++- .../hooks/__tests__/use-conversations.spec.ts | 59 +++ .../src/lib/hooks/use-conversations.ts | 81 +++- packages/mukti-web/src/lib/stores/ai-store.ts | 24 +- 20 files changed, 1621 insertions(+), 218 deletions(-) create mode 100644 packages/mukti-api/src/modules/ai/__tests__/ai.controller.spec.ts create mode 100644 packages/mukti-api/src/modules/conversations/services/__tests__/queue.gemini-stream.spec.ts create mode 100644 packages/mukti-api/src/modules/conversations/services/gemini.service.ts diff --git a/packages/mukti-api/src/modules/ai/__tests__/ai.controller.spec.ts b/packages/mukti-api/src/modules/ai/__tests__/ai.controller.spec.ts new file mode 100644 index 00000000..71341c0d --- /dev/null +++ b/packages/mukti-api/src/modules/ai/__tests__/ai.controller.spec.ts @@ -0,0 +1,229 @@ +import { ConfigService } from '@nestjs/config'; +import { getModelToken } from '@nestjs/mongoose'; +import { Test, type TestingModule } from '@nestjs/testing'; + +jest.mock('@openrouter/sdk', () => ({ + OpenRouter: jest.fn(() => ({})), +})); + +import { User } from '../../../schemas/user.schema'; +import { AiController } from '../ai.controller'; +import { AiPolicyService } from '../services/ai-policy.service'; +import { AiSecretsService } from '../services/ai-secrets.service'; +import { OpenRouterModelsService } from '../services/openrouter-models.service'; + +type PlainObject = Record; + +function applyUpdate(target: PlainObject, update: PlainObject) { + const set = update.$set ?? {}; + const unset = update.$unset ?? {}; + + Object.entries(set).forEach(([path, value]) => { + setNested(target, path, value); + }); + + Object.keys(unset).forEach((path) => { + unsetNested(target, path); + }); +} + +function setNested(target: PlainObject, path: string, value: unknown) { + const keys = path.split('.'); + let cursor = target; + + for (let index = 0; index < keys.length - 1; index += 1) { + const key = keys[index]; + if (!cursor[key] || typeof cursor[key] !== 'object') { + cursor[key] = {}; + } + cursor = cursor[key]; + } + + cursor[keys[keys.length - 1]] = value; +} + +function unsetNested(target: PlainObject, path: string) { + const keys = path.split('.'); + let cursor = target; + + for (let index = 0; index < keys.length - 1; index += 1) { + const key = keys[index]; + if (!cursor[key] || typeof cursor[key] !== 'object') { + return; + } + cursor = cursor[key]; + } + + delete cursor[keys[keys.length - 1]]; +} + +describe('AiController', () => { + let controller: AiController; + let aiPolicyService: AiPolicyService; + let openRouterModelsService: jest.Mocked; + let state: PlainObject; + + const mockUserModel = { + findById: jest.fn(), + updateOne: jest.fn(), + }; + + const mockAiSecretsService = { + decryptString: jest.fn((cipher: string) => cipher.replace(/^enc:/, '')), + encryptString: jest.fn((plain: string) => `enc:${plain}`), + }; + + const mockOpenRouterModelsService = { + listModels: jest.fn(), + validateModelExists: jest.fn().mockResolvedValue(true), + }; + + const mockConfigService = { + get: jest.fn((key: string) => { + if (key === 'OPENROUTER_API_KEY') { + return 'server-openrouter-key'; + } + + return undefined; + }), + }; + + const makeFindByIdChain = () => ({ + lean: jest.fn(async () => state), + select: jest.fn().mockImplementation(() => makeFindByIdChain()), + }); + + beforeEach(async () => { + state = { + _id: 'user-123', + geminiApiKeyEncrypted: undefined, + geminiApiKeyLast4: null, + geminiApiKeyUpdatedAt: undefined, + openRouterApiKeyEncrypted: undefined, + openRouterApiKeyLast4: null, + openRouterApiKeyUpdatedAt: undefined, + preferences: { + activeModel: 'openai/gpt-5-mini', + activeProvider: 'openrouter', + }, + }; + + mockUserModel.findById.mockImplementation(() => makeFindByIdChain()); + mockUserModel.updateOne.mockImplementation(async (_query, update) => { + applyUpdate(state, update as PlainObject); + return { acknowledged: true }; + }); + mockOpenRouterModelsService.listModels.mockResolvedValue([ + { id: 'openai/gpt-5-mini', name: 'GPT-5 Mini' }, + ]); + + const module: TestingModule = await Test.createTestingModule({ + controllers: [AiController], + providers: [ + AiPolicyService, + { + provide: getModelToken(User.name), + useValue: mockUserModel, + }, + { + provide: AiSecretsService, + useValue: mockAiSecretsService, + }, + { + provide: OpenRouterModelsService, + useValue: mockOpenRouterModelsService, + }, + { + provide: ConfigService, + useValue: mockConfigService, + }, + ], + }).compile(); + + controller = module.get(AiController); + aiPolicyService = module.get(AiPolicyService); + openRouterModelsService = module.get(OpenRouterModelsService); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('saving Gemini key clears OpenRouter key and persists Gemini as active provider', async () => { + state.openRouterApiKeyEncrypted = 'enc:sk-or-v1-previous'; + state.openRouterApiKeyLast4 = '9999'; + state.openRouterApiKeyUpdatedAt = new Date(); + state.preferences.activeModel = 'openai/gpt-5-mini'; + state.preferences.activeProvider = 'openrouter'; + + const result = await controller.setGeminiKey('user-123', { + apiKey: 'AIzaSy_test_gemini_key', + }); + + expect(result.success).toBe(true); + expect(result.data.activeProvider).toBe('gemini'); + expect(state.openRouterApiKeyEncrypted).toBeUndefined(); + expect(state.openRouterApiKeyLast4).toBeUndefined(); + expect(state.openRouterApiKeyUpdatedAt).toBeUndefined(); + expect(state.preferences.activeProvider).toBe('gemini'); + expect(state.preferences.activeModel).toBe( + aiPolicyService.getDefaultModel('gemini'), + ); + }); + + it('saving OpenRouter key clears Gemini key and persists OpenRouter as active provider', async () => { + state.geminiApiKeyEncrypted = 'enc:AIzaSy_previous'; + state.geminiApiKeyLast4 = '8888'; + state.geminiApiKeyUpdatedAt = new Date(); + state.preferences.activeModel = 'gemini-2.0-flash'; + state.preferences.activeProvider = 'gemini'; + + const result = await controller.setOpenRouterKey('user-123', { + apiKey: 'sk-or-v1-test-openrouter-key', + }); + + expect(openRouterModelsService.listModels).toHaveBeenCalledTimes(1); + expect(result.success).toBe(true); + expect(result.data.activeProvider).toBe('openrouter'); + expect(state.geminiApiKeyEncrypted).toBeUndefined(); + expect(state.geminiApiKeyLast4).toBeUndefined(); + expect(state.geminiApiKeyUpdatedAt).toBeUndefined(); + expect(state.preferences.activeProvider).toBe('openrouter'); + expect(state.preferences.activeModel).toBe( + aiPolicyService.getDefaultModel('openrouter'), + ); + }); + + it('GET /ai/settings returns activeProvider from persisted preferences', async () => { + state.geminiApiKeyEncrypted = 'enc:AIzaSy_test'; + state.geminiApiKeyLast4 = '4321'; + state.geminiApiKeyUpdatedAt = new Date(); + state.preferences.activeProvider = 'gemini'; + state.preferences.activeModel = 'gemini-2.0-flash'; + + const result = await controller.getSettings('user-123'); + + expect(result.success).toBe(true); + expect(result.data.activeProvider).toBe('gemini'); + expect(result.data.activeModel).toBe('gemini-2.0-flash'); + expect(result.data.hasGeminiKey).toBe(true); + expect(result.data.hasOpenRouterKey).toBe(false); + }); + + it('GET /ai/models returns Gemini models when active provider is Gemini', async () => { + state.geminiApiKeyEncrypted = 'enc:AIzaSy_test'; + state.geminiApiKeyUpdatedAt = new Date(); + state.preferences.activeProvider = 'gemini'; + state.preferences.activeModel = 'openai/gpt-5-mini'; + + const result = await controller.getModels('user-123'); + + expect(result.success).toBe(true); + expect(result.data.provider).toBe('gemini'); + expect(result.data.mode).toBe('gemini'); + expect(result.data.models).toEqual(aiPolicyService.getGeminiModels()); + expect(state.preferences.activeModel).toBe( + aiPolicyService.getDefaultModel('gemini'), + ); + }); +}); diff --git a/packages/mukti-api/src/modules/ai/ai.controller.ts b/packages/mukti-api/src/modules/ai/ai.controller.ts index 7e1fc618..8b4bfb50 100644 --- a/packages/mukti-api/src/modules/ai/ai.controller.ts +++ b/packages/mukti-api/src/modules/ai/ai.controller.ts @@ -19,7 +19,7 @@ import { CurrentUser } from '../auth/decorators/current-user.decorator'; import { UpdateAiSettingsDto } from './dto/ai-settings.dto'; import { SetGeminiKeyDto } from './dto/gemini-key.dto'; import { SetOpenRouterKeyDto } from './dto/openrouter-key.dto'; -import { AiPolicyService } from './services/ai-policy.service'; +import { AiPolicyService, type AiProvider } from './services/ai-policy.service'; import { AiSecretsService } from './services/ai-secrets.service'; import { OpenRouterModelsService } from './services/openrouter-models.service'; @@ -40,7 +40,7 @@ export class AiController { const user = await this.userModel .findById(userId) .select( - 'preferences openRouterApiKeyLast4 openRouterApiKeyUpdatedAt geminiApiKeyLast4 geminiApiKeyUpdatedAt', + 'preferences openRouterApiKeyLast4 openRouterApiKeyUpdatedAt geminiApiKeyLast4 geminiApiKeyUpdatedAt +openRouterApiKeyEncrypted +geminiApiKeyEncrypted', ) .lean(); @@ -48,12 +48,43 @@ export class AiController { throw new Error('User not found'); } + const hasOpenRouterByok = this.aiPolicyService.hasUserOpenRouterKey(user); + const hasGeminiKey = this.aiPolicyService.hasUserGeminiKey(user); + const activeProvider = this.aiPolicyService.resolveActiveProvider({ + hasGeminiKey, + hasOpenRouterAccess: hasOpenRouterByok || this.hasServerOpenRouterKey(), + preferredProvider: user.preferences?.activeProvider, + }); + const activeModel = this.aiPolicyService.coerceModelForProvider({ + activeProvider, + hasOpenRouterByok, + model: user.preferences?.activeModel, + }); + + const settingsPatch: Record = {}; + if (user.preferences?.activeProvider !== activeProvider) { + settingsPatch['preferences.activeProvider'] = activeProvider; + } + if (user.preferences?.activeModel !== activeModel) { + settingsPatch['preferences.activeModel'] = activeModel; + } + + if (Object.keys(settingsPatch).length > 0) { + await this.userModel.updateOne( + { _id: userId }, + { + $set: settingsPatch, + }, + ); + } + return { data: { - activeModel: user.preferences?.activeModel, + activeModel, + activeProvider, geminiKeyLast4: user.geminiApiKeyLast4 ?? null, - hasGeminiKey: !!user.geminiApiKeyUpdatedAt, - hasOpenRouterKey: !!user.openRouterApiKeyUpdatedAt, + hasGeminiKey, + hasOpenRouterKey: hasOpenRouterByok, openRouterKeyLast4: user.openRouterApiKeyLast4 ?? null, }, success: true, @@ -65,40 +96,56 @@ export class AiController { @CurrentUser('_id') userId: string, @Body() dto: UpdateAiSettingsDto, ) { - if (!dto.activeModel) { - return { - data: { activeModel: null }, - success: true, - }; - } - const user = await this.userModel .findById(userId) - .select('+openRouterApiKeyEncrypted preferences') + .select('+openRouterApiKeyEncrypted +geminiApiKeyEncrypted preferences') .lean(); if (!user) { throw new Error('User not found'); } - const hasByok = this.aiPolicyService.hasUserOpenRouterKey(user); - - const validationApiKey = this.getValidationApiKey({ hasByok, user }); - - const effectiveModel = await this.aiPolicyService.resolveEffectiveModel({ - hasByok, - requestedModel: dto.activeModel, - userActiveModel: user.preferences?.activeModel, - validationApiKey, + const hasOpenRouterByok = this.aiPolicyService.hasUserOpenRouterKey(user); + const hasGeminiKey = this.aiPolicyService.hasUserGeminiKey(user); + const activeProvider = this.aiPolicyService.resolveActiveProvider({ + hasGeminiKey, + hasOpenRouterAccess: hasOpenRouterByok || this.hasServerOpenRouterKey(), + preferredProvider: user.preferences?.activeProvider, }); + const validationApiKey = + activeProvider === 'openrouter' + ? this.getValidationApiKey({ hasByok: hasOpenRouterByok, user }) + : undefined; + + const effectiveModel = dto.activeModel + ? await this.aiPolicyService.resolveEffectiveModel({ + activeProvider, + hasByok: hasOpenRouterByok, + requestedModel: dto.activeModel, + userActiveModel: user.preferences?.activeModel, + validationApiKey, + }) + : this.aiPolicyService.coerceModelForProvider({ + activeProvider, + hasOpenRouterByok, + model: user.preferences?.activeModel, + }); await this.userModel.updateOne( { _id: userId }, - { $set: { 'preferences.activeModel': effectiveModel } }, + { + $set: { + 'preferences.activeModel': effectiveModel, + 'preferences.activeProvider': activeProvider, + }, + }, ); return { - data: { activeModel: effectiveModel }, + data: { + activeModel: effectiveModel, + activeProvider, + }, success: true, }; } @@ -134,6 +181,15 @@ export class AiController { const encrypted = this.aiSecretsService.encryptString(apiKey); const last4 = apiKey.slice(-4); + const user = await this.userModel + .findById(userId) + .select('preferences') + .lean(); + const activeModel = this.aiPolicyService.coerceModelForProvider({ + activeProvider: 'openrouter', + hasOpenRouterByok: true, + model: user?.preferences?.activeModel, + }); await this.userModel.updateOne( { _id: userId }, @@ -142,12 +198,23 @@ export class AiController { openRouterApiKeyEncrypted: encrypted, openRouterApiKeyLast4: last4, openRouterApiKeyUpdatedAt: new Date(), + 'preferences.activeModel': activeModel, + 'preferences.activeProvider': 'openrouter', + }, + $unset: { + geminiApiKeyEncrypted: 1, + geminiApiKeyLast4: 1, + geminiApiKeyUpdatedAt: 1, }, }, ); return { data: { + activeModel, + activeProvider: 'openrouter', + geminiKeyLast4: null, + hasGeminiKey: false, hasOpenRouterKey: true, openRouterKeyLast4: last4, }, @@ -158,9 +225,38 @@ export class AiController { @Delete('openrouter-key') @HttpCode(HttpStatus.OK) async deleteOpenRouterKey(@CurrentUser('_id') userId: string) { + const user = await this.userModel + .findById(userId) + .select('+geminiApiKeyEncrypted preferences') + .lean(); + + if (!user) { + throw new Error('User not found'); + } + + const hasGeminiKey = this.aiPolicyService.hasUserGeminiKey(user); + const nextPreferredProvider: AiProvider | undefined = + user.preferences?.activeProvider === 'openrouter' && hasGeminiKey + ? 'gemini' + : user.preferences?.activeProvider; + const activeProvider = this.aiPolicyService.resolveActiveProvider({ + hasGeminiKey, + hasOpenRouterAccess: this.hasServerOpenRouterKey(), + preferredProvider: nextPreferredProvider, + }); + const activeModel = this.aiPolicyService.coerceModelForProvider({ + activeProvider, + hasOpenRouterByok: false, + model: user.preferences?.activeModel, + }); + await this.userModel.updateOne( { _id: userId }, { + $set: { + 'preferences.activeModel': activeModel, + 'preferences.activeProvider': activeProvider, + }, $unset: { openRouterApiKeyEncrypted: 1, openRouterApiKeyLast4: 1, @@ -169,29 +265,10 @@ export class AiController { }, ); - // If activeModel is not curated, reset to default. - const user = await this.userModel - .findById(userId) - .select('preferences') - .lean(); - const activeModel = user?.preferences?.activeModel; - const isCurated = this.aiPolicyService - .getCuratedModels() - .some((m) => m.id === activeModel); - - if (!isCurated) { - await this.userModel.updateOne( - { _id: userId }, - { - $set: { - 'preferences.activeModel': this.aiPolicyService.getDefaultModel(), - }, - }, - ); - } - return { data: { + activeModel, + activeProvider, hasOpenRouterKey: false, openRouterKeyLast4: null, }, @@ -221,6 +298,15 @@ export class AiController { const encrypted = this.aiSecretsService.encryptString(apiKey); const last4 = apiKey.slice(-4); + const user = await this.userModel + .findById(userId) + .select('preferences') + .lean(); + const activeModel = this.aiPolicyService.coerceModelForProvider({ + activeProvider: 'gemini', + hasOpenRouterByok: false, + model: user?.preferences?.activeModel, + }); await this.userModel.updateOne( { _id: userId }, @@ -229,14 +315,25 @@ export class AiController { geminiApiKeyEncrypted: encrypted, geminiApiKeyLast4: last4, geminiApiKeyUpdatedAt: new Date(), + 'preferences.activeModel': activeModel, + 'preferences.activeProvider': 'gemini', + }, + $unset: { + openRouterApiKeyEncrypted: 1, + openRouterApiKeyLast4: 1, + openRouterApiKeyUpdatedAt: 1, }, }, ); return { data: { + activeModel, + activeProvider: 'gemini', geminiKeyLast4: last4, hasGeminiKey: true, + hasOpenRouterKey: false, + openRouterKeyLast4: null, }, success: true, }; @@ -245,9 +342,38 @@ export class AiController { @Delete('gemini-key') @HttpCode(HttpStatus.OK) async deleteGeminiKey(@CurrentUser('_id') userId: string) { + const user = await this.userModel + .findById(userId) + .select('+openRouterApiKeyEncrypted preferences') + .lean(); + + if (!user) { + throw new Error('User not found'); + } + + const hasOpenRouterByok = this.aiPolicyService.hasUserOpenRouterKey(user); + const nextPreferredProvider: AiProvider | undefined = + user.preferences?.activeProvider === 'gemini' + ? 'openrouter' + : user.preferences?.activeProvider; + const activeProvider = this.aiPolicyService.resolveActiveProvider({ + hasGeminiKey: false, + hasOpenRouterAccess: hasOpenRouterByok || this.hasServerOpenRouterKey(), + preferredProvider: nextPreferredProvider, + }); + const activeModel = this.aiPolicyService.coerceModelForProvider({ + activeProvider, + hasOpenRouterByok, + model: user.preferences?.activeModel, + }); + await this.userModel.updateOne( { _id: userId }, { + $set: { + 'preferences.activeModel': activeModel, + 'preferences.activeProvider': activeProvider, + }, $unset: { geminiApiKeyEncrypted: 1, geminiApiKeyLast4: 1, @@ -258,6 +384,8 @@ export class AiController { return { data: { + activeModel, + activeProvider, geminiKeyLast4: null, hasGeminiKey: false, }, @@ -269,16 +397,50 @@ export class AiController { async getModels(@CurrentUser('_id') userId: string) { const user = await this.userModel .findById(userId) - .select('+openRouterApiKeyEncrypted preferences') + .select('+openRouterApiKeyEncrypted +geminiApiKeyEncrypted preferences') .lean(); if (!user) { throw new Error('User not found'); } - const hasByok = this.aiPolicyService.hasUserOpenRouterKey(user); + const hasOpenRouterByok = this.aiPolicyService.hasUserOpenRouterKey(user); + const hasGeminiKey = this.aiPolicyService.hasUserGeminiKey(user); + const activeProvider = this.aiPolicyService.resolveActiveProvider({ + hasGeminiKey, + hasOpenRouterAccess: hasOpenRouterByok || this.hasServerOpenRouterKey(), + preferredProvider: user.preferences?.activeProvider, + }); + const activeModel = this.aiPolicyService.coerceModelForProvider({ + activeProvider, + hasOpenRouterByok, + model: user.preferences?.activeModel, + }); + + if (user.preferences?.activeModel !== activeModel) { + await this.userModel.updateOne( + { _id: userId }, + { + $set: { + 'preferences.activeModel': activeModel, + 'preferences.activeProvider': activeProvider, + }, + }, + ); + } + + if (activeProvider === 'gemini') { + return { + data: { + mode: 'gemini', + models: this.aiPolicyService.getGeminiModels(), + provider: 'gemini', + }, + success: true, + }; + } - if (!hasByok) { + if (!hasOpenRouterByok) { const validationApiKey = this.configService.get('OPENROUTER_API_KEY') ?? ''; if (validationApiKey) { @@ -297,6 +459,7 @@ export class AiController { data: { mode: 'curated', models: this.aiPolicyService.getCuratedModels(), + provider: 'openrouter', }, success: true, }; @@ -314,6 +477,7 @@ export class AiController { id: m.id, name: m.name, })), + provider: 'openrouter', }, success: true, }; @@ -335,4 +499,8 @@ export class AiController { return serverKey; } + + private hasServerOpenRouterKey(): boolean { + return !!(this.configService.get('OPENROUTER_API_KEY') ?? ''); + } } diff --git a/packages/mukti-api/src/modules/ai/dto/ai-settings.dto.ts b/packages/mukti-api/src/modules/ai/dto/ai-settings.dto.ts index 5877892e..73cfe4bb 100644 --- a/packages/mukti-api/src/modules/ai/dto/ai-settings.dto.ts +++ b/packages/mukti-api/src/modules/ai/dto/ai-settings.dto.ts @@ -3,7 +3,8 @@ import { IsOptional, IsString } from 'class-validator'; export class UpdateAiSettingsDto { @ApiPropertyOptional({ - description: 'Globally active OpenRouter model id', + description: + 'Globally active AI model id for the currently active provider', example: 'openai/gpt-5-mini', }) @IsOptional() diff --git a/packages/mukti-api/src/modules/ai/services/ai-policy.service.ts b/packages/mukti-api/src/modules/ai/services/ai-policy.service.ts index 22555b10..6b04fb39 100644 --- a/packages/mukti-api/src/modules/ai/services/ai-policy.service.ts +++ b/packages/mukti-api/src/modules/ai/services/ai-policy.service.ts @@ -5,17 +5,25 @@ import type { User } from '../../../schemas/user.schema'; import { OpenRouterModelsService } from './openrouter-models.service'; +export type AiModelMode = 'curated' | 'gemini' | 'openrouter'; +export type AiProvider = 'gemini' | 'openrouter'; + export interface AllowedModel { id: string; label: string; } -const DEFAULT_MODEL = 'openai/gpt-5-mini'; +const DEFAULT_GEMINI_MODEL = 'gemini-2.0-flash'; +const DEFAULT_OPENROUTER_MODEL = 'openai/gpt-5-mini'; -const CURATED_MODELS: AllowedModel[] = [ +const CURATED_OPENROUTER_MODELS: AllowedModel[] = [ { id: 'openai/gpt-5-mini', label: 'GPT-5 Mini' }, ]; +const GEMINI_MODELS: AllowedModel[] = [ + { id: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' }, +]; + @Injectable() export class AiPolicyService { constructor( @@ -24,11 +32,101 @@ export class AiPolicyService { ) {} getCuratedModels(): AllowedModel[] { - return CURATED_MODELS; + return CURATED_OPENROUTER_MODELS; + } + + getDefaultProvider(): AiProvider { + return 'openrouter'; + } + + getDefaultModel(provider: AiProvider = 'openrouter'): string { + return provider === 'gemini' + ? DEFAULT_GEMINI_MODEL + : DEFAULT_OPENROUTER_MODEL; + } + + getGeminiModels(): AllowedModel[] { + return GEMINI_MODELS; + } + + getModelMode(params: { + activeProvider: AiProvider; + hasOpenRouterByok: boolean; + }): AiModelMode { + if (params.activeProvider === 'gemini') { + return 'gemini'; + } + + return params.hasOpenRouterByok ? 'openrouter' : 'curated'; } - getDefaultModel(): string { - return DEFAULT_MODEL; + isGeminiModel(model: string): boolean { + return GEMINI_MODELS.some((allowed) => allowed.id === model.trim()); + } + + isModelCompatibleWithProvider(params: { + activeProvider: AiProvider; + hasOpenRouterByok: boolean; + model?: string; + }): boolean { + const model = params.model?.trim(); + + if (!model) { + return false; + } + + if (params.activeProvider === 'gemini') { + return this.isGeminiModel(model); + } + + if (params.hasOpenRouterByok) { + return !this.isGeminiModel(model); + } + + return CURATED_OPENROUTER_MODELS.some((allowed) => allowed.id === model); + } + + coerceModelForProvider(params: { + activeProvider: AiProvider; + hasOpenRouterByok: boolean; + model?: string; + }): string { + if (this.isModelCompatibleWithProvider(params)) { + return params.model!.trim(); + } + + return this.getDefaultModel(params.activeProvider); + } + + resolveActiveProvider(params: { + hasGeminiKey: boolean; + hasOpenRouterAccess: boolean; + preferredProvider?: AiProvider; + }): AiProvider { + if (params.preferredProvider === 'gemini' && params.hasGeminiKey) { + return 'gemini'; + } + + if ( + params.preferredProvider === 'openrouter' && + params.hasOpenRouterAccess + ) { + return 'openrouter'; + } + + if (params.hasGeminiKey && !params.hasOpenRouterAccess) { + return 'gemini'; + } + + if (params.hasOpenRouterAccess) { + return 'openrouter'; + } + + if (params.hasGeminiKey) { + return 'gemini'; + } + + return this.getDefaultProvider(); } getValidationApiKey(params: { @@ -58,16 +156,36 @@ export class AiPolicyService { } async resolveEffectiveModel(params: { + activeProvider?: AiProvider; hasByok: boolean; requestedModel?: string; userActiveModel?: string; - validationApiKey: string; + validationApiKey?: string; }): Promise { + const activeProvider = params.activeProvider ?? 'openrouter'; const candidate = - params.requestedModel ?? params.userActiveModel ?? DEFAULT_MODEL; + params.requestedModel ?? + params.userActiveModel ?? + this.getDefaultModel(activeProvider); + + if (activeProvider === 'gemini') { + const isAllowed = GEMINI_MODELS.some((m) => m.id === candidate); + if (!isAllowed) { + throw new BadRequestException({ + error: { + code: 'MODEL_NOT_ALLOWED', + message: 'Model not available for Gemini', + }, + }); + } + + return candidate; + } if (!params.hasByok) { - const isCurated = CURATED_MODELS.some((m) => m.id === candidate); + const isCurated = CURATED_OPENROUTER_MODELS.some( + (m) => m.id === candidate, + ); if (!isCurated) { throw new BadRequestException({ error: { @@ -78,6 +196,10 @@ export class AiPolicyService { } } + if (!params.validationApiKey) { + throw new Error('OPENROUTER_API_KEY not configured'); + } + // Always validate the model exists on OpenRouter. await this.validateModelOrThrow({ apiKey: params.validationApiKey, diff --git a/packages/mukti-api/src/modules/conversations/__tests__/conversation.controller.spec.ts b/packages/mukti-api/src/modules/conversations/__tests__/conversation.controller.spec.ts index ce5713c5..bfb457af 100644 --- a/packages/mukti-api/src/modules/conversations/__tests__/conversation.controller.spec.ts +++ b/packages/mukti-api/src/modules/conversations/__tests__/conversation.controller.spec.ts @@ -72,7 +72,11 @@ describe('ConversationController', () => { }; const mockAiPolicyService = { + coerceModelForProvider: jest.fn().mockReturnValue('test-model'), + hasUserGeminiKey: jest.fn().mockReturnValue(false), + hasUserOpenRouterKey: jest.fn().mockReturnValue(false), resolveEffectiveModel: jest.fn().mockResolvedValue(undefined), + resolveActiveProvider: jest.fn().mockReturnValue('openrouter'), }; const mockAiSecretsService = { @@ -555,6 +559,7 @@ describe('ConversationController', () => { 'free', 'elenchus', undefined, + 'openrouter', false, ); }); diff --git a/packages/mukti-api/src/modules/conversations/conversation.controller.ts b/packages/mukti-api/src/modules/conversations/conversation.controller.ts index 4866519f..5fb296da 100644 --- a/packages/mukti-api/src/modules/conversations/conversation.controller.ts +++ b/packages/mukti-api/src/modules/conversations/conversation.controller.ts @@ -1,4 +1,5 @@ import { + BadRequestException, Body, Controller, Delete, @@ -306,41 +307,78 @@ export class ConversationController { const userRecord = await this.userModel .findById(user._id) - .select('+openRouterApiKeyEncrypted preferences') + .select('+openRouterApiKeyEncrypted +geminiApiKeyEncrypted preferences') .lean(); if (!userRecord) { throw new Error('User not found'); } - const usedByok = !!userRecord.openRouterApiKeyEncrypted; const serverApiKey = this.configService.get('OPENROUTER_API_KEY') ?? ''; + const hasOpenRouterByok = + this.aiPolicyService.hasUserOpenRouterKey(userRecord); + const hasGeminiKey = this.aiPolicyService.hasUserGeminiKey(userRecord); + const activeProvider = this.aiPolicyService.resolveActiveProvider({ + hasGeminiKey, + hasOpenRouterAccess: hasOpenRouterByok || !!serverApiKey, + preferredProvider: userRecord.preferences?.activeProvider, + }); + + if (activeProvider === 'gemini' && !hasGeminiKey) { + throw new BadRequestException({ + error: { + code: 'GEMINI_KEY_MISSING', + message: 'Gemini API key is required for Gemini provider', + }, + }); + } - if (!usedByok && !serverApiKey) { + if ( + activeProvider === 'openrouter' && + !hasOpenRouterByok && + !serverApiKey + ) { throw new Error('OPENROUTER_API_KEY not configured'); } - const validationApiKey = usedByok - ? this.aiSecretsService.decryptString( - userRecord.openRouterApiKeyEncrypted!, - ) - : serverApiKey; + const validationApiKey = + activeProvider === 'openrouter' + ? hasOpenRouterByok + ? this.aiSecretsService.decryptString( + userRecord.openRouterApiKeyEncrypted!, + ) + : serverApiKey + : undefined; + const normalizedActiveModel = this.aiPolicyService.coerceModelForProvider({ + activeProvider, + hasOpenRouterByok, + model: userRecord.preferences?.activeModel, + }); const effectiveModel = await this.aiPolicyService.resolveEffectiveModel({ - hasByok: usedByok, + activeProvider, + hasByok: hasOpenRouterByok, requestedModel: sendMessageDto.model, - userActiveModel: userRecord.preferences?.activeModel, + userActiveModel: normalizedActiveModel, validationApiKey, }); const shouldPersistModel = - !!sendMessageDto.model || !userRecord.preferences?.activeModel; + !!sendMessageDto.model || + !userRecord.preferences?.activeModel || + userRecord.preferences?.activeModel !== normalizedActiveModel || + userRecord.preferences?.activeProvider !== activeProvider; if (shouldPersistModel) { await this.userModel.updateOne( { _id: user._id }, - { $set: { 'preferences.activeModel': effectiveModel } }, + { + $set: { + 'preferences.activeModel': effectiveModel, + 'preferences.activeProvider': activeProvider, + }, + }, ); } @@ -351,7 +389,8 @@ export class ConversationController { subscriptionTier, conversation.technique, effectiveModel, - usedByok, + activeProvider, + hasOpenRouterByok, ); return { diff --git a/packages/mukti-api/src/modules/conversations/conversations.module.ts b/packages/mukti-api/src/modules/conversations/conversations.module.ts index 7cdd50fe..995952de 100644 --- a/packages/mukti-api/src/modules/conversations/conversations.module.ts +++ b/packages/mukti-api/src/modules/conversations/conversations.module.ts @@ -21,6 +21,7 @@ import { User, UserSchema } from '../../schemas/user.schema'; import { AiModule } from '../ai/ai.module'; import { ConversationController } from './conversation.controller'; import { ConversationService } from './services/conversation.service'; +import { GeminiService } from './services/gemini.service'; import { MessageService } from './services/message.service'; import { OpenRouterService } from './services/openrouter.service'; import { QueueService } from './services/queue.service'; @@ -45,6 +46,7 @@ import { StreamService } from './services/stream.service'; exports: [ SeedService, ConversationService, + GeminiService, MessageService, OpenRouterService, QueueService, @@ -97,6 +99,7 @@ import { StreamService } from './services/stream.service'; providers: [ SeedService, ConversationService, + GeminiService, MessageService, OpenRouterService, QueueService, diff --git a/packages/mukti-api/src/modules/conversations/dto/send-message.dto.ts b/packages/mukti-api/src/modules/conversations/dto/send-message.dto.ts index ab1c66fa..5f555987 100644 --- a/packages/mukti-api/src/modules/conversations/dto/send-message.dto.ts +++ b/packages/mukti-api/src/modules/conversations/dto/send-message.dto.ts @@ -20,7 +20,7 @@ export class SendMessageDto { content: string; @ApiPropertyOptional({ - description: 'OpenRouter model id to use for this message', + description: 'Model id to use for this message (provider-aware)', example: 'openai/gpt-5-mini', }) @IsOptional() diff --git a/packages/mukti-api/src/modules/conversations/services/__tests__/queue.gemini-stream.spec.ts b/packages/mukti-api/src/modules/conversations/services/__tests__/queue.gemini-stream.spec.ts new file mode 100644 index 00000000..88f888ea --- /dev/null +++ b/packages/mukti-api/src/modules/conversations/services/__tests__/queue.gemini-stream.spec.ts @@ -0,0 +1,226 @@ +import { getQueueToken } from '@nestjs/bullmq'; +import { ConfigService } from '@nestjs/config'; +import { getModelToken } from '@nestjs/mongoose'; +import { Test, type TestingModule } from '@nestjs/testing'; + +jest.mock('@openrouter/sdk', () => ({ + OpenRouter: jest.fn(() => ({})), +})); + +import { Conversation } from '../../../../schemas/conversation.schema'; +import { Technique } from '../../../../schemas/technique.schema'; +import { UsageEvent } from '../../../../schemas/usage-event.schema'; +import { User } from '../../../../schemas/user.schema'; +import { AiPolicyService } from '../../../ai/services/ai-policy.service'; +import { AiSecretsService } from '../../../ai/services/ai-secrets.service'; +import { GeminiService } from '../gemini.service'; +import { MessageService } from '../message.service'; +import { OpenRouterService } from '../openrouter.service'; +import { QueueService } from '../queue.service'; +import { StreamService } from '../stream.service'; + +describe('QueueService Gemini Streaming', () => { + let service: QueueService; + let streamService: { emitToConversation: jest.Mock }; + + beforeEach(async () => { + streamService = { + emitToConversation: jest.fn(), + }; + + const module: TestingModule = await Test.createTestingModule({ + providers: [ + QueueService, + { + provide: getQueueToken('conversation-requests'), + useValue: { + add: jest.fn(), + getJob: jest.fn(), + }, + }, + { + provide: getModelToken(Conversation.name), + useValue: { + findById: jest.fn().mockResolvedValue({ + _id: '507f1f77bcf86cd799439011', + recentMessages: [], + }), + }, + }, + { + provide: getModelToken(Technique.name), + useValue: { + findOne: jest.fn().mockResolvedValue({ + template: { + exampleQuestions: [], + followUpStrategy: 'Probe assumptions', + questioningStyle: 'Socratic', + systemPrompt: 'Ask reflective questions', + }, + }), + }, + }, + { + provide: getModelToken(UsageEvent.name), + useValue: { + create: jest.fn().mockResolvedValue(undefined), + }, + }, + { + provide: getModelToken(User.name), + useValue: { + findById: jest.fn().mockReturnValue({ + lean: jest.fn().mockResolvedValue({ + geminiApiKeyEncrypted: 'enc:AIzaSy_test_key', + }), + select: jest.fn().mockReturnThis(), + }), + }, + }, + { + provide: ConfigService, + useValue: { + get: jest.fn(), + }, + }, + { + provide: AiPolicyService, + useValue: { + getCuratedModels: jest.fn(() => [ + { id: 'openai/gpt-5-mini', label: 'GPT-5 Mini' }, + ]), + isGeminiModel: jest.fn().mockReturnValue(true), + }, + }, + { + provide: AiSecretsService, + useValue: { + decryptString: jest.fn((value: string) => + value.replace(/^enc:/, ''), + ), + }, + }, + { + provide: GeminiService, + useValue: { + sendMessage: jest.fn().mockResolvedValue({ + completionTokens: 12, + content: 'Have you considered alternative assumptions?', + cost: 0, + model: 'gemini-2.0-flash', + promptTokens: 21, + totalTokens: 33, + }), + }, + }, + { + provide: MessageService, + useValue: { + addMessageToConversation: jest.fn().mockResolvedValue({ + _id: { + toString: () => '507f1f77bcf86cd799439011', + }, + recentMessages: [ + { + content: 'What should I improve?', + role: 'user', + timestamp: new Date('2025-01-01T00:00:00.000Z'), + }, + { + content: 'Have you considered alternative assumptions?', + role: 'assistant', + timestamp: new Date('2025-01-01T00:00:00.000Z'), + }, + ], + totalMessageCount: 2, + }), + archiveOldMessages: jest.fn(), + buildConversationContext: jest.fn().mockReturnValue({ + messages: [], + systemPrompt: 'Ask reflective questions', + technique: {}, + }), + }, + }, + { + provide: OpenRouterService, + useValue: { + buildPrompt: jest.fn(), + sendChatCompletion: jest.fn(), + }, + }, + { + provide: StreamService, + useValue: streamService, + }, + ], + }).compile(); + + service = module.get(QueueService); + }); + + it('emits processing, message(user+assistant), and complete events for Gemini jobs', async () => { + const result = await service.process({ + data: { + conversationId: '507f1f77bcf86cd799439011', + message: 'What should I improve?', + model: 'gemini-2.0-flash', + provider: 'gemini', + subscriptionTier: 'free', + technique: 'elenchus', + usedByok: false, + userId: '507f1f77bcf86cd799439012', + }, + id: 'job-1', + } as any); + + const emittedEventTypes = streamService.emitToConversation.mock.calls.map( + (call) => call[1].type, + ); + + expect(emittedEventTypes).toContain('processing'); + expect(emittedEventTypes).toContain('progress'); + expect(emittedEventTypes).toContain('message'); + expect(emittedEventTypes).toContain('complete'); + + expect(streamService.emitToConversation).toHaveBeenCalledWith( + '507f1f77bcf86cd799439011', + { + data: { + content: 'What should I improve?', + role: 'user', + sequence: 1, + timestamp: '2025-01-01T00:00:00.000Z', + }, + type: 'message', + }, + ); + expect(streamService.emitToConversation).toHaveBeenCalledWith( + '507f1f77bcf86cd799439011', + { + data: { + content: 'Have you considered alternative assumptions?', + role: 'assistant', + sequence: 2, + timestamp: '2025-01-01T00:00:00.000Z', + tokens: 33, + }, + type: 'message', + }, + ); + expect(streamService.emitToConversation).toHaveBeenCalledWith( + '507f1f77bcf86cd799439011', + { + data: { + cost: 0, + jobId: 'job-1', + latency: expect.any(Number), + tokens: 33, + }, + type: 'complete', + }, + ); + + expect(result.tokens).toBe(33); + }); +}); diff --git a/packages/mukti-api/src/modules/conversations/services/__tests__/queue.service.spec.ts b/packages/mukti-api/src/modules/conversations/services/__tests__/queue.service.spec.ts index 2489f5f4..bf46b590 100644 --- a/packages/mukti-api/src/modules/conversations/services/__tests__/queue.service.spec.ts +++ b/packages/mukti-api/src/modules/conversations/services/__tests__/queue.service.spec.ts @@ -14,6 +14,7 @@ import { UsageEvent } from '../../../../schemas/usage-event.schema'; import { User } from '../../../../schemas/user.schema'; import { AiPolicyService } from '../../../ai/services/ai-policy.service'; import { AiSecretsService } from '../../../ai/services/ai-secrets.service'; +import { GeminiService } from '../gemini.service'; import { MessageService } from '../message.service'; import { OpenRouterService } from '../openrouter.service'; import { QueueService } from '../queue.service'; @@ -126,6 +127,10 @@ describe('QueueService', () => { sendChatCompletion: jest.fn(), }; + const mockGeminiService = { + sendMessage: jest.fn(), + }; + const mockStreamService = { addConnection: jest.fn(), cleanupConversation: jest.fn(), @@ -176,6 +181,10 @@ describe('QueueService', () => { provide: OpenRouterService, useValue: mockOpenRouterService, }, + { + provide: GeminiService, + useValue: mockGeminiService, + }, { provide: StreamService, useValue: mockStreamService, @@ -232,6 +241,7 @@ describe('QueueService', () => { subscriptionTier, technique, 'openai/gpt-5-mini', + 'openrouter', false, ); @@ -272,6 +282,7 @@ describe('QueueService', () => { 'free', 'elenchus', 'openai/gpt-5-mini', + 'openrouter', false, ); @@ -283,6 +294,7 @@ describe('QueueService', () => { 'paid', 'dialectic', 'openai/gpt-5-mini', + 'openrouter', false, ); @@ -306,6 +318,7 @@ describe('QueueService', () => { 'free', 'elenchus', 'openai/gpt-5-mini', + 'openrouter', false, ); @@ -342,6 +355,7 @@ describe('QueueService', () => { 'free', 'elenchus', 'openai/gpt-5-mini', + 'openrouter', false, ); await service.enqueueRequest( @@ -351,6 +365,7 @@ describe('QueueService', () => { 'paid', 'dialectic', 'openai/gpt-5-mini', + 'openrouter', false, ); @@ -413,6 +428,7 @@ describe('QueueService', () => { subscriptionTier, technique, 'openai/gpt-5-mini', + 'openrouter', false, ); @@ -425,6 +441,7 @@ describe('QueueService', () => { conversationId, message, model: 'openai/gpt-5-mini', + provider: 'openrouter', subscriptionTier, technique, usedByok: false, @@ -466,6 +483,7 @@ describe('QueueService', () => { 'free', 'elenchus', 'openai/gpt-5-mini', + 'openrouter', false, ); diff --git a/packages/mukti-api/src/modules/conversations/services/gemini.service.ts b/packages/mukti-api/src/modules/conversations/services/gemini.service.ts new file mode 100644 index 00000000..069ea0cd --- /dev/null +++ b/packages/mukti-api/src/modules/conversations/services/gemini.service.ts @@ -0,0 +1,76 @@ +import { Injectable, Logger } from '@nestjs/common'; + +import type { TechniqueTemplate } from '../../../schemas/technique.schema'; +import { GeminiClientFactory } from '../../ai/services/gemini-client.factory'; + +export interface GeminiResponse { + completionTokens: number; + content: string; + cost: number; + model: string; + promptTokens: number; + totalTokens: number; +} + +@Injectable() +export class GeminiService { + private readonly logger = new Logger(GeminiService.name); + + constructor(private readonly geminiClientFactory: GeminiClientFactory) {} + + async sendMessage(params: { + apiKey: string; + conversationHistory: { content: string; role: string }[]; + model: string; + technique: TechniqueTemplate; + userMessage: string; + }): Promise { + const client = this.geminiClientFactory.create(params.apiKey); + const model = client.getGenerativeModel({ model: params.model }); + + const historyText = params.conversationHistory + .map((message) => `${message.role.toUpperCase()}: ${message.content}`) + .join('\n\n'); + const prompt = [ + 'You are Mukti, a Socratic thinking partner.', + '', + 'System instructions:', + params.technique.systemPrompt, + '', + 'Conversation history:', + historyText || '(none)', + '', + 'Latest user message:', + params.userMessage, + '', + 'Reply as the assistant.', + ].join('\n'); + + this.logger.log(`Sending Gemini request with model: ${params.model}`); + + const generated = await model.generateContent(prompt); + const response = generated.response; + const content = response.text(); + const usage = (response as any).usageMetadata as + | { + candidatesTokenCount?: number; + promptTokenCount?: number; + totalTokenCount?: number; + } + | undefined; + + const promptTokens = usage?.promptTokenCount ?? 0; + const completionTokens = usage?.candidatesTokenCount ?? 0; + const totalTokens = + usage?.totalTokenCount ?? promptTokens + completionTokens; + + return { + completionTokens, + content, + cost: 0, + model: params.model, + promptTokens, + totalTokens, + }; + } +} diff --git a/packages/mukti-api/src/modules/conversations/services/queue.service.ts b/packages/mukti-api/src/modules/conversations/services/queue.service.ts index 71c20e5f..0c01162f 100644 --- a/packages/mukti-api/src/modules/conversations/services/queue.service.ts +++ b/packages/mukti-api/src/modules/conversations/services/queue.service.ts @@ -12,14 +12,19 @@ import { import { Technique, TechniqueDocument, + type TechniqueTemplate, } from '../../../schemas/technique.schema'; import { UsageEvent, UsageEventDocument, } from '../../../schemas/usage-event.schema'; import { User, UserDocument } from '../../../schemas/user.schema'; -import { AiPolicyService } from '../../ai/services/ai-policy.service'; +import { + AiPolicyService, + type AiProvider, +} from '../../ai/services/ai-policy.service'; import { AiSecretsService } from '../../ai/services/ai-secrets.service'; +import { GeminiService } from './gemini.service'; import { MessageService } from './message.service'; import { OpenRouterService } from './openrouter.service'; import { StreamService } from './stream.service'; @@ -31,6 +36,7 @@ export interface ConversationRequestJobData { conversationId: string; message: string; model: string; + provider: AiProvider; subscriptionTier: string; technique: string; usedByok: boolean; @@ -81,6 +87,7 @@ export class QueueService extends WorkerHost { private readonly configService: ConfigService, private readonly aiPolicyService: AiPolicyService, private readonly aiSecretsService: AiSecretsService, + private readonly geminiService: GeminiService, private readonly messageService: MessageService, private readonly openRouterService: OpenRouterService, private readonly streamService: StreamService, @@ -125,6 +132,7 @@ export class QueueService extends WorkerHost { subscriptionTier: string, technique: string, model: string, + provider: AiProvider, usedByok: boolean, ): Promise<{ jobId: string; position: number }> { const userIdString = this.formatId(userId); @@ -143,6 +151,7 @@ export class QueueService extends WorkerHost { conversationId: conversationIdString, message, model, + provider, subscriptionTier, technique, usedByok, @@ -390,6 +399,7 @@ export class QueueService extends WorkerHost { conversationId, message, model, + provider, subscriptionTier: _subscriptionTier, technique, usedByok, @@ -444,18 +454,12 @@ export class QueueService extends WorkerHost { type: 'progress', }); - // 3. Build prompt and call OpenRouter - const effectiveModel = this.validateEffectiveModel(model, usedByok); - const apiKey = await this.resolveApiKey(userId, usedByok); - const messages = this.openRouterService.buildPrompt( - techniqueDoc.template, - context.messages.map((m) => ({ - content: m.content, - role: m.role as 'assistant' | 'system' | 'user', - timestamp: new Date(), - })), - message, - ); + // 3. Build prompt and call provider + const effectiveModel = this.validateEffectiveModel({ + model, + provider, + usedByok, + }); // Emit progress event - calling AI this.streamService.emitToConversation(conversationId, { @@ -466,12 +470,28 @@ export class QueueService extends WorkerHost { type: 'progress', }); - const response = await this.openRouterService.sendChatCompletion( - messages, - effectiveModel, - apiKey, - techniqueDoc.template, - ); + const response = + provider === 'gemini' + ? await this.geminiService.sendMessage({ + apiKey: await this.resolveApiKey({ + provider, + usedByok, + userId, + }), + conversationHistory: context.messages, + model: effectiveModel, + technique: techniqueDoc.template, + userMessage: message, + }) + : await this.runOpenRouterCompletion({ + contextMessages: context.messages, + model: effectiveModel, + provider, + techniqueTemplate: techniqueDoc.template, + usedByok, + userId, + userMessage: message, + }); // 4. Add messages to conversation const updatedConversation = @@ -489,10 +509,16 @@ export class QueueService extends WorkerHost { }, ); - // Get the sequence numbers for the messages - const userMessageSequence = updatedConversation.recentMessages.length - 1; - const assistantMessageSequence = - updatedConversation.recentMessages.length; + const userMessageSequence = updatedConversation.totalMessageCount - 1; + const assistantMessageSequence = updatedConversation.totalMessageCount; + const userMessageTimestamp = + updatedConversation.recentMessages[ + updatedConversation.recentMessages.length - 2 + ]?.timestamp ?? new Date(); + const assistantMessageTimestamp = + updatedConversation.recentMessages[ + updatedConversation.recentMessages.length - 1 + ]?.timestamp ?? new Date(); // Emit message event for user message this.streamService.emitToConversation(conversationId, { @@ -500,7 +526,7 @@ export class QueueService extends WorkerHost { content: message, role: 'user', sequence: userMessageSequence, - timestamp: new Date().toISOString(), + timestamp: userMessageTimestamp.toISOString(), }, type: 'message', }); @@ -511,7 +537,7 @@ export class QueueService extends WorkerHost { content: response.content, role: 'assistant', sequence: assistantMessageSequence, - timestamp: new Date().toISOString(), + timestamp: assistantMessageTimestamp.toISOString(), tokens: response.totalTokens, }, type: 'message', @@ -600,13 +626,27 @@ export class QueueService extends WorkerHost { return error instanceof Error ? error.stack : undefined; } - private async resolveApiKey( - userId: string, - usedByok: boolean, - ): Promise { - if (usedByok) { + private async resolveApiKey(params: { + provider: AiProvider; + usedByok: boolean; + userId: string; + }): Promise { + if (params.provider === 'gemini') { + const user = await this.userModel + .findById(params.userId) + .select('+geminiApiKeyEncrypted') + .lean(); + + if (!user?.geminiApiKeyEncrypted) { + throw new Error('GEMINI_KEY_MISSING'); + } + + return this.aiSecretsService.decryptString(user.geminiApiKeyEncrypted); + } + + if (params.usedByok) { const user = await this.userModel - .findById(userId) + .findById(params.userId) .select('+openRouterApiKeyEncrypted') .lean(); @@ -647,14 +687,60 @@ export class QueueService extends WorkerHost { ); } - private validateEffectiveModel(model: string, usedByok: boolean): string { - const trimmed = model.trim(); + private async runOpenRouterCompletion(params: { + contextMessages: { content: string; role: string }[]; + model: string; + provider: AiProvider; + techniqueTemplate: TechniqueTemplate; + usedByok: boolean; + userId: string; + userMessage: string; + }) { + const messages = this.openRouterService.buildPrompt( + params.techniqueTemplate, + params.contextMessages.map((message) => ({ + content: message.content, + role: message.role as 'assistant' | 'system' | 'user', + timestamp: new Date(), + })), + params.userMessage, + ); + + const apiKey = await this.resolveApiKey({ + provider: params.provider, + usedByok: params.usedByok, + userId: params.userId, + }); + + return this.openRouterService.sendChatCompletion( + messages, + params.model, + apiKey, + params.techniqueTemplate, + ); + } + + private validateEffectiveModel(params: { + model: string; + provider: AiProvider; + usedByok: boolean; + }): string { + const trimmed = params.model.trim(); if (!trimmed) { throw new Error('Model is required'); } - if (!usedByok) { + if (params.provider === 'gemini') { + const isGemini = this.aiPolicyService.isGeminiModel(trimmed); + if (!isGemini) { + throw new Error('MODEL_NOT_ALLOWED'); + } + + return trimmed; + } + + if (!params.usedByok) { const isCurated = this.aiPolicyService .getCuratedModels() .some((allowed) => allowed.id === trimmed); diff --git a/packages/mukti-api/src/modules/conversations/services/stream.service.ts b/packages/mukti-api/src/modules/conversations/services/stream.service.ts index 573a859c..1658773e 100644 --- a/packages/mukti-api/src/modules/conversations/services/stream.service.ts +++ b/packages/mukti-api/src/modules/conversations/services/stream.service.ts @@ -1,4 +1,6 @@ -import { Injectable, Logger } from '@nestjs/common'; +import { InjectQueue } from '@nestjs/bullmq'; +import { Queue } from 'bullmq'; +import { Injectable, Logger, OnModuleDestroy, Optional } from '@nestjs/common'; /** * Union type of all possible stream events. @@ -102,6 +104,25 @@ interface StreamConnection { userId: string; } +interface StreamBridgeEnvelope { + event: StreamEvent; + scope: 'conversation' | 'user'; + sourceId: string; + userId?: string; +} + +interface RedisPubSubClient { + disconnect?: () => void; + duplicate?: () => RedisPubSubClient; + on?: (event: string, listener: (...args: any[]) => void) => void; + publish?: (channel: string, message: string) => Promise | number; + quit?: () => Promise; + subscribe?: ( + channel: string, + listener?: (message: string, channel: string) => void, + ) => Promise | unknown; +} + /** * Service responsible for managing Server-Sent Events (SSE) connections for real-time conversation updates. * Handles connection lifecycle, event emission, and cleanup for multiple concurrent connections. @@ -118,15 +139,33 @@ interface StreamConnection { * manage multiple connections per conversation while maintaining O(1) lookup performance. */ @Injectable() -export class StreamService { +export class StreamService implements OnModuleDestroy { /** * Map of conversation IDs to arrays of active connections. * Key: conversationId * Value: Array of StreamConnection objects for that conversation */ private readonly connections = new Map(); + private readonly streamBridgeChannel = 'mukti:conversation-stream-events:v1'; + private readonly streamSourceId = `${process.pid}-${Math.random() + .toString(36) + .slice(2, 10)}`; + private bridgeInitialized = false; + private bridgeInitialization?: Promise; private readonly logger = new Logger(StreamService.name); + private redisPublisher: null | RedisPubSubClient = null; + private redisSubscriber: null | RedisPubSubClient = null; + + constructor( + @Optional() + @InjectQueue('conversation-requests') + private readonly conversationRequestsQueue?: Queue, + ) { + if (process.env.NODE_ENV !== 'test') { + void this.ensureBridge(); + } + } /** * Registers a new SSE connection for a conversation. @@ -160,6 +199,10 @@ export class StreamService { connectionId: string, emitFn: (event: StreamEvent) => void, ): void { + if (process.env.NODE_ENV !== 'test') { + void this.ensureBridge(); + } + this.logger.log( `Adding SSE connection: conversationId=${conversationId}, userId=${userId}, connectionId=${connectionId}`, ); @@ -262,46 +305,18 @@ export class StreamService { conversationId: string, event: Omit, ): void { - const connections = this.connections.get(conversationId); - - if (!connections || connections.length === 0) { - this.logger.debug( - `No active connections for conversation ${conversationId}. Event not emitted.`, - ); - return; - } - - // Add timestamp and conversationId to the event const fullEvent: StreamEvent = { ...event, conversationId, timestamp: new Date().toISOString(), } as StreamEvent; - this.logger.log( - `Emitting event to conversation ${conversationId}: type=${event.type}, connections=${connections.length}`, - ); - - // Emit to all connections for this conversation - let successCount = 0; - let errorCount = 0; - - for (const connection of connections) { - try { - connection.emitFn(fullEvent); - successCount++; - } catch (error) { - errorCount++; - this.logger.error( - `Failed to emit event to connection ${connection.connectionId}: ${this.getErrorMessage(error)}`, - this.getErrorStack(error), - ); - } - } - - this.logger.log( - `Event emitted to conversation ${conversationId}: success=${successCount}, errors=${errorCount}`, - ); + this.emitToConversationLocal(conversationId, fullEvent); + this.publishToBridge({ + event: fullEvent, + scope: 'conversation', + sourceId: this.streamSourceId, + }); } /** @@ -341,58 +356,19 @@ export class StreamService { userId: string, event: Omit, ): void { - const connections = this.connections.get(conversationId); - - if (!connections || connections.length === 0) { - this.logger.debug( - `No active connections for conversation ${conversationId}. Event not emitted.`, - ); - return; - } - - // Filter connections for the specific user - const userConnections = connections.filter( - (conn) => conn.userId === userId, - ); - - if (userConnections.length === 0) { - this.logger.debug( - `No active connections for user ${userId} in conversation ${conversationId}. Event not emitted.`, - ); - return; - } - - // Add timestamp and conversationId to the event const fullEvent: StreamEvent = { ...event, conversationId, timestamp: new Date().toISOString(), } as StreamEvent; - this.logger.log( - `Emitting event to user ${userId} in conversation ${conversationId}: type=${event.type}, connections=${userConnections.length}`, - ); - - // Emit to all of the user's connections - let successCount = 0; - let errorCount = 0; - - for (const connection of userConnections) { - try { - connection.emitFn(fullEvent); - successCount++; - } catch (error) { - errorCount++; - this.logger.error( - `Failed to emit event to connection ${connection.connectionId}: ${this.getErrorMessage(error)}`, - this.getErrorStack(error), - ); - } - } - - this.logger.log( - `Event emitted to user ${userId}: success=${successCount}, errors=${errorCount}`, - ); + this.emitToUserLocal(conversationId, userId, fullEvent); + this.publishToBridge({ + event: fullEvent, + scope: 'user', + sourceId: this.streamSourceId, + userId, + }); } /** @@ -494,6 +470,247 @@ export class StreamService { } } + async onModuleDestroy(): Promise { + if (this.redisSubscriber?.quit) { + try { + await this.redisSubscriber.quit(); + } catch { + this.redisSubscriber.disconnect?.(); + } + } else { + this.redisSubscriber?.disconnect?.(); + } + + this.redisSubscriber = null; + this.redisPublisher = null; + this.bridgeInitialized = false; + } + + private async ensureBridge(): Promise { + if (this.bridgeInitialized || this.bridgeInitialization) { + return this.bridgeInitialization; + } + + const queue = this.conversationRequestsQueue; + if (!queue) { + return; + } + + this.bridgeInitialization = (async () => { + try { + const queueClient = + (await queue.client) as unknown as RedisPubSubClient; + + if (!queueClient?.publish || !queueClient.duplicate) { + this.logger.warn( + 'Redis bridge disabled: queue client does not support pub/sub', + ); + return; + } + + const subscriber = queueClient.duplicate(); + if (!subscriber.subscribe) { + this.logger.warn( + 'Redis bridge disabled: duplicate client does not support subscribe', + ); + return; + } + + subscriber.on?.('error', (error: unknown) => { + this.logger.error( + `Redis stream bridge subscriber error: ${this.getErrorMessage(error)}`, + this.getErrorStack(error), + ); + }); + + if (subscriber.on) { + subscriber.on('message', (channel: string, payload: string) => { + if (channel !== this.streamBridgeChannel) { + return; + } + + this.handleBridgePayload(payload); + }); + } + + await subscriber.subscribe(this.streamBridgeChannel); + + this.redisPublisher = queueClient; + this.redisSubscriber = subscriber; + this.bridgeInitialized = true; + + this.logger.log( + `Redis stream bridge initialized on channel ${this.streamBridgeChannel}`, + ); + } catch (error) { + this.logger.warn( + `Redis stream bridge unavailable: ${this.getErrorMessage(error)}`, + ); + } finally { + this.bridgeInitialization = undefined; + } + })(); + + return this.bridgeInitialization; + } + + private emitToConversationLocal( + conversationId: string, + event: StreamEvent, + ): void { + const connections = this.connections.get(conversationId); + + if (!connections || connections.length === 0) { + this.logger.debug( + `No active connections for conversation ${conversationId}. Event not emitted locally.`, + ); + return; + } + + this.logger.log( + `Emitting local event to conversation ${conversationId}: type=${event.type}, connections=${connections.length}`, + ); + + let successCount = 0; + let errorCount = 0; + + for (const connection of connections) { + try { + connection.emitFn(event); + successCount++; + } catch (error) { + errorCount++; + this.logger.error( + `Failed to emit event to connection ${connection.connectionId}: ${this.getErrorMessage(error)}`, + this.getErrorStack(error), + ); + } + } + + this.logger.log( + `Local event emitted to conversation ${conversationId}: success=${successCount}, errors=${errorCount}`, + ); + } + + private emitToUserLocal( + conversationId: string, + userId: string, + event: StreamEvent, + ): void { + const connections = this.connections.get(conversationId); + + if (!connections || connections.length === 0) { + this.logger.debug( + `No active connections for conversation ${conversationId}. Event not emitted locally.`, + ); + return; + } + + const userConnections = connections.filter( + (connection) => connection.userId === userId, + ); + + if (userConnections.length === 0) { + this.logger.debug( + `No active connections for user ${userId} in conversation ${conversationId}. Event not emitted locally.`, + ); + return; + } + + this.logger.log( + `Emitting local event to user ${userId} in conversation ${conversationId}: type=${event.type}, connections=${userConnections.length}`, + ); + + let successCount = 0; + let errorCount = 0; + + for (const connection of userConnections) { + try { + connection.emitFn(event); + successCount++; + } catch (error) { + errorCount++; + this.logger.error( + `Failed to emit event to connection ${connection.connectionId}: ${this.getErrorMessage(error)}`, + this.getErrorStack(error), + ); + } + } + + this.logger.log( + `Local event emitted to user ${userId}: success=${successCount}, errors=${errorCount}`, + ); + } + + private handleBridgePayload(payload: string): void { + try { + const envelope = JSON.parse(payload) as StreamBridgeEnvelope; + + if (envelope.sourceId === this.streamSourceId) { + return; + } + + if (envelope.scope === 'user' && envelope.userId) { + this.emitToUserLocal( + envelope.event.conversationId, + envelope.userId, + envelope.event, + ); + return; + } + + this.emitToConversationLocal( + envelope.event.conversationId, + envelope.event, + ); + } catch (error) { + this.logger.error( + `Failed to process bridged stream payload: ${this.getErrorMessage(error)}`, + this.getErrorStack(error), + ); + } + } + + private publishToBridge(envelope: StreamBridgeEnvelope): void { + if (process.env.NODE_ENV === 'test' || !this.conversationRequestsQueue) { + return; + } + + if (!this.bridgeInitialized || !this.redisPublisher?.publish) { + void this.ensureBridge().then(() => { + if (!this.bridgeInitialized || !this.redisPublisher?.publish) { + return; + } + + void Promise.resolve( + this.redisPublisher.publish( + this.streamBridgeChannel, + JSON.stringify(envelope), + ), + ).catch((error: unknown) => { + this.logger.error( + `Failed to publish bridged event: ${this.getErrorMessage(error)}`, + this.getErrorStack(error), + ); + }); + }); + + return; + } + + void Promise.resolve( + this.redisPublisher.publish( + this.streamBridgeChannel, + JSON.stringify(envelope), + ), + ).catch((error: unknown) => { + this.logger.error( + `Failed to publish bridged event: ${this.getErrorMessage(error)}`, + this.getErrorStack(error), + ); + }); + } + private getErrorMessage(error: unknown): string { if (error instanceof Error) { return error.message; diff --git a/packages/mukti-api/src/schemas/user.schema.ts b/packages/mukti-api/src/schemas/user.schema.ts index 6e1b59e6..15f9c986 100644 --- a/packages/mukti-api/src/schemas/user.schema.ts +++ b/packages/mukti-api/src/schemas/user.schema.ts @@ -5,6 +5,7 @@ export type UserDocument = Document & User; export interface UserPreferences { activeModel?: string; + activeProvider?: 'gemini' | 'openrouter'; defaultTechnique?: string; emailNotifications?: boolean; language?: string; diff --git a/packages/mukti-web/src/app/dashboard/settings/page.tsx b/packages/mukti-web/src/app/dashboard/settings/page.tsx index c053ee0a..9deb84f6 100644 --- a/packages/mukti-web/src/app/dashboard/settings/page.tsx +++ b/packages/mukti-web/src/app/dashboard/settings/page.tsx @@ -13,6 +13,7 @@ import { useAiStore } from '@/lib/stores/ai-store'; export default function SettingsPage() { const { activeModel, + activeProvider, deleteGeminiKey, deleteOpenRouterKey, geminiKeyLast4, @@ -20,6 +21,7 @@ export default function SettingsPage() { hasOpenRouterKey, hydrate, isHydrated, + mode, models, openRouterKeyLast4, refreshModels, @@ -42,20 +44,40 @@ export default function SettingsPage() { } }, [hydrate, isHydrated]); + const openRouterActive = activeProvider === 'openrouter' && hasOpenRouterKey; + const geminiActive = activeProvider === 'gemini' && hasGeminiKey; + const selectorDescription = + activeProvider === 'gemini' + ? 'Choose which Gemini model to use.' + : mode === 'openrouter' + ? 'Choose which OpenRouter model to use.' + : 'Choose which curated model to use.'; + return (

AI

- Pick your active model and optionally connect your own OpenRouter key. + Exactly one provider is active at a time. Saving a provider key activates it + immediately.

+
+ Active provider + + {activeProvider} + +
+
- +
{ setSavingModel(true); @@ -66,6 +88,7 @@ export default function SettingsPage() { setSavingModel(false); } }} + title={`Select ${activeProvider === 'gemini' ? 'Gemini' : 'OpenRouter'} Model`} value={activeModel} />