diff --git a/jest.config.js b/jest.config.js index c6866172..c0afb5df 100644 --- a/jest.config.js +++ b/jest.config.js @@ -25,6 +25,13 @@ module.exports = { 'src/services/embeddings/ConversationIndexer.ts', 'src/services/embeddings/TraceIndexer.ts', 'src/agents/searchManager/services/ConversationSearchStrategy.ts', + // OAuth service layer + providers + adapter + 'src/services/oauth/PKCEUtils.ts', + 'src/services/oauth/OAuthCallbackServer.ts', + 'src/services/oauth/OAuthService.ts', + 'src/services/oauth/providers/OpenRouterOAuthProvider.ts', + 'src/services/oauth/providers/OpenAICodexOAuthProvider.ts', + 'src/services/llm/adapters/openai-codex/OpenAICodexAdapter.ts', '!src/**/*.d.ts' ], coverageThreshold: { @@ -126,6 +133,47 @@ module.exports = { functions: 85, lines: 85, statements: 85 + }, + // OAuth service layer: pure crypto utils (high bar) + './src/services/oauth/PKCEUtils.ts': { + branches: 80, + functions: 100, + lines: 100, + statements: 100 + }, + // OAuth callback server: integration-style tests cover all paths + './src/services/oauth/OAuthCallbackServer.ts': { + branches: 75, + functions: 80, + lines: 80, + statements: 80 + }, + // OAuth service: orchestration with mocked dependencies + './src/services/oauth/OAuthService.ts': { + branches: 75, + functions: 80, + lines: 80, + statements: 80 + }, + // OAuth providers: API integration with mocked fetch + './src/services/oauth/providers/OpenRouterOAuthProvider.ts': { + branches: 75, + functions: 80, + lines: 80, + statements: 80 + }, + './src/services/oauth/providers/OpenAICodexOAuthProvider.ts': { + branches: 75, + functions: 80, + lines: 80, + statements: 80 + }, + // Codex adapter: SSE parsing + token management with mocked fetch + './src/services/llm/adapters/openai-codex/OpenAICodexAdapter.ts': { + branches: 75, + functions: 80, + lines: 80, + statements: 80 } }, coverageDirectory: 'coverage', diff --git a/src/components/LLMProviderModal.ts b/src/components/LLMProviderModal.ts index 69149fa2..3e43e348 
100644 --- a/src/components/LLMProviderModal.ts +++ b/src/components/LLMProviderModal.ts @@ -20,6 +20,8 @@ import { IProviderModal, ProviderModalConfig, ProviderModalDependencies, + OAuthModalConfig, + SecondaryOAuthProviderConfig, } from './llm-provider/types'; import { NexusProviderModal } from './llm-provider/providers/NexusProviderModal'; import { OllamaProviderModal } from './llm-provider/providers/OllamaProviderModal'; @@ -36,6 +38,8 @@ export interface LLMProviderModalConfig { keyFormat: string; signupUrl: string; config: LLMProviderConfig; + oauthConfig?: OAuthModalConfig; + secondaryOAuthProvider?: SecondaryOAuthProviderConfig; onSave: (config: LLMProviderConfig) => void; } @@ -144,6 +148,8 @@ export class LLMProviderModal extends Modal { signupUrl: this.config.signupUrl, config: { ...this.config.config }, onConfigChange: (config: LLMProviderConfig) => this.handleConfigChange(config), + oauthConfig: this.config.oauthConfig, + secondaryOAuthProvider: this.config.secondaryOAuthProvider, }; } diff --git a/src/components/llm-provider/providers/GenericProviderModal.ts b/src/components/llm-provider/providers/GenericProviderModal.ts index 0411b24e..208bc389 100644 --- a/src/components/llm-provider/providers/GenericProviderModal.ts +++ b/src/components/llm-provider/providers/GenericProviderModal.ts @@ -2,7 +2,7 @@ * GenericProviderModal * * Provider modal for API-key based providers (OpenAI, Anthropic, Google, etc.). - * Handles API key input, validation, and model toggles. + * Handles API key input, validation, model toggles, and optional OAuth connect. 
*/ import { Setting, Notice } from 'obsidian'; @@ -10,9 +10,13 @@ import { IProviderModal, ProviderModalConfig, ProviderModalDependencies, + OAuthModalConfig, + SecondaryOAuthProviderConfig, } from '../types'; import { LLMValidationService } from '../../../services/llm/validation/ValidationService'; import { ModelWithProvider } from '../../../services/StaticModelsService'; +import { OAuthConsentModal, OAuthPreAuthModal } from './OAuthModals'; +import { OAuthService } from '../../../services/oauth/OAuthService'; export class GenericProviderModal implements IProviderModal { private config: ProviderModalConfig; @@ -22,11 +26,19 @@ export class GenericProviderModal implements IProviderModal { private container: HTMLElement | null = null; private apiKeyInput: HTMLInputElement | null = null; private modelsContainer: HTMLElement | null = null; + private oauthBannerContainer: HTMLElement | null = null; + private connectButton: HTMLButtonElement | null = null; + + // Secondary OAuth UI elements + private secondaryBannerContainer: HTMLElement | null = null; + private secondaryConnectButton: HTMLButtonElement | null = null; // State private apiKey: string = ''; private models: ModelWithProvider[] = []; private isValidated: boolean = false; + private isOAuthConnecting: boolean = false; + private isSecondaryOAuthConnecting: boolean = false; private validationTimeout: ReturnType | null = null; constructor(config: ProviderModalConfig, deps: ProviderModalDependencies) { @@ -46,15 +58,23 @@ export class GenericProviderModal implements IProviderModal { this.renderApiKeySection(container); this.renderModelsSection(container); + + if (this.config.secondaryOAuthProvider) { + this.renderSecondaryOAuthSection(container); + } } /** - * Render API key input section + * Render API key input section, with optional OAuth connect button and connected banner */ private renderApiKeySection(container: HTMLElement): void { - container.createEl('h2', { text: 'API Key' }); + container.createEl('h2', 
{ text: 'API key' }); + + // OAuth connected banner (shown above the key input when connected) + this.oauthBannerContainer = container.createDiv('oauth-banner-container'); + this.renderOAuthBanner(); - new Setting(container) + const setting = new Setting(container) .setDesc(`Enter your ${this.config.providerName} API key (format: ${this.config.keyFormat})`) .addText(text => { this.apiKeyInput = text.inputEl; @@ -71,7 +91,7 @@ export class GenericProviderModal implements IProviderModal { }) .addButton(button => { button - .setButtonText('Get Key') + .setButtonText('Get key') .setTooltip(`Open ${this.config.providerName} API key page`) .onClick(() => { window.open(this.config.signupUrl, '_blank'); @@ -79,6 +99,335 @@ export class GenericProviderModal implements IProviderModal { }); } + /** + * Render the OAuth banner area: connected banner when connected, + * standalone connect button when disconnected but OAuth is available + */ + private renderOAuthBanner(): void { + if (!this.oauthBannerContainer) return; + this.oauthBannerContainer.empty(); + + if (!this.config.oauthConfig) return; + + const oauthState = this.config.config.oauth; + + if (oauthState?.connected) { + // Connected state: show connected banner with disconnect button + const banner = this.oauthBannerContainer.createDiv('oauth-connected-banner'); + + const statusText = banner.createSpan('oauth-connected-status'); + statusText.textContent = `Connected via ${this.config.oauthConfig.providerLabel}`; + + const disconnectBtn = banner.createEl('button', { + text: 'Disconnect', + cls: 'oauth-disconnect-btn', + }); + disconnectBtn.setAttribute('aria-label', `Disconnect ${this.config.oauthConfig.providerLabel} OAuth`); + disconnectBtn.onclick = () => this.handleOAuthDisconnect(); + } else { + // Disconnected state: show standalone connect button + const connectDiv = this.oauthBannerContainer.createDiv('oauth-connect-standalone'); + const label = this.config.oauthConfig.providerLabel; + this.connectButton = 
connectDiv.createEl('button', { + text: `Connect with ${label}`, + cls: 'mod-cta oauth-connect-btn', + }); + this.connectButton.setAttribute('aria-label', `Connect with ${label} via OAuth`); + this.connectButton.onclick = () => this.handleOAuthConnect(); + } + } + + /** + * Handle the OAuth connect button click + */ + private async handleOAuthConnect(): Promise { + const oauthConfig = this.config.oauthConfig; + if (!oauthConfig || this.isOAuthConnecting) return; + + const hasPreAuthFields = oauthConfig.preAuthFields && oauthConfig.preAuthFields.length > 0; + + // Experimental provider: always show consent modal (includes pre-auth fields) + if (oauthConfig.experimental) { + new OAuthConsentModal( + this.deps.app, + oauthConfig, + (params) => this.executeOAuthFlow(oauthConfig, params), + () => { /* cancelled */ }, + ).open(); + return; + } + + // Non-experimental with pre-auth fields: show pre-auth modal + if (hasPreAuthFields) { + new OAuthPreAuthModal( + this.deps.app, + oauthConfig, + (params) => this.executeOAuthFlow(oauthConfig, params), + () => { /* cancelled */ }, + ).open(); + return; + } + + // No consent or pre-auth needed: start flow directly + await this.executeOAuthFlow(oauthConfig, {}); + } + + /** + * Execute the OAuth flow and handle the result + */ + private async executeOAuthFlow( + oauthConfig: OAuthModalConfig, + params: Record, + ): Promise { + this.setOAuthConnecting(true); + + try { + const result = await oauthConfig.startFlow(params); + + if (result.success && result.apiKey) { + // Update API key + this.apiKey = result.apiKey; + this.config.config.apiKey = result.apiKey; + + if (this.apiKeyInput) { + this.apiKeyInput.value = result.apiKey; + } + + // Set OAuth state + this.config.config.oauth = { + connected: true, + providerId: this.config.providerId, + connectedAt: Date.now(), + refreshToken: result.refreshToken, + expiresAt: result.expiresAt, + metadata: result.metadata, + }; + + // Auto-enable the provider + this.config.config.enabled = 
true; + this.saveConfig(); + + // Refresh the banner + this.renderOAuthBanner(); + + new Notice(`Connected to ${oauthConfig.providerLabel} successfully`); + } else { + const errorMsg = result.error || 'OAuth flow failed'; + new Notice(`${oauthConfig.providerLabel} connection failed: ${errorMsg}`); + } + } catch (error) { + const errorMsg = error instanceof Error ? error.message : 'Unknown error'; + new Notice(`${oauthConfig.providerLabel} connection failed: ${errorMsg}`); + } finally { + this.setOAuthConnecting(false); + } + } + + /** + * Handle OAuth disconnect + */ + private handleOAuthDisconnect(): void { + this.apiKey = ''; + this.config.config.apiKey = ''; + this.config.config.oauth = undefined; + + if (this.apiKeyInput) { + this.apiKeyInput.value = ''; + } + + this.saveConfig(); + this.renderOAuthBanner(); + + new Notice(`Disconnected from ${this.config.oauthConfig?.providerLabel || 'provider'}`); + } + + /** + * Update the connect button state during OAuth flow + */ + private setOAuthConnecting(connecting: boolean): void { + this.isOAuthConnecting = connecting; + if (!this.connectButton || !this.config.oauthConfig) return; + + if (connecting) { + this.connectButton.textContent = 'Connecting...'; + this.connectButton.disabled = true; + this.connectButton.addClass('oauth-connecting'); + } else { + const label = this.config.oauthConfig.providerLabel; + this.connectButton.textContent = `Connect with ${label}`; + this.connectButton.disabled = false; + this.connectButton.removeClass('oauth-connecting'); + } + } + + /** + * Render a secondary OAuth provider sub-section (e.g., Codex inside OpenAI modal) + */ + private renderSecondaryOAuthSection(container: HTMLElement): void { + const secondary = this.config.secondaryOAuthProvider; + if (!secondary) return; + + const section = container.createDiv('secondary-oauth-section'); + + section.createEl('h2', { text: secondary.providerLabel }); + section.createEl('p', { + text: secondary.description, + cls: 
'setting-item-description', + }); + + // Banner container for connected/disconnected state + this.secondaryBannerContainer = section.createDiv('oauth-banner-container'); + this.renderSecondaryOAuthBanner(); + } + + /** + * Render the secondary OAuth banner: connected banner or connect button + */ + private renderSecondaryOAuthBanner(): void { + if (!this.secondaryBannerContainer) return; + this.secondaryBannerContainer.empty(); + + const secondary = this.config.secondaryOAuthProvider; + if (!secondary) return; + + const oauthState = secondary.config.oauth; + + if (oauthState?.connected) { + const banner = this.secondaryBannerContainer.createDiv('oauth-connected-banner'); + + const statusText = banner.createSpan('oauth-connected-status'); + statusText.textContent = `Connected via ${secondary.oauthConfig.providerLabel}`; + + const disconnectBtn = banner.createEl('button', { + text: 'Disconnect', + cls: 'oauth-disconnect-btn', + }); + disconnectBtn.setAttribute('aria-label', `Disconnect ${secondary.oauthConfig.providerLabel} OAuth`); + disconnectBtn.onclick = () => this.handleSecondaryOAuthDisconnect(); + } else { + const connectDiv = this.secondaryBannerContainer.createDiv('oauth-connect-standalone'); + const label = secondary.oauthConfig.providerLabel; + this.secondaryConnectButton = connectDiv.createEl('button', { + text: `Connect with ${label}`, + cls: 'mod-cta oauth-connect-btn', + }); + this.secondaryConnectButton.setAttribute('aria-label', `Connect with ${label} via OAuth`); + this.secondaryConnectButton.onclick = () => this.handleSecondaryOAuthConnect(); + } + } + + /** + * Handle secondary OAuth connect button click + */ + private async handleSecondaryOAuthConnect(): Promise { + const secondary = this.config.secondaryOAuthProvider; + if (!secondary || this.isSecondaryOAuthConnecting) return; + + const oauthConfig = secondary.oauthConfig; + + // Experimental provider: show consent modal + if (oauthConfig.experimental) { + new OAuthConsentModal( + 
this.deps.app, + oauthConfig, + (params) => this.executeSecondaryOAuthFlow(secondary, params), + () => { /* cancelled */ }, + ).open(); + return; + } + + // Pre-auth fields: show pre-auth modal + const hasPreAuthFields = oauthConfig.preAuthFields && oauthConfig.preAuthFields.length > 0; + if (hasPreAuthFields) { + new OAuthPreAuthModal( + this.deps.app, + oauthConfig, + (params) => this.executeSecondaryOAuthFlow(secondary, params), + () => { /* cancelled */ }, + ).open(); + return; + } + + // No consent or pre-auth: start directly + await this.executeSecondaryOAuthFlow(secondary, {}); + } + + /** + * Execute the secondary OAuth flow and handle the result + */ + private async executeSecondaryOAuthFlow( + secondary: SecondaryOAuthProviderConfig, + params: Record, + ): Promise { + this.setSecondaryOAuthConnecting(true); + + try { + const result = await secondary.oauthConfig.startFlow(params); + + if (result.success && result.apiKey) { + secondary.config.apiKey = result.apiKey; + secondary.config.oauth = { + connected: true, + providerId: secondary.providerId, + connectedAt: Date.now(), + refreshToken: result.refreshToken, + expiresAt: result.expiresAt, + metadata: result.metadata, + }; + secondary.config.enabled = true; + secondary.onConfigChange(secondary.config); + + this.renderSecondaryOAuthBanner(); + + new Notice(`Connected to ${secondary.oauthConfig.providerLabel} successfully`); + } else { + const errorMsg = result.error || 'OAuth flow failed'; + new Notice(`${secondary.oauthConfig.providerLabel} connection failed: ${errorMsg}`); + } + } catch (error) { + const errorMsg = error instanceof Error ? 
error.message : 'Unknown error'; + new Notice(`${secondary.oauthConfig.providerLabel} connection failed: ${errorMsg}`); + } finally { + this.setSecondaryOAuthConnecting(false); + } + } + + /** + * Handle secondary OAuth disconnect + */ + private handleSecondaryOAuthDisconnect(): void { + const secondary = this.config.secondaryOAuthProvider; + if (!secondary) return; + + secondary.config.apiKey = ''; + secondary.config.oauth = undefined; + secondary.onConfigChange(secondary.config); + + this.renderSecondaryOAuthBanner(); + + new Notice(`Disconnected from ${secondary.oauthConfig.providerLabel}`); + } + + /** + * Update the secondary connect button state during OAuth flow + */ + private setSecondaryOAuthConnecting(connecting: boolean): void { + this.isSecondaryOAuthConnecting = connecting; + const secondary = this.config.secondaryOAuthProvider; + if (!this.secondaryConnectButton || !secondary) return; + + if (connecting) { + this.secondaryConnectButton.textContent = 'Connecting...'; + this.secondaryConnectButton.disabled = true; + this.secondaryConnectButton.addClass('oauth-connecting'); + } else { + const label = secondary.oauthConfig.providerLabel; + this.secondaryConnectButton.textContent = `Connect with ${label}`; + this.secondaryConnectButton.disabled = false; + this.secondaryConnectButton.removeClass('oauth-connecting'); + } + } + /** * Handle API key input changes */ @@ -94,6 +443,12 @@ export class GenericProviderModal implements IProviderModal { this.config.config.lastValidated = undefined; this.config.config.validationHash = undefined; + // Clear OAuth badge if user manually types a key + if (this.config.config.oauth?.connected) { + this.config.config.oauth = undefined; + this.renderOAuthBanner(); + } + // Clear existing timeout if (this.validationTimeout) { clearTimeout(this.validationTimeout); @@ -122,7 +477,7 @@ export class GenericProviderModal implements IProviderModal { * Render models section */ private renderModelsSection(container: HTMLElement): 
void { - container.createEl('h2', { text: 'Available Models' }); + container.createEl('h2', { text: 'Available models' }); this.modelsContainer = container.createDiv('models-container'); this.loadModels(); @@ -279,8 +634,17 @@ export class GenericProviderModal implements IProviderModal { this.validationTimeout = null; } + // Cancel any in-progress OAuth flow so the callback server shuts down + if (this.isOAuthConnecting || this.isSecondaryOAuthConnecting) { + OAuthService.getInstance().cancelFlow(); + } + this.container = null; this.apiKeyInput = null; this.modelsContainer = null; + this.oauthBannerContainer = null; + this.connectButton = null; + this.secondaryBannerContainer = null; + this.secondaryConnectButton = null; } } diff --git a/src/components/llm-provider/providers/OAuthModals.ts b/src/components/llm-provider/providers/OAuthModals.ts new file mode 100644 index 00000000..fe07e721 --- /dev/null +++ b/src/components/llm-provider/providers/OAuthModals.ts @@ -0,0 +1,158 @@ +/** + * OAuthModals + * + * Helper modals for the OAuth connect flow: + * - OAuthConsentModal: experimental provider warning + optional pre-auth fields + * - OAuthPreAuthModal: pre-auth field collection for non-experimental providers + */ + +import { Modal, App, Setting } from 'obsidian'; +import type { OAuthModalConfig } from '../types'; + +/** + * Modal shown before an experimental OAuth flow starts. + * Displays a warning and optionally collects pre-auth fields. 
+ */ +export class OAuthConsentModal extends Modal { + private oauthConfig: OAuthModalConfig; + private onConfirm: (params: Record) => void; + private onCancel: () => void; + + constructor( + app: App, + oauthConfig: OAuthModalConfig, + onConfirm: (params: Record) => void, + onCancel: () => void, + ) { + super(app); + this.oauthConfig = oauthConfig; + this.onConfirm = onConfirm; + this.onCancel = onCancel; + } + + onOpen(): void { + const { contentEl } = this; + contentEl.empty(); + contentEl.addClass('oauth-consent-modal'); + + contentEl.createEl('h2', { text: 'Experimental feature' }); + + if (this.oauthConfig.experimentalWarning) { + contentEl.createEl('p', { + text: this.oauthConfig.experimentalWarning, + cls: 'oauth-consent-warning', + }); + } + + const fieldValues: Record = {}; + if (this.oauthConfig.preAuthFields && this.oauthConfig.preAuthFields.length > 0) { + this.renderFields(contentEl, fieldValues); + } + + const buttonContainer = contentEl.createDiv('oauth-consent-buttons'); + + const cancelBtn = buttonContainer.createEl('button', { text: 'Cancel' }); + cancelBtn.addEventListener('click', () => { + this.onCancel(); + this.close(); + }); + + const confirmBtn = buttonContainer.createEl('button', { + text: 'I understand, connect', + cls: 'mod-cta', + }); + confirmBtn.addEventListener('click', () => { + this.onConfirm(fieldValues); + this.close(); + }); + } + + onClose(): void { + this.contentEl.empty(); + } + + private renderFields( + container: HTMLElement, + fieldValues: Record, + ): void { + const fieldsContainer = container.createDiv('oauth-consent-fields'); + for (const field of this.oauthConfig.preAuthFields!) 
{ + fieldValues[field.key] = field.defaultValue || ''; + new Setting(fieldsContainer) + .setName(field.label) + .addText(text => { + text + .setPlaceholder(field.placeholder || '') + .setValue(field.defaultValue || '') + .onChange(value => { fieldValues[field.key] = value; }); + }); + } + } +} + +/** + * Modal for collecting pre-auth fields when there is no experimental warning. + */ +export class OAuthPreAuthModal extends Modal { + private oauthConfig: OAuthModalConfig; + private onConfirm: (params: Record) => void; + private onCancel: () => void; + + constructor( + app: App, + oauthConfig: OAuthModalConfig, + onConfirm: (params: Record) => void, + onCancel: () => void, + ) { + super(app); + this.oauthConfig = oauthConfig; + this.onConfirm = onConfirm; + this.onCancel = onCancel; + } + + onOpen(): void { + const { contentEl } = this; + contentEl.empty(); + contentEl.addClass('oauth-preauth-modal'); + + contentEl.createEl('h2', { + text: `Connect with ${this.oauthConfig.providerLabel}`, + }); + + const fieldValues: Record = {}; + const fieldsContainer = contentEl.createDiv('oauth-preauth-fields'); + + for (const field of this.oauthConfig.preAuthFields || []) { + fieldValues[field.key] = field.defaultValue || ''; + new Setting(fieldsContainer) + .setName(field.label) + .addText(text => { + text + .setPlaceholder(field.placeholder || '') + .setValue(field.defaultValue || '') + .onChange(value => { fieldValues[field.key] = value; }); + }); + } + + const buttonContainer = contentEl.createDiv('oauth-preauth-buttons'); + + const cancelBtn = buttonContainer.createEl('button', { text: 'Cancel' }); + cancelBtn.addEventListener('click', () => { + this.onCancel(); + this.close(); + }); + + const confirmBtn = buttonContainer.createEl('button', { + text: 'Connect', + cls: 'mod-cta', + }); + confirmBtn.addEventListener('click', () => { + this.onConfirm(fieldValues); + this.close(); + }); + } + + onClose(): void { + this.contentEl.empty(); + } +} diff --git 
a/src/components/llm-provider/types.ts b/src/components/llm-provider/types.ts index 57f9acb5..29b3012e 100644 --- a/src/components/llm-provider/types.ts +++ b/src/components/llm-provider/types.ts @@ -58,6 +58,53 @@ export interface ProviderModalConfig { /** Callback when configuration changes (for auto-save) */ onConfigChange: (config: LLMProviderConfig) => void; + + /** Optional OAuth configuration for providers that support OAuth connect */ + oauthConfig?: OAuthModalConfig; + + /** Optional secondary OAuth provider shown as a sub-section in the modal */ + secondaryOAuthProvider?: SecondaryOAuthProviderConfig; +} + +/** + * Secondary OAuth provider shown as a sub-section inside a primary provider modal. + * For example, Codex (ChatGPT OAuth) shown inside the OpenAI modal. + */ +export interface SecondaryOAuthProviderConfig { + /** Provider identifier (e.g., 'openai-codex') */ + providerId: string; + /** Display label (e.g., "ChatGPT (Codex)") */ + providerLabel: string; + /** Description text shown in the sub-section */ + description: string; + /** Current provider configuration for the secondary provider */ + config: LLMProviderConfig; + /** OAuth configuration for the secondary provider's connect button */ + oauthConfig: OAuthModalConfig; + /** Callback when secondary provider configuration changes */ + onConfigChange: (config: LLMProviderConfig) => void; +} + +/** + * OAuth configuration for the provider modal connect button + */ +export interface OAuthModalConfig { + /** Display label (e.g., "OpenRouter", "ChatGPT (Experimental)") */ + providerLabel: string; + /** If true, show a consent dialog before starting the flow */ + experimental?: boolean; + /** Warning text for experimental providers */ + experimentalWarning?: string; + /** Fields to collect before opening the browser (e.g., key_name, credit limit) */ + preAuthFields?: Array<{ + key: string; + label: string; + placeholder?: string; + required: boolean; + defaultValue?: string; + }>; + /** Start the 
OAuth flow with collected params, returns the API key on success */ + startFlow(params: Record): Promise<{ success: boolean; apiKey?: string; refreshToken?: string; expiresAt?: number; metadata?: Record; error?: string }>; } /** diff --git a/src/components/shared/ChatSettingsRenderer.ts b/src/components/shared/ChatSettingsRenderer.ts index 131883ac..55a6da0b 100644 --- a/src/components/shared/ChatSettingsRenderer.ts +++ b/src/components/shared/ChatSettingsRenderer.ts @@ -11,12 +11,13 @@ * The difference is only WHERE data is saved (via callbacks). */ -import { App, Setting } from 'obsidian'; +import { App, Setting, EventRef } from 'obsidian'; import { LLMProviderManager } from '../../services/llm/providers/ProviderManager'; import { StaticModelsService } from '../../services/StaticModelsService'; import { LLMProviderSettings, ThinkingEffort } from '../../types/llm/ProviderTypes'; import { FilePickerRenderer } from '../workspace/FilePickerRenderer'; import { isDesktop, isProviderCompatible } from '../../utils/platform'; +import { LLMSettingsNotifier } from '../../services/llm/LLMSettingsNotifier'; /** * Current settings state @@ -95,7 +96,8 @@ const PROVIDER_NAMES: Record = { groq: 'Groq', openrouter: 'OpenRouter', requesty: 'Requesty', - perplexity: 'Perplexity' + perplexity: 'Perplexity', + 'openai-codex': 'ChatGPT' }; const EFFORT_LEVELS: ThinkingEffort[] = ['low', 'medium', 'high']; @@ -129,6 +131,10 @@ export class ChatSettingsRenderer { private effortSection?: HTMLElement; private agentEffortSection?: HTMLElement; private contextNotesListEl?: HTMLElement; + private settingsEventRef?: EventRef; + // Maps dropdown option value -> actual { provider, modelId } for merged model lists + private modelOptionMap: Map = new Map(); + private agentModelOptionMap: Map = new Map(); constructor(container: HTMLElement, config: ChatSettingsRendererConfig) { this.container = container; @@ -140,6 +146,19 @@ export class ChatSettingsRenderer { config.llmProviderSettings, 
config.app.vault ); + + this.settingsEventRef = LLMSettingsNotifier.onSettingsChanged((newSettings) => { + this.config.llmProviderSettings = newSettings; + this.providerManager.updateSettings(newSettings); + this.render(); + }); + } + + destroy(): void { + if (this.settingsEventRef) { + LLMSettingsNotifier.unsubscribe(this.settingsEventRef); + this.settingsEventRef = undefined; + } } render(): void { @@ -161,6 +180,8 @@ export class ChatSettingsRenderer { private getEnabledProviders(): string[] { const llmSettings = this.config.llmProviderSettings; return Object.keys(llmSettings.providers).filter(id => { + // Codex models are merged into the OpenAI provider display + if (id === 'openai-codex') return false; const config = llmSettings.providers[id]; if (!config?.enabled) return false; if (!isProviderCompatible(id)) return false; @@ -171,6 +192,11 @@ export class ChatSettingsRenderer { }); } + private isCodexConnected(): boolean { + const codexConfig = this.config.llmProviderSettings.providers['openai-codex']; + return !!(codexConfig?.oauth?.connected && codexConfig?.apiKey); + } + // ========== MODEL SECTION ========== private renderModelSection(parent: HTMLElement): void { @@ -183,10 +209,12 @@ export class ChatSettingsRenderer { .setName('Provider') .addDropdown(dropdown => { const providers = this.getEnabledProviders(); + // openai-codex is displayed under openai in the dropdown + const displayProvider = this.settings.provider === 'openai-codex' ? 'openai' : this.settings.provider; // If the currently-selected provider isn't usable on this platform (e.g. desktop-only // providers on mobile), fall back to the first available option. 
- if (providers.length > 0 && !providers.includes(this.settings.provider)) { + if (providers.length > 0 && !providers.includes(displayProvider)) { const nextProvider = providers[0]; this.settings.provider = nextProvider; this.settings.model = ''; @@ -207,7 +235,7 @@ export class ChatSettingsRenderer { }); } - dropdown.setValue(this.settings.provider); + dropdown.setValue(displayProvider); dropdown.onChange(async (value) => { this.settings.provider = value; this.settings.model = await this.getDefaultModelForProvider(value); @@ -218,8 +246,10 @@ export class ChatSettingsRenderer { // Model const providerId = this.settings.provider; + // For display purposes, openai-codex models appear under openai + const modelProviderId = providerId === 'openai-codex' ? 'openai' : providerId; - if (providerId === 'ollama') { + if (modelProviderId === 'ollama') { new Setting(content) .setName('Model') .addText(text => text @@ -230,33 +260,52 @@ export class ChatSettingsRenderer { new Setting(content) .setName('Model') .addDropdown(async dropdown => { - if (!providerId) { + if (!modelProviderId) { dropdown.addOption('', 'Select a provider first'); return; } try { - const models = await this.providerManager.getModelsForProvider(providerId); + this.modelOptionMap.clear(); + let models = await this.providerManager.getModelsForProvider(modelProviderId); + + // Merge Codex models into OpenAI list when Codex OAuth is connected + if (modelProviderId === 'openai' && this.isCodexConnected()) { + const codexModels = await this.providerManager.getModelsForProvider('openai-codex'); + const openaiModelIds = new Set(models.map(m => m.id)); + for (const cm of codexModels) { + // Skip duplicates (same model ID available in both providers) + if (!openaiModelIds.has(cm.id)) { + models = [...models, { ...cm, name: `${cm.name} (ChatGPT)` }]; + } + } + } + if (models.length === 0) { dropdown.addOption('', 'No models available'); } else { models.forEach(model => { - dropdown.addOption(model.id, 
model.name); + const optionKey = model.id; + this.modelOptionMap.set(optionKey, { provider: model.provider, modelId: model.id }); + dropdown.addOption(optionKey, model.name); }); const exists = models.some(m => m.id === this.settings.model); if (exists) { dropdown.setValue(this.settings.model); } else if (models.length > 0) { + const firstEntry = this.modelOptionMap.get(models[0].id); this.settings.model = models[0].id; + if (firstEntry) this.settings.provider = firstEntry.provider; dropdown.setValue(this.settings.model); - this.notifyChange(); } } dropdown.onChange((value) => { - this.settings.model = value; + const entry = this.modelOptionMap.get(value); + this.settings.model = entry?.modelId ?? value; + this.settings.provider = entry?.provider ?? modelProviderId; this.notifyChange(); // Re-render to update reasoning visibility this.render(); @@ -335,13 +384,14 @@ export class ChatSettingsRenderer { // Get only API-based providers (exclude local ones) const apiProviders = this.getEnabledProviders().filter(id => !LOCAL_PROVIDERS.includes(id)); + const agentDisplayProvider = this.settings.agentProvider === 'openai-codex' ? 'openai' : this.settings.agentProvider; // Provider dropdown new Setting(content) .setName('Provider') .addDropdown(dropdown => { // If the currently-selected agent provider isn't available, fall back to first API provider - if (apiProviders.length > 0 && this.settings.agentProvider && !apiProviders.includes(this.settings.agentProvider)) { + if (apiProviders.length > 0 && agentDisplayProvider && !apiProviders.includes(agentDisplayProvider)) { const nextProvider = apiProviders[0]; this.settings.agentProvider = nextProvider; this.settings.agentModel = ''; @@ -361,7 +411,7 @@ export class ChatSettingsRenderer { }); } - dropdown.setValue(this.settings.agentProvider || ''); + dropdown.setValue(agentDisplayProvider || ''); dropdown.onChange(async (value) => { this.settings.agentProvider = value === '' ? 
undefined : value; this.settings.agentModel = value ? await this.getDefaultModelForProvider(value) : undefined; @@ -372,37 +422,55 @@ export class ChatSettingsRenderer { // Model dropdown - always shown (mirrors Chat Model pattern) const agentProviderId = this.settings.agentProvider; + const agentModelProviderId = agentProviderId === 'openai-codex' ? 'openai' : agentProviderId; new Setting(content) .setName('Model') .addDropdown(async dropdown => { - if (!agentProviderId) { + if (!agentModelProviderId) { dropdown.addOption('', 'Select a provider first'); return; } try { - const models = await this.providerManager.getModelsForProvider(agentProviderId); + this.agentModelOptionMap.clear(); + let models = await this.providerManager.getModelsForProvider(agentModelProviderId); + + // Merge Codex models into OpenAI list when Codex OAuth is connected + if (agentModelProviderId === 'openai' && this.isCodexConnected()) { + const codexModels = await this.providerManager.getModelsForProvider('openai-codex'); + const openaiModelIds = new Set(models.map(m => m.id)); + for (const cm of codexModels) { + if (!openaiModelIds.has(cm.id)) { + models = [...models, { ...cm, name: `${cm.name} (ChatGPT)` }]; + } + } + } if (models.length === 0) { dropdown.addOption('', 'No models available'); } else { models.forEach(model => { - dropdown.addOption(model.id, model.name); + const optionKey = model.id; + this.agentModelOptionMap.set(optionKey, { provider: model.provider, modelId: model.id }); + dropdown.addOption(optionKey, model.name); }); const exists = models.some(m => m.id === this.settings.agentModel); if (exists) { dropdown.setValue(this.settings.agentModel!); } else if (models.length > 0) { + const firstEntry = this.agentModelOptionMap.get(models[0].id); this.settings.agentModel = models[0].id; + if (firstEntry) this.settings.agentProvider = firstEntry.provider; dropdown.setValue(this.settings.agentModel); - this.notifyChange(); } } dropdown.onChange((value) => { - 
this.settings.agentModel = value; + const entry = this.agentModelOptionMap.get(value); + this.settings.agentModel = entry?.modelId ?? value; + this.settings.agentProvider = entry?.provider ?? agentModelProviderId; this.notifyChange(); // Re-render to update reasoning visibility this.render(); diff --git a/src/core/services/ServiceDefinitions.ts b/src/core/services/ServiceDefinitions.ts index 577f9f1c..a289b202 100644 --- a/src/core/services/ServiceDefinitions.ts +++ b/src/core/services/ServiceDefinitions.ts @@ -228,6 +228,11 @@ export const CORE_SERVICE_DEFINITIONS: ServiceDefinition[] = [ llmService.setToolExecutor(directToolExecutor); } + // Wire settings persistence so token refresh is saved to disk immediately + llmService.setOnSettingsDirty(() => { + context.settings.saveSettings().catch(() => {}); + }); + return llmService; } }, diff --git a/src/main.ts b/src/main.ts index 9986364e..e64673a5 100644 --- a/src/main.ts +++ b/src/main.ts @@ -92,6 +92,22 @@ export default class NexusPlugin extends Plugin { } } + // Register OAuth providers (desktop only — needs local callback server) + if (Platform.isDesktop) { + try { + const { OAuthService } = await import('./services/oauth/OAuthService'); + const { OpenRouterOAuthProvider } = await import('./services/oauth/providers/OpenRouterOAuthProvider'); + const { OpenAICodexOAuthProvider } = await import('./services/oauth/providers/OpenAICodexOAuthProvider'); + + const oauthService = OAuthService.getInstance(); + oauthService.registerProvider(new OpenRouterOAuthProvider()); + oauthService.registerProvider(new OpenAICodexOAuthProvider()); + } catch (error) { + console.error(`[${BRAND_NAME}] Failed to initialize OAuth providers:`, error); + // Continue without OAuth — manual API key entry still works + } + } + // Create and initialize lifecycle manager const lifecycleConfig: PluginLifecycleConfig = { plugin: this, @@ -123,6 +139,16 @@ export default class NexusPlugin extends Plugin { await this.connector.stop(); } + // 
Clean up OAuth singleton (cancels any in-flight flow, releases callback server) + if (Platform.isDesktop) { + try { + const { OAuthService } = await import('./services/oauth/OAuthService'); + OAuthService.resetInstance(); + } catch { + // OAuth may not have been loaded; ignore + } + } + // Service manager cleanup handled by lifecycle manager } diff --git a/src/services/StaticModelsService.ts b/src/services/StaticModelsService.ts index e46d3591..d5ab26ff 100644 --- a/src/services/StaticModelsService.ts +++ b/src/services/StaticModelsService.ts @@ -12,6 +12,7 @@ import { GROQ_MODELS } from './llm/adapters/groq/GroqModels'; import { OPENROUTER_MODELS } from './llm/adapters/openrouter/OpenRouterModels'; import { REQUESTY_MODELS } from './llm/adapters/requesty/RequestyModels'; import { PERPLEXITY_MODELS } from './llm/adapters/perplexity/PerplexityModels'; +import { OPENAI_CODEX_MODELS } from './llm/adapters/openai-codex/OpenAICodexModels'; export interface ModelWithProvider { provider: string; @@ -63,7 +64,8 @@ export class StaticModelsService { { provider: 'groq', models: GROQ_MODELS }, { provider: 'openrouter', models: OPENROUTER_MODELS }, { provider: 'requesty', models: REQUESTY_MODELS }, - { provider: 'perplexity', models: PERPLEXITY_MODELS } + { provider: 'perplexity', models: PERPLEXITY_MODELS }, + { provider: 'openai-codex', models: OPENAI_CODEX_MODELS } ]; providerModels.forEach(({ provider, models }) => { @@ -110,6 +112,9 @@ export class StaticModelsService { case 'perplexity': providerModels = PERPLEXITY_MODELS; break; + case 'openai-codex': + providerModels = OPENAI_CODEX_MODELS; + break; default: return []; } @@ -149,7 +154,7 @@ export class StaticModelsService { * Get provider information */ getAvailableProviders(): string[] { - return ['openai', 'anthropic', 'google', 'mistral', 'groq', 'openrouter', 'requesty', 'perplexity']; + return ['openai', 'anthropic', 'google', 'mistral', 'groq', 'openrouter', 'requesty', 'perplexity', 'openai-codex']; } /** diff 
--git a/src/services/llm/adapters/ModelRegistry.ts b/src/services/llm/adapters/ModelRegistry.ts index 8ca3bd5f..311eec3f 100644 --- a/src/services/llm/adapters/ModelRegistry.ts +++ b/src/services/llm/adapters/ModelRegistry.ts @@ -15,6 +15,7 @@ import { MISTRAL_MODELS, MISTRAL_DEFAULT_MODEL } from './mistral/MistralModels'; import { OPENROUTER_MODELS, OPENROUTER_DEFAULT_MODEL } from './openrouter/OpenRouterModels'; import { REQUESTY_MODELS, REQUESTY_DEFAULT_MODEL } from './requesty/RequestyModels'; import { GROQ_MODELS, GROQ_DEFAULT_MODEL } from './groq/GroqModels'; +import { OPENAI_CODEX_MODELS, OPENAI_CODEX_DEFAULT_MODEL } from './openai-codex/OpenAICodexModels'; import type { LLMProviderSettings } from '../../../types'; // Re-export ModelSpec for convenience @@ -27,6 +28,7 @@ export type { ModelSpec }; */ export const AI_MODELS: Record = { openai: OPENAI_MODELS, + 'openai-codex': OPENAI_CODEX_MODELS, google: GOOGLE_MODELS, anthropic: ANTHROPIC_MODELS, mistral: MISTRAL_MODELS, @@ -208,6 +210,7 @@ export class ModelRegistry { */ export const DEFAULT_MODELS: Record = { openai: OPENAI_DEFAULT_MODEL, + 'openai-codex': OPENAI_CODEX_DEFAULT_MODEL, google: GOOGLE_DEFAULT_MODEL, anthropic: ANTHROPIC_DEFAULT_MODEL, mistral: MISTRAL_DEFAULT_MODEL, diff --git a/src/services/llm/adapters/openai-codex/OpenAICodexAdapter.ts b/src/services/llm/adapters/openai-codex/OpenAICodexAdapter.ts new file mode 100644 index 00000000..ac8c6df1 --- /dev/null +++ b/src/services/llm/adapters/openai-codex/OpenAICodexAdapter.ts @@ -0,0 +1,709 @@ +/** + * OpenAI Codex Adapter + * Location: src/services/llm/adapters/openai-codex/OpenAICodexAdapter.ts + * + * LLM adapter that routes inference to the Codex endpoint using OAuth tokens + * obtained via the PKCE flow against auth.openai.com. The Codex API uses a + * custom SSE streaming format (Responses API style), not the standard Chat + * Completions format. 
+ * + * Key differences from standard OpenAI adapter: + * - Auth: OAuth Bearer token + ChatGPT-Account-Id header (not API key) + * - Endpoint: chatgpt.com/backend-api/codex/responses (not api.openai.com) + * - Request body: { input: [...], stream: true, store: false } (Responses API) + * - SSE events: delta.text / delta.content (not choices[].delta.content) + * - Token refresh: proactive refresh when access_token nears expiry + * - Cost: $0 (subscription-based, not per-token) + * + * Desktop only: uses Node.js https module to bypass browser CORS restrictions. + * + * Used by: AdapterRegistry (initializes this adapter when openai-codex is + * enabled with OAuth state), StreamingOrchestrator (for streaming inference). + */ + +import { BaseAdapter } from '../BaseAdapter'; +import { + GenerateOptions, + StreamChunk, + LLMResponse, + ModelInfo, + ProviderCapabilities, + ModelPricing, + LLMProviderError, + ToolCall +} from '../types'; +import { ModelRegistry } from '../ModelRegistry'; +import { BRAND_NAME } from '../../../../constants/branding'; + +/** Codex API endpoint (requires ChatGPT subscription) */ +const CODEX_API_ENDPOINT = 'https://chatgpt.com/backend-api/codex/responses'; + +/** OpenAI OAuth token endpoint for refresh */ +const OAUTH_TOKEN_ENDPOINT = 'https://auth.openai.com/oauth/token'; + +/** OAuth client ID (same as used during PKCE flow) */ +const OAUTH_CLIENT_ID = 'app_EMoamEEZ73f0CkXaXp7hrann'; + +/** Proactive refresh threshold: refresh if token expires within 5 minutes */ +const TOKEN_REFRESH_THRESHOLD_MS = 5 * 60 * 1000; + +/** Timeout for token refresh requests (30 seconds) */ +const TOKEN_REFRESH_TIMEOUT_MS = 30_000; + +/** Timeout for streaming inference requests (2 minutes) */ +const STREAMING_REQUEST_TIMEOUT_MS = 120_000; + +/** Two-tool architecture tool names (must match ToolManager slugs) */ +const TOOL_NAMES = { discover: 'getTools', execute: 'useTools' } as const; + +/** + * OAuth token state managed by the adapter. 
+ * Mirrors the fields persisted in OAuthState on LLMProviderConfig.oauth. + */ +export interface CodexOAuthTokens { + accessToken: string; + refreshToken: string; + expiresAt: number; + accountId: string; +} + +/** + * Callback to persist refreshed tokens back to plugin settings. + * The adapter calls this after a successful token refresh so the + * new tokens survive across plugin restarts. + */ +export type TokenPersistCallback = (tokens: CodexOAuthTokens) => void; + +export class OpenAICodexAdapter extends BaseAdapter { + readonly name = 'openai-codex'; + readonly baseUrl = CODEX_API_ENDPOINT; + + private tokens: CodexOAuthTokens; + private onTokenRefresh?: TokenPersistCallback; + private refreshInProgress: Promise | null = null; + + /** + * @param tokens - Current OAuth token state (access token, refresh token, expiry, account ID) + * @param onTokenRefresh - Optional callback invoked after successful token refresh to persist new tokens + */ + constructor(tokens: CodexOAuthTokens, onTokenRefresh?: TokenPersistCallback) { + // Pass accessToken as apiKey for BaseAdapter compatibility; baseUrl is the Codex endpoint + super(tokens.accessToken, 'gpt-5.3-codex', CODEX_API_ENDPOINT, false); + this.tokens = { ...tokens }; + this.onTokenRefresh = onTokenRefresh; + this.initializeCache(); + } + + /** + * Ensure the access token is fresh before making a request. + * Uses a deduplication lock to prevent concurrent refresh attempts. 
+ */ + private async ensureFreshToken(): Promise { + const timeUntilExpiry = this.tokens.expiresAt - Date.now(); + + if (timeUntilExpiry > TOKEN_REFRESH_THRESHOLD_MS) { + return; // Token is still fresh + } + + // Deduplicate: if a refresh is already in flight, wait for it + if (this.refreshInProgress) { + await this.refreshInProgress; + return; + } + + this.refreshInProgress = this.performTokenRefresh(); + try { + await this.refreshInProgress; + } finally { + this.refreshInProgress = null; + } + } + + /** + * Execute the OAuth token refresh against auth.openai.com. + * Updates internal state and invokes the persistence callback. + * + * NOTE: This duplicates the refresh logic in OpenAICodexOAuthProvider.refreshToken(). + * The duplication is intentional — the OAuthProvider uses fetch() (fine for the + * initial OAuth UI flow), while the adapter must use Node.js https to bypass + * CORS restrictions in Electron's renderer process during inference. + */ + private async performTokenRefresh(): Promise { + const body = new URLSearchParams({ + grant_type: 'refresh_token', + client_id: OAUTH_CLIENT_ID, + refresh_token: this.tokens.refreshToken + }); + + // Use Node.js https to bypass browser CORS restrictions + // eslint-disable-next-line @typescript-eslint/no-var-requires + const httpsModule = require('https') as typeof import('https'); + const bodyStr = body.toString(); + const parsedUrl = new URL(OAUTH_TOKEN_ENDPOINT); + + const { statusCode, data } = await new Promise<{ statusCode: number; data: string }>( + (resolve, reject) => { + let responseData = ''; + const req = httpsModule.request( + { + hostname: parsedUrl.hostname, + path: parsedUrl.pathname, + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Content-Length': Buffer.byteLength(bodyStr) + } + }, + (res) => { + res.on('data', (chunk: Buffer) => { responseData += chunk.toString(); }); + res.on('end', () => resolve({ statusCode: res.statusCode ?? 
0, data: responseData })); + } + ); + req.setTimeout(TOKEN_REFRESH_TIMEOUT_MS, () => { + req.destroy(new Error('Token refresh request timed out')); + }); + req.on('error', reject); + req.write(bodyStr); + req.end(); + } + ); + + if (statusCode < 200 || statusCode >= 300) { + throw new LLMProviderError( + `Token refresh failed (HTTP ${statusCode}): ${data}`, + this.name, + 'AUTHENTICATION_ERROR' + ); + } + + let tokenData: Record; + try { + tokenData = JSON.parse(data); + } catch { + throw new LLMProviderError( + `Token refresh returned malformed response: ${data.slice(0, 200)}`, + this.name, + 'AUTHENTICATION_ERROR' + ); + } + + // Validate expires_in — default to 10 days if missing or invalid + const rawExpiresIn = tokenData.expires_in; + const expiresIn = (typeof rawExpiresIn === 'number' && rawExpiresIn > 0) + ? rawExpiresIn + : 864000; + + // Update internal token state + this.tokens = { + accessToken: tokenData.access_token as string, + refreshToken: (tokenData.refresh_token as string) || this.tokens.refreshToken, // Rotation: use new if provided + expiresAt: Date.now() + (expiresIn * 1000), + accountId: this.tokens.accountId // Account ID doesn't change on refresh + }; + + // Update the apiKey field used by BaseAdapter + this.apiKey = this.tokens.accessToken; + + // Persist the refreshed tokens + if (this.onTokenRefresh) { + this.onTokenRefresh(this.tokens); + } + } + + /** + * Build the request headers for the Codex API. + */ + private buildCodexHeaders(): Record { + return { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${this.tokens.accessToken}`, + 'ChatGPT-Account-Id': this.tokens.accountId, + 'originator': 'opencode', + 'User-Agent': `claudesidian-mcp/${BRAND_NAME}` + }; + } + + /** + * Convert the plugin's message format to the Codex input array format. 
+ * Codex expects: { role: string, content: string }[] + */ + private buildCodexInput( + prompt: string, + systemPrompt?: string, + conversationHistory?: Array> + ): Array> { + // If conversation history is provided, use it directly. + // Items may be role-based messages ({role, content}) or Responses API + // items ({type: "function_call"|"function_call_output", ...}). + if (conversationHistory && conversationHistory.length > 0) { + return conversationHistory; + } + + // Otherwise build from prompt + optional system prompt + const input: Array> = []; + if (systemPrompt) { + input.push({ role: 'system', content: systemPrompt }); + } + input.push({ role: 'user', content: prompt }); + return input; + } + + /** + * Generate a non-streaming response. + * Note: The Codex endpoint requires stream: true, so we collect + * all SSE chunks and return the assembled result. + */ + async generateUncached(prompt: string, options?: GenerateOptions): Promise { + try { + await this.ensureFreshToken(); + + const model = options?.model || this.currentModel; + let fullText = ''; + let collectedToolCalls: ToolCall[] = []; + + // Codex requires streaming; collect all chunks + for await (const chunk of this.generateStreamAsync(prompt, options)) { + if (chunk.content) { + fullText += chunk.content; + } + if (chunk.toolCalls && chunk.toolCalls.length > 0) { + collectedToolCalls = chunk.toolCalls; + } + } + + const hasToolCalls = collectedToolCalls.length > 0; + return this.buildLLMResponse( + fullText, + model, + { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, // Codex doesn't report usage + {}, + hasToolCalls ? 'tool_calls' : 'stop', + hasToolCalls ? collectedToolCalls : undefined + ); + } catch (error) { + throw this.handleError(error, 'generation'); + } + } + + /** + * Generate a streaming response from the Codex endpoint. + * Reads SSE events and extracts text deltas from the Responses API format. 
+ */ + async* generateStreamAsync( + prompt: string, + options?: GenerateOptions + ): AsyncGenerator { + try { + await this.ensureFreshToken(); + + const model = options?.model || this.currentModel; + const input = this.buildCodexInput( + prompt, + options?.systemPrompt, + options?.conversationHistory + ); + + const requestBody: Record = { + model, + input, + stream: true, + store: false + }; + + // Always include instructions — Codex API requires this field on every request + // (including tool continuation calls which pass conversationHistory) + requestBody.instructions = options?.systemPrompt || ''; + + if (options?.temperature !== undefined) { + requestBody.temperature = options.temperature; + } + if (options?.maxTokens !== undefined) { + requestBody.max_output_tokens = options.maxTokens; + } + + // Convert tools from Chat Completions format to Responses API flat format + // Codex expects: { type: "function", name: "...", parameters: {...} } + // Chat Completions sends: { type: "function", function: { name: "...", parameters: {...} } } + if (options?.tools && options.tools.length > 0) { + requestBody.tools = options.tools.map((tool) => { + const fn = tool.function as Record | undefined; + if (fn) { + const converted: Record = { + type: 'function', + name: fn.name, + parameters: fn.parameters || {} + }; + // Only include optional fields if they have values + // (null/undefined fields can cause API errors) + if (fn.description) converted.description = fn.description; + if (fn.strict !== undefined && fn.strict !== null) converted.strict = fn.strict; + return converted; + } + // Already in Responses API format + return tool; + }); + + // Tell the API to allow tool calls (default may be "none" for some models) + requestBody.tool_choice = 'auto'; + + // Prepend Codex-specific tool instruction to ensure the model uses tools + // rather than responding with plain text describing what it would do + const toolPreamble = 'You are an AI assistant with tool access. 
' + + 'Fulfill user requests by calling tools immediately — do NOT describe what you will do. ' + + `Call ${TOOL_NAMES.discover} first to discover available tools, then call ${TOOL_NAMES.execute} to execute them.\n\n`; + requestBody.instructions = toolPreamble + (requestBody.instructions || ''); + + } + + // Use Node.js https to bypass browser CORS restrictions + // eslint-disable-next-line @typescript-eslint/no-var-requires + const httpsModule = require('https') as typeof import('https'); + const bodyStr = JSON.stringify(requestBody); + const headers = this.buildCodexHeaders(); + const parsedUrl = new URL(CODEX_API_ENDPOINT); + + const { statusCode, nodeRes } = await new Promise<{ + statusCode: number; + nodeRes: import('http').IncomingMessage; + }>((resolve, reject) => { + const req = httpsModule.request( + { + hostname: parsedUrl.hostname, + path: parsedUrl.pathname, + method: 'POST', + headers: { ...headers, 'Content-Length': Buffer.byteLength(bodyStr) } + }, + (res) => resolve({ statusCode: res.statusCode ?? 0, nodeRes: res }) + ); + req.setTimeout(STREAMING_REQUEST_TIMEOUT_MS, () => { + req.destroy(new Error('Codex streaming request timed out')); + }); + req.on('error', reject); + req.write(bodyStr); + req.end(); + }); + + // Error handling for non-2xx responses + if (statusCode >= 400) { + // Use event-listener pattern instead of async iteration — async iteration on + // Node.js IncomingMessage hangs in Electron's renderer process + const errorBody = await new Promise((resolve, reject) => { + let data = ''; + nodeRes.on('data', (c: Buffer) => { data += c.toString(); }); + nodeRes.on('end', () => resolve(data)); + nodeRes.on('error', reject); + }); + + // Detect expired/invalid token specifically + if (statusCode === 401 || statusCode === 403) { + throw new LLMProviderError( + `Codex API authentication failed (HTTP ${statusCode}). Token may be expired or revoked. 
Please reconnect via OAuth.`, + this.name, + 'AUTHENTICATION_ERROR' + ); + } + + // Rate limit — throw specific code so StreamingOrchestrator can fall back + if (statusCode === 429) { + throw new LLMProviderError( + `Codex rate limited (HTTP 429). ${errorBody}`, + this.name, + 'RATE_LIMIT_ERROR' + ); + } + + throw new LLMProviderError( + `Codex API error (HTTP ${statusCode}): ${errorBody}`, + this.name, + 'HTTP_ERROR' + ); + } + + // Parse SSE stream from the Node.js IncomingMessage + // The Codex API returns SSE with data: {json} lines containing + // response events in the Responses API format. + yield* this.parseNodeSSEStream(nodeRes); + + } catch (error) { + throw this.handleError(error, 'streaming generation'); + } + } + + /** + * Parse a Codex SSE stream from a Node.js IncomingMessage. + * + * The Codex Responses API emits events like: + * data: {"type":"response.output_text.delta","delta":{"text":"Hello"}} + * data: {"type":"response.output_text.done","text":"Hello world"} + * data: {"type":"response.completed",...} + * data: [DONE] + * + * We extract text deltas and yield StreamChunks. 
+ */ + private async* parseNodeSSEStream( + nodeRes: import('http').IncomingMessage + ): AsyncGenerator { + let buffer = ''; + const toolCallsMap = new Map(); + let currentResponseId: string | undefined; + + // Use event-listener queue instead of async iteration — async iteration on + // Node.js IncomingMessage hangs in Electron's renderer process + const chunkQueue: string[] = []; + let streamEnded = false; + let streamError: Error | null = null; + let chunkWaiter: (() => void) | null = null; + + const notifyWaiter = () => { + if (chunkWaiter) { + const resolve = chunkWaiter; + chunkWaiter = null; + resolve(); + } + }; + + nodeRes.on('data', (chunk: Buffer) => { + chunkQueue.push(chunk.toString()); + notifyWaiter(); + }); + nodeRes.on('end', () => { + streamEnded = true; + notifyWaiter(); + }); + nodeRes.on('error', (err: Error) => { + streamError = err; + notifyWaiter(); + }); + + while (!streamEnded || chunkQueue.length > 0) { + if (streamError) throw streamError; + if (chunkQueue.length === 0) { + await new Promise(resolve => { chunkWaiter = resolve; }); + continue; + } + const rawChunk = chunkQueue.shift()!; + buffer += rawChunk; + + // Process complete lines from the buffer + const lines = buffer.split('\n'); + // Keep the last (potentially incomplete) line in the buffer + buffer = lines.pop() || ''; + + for (const line of lines) { + const trimmed = line.trim(); + + if (!trimmed || trimmed.startsWith(':')) { + // Empty line or SSE comment — skip + continue; + } + + if (!trimmed.startsWith('data: ')) { + continue; + } + + const jsonStr = trimmed.slice(6).trim(); + + if (jsonStr === '[DONE]') { + const finalToolCalls = toolCallsMap.size > 0 ? Array.from(toolCallsMap.values()) : undefined; + const metadata = currentResponseId ? { responseId: currentResponseId } : undefined; + yield { + content: '', + complete: true, + toolCalls: finalToolCalls, + toolCallsReady: finalToolCalls ? 
true : undefined, + metadata + }; + return; + } + + let event: Record; + try { + event = JSON.parse(jsonStr); + } catch { + // Malformed JSON — skip this line + continue; + } + + const eventType = event.type as string | undefined; + + // Capture response ID for stateful continuations (Responses API) + const responseObj = event.response as Record | undefined; + if (responseObj?.id && typeof responseObj.id === 'string' && !currentResponseId) { + currentResponseId = responseObj.id; + } + + // Accumulate completed function calls + if (eventType === 'response.output_item.done') { + const item = event.item as Record | undefined; + if (item && item.type === 'function_call') { + const index = (event.output_index as number) || 0; + toolCallsMap.set(index, { + id: (item.call_id as string) || (item.id as string) || '', + type: 'function', + function: { + name: (item.name as string) || '', + arguments: (item.arguments as string) || '{}' + } + }); + } + } + + // Arguments are streamed incrementally; we capture the complete call in output_item.done + if (eventType === 'response.function_call_arguments.delta') { + continue; + } + + // Extract text delta from various event shapes + const delta = this.extractDeltaText(event); + if (delta) { + yield { content: delta, complete: false }; + } + + // Detect completion event + if (eventType === 'response.completed' || eventType === 'response.done') { + const finalToolCalls = toolCallsMap.size > 0 ? Array.from(toolCallsMap.values()) : undefined; + const metadata = currentResponseId ? { responseId: currentResponseId } : undefined; + yield { + content: '', + complete: true, + toolCalls: finalToolCalls, + toolCallsReady: finalToolCalls ? true : undefined, + metadata + }; + return; + } + } + } + + // If stream ended without explicit [DONE], emit completion + const finalToolCalls = toolCallsMap.size > 0 ? Array.from(toolCallsMap.values()) : undefined; + const metadata = currentResponseId ? 
{ responseId: currentResponseId } : undefined; + yield { + content: '', + complete: true, + toolCalls: finalToolCalls, + toolCallsReady: finalToolCalls ? true : undefined, + metadata + }; + } + + /** + * Extract text content from a Codex SSE event. + * The Responses API uses several event shapes for text delivery. + */ + private extractDeltaText(event: Record): string | null { + // Shape 1a: { delta: "text" } — Codex Responses API output_text.delta + // The delta field is the text string itself, not a nested object + if (typeof event.delta === 'string' && event.delta) { + return event.delta; + } + + // Shape 1b: { delta: { text: "..." } } — alternative nested delta format + const delta = event.delta as Record | undefined; + if (delta && typeof delta === 'object') { + if (typeof delta.text === 'string' && delta.text) return delta.text; + if (typeof delta.content === 'string' && delta.content) return delta.content; + } + + // Shape 2: { text: "..." } at top level — output_text.done event + // (Skip for done events to avoid duplicating the full text) + const eventType = event.type as string | undefined; + if (eventType === 'response.output_text.done') { + return null; // Full text is a recap, not a delta + } + + // Shape 3: { content: "..." } at top level — some event variants + if (typeof event.content === 'string' && event.content) { + return event.content; + } + + return null; + } + + /** + * List available Codex models from the static model registry. + */ + async listModels(): Promise { + const codexModels = ModelRegistry.getProviderModels('openai-codex'); + return codexModels.map(model => ModelRegistry.toModelInfo(model)); + } + + /** + * Get provider capabilities. 
+ */ + getCapabilities(): ProviderCapabilities { + return { + supportsStreaming: true, + supportsJSON: true, + supportsImages: true, + supportsFunctions: true, + supportsThinking: false, + maxContextWindow: 400000, + supportedFeatures: [ + 'streaming', + 'json_mode', + 'image_input', + 'tool_calling', + 'subscription_based', + 'oauth_required' + ] + }; + } + + /** + * Get model pricing — Codex models are subscription-based ($0 per token). + */ + async getModelPricing(modelId: string): Promise { + const models = ModelRegistry.getProviderModels('openai-codex'); + const model = models.find(m => m.apiName === modelId); + if (!model) return null; + + return { + rateInputPerMillion: 0, + rateOutputPerMillion: 0, + currency: 'USD' + }; + } + + /** + * Override isAvailable to check OAuth token validity instead of API key. + */ + async isAvailable(): Promise { + return !!( + this.tokens.accessToken && + this.tokens.refreshToken && + this.tokens.accountId + ); + } + + /** + * Get the current token state (for diagnostics or UI display). + * Masks sensitive values. + */ + getTokenStatus(): { + hasAccessToken: boolean; + hasRefreshToken: boolean; + hasAccountId: boolean; + expiresAt: number; + isExpired: boolean; + needsRefresh: boolean; + } { + const now = Date.now(); + return { + hasAccessToken: !!this.tokens.accessToken, + hasRefreshToken: !!this.tokens.refreshToken, + hasAccountId: !!this.tokens.accountId, + expiresAt: this.tokens.expiresAt, + isExpired: now >= this.tokens.expiresAt, + needsRefresh: (this.tokens.expiresAt - now) < TOKEN_REFRESH_THRESHOLD_MS + }; + } + + /** + * Update the OAuth tokens (e.g., after an external refresh or reconnect). 
+ */ + updateTokens(tokens: CodexOAuthTokens): void { + this.tokens = { ...tokens }; + this.apiKey = tokens.accessToken; + } +} diff --git a/src/services/llm/adapters/openai-codex/OpenAICodexModels.ts b/src/services/llm/adapters/openai-codex/OpenAICodexModels.ts new file mode 100644 index 00000000..93718cd5 --- /dev/null +++ b/src/services/llm/adapters/openai-codex/OpenAICodexModels.ts @@ -0,0 +1,113 @@ +/** + * OpenAI Codex Model Specifications + * Location: src/services/llm/adapters/openai-codex/OpenAICodexModels.ts + * + * Defines models available through the Codex endpoint (ChatGPT subscription). + * All costs are $0 since these models are included in the user's ChatGPT + * subscription — no per-token API billing. + * + * Used by: ModelRegistry (AI_MODELS), OpenAICodexAdapter, ProviderManager + */ + +import { ModelSpec } from '../modelTypes'; + +export const OPENAI_CODEX_MODELS: ModelSpec[] = [ + { + provider: 'openai-codex', + name: 'GPT-5.3 Codex', + apiName: 'gpt-5.3-codex', + contextWindow: 400000, + maxTokens: 128000, + inputCostPerMillion: 0, + outputCostPerMillion: 0, + capabilities: { + supportsJSON: true, + supportsImages: true, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } + }, + { + provider: 'openai-codex', + name: 'GPT-5.2 Codex', + apiName: 'gpt-5.2-codex', + contextWindow: 400000, + maxTokens: 128000, + inputCostPerMillion: 0, + outputCostPerMillion: 0, + capabilities: { + supportsJSON: true, + supportsImages: true, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } + }, + { + provider: 'openai-codex', + name: 'GPT-5.2', + apiName: 'gpt-5.2', + contextWindow: 400000, + maxTokens: 128000, + inputCostPerMillion: 0, + outputCostPerMillion: 0, + capabilities: { + supportsJSON: true, + supportsImages: true, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } + }, + { + provider: 'openai-codex', + name: 'GPT-5.1 Codex', + apiName: 'gpt-5.1-codex', + 
contextWindow: 400000, + maxTokens: 128000, + inputCostPerMillion: 0, + outputCostPerMillion: 0, + capabilities: { + supportsJSON: true, + supportsImages: true, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } + }, + { + provider: 'openai-codex', + name: 'GPT-5.1 Codex Max', + apiName: 'gpt-5.1-codex-max', + contextWindow: 400000, + maxTokens: 128000, + inputCostPerMillion: 0, + outputCostPerMillion: 0, + capabilities: { + supportsJSON: true, + supportsImages: true, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } + }, + { + provider: 'openai-codex', + name: 'GPT-5.1 Codex Mini', + apiName: 'gpt-5.1-codex-mini', + contextWindow: 200000, + maxTokens: 64000, + inputCostPerMillion: 0, + outputCostPerMillion: 0, + capabilities: { + supportsJSON: true, + supportsImages: false, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } + } +]; + +export const OPENAI_CODEX_DEFAULT_MODEL = 'gpt-5.3-codex'; diff --git a/src/services/llm/adapters/types.ts b/src/services/llm/adapters/types.ts index 6fce453d..d56b6e8d 100644 --- a/src/services/llm/adapters/types.ts +++ b/src/services/llm/adapters/types.ts @@ -6,7 +6,7 @@ /** * Supported LLM providers */ -export type SupportedProvider = 'openai' | 'openrouter' | 'anthropic' | 'google' | 'groq' | 'mistral' | 'perplexity' | 'requesty'; +export type SupportedProvider = 'openai' | 'openai-codex' | 'openrouter' | 'anthropic' | 'google' | 'groq' | 'mistral' | 'perplexity' | 'requesty'; export interface GenerateOptions { model?: string; diff --git a/src/services/llm/core/AdapterRegistry.ts b/src/services/llm/core/AdapterRegistry.ts index a7e6a55e..8ae4db2c 100644 --- a/src/services/llm/core/AdapterRegistry.ts +++ b/src/services/llm/core/AdapterRegistry.ts @@ -22,6 +22,7 @@ import { isMobile } from '../../../utils/platform'; // Type imports for TypeScript (don't affect bundling) import type { WebLLMAdapter as WebLLMAdapterType } from 
'../adapters/webllm/WebLLMAdapter'; +import type { CodexOAuthTokens } from '../adapters/openai-codex/OpenAICodexAdapter'; /** * Interface for adapter registry operations @@ -71,6 +72,7 @@ export class AdapterRegistry implements IAdapterRegistry { private vault?: Vault; private webllmAdapter?: WebLLMAdapterType; private initPromise?: Promise; + private _onSettingsDirty?: () => void; constructor(settings: LLMProviderSettings, vault?: Vault) { this.settings = settings; @@ -98,6 +100,14 @@ export class AdapterRegistry implements IAdapterRegistry { } } + /** + * Set a callback invoked when adapter-level changes (e.g. token refresh) dirty the settings. + * The callback should persist settings to disk. + */ + setOnSettingsDirty(cb: () => void): void { + this._onSettingsDirty = cb; + } + /** * Update settings and reinitialize all adapters */ @@ -188,6 +198,10 @@ export class AdapterRegistry implements IAdapterRegistry { return new PerplexityAdapter(config.apiKey); }); + // OpenAI Codex (OAuth-based, uses fetch — mobile compatible) + // Requires OAuth tokens instead of a traditional API key + await this.initializeCodexAdapter(providers['openai-codex']); + // ═══════════════════════════════════════════════════════════════════════════ // DESKTOP-ONLY PROVIDERS (use Node.js SDKs) // Skip on mobile to avoid crashes from SDK Node.js dependencies @@ -279,6 +293,51 @@ export class AdapterRegistry implements IAdapterRegistry { } } + /** + * Initialize the OpenAI Codex adapter from OAuth state. + * Unlike API-key providers, Codex uses OAuth tokens stored in config.oauth. + * The adapter handles proactive token refresh and calls back to persist new tokens. 
+ */ + private async initializeCodexAdapter(config: LLMProviderConfig | undefined): Promise { + if (!config?.enabled) return; + + const oauth = config.oauth; + if (!oauth?.connected || !config.apiKey || !oauth.refreshToken || !oauth.metadata?.accountId) { + return; // Not connected via OAuth — skip initialization + } + + try { + const { OpenAICodexAdapter } = await import('../adapters/openai-codex/OpenAICodexAdapter'); + + const tokens: CodexOAuthTokens = { + accessToken: config.apiKey, // OAuth access token is stored as apiKey + refreshToken: oauth.refreshToken, + expiresAt: oauth.expiresAt || 0, + accountId: oauth.metadata.accountId + }; + + // Token refresh callback: updates the settings so refreshed tokens + // persist across plugin restarts, then triggers a settings save. + const onTokenRefresh = (newTokens: CodexOAuthTokens): void => { + // Update the config object in-place (settings reference) + config.apiKey = newTokens.accessToken; + const oauthState = config.oauth; + if (oauthState) { + oauthState.refreshToken = newTokens.refreshToken; + oauthState.expiresAt = newTokens.expiresAt; + } + // Persist to disk immediately so rotated tokens survive a crash + this._onSettingsDirty?.(); + }; + + const adapter = new OpenAICodexAdapter(tokens, onTokenRefresh); + this.adapters.set('openai-codex', adapter); + } catch (error) { + console.error('AdapterRegistry: Failed to initialize OpenAI Codex adapter:', error); + this.logError('openai-codex', error); + } + } + /** * Initialize a single provider adapter using async factory pattern * Handles common validation and error logging with dynamic import support diff --git a/src/services/llm/core/LLMService.ts b/src/services/llm/core/LLMService.ts index 693e3b22..14c74dfc 100644 --- a/src/services/llm/core/LLMService.ts +++ b/src/services/llm/core/LLMService.ts @@ -84,6 +84,14 @@ export class LLMService { this.toolExecutor = executor; } + /** + * Set a callback invoked when adapter-level changes (e.g. 
token refresh) dirty settings. + * The callback should persist settings to disk. + */ + setOnSettingsDirty(cb: () => void): void { + this.adapterRegistry.setOnSettingsDirty(cb); + } + /** Update settings and reinitialize adapters */ updateSettings(settings: LLMProviderSettings): void { diff --git a/src/services/llm/core/ProviderMessageBuilder.ts b/src/services/llm/core/ProviderMessageBuilder.ts index 9b8d5302..82db1a84 100644 --- a/src/services/llm/core/ProviderMessageBuilder.ts +++ b/src/services/llm/core/ProviderMessageBuilder.ts @@ -188,6 +188,53 @@ export class ProviderMessageBuilder { conversationHistory, systemPrompt: generateOptions.systemPrompt }; + } else if (provider === 'openai-codex') { + // Codex uses stateless Responses API — no previous_response_id. + // Build a full input array: prior messages + user prompt + function_call + function_call_output items. + const inputItems: Array> = []; + + // Add prior conversation messages + for (const msg of previousMessages) { + if (msg.role === 'system') continue; // system prompt goes in instructions + inputItems.push({ role: msg.role, content: msg.content }); + } + + // Add current user prompt + if (userPrompt) { + inputItems.push({ role: 'user', content: userPrompt }); + } + + // Add function_call items (what the model called) + for (const tc of toolCalls) { + const name = ('name' in tc && tc.name) ? tc.name as string : tc.function?.name || ''; + const args = tc.function?.arguments || '{}'; + inputItems.push({ + type: 'function_call', + call_id: tc.id, + name, + arguments: args + }); + } + + // Add function_call_output items (what the tools returned) + for (let i = 0; i < toolResults.length; i++) { + const result = toolResults[i]; + const tc = toolCalls[i]; + inputItems.push({ + type: 'function_call_output', + call_id: tc?.id || result.id, + output: result.success + ? 
JSON.stringify(result.result || {}) + : JSON.stringify({ error: result.error || 'Tool execution failed' }) + }); + } + + return { + ...generateOptions, + conversationHistory: inputItems, + systemPrompt: generateOptions.systemPrompt, + tools: generateOptions.tools + }; } else if (provider === 'openai') { // OpenAI uses Responses API with function_call_output items + previous_response_id // State is reliably tracked on OpenAI's servers diff --git a/src/services/llm/core/StreamingOrchestrator.ts b/src/services/llm/core/StreamingOrchestrator.ts index 888da872..2889ea8c 100644 --- a/src/services/llm/core/StreamingOrchestrator.ts +++ b/src/services/llm/core/StreamingOrchestrator.ts @@ -12,7 +12,8 @@ import { IToolExecutor } from '../adapters/shared/ToolExecutionUtils'; import { LLMProviderSettings } from '../../../types'; import { IAdapterRegistry } from './AdapterRegistry'; -import { TokenUsage } from '../adapters/types'; +import { TokenUsage, LLMProviderError } from '../adapters/types'; +import { Notice } from 'obsidian'; import { ToolCall as ChatToolCall } from '../../../types/chat/ChatTypes'; import { ProviderMessageBuilder, @@ -90,81 +91,177 @@ export class StreamingOrchestrator { const isGoogleModel = provider === 'google'; const promptToPass = isGoogleModel ? '' : userPrompt; - for await (const chunk of adapter.generateStreamAsync(promptToPass, generateOptions)) { - // Track usage from chunks - if (chunk.usage) { - finalUsage = chunk.usage; - } + // Determine the active adapter and provider for streaming. + // These may be swapped to a fallback on Codex rate limit (429). 
+ let activeAdapter = adapter; + let activeProvider = provider; - // Handle text content streaming - if (chunk.content) { - fullContent += chunk.content; + try { + for await (const chunk of activeAdapter.generateStreamAsync(promptToPass, generateOptions)) { + // Track usage from chunks + if (chunk.usage) { + finalUsage = chunk.usage; + } - yield { - chunk: chunk.content, - complete: false, - content: fullContent, - toolCalls: undefined - }; - } + // Handle text content streaming + if (chunk.content) { + fullContent += chunk.content; - // Handle reasoning/thinking content (Claude, GPT-5, Gemini) - if (chunk.reasoning) { - yield { - chunk: '', - complete: false, - content: fullContent, - toolCalls: undefined, - reasoning: chunk.reasoning, - reasoningComplete: chunk.reasoningComplete - }; - } + yield { + chunk: chunk.content, + complete: false, + content: fullContent, + toolCalls: undefined + }; + } + + // Handle reasoning/thinking content (Claude, GPT-5, Gemini) + if (chunk.reasoning) { + yield { + chunk: '', + complete: false, + content: fullContent, + toolCalls: undefined, + reasoning: chunk.reasoning, + reasoningComplete: chunk.reasoningComplete + }; + } + + // Handle dynamic tool call detection + if (chunk.toolCalls) { + const chatToolCalls: ChatToolCall[] = chunk.toolCalls.map(tc => ({ + ...tc, + type: tc.type || 'function', + function: tc.function || { name: '', arguments: '{}' } + })); + + // ALWAYS yield tool calls for progressive UI display + yield { + chunk: '', + complete: false, + content: fullContent, + toolCalls: chatToolCalls, + toolCallsReady: chunk.complete || false + }; + + // Only STORE tool calls for execution when streaming is COMPLETE + if (chunk.complete) { + detectedToolCalls = chatToolCalls; + } + } - // Handle dynamic tool call detection - if (chunk.toolCalls) { - const chatToolCalls: ChatToolCall[] = chunk.toolCalls.map(tc => ({ - ...tc, - type: tc.type || 'function', - function: tc.function || { name: '', arguments: '{}' } - })); - - // 
ALWAYS yield tool calls for progressive UI display - yield { - chunk: '', - complete: false, - content: fullContent, - toolCalls: chatToolCalls, - toolCallsReady: chunk.complete || false - }; - - // Only STORE tool calls for execution when streaming is COMPLETE if (chunk.complete) { - detectedToolCalls = chatToolCalls; + // Store response ID for future continuations (OpenAI/Codex use Responses API) + const rawResponseId = chunk.metadata?.responseId; + if ((activeProvider === 'openai' || activeProvider === 'openai-codex') && rawResponseId && typeof rawResponseId === 'string') { + const responseId = rawResponseId; + + // Only capture if we don't already have one (from options or memory) + const existingId = options?.responsesApiId || + (options?.conversationId ? this.conversationResponseIds.get(options.conversationId) : undefined); + + if (!existingId) { + // Store in memory for this session + if (options?.conversationId) { + this.conversationResponseIds.set(options.conversationId, responseId); + } + // Notify caller to persist to conversation metadata + if (options?.onResponsesApiId) { + options.onResponsesApiId(responseId); + } + } + } + break; } } + } catch (error) { + // On Codex 429, fall back to standard OpenAI adapter if available + if ( + error instanceof LLMProviderError && + error.code === 'RATE_LIMIT_ERROR' && + error.provider === 'openai-codex' + ) { + const fallbackAdapter = this.adapterRegistry.getAdapter('openai'); + if (fallbackAdapter) { + new Notice('ChatGPT rate limit reached — falling back to OpenAI API'); + activeAdapter = fallbackAdapter; + activeProvider = 'openai'; + + // Reset streaming state for retry + fullContent = ''; + detectedToolCalls = []; + finalUsage = undefined; + + for await (const chunk of fallbackAdapter.generateStreamAsync(promptToPass, generateOptions)) { + if (chunk.usage) { + finalUsage = chunk.usage; + } - if (chunk.complete) { - // Store response ID for future continuations (OpenAI uses Responses API) - const rawResponseId 
= chunk.metadata?.responseId; - if (provider === 'openai' && rawResponseId && typeof rawResponseId === 'string') { - const responseId = rawResponseId; + if (chunk.content) { + fullContent += chunk.content; + yield { + chunk: chunk.content, + complete: false, + content: fullContent, + toolCalls: undefined + }; + } + + if (chunk.reasoning) { + yield { + chunk: '', + complete: false, + content: fullContent, + toolCalls: undefined, + reasoning: chunk.reasoning, + reasoningComplete: chunk.reasoningComplete + }; + } - // Only capture if we don't already have one (from options or memory) - const existingId = options?.responsesApiId || - (options?.conversationId ? this.conversationResponseIds.get(options.conversationId) : undefined); + if (chunk.toolCalls) { + const chatToolCalls: ChatToolCall[] = chunk.toolCalls.map(tc => ({ + ...tc, + type: tc.type || 'function', + function: tc.function || { name: '', arguments: '{}' } + })); - if (!existingId) { - // Store in memory for this session - if (options?.conversationId) { - this.conversationResponseIds.set(options.conversationId, responseId); + yield { + chunk: '', + complete: false, + content: fullContent, + toolCalls: chatToolCalls, + toolCallsReady: chunk.complete || false + }; + + if (chunk.complete) { + detectedToolCalls = chatToolCalls; + } } - // Notify caller to persist to conversation metadata - if (options?.onResponsesApiId) { - options.onResponsesApiId(responseId); + + if (chunk.complete) { + const rawResponseId = chunk.metadata?.responseId; + if ((activeProvider === 'openai' || activeProvider === 'openai-codex') && rawResponseId && typeof rawResponseId === 'string') { + const responseId = rawResponseId; + const existingId = options?.responsesApiId || + (options?.conversationId ? 
this.conversationResponseIds.get(options.conversationId) : undefined); + if (!existingId) { + if (options?.conversationId) { + this.conversationResponseIds.set(options.conversationId, responseId); + } + if (options?.onResponsesApiId) { + options.onResponsesApiId(responseId); + } + } + } + break; } } + } else { + // No fallback available — re-throw original error + throw error; } - break; + } else { + throw error; } } @@ -182,8 +279,8 @@ export class StreamingOrchestrator { // Tool calls detected - delegate to ToolContinuationService yield* this.toolContinuation.executeToolsAndContinue( - adapter, - provider, + activeAdapter, + activeProvider, detectedToolCalls, previousMessages, userPrompt, diff --git a/src/services/llm/providers/ProviderManager.ts b/src/services/llm/providers/ProviderManager.ts index 419d633b..020e6123 100644 --- a/src/services/llm/providers/ProviderManager.ts +++ b/src/services/llm/providers/ProviderManager.ts @@ -270,6 +270,11 @@ export class LLMProviderManager { name: 'Perplexity', description: 'Web search-enabled models with real-time information and citations' }, + { + id: 'openai-codex', + name: 'ChatGPT (Codex)', + description: 'GPT models via ChatGPT subscription — free inference, requires OAuth sign-in' + }, { id: 'webllm', name: 'Nexus (Local)', @@ -295,6 +300,9 @@ export class LLMProviderManager { let hasApiKey = false; if (provider.id === 'webllm') { hasApiKey = true; // WebLLM doesn't need an API key + } else if (provider.id === 'openai-codex') { + // Codex uses OAuth — check for connected OAuth state with access token + hasApiKey = !!(config?.oauth?.connected && config?.apiKey); } else if (provider.id === 'ollama' || provider.id === 'lmstudio') { hasApiKey = !!(config?.apiKey && config.apiKey.trim()); } else { diff --git a/src/services/oauth/IOAuthProvider.ts b/src/services/oauth/IOAuthProvider.ts new file mode 100644 index 00000000..87e7e8ca --- /dev/null +++ b/src/services/oauth/IOAuthProvider.ts @@ -0,0 +1,126 @@ +/** + * 
IOAuthProvider.ts + * Location: src/services/oauth/IOAuthProvider.ts + * + * Defines the core interfaces for the OAuth 2.0 PKCE provider system. + * All OAuth providers implement IOAuthProvider. OAuthState is persisted + * on LLMProviderConfig to track connection status. OAuthResult is the + * transient result returned after a successful OAuth flow. + * + * Used by: OAuthService (orchestrates flows), provider implementations + * (OpenRouter, Codex), ProviderTypes.ts (re-exports OAuthState), and + * the settings UI (reads OAuthProviderConfig for button rendering). + */ + +/** + * Static configuration for an OAuth provider. Describes the provider's + * endpoints, port preference, scopes, and display metadata. Immutable + * after construction. + */ +export interface OAuthProviderConfig { + /** Matches SupportedProvider enum value (e.g., 'openrouter', 'openai-codex') */ + providerId: string; + /** Human-readable label for the UI (e.g., "OpenRouter", "ChatGPT (Experimental)") */ + displayName: string; + /** Authorization endpoint URL */ + authUrl: string; + /** Token exchange endpoint URL */ + tokenUrl: string; + /** Preferred localhost port for the OAuth callback server */ + preferredPort: number; + /** Path for the callback route (e.g., '/callback' or '/auth/callback') */ + callbackPath: string; + /** OAuth scopes to request (empty array if provider doesn't use scopes) */ + scopes: string[]; + /** Whether tokens are permanent API keys or expiring access tokens */ + tokenType: 'permanent-key' | 'expiring-token'; + /** OAuth client ID (empty string if none required, e.g., OpenRouter) */ + clientId: string; + /** If true, UI shows a consent/warning dialog before starting the flow */ + experimental?: boolean; + /** Warning text displayed in the consent dialog for experimental providers */ + experimentalWarning?: string; + /** Override hostname used in the redirect_uri (default: '127.0.0.1'). Server still binds to 127.0.0.1. 
*/ + callbackHostname?: string; +} + +/** + * Transient result from a completed OAuth flow. Contains the API key + * (or access token) and optional refresh/expiry data for providers + * that issue expiring tokens. + */ +export interface OAuthResult { + /** The key or access token to store in LLMProviderConfig.apiKey */ + apiKey: string; + /** Refresh token for expiring-token providers (Codex) */ + refreshToken?: string; + /** Expiration timestamp in Unix milliseconds (for expiring-token providers) */ + expiresAt?: number; + /** Provider-specific metadata (e.g., { accountId, idToken }) */ + metadata?: Record; +} + +/** + * Contract for OAuth provider implementations. Each provider knows how + * to build its authorization URL and exchange an authorization code for + * tokens. Providers with expiring tokens also implement refreshToken(). + */ +export interface IOAuthProvider { + /** Static configuration describing this provider */ + readonly config: OAuthProviderConfig; + + /** + * Build the full authorization URL that opens in the user's browser. + * @param callbackUrl - The localhost callback URL (e.g., http://127.0.0.1:3000/callback) + * @param codeChallenge - Base64url-encoded S256 PKCE challenge + * @param state - Random state string for CSRF protection + * @param preAuthParams - Optional provider-specific params (e.g., key_label for OpenRouter) + * @returns Full authorization URL string + */ + buildAuthUrl( + callbackUrl: string, + codeChallenge: string, + state: string, + preAuthParams?: Record + ): string; + + /** + * Exchange an authorization code for tokens/API key. 
+ * @param code - Authorization code from the callback + * @param codeVerifier - Original PKCE code verifier (never logged or persisted) + * @param callbackUrl - The callback URL used during authorization (must match exactly) + * @returns OAuthResult with the API key and optional token data + */ + exchangeCode( + code: string, + codeVerifier: string, + callbackUrl: string + ): Promise; + + /** + * Refresh an expired access token. Only implemented by providers with + * tokenType === 'expiring-token'. Returns null if refresh fails + * (user must re-authenticate). + */ + refreshToken?(refreshToken: string): Promise; +} + +/** + * Persistent OAuth state stored on LLMProviderConfig.oauth. Tracks + * whether a provider was connected via OAuth and holds token data + * needed for refresh flows. + */ +export interface OAuthState { + /** Whether this provider is currently OAuth-connected */ + connected: boolean; + /** The provider ID that was used for OAuth (e.g., 'openrouter') */ + providerId: string; + /** Timestamp (Unix ms) when the OAuth connection was established */ + connectedAt: number; + /** Refresh token for expiring-token providers */ + refreshToken?: string; + /** Token expiration timestamp (Unix ms) for expiring-token providers */ + expiresAt?: number; + /** Provider-specific metadata (e.g., accountId for Codex) */ + metadata?: Record; +} diff --git a/src/services/oauth/OAuthCallbackServer.ts b/src/services/oauth/OAuthCallbackServer.ts new file mode 100644 index 00000000..67fc59dc --- /dev/null +++ b/src/services/oauth/OAuthCallbackServer.ts @@ -0,0 +1,245 @@ +/** + * OAuthCallbackServer.ts + * Location: src/services/oauth/OAuthCallbackServer.ts + * + * Ephemeral localhost HTTP server that receives a single OAuth callback. + * Binds to 127.0.0.1 ONLY (not 'localhost', not 0.0.0.0) for security. + * Single-use: accepts one valid callback, then shuts down immediately. + * Auto-shuts down after a configurable timeout (default 5 minutes). 
+ * + * Used by: OAuthService.ts (starts server before opening browser, + * waits for callback, then shuts down). + */ + +import { createServer, IncomingMessage, ServerResponse, Server } from 'node:http'; +import { URL } from 'node:url'; +import { timingSafeEqual } from 'node:crypto'; + +/** Common no-cache headers for all callback responses (prevents browser caching auth codes) */ +const NO_CACHE_HEADERS: Record = { + 'Cache-Control': 'no-store, no-cache', + 'Pragma': 'no-cache', +}; + +/** Result from a successful OAuth callback */ +export interface CallbackResult { + /** Authorization code from the OAuth provider */ + code: string; + /** State parameter for CSRF validation */ + state: string; +} + +/** Handle returned by OAuthCallbackServer.start() */ +export interface CallbackServerHandle { + /** The port the server is listening on */ + port: number; + /** Full callback URL (e.g., http://127.0.0.1:3000/callback) */ + callbackUrl: string; + /** Promise that resolves when a valid callback is received */ + waitForCallback(): Promise; + /** Force shutdown the server (idempotent) */ + shutdown(): void; +} + +/** Shared CSS for callback pages — respects system light/dark preference */ +const CALLBACK_STYLE = ` +`; + +/** Static HTML success page -- no dynamic content for security */ +const HTML_SUCCESS = ` + +Authorization Successful${CALLBACK_STYLE} + +
+

Connected!

+

You can close this tab and return to Obsidian.

+
+ + +`; + +/** Static HTML error page */ +const HTML_ERROR = ` + +Authorization Failed${CALLBACK_STYLE} + +
+

Authorization Failed

+

Something went wrong. Please close this tab and try again in Obsidian.

+
+ +`; + +/** + * Options for starting the callback server. + */ +export interface CallbackServerOptions { + /** Port to bind to */ + port: number; + /** URL path to listen for callbacks on (e.g., '/callback') */ + callbackPath: string; + /** Expected state parameter value for CSRF validation */ + expectedState: string; + /** Timeout in milliseconds before auto-shutdown (default: 300_000 = 5 minutes) */ + timeoutMs?: number; + /** Hostname used in the callbackUrl string (default: '127.0.0.1'). Server always binds to 127.0.0.1. */ + callbackUrlHostname?: string; +} + +/** + * Start an ephemeral OAuth callback server. + * + * @param options - Server configuration + * @returns Handle with port, callbackUrl, waitForCallback(), and shutdown() + * @throws Error with descriptive message on EADDRINUSE or other server errors + */ +export function startCallbackServer(options: CallbackServerOptions): Promise { + const { + port, + callbackPath, + expectedState, + timeoutMs = 300_000, + callbackUrlHostname = '127.0.0.1', + } = options; + + return new Promise((resolveStart, rejectStart) => { + let settled = false; + let callbackResolve: ((result: CallbackResult) => void) | null = null; + let callbackReject: ((error: Error) => void) | null = null; + let timeoutHandle: ReturnType | null = null; + let server: Server | null = null; + + const cleanup = () => { + if (timeoutHandle) { + clearTimeout(timeoutHandle); + timeoutHandle = null; + } + if (server) { + try { + server.close(); + } catch { + // Ignore close errors during cleanup + } + server = null; + } + }; + + const shutdown = () => { + if (!settled) { + settled = true; + if (callbackReject) { + callbackReject(new Error('OAuth callback server was shut down')); + } + } + cleanup(); + }; + + // Create the callback promise that callers await + const callbackPromise = new Promise((resolve, reject) => { + callbackResolve = resolve; + callbackReject = reject; + }); + + server = createServer((req: IncomingMessage, res: ServerResponse) 
=> { + // Only handle GET requests to the callback path + const url = new URL(req.url || '/', `http://127.0.0.1:${port}`); + + if (url.pathname !== callbackPath) { + res.writeHead(404, { 'Content-Type': 'text/plain', ...NO_CACHE_HEADERS }); + res.end('Not found'); + return; + } + + // Check for OAuth error from provider + const error = url.searchParams.get('error'); + const errorDescription = url.searchParams.get('error_description'); + if (error) { + const msg = errorDescription || error; + res.writeHead(400, { 'Content-Type': 'text/html', ...NO_CACHE_HEADERS }); + res.end(HTML_ERROR); + if (!settled) { + settled = true; + callbackReject?.(new Error(`OAuth error: ${msg}`)); + cleanup(); + } + return; + } + + // Validate state parameter (CSRF protection via timing-safe comparison) + const state = url.searchParams.get('state') || ''; + const stateValid = + state.length === expectedState.length && + timingSafeEqual(Buffer.from(state), Buffer.from(expectedState)); + if (!stateValid) { + res.writeHead(400, { 'Content-Type': 'text/html', ...NO_CACHE_HEADERS }); + res.end(HTML_ERROR); + if (!settled) { + settled = true; + callbackReject?.(new Error('State mismatch: potential CSRF attack')); + cleanup(); + } + return; + } + + // Extract authorization code + const code = url.searchParams.get('code'); + if (!code) { + res.writeHead(400, { 'Content-Type': 'text/html', ...NO_CACHE_HEADERS }); + res.end(HTML_ERROR); + if (!settled) { + settled = true; + callbackReject?.(new Error('Missing authorization code in callback')); + cleanup(); + } + return; + } + + // Success: return static HTML and resolve the callback promise + res.writeHead(200, { 'Content-Type': 'text/html', ...NO_CACHE_HEADERS }); + res.end(HTML_SUCCESS); + + if (!settled) { + settled = true; + callbackResolve?.({ code, state }); + cleanup(); + } + }); + + // Handle server errors (including EADDRINUSE) + server.on('error', (err: NodeJS.ErrnoException) => { + cleanup(); + if (err.code === 'EADDRINUSE') { + 
rejectStart(new Error( + `Port ${port} is already in use. If MCP HTTP transport is running on this port, please use manual API key entry.` + )); + } else { + rejectStart(new Error(`OAuth callback server error: ${err.message}`)); + } + }); + + // Bind to 127.0.0.1 ONLY -- never 'localhost' (resolves to IPv6 on some systems), never 0.0.0.0 + server.listen(port, '127.0.0.1', () => { + // Set up timeout for auto-shutdown + timeoutHandle = setTimeout(() => { + if (!settled) { + settled = true; + callbackReject?.(new Error('OAuth callback timeout: authorization took too long')); + cleanup(); + } + }, timeoutMs); + + // Return the handle + resolveStart({ + port, + callbackUrl: `http://${callbackUrlHostname}:${port}${callbackPath}`, + waitForCallback: () => callbackPromise, + shutdown, + }); + }); + }); +} diff --git a/src/services/oauth/OAuthService.ts b/src/services/oauth/OAuthService.ts new file mode 100644 index 00000000..154570c9 --- /dev/null +++ b/src/services/oauth/OAuthService.ts @@ -0,0 +1,206 @@ +/** + * OAuthService.ts + * Location: src/services/oauth/OAuthService.ts + * + * Singleton service that orchestrates OAuth 2.0 PKCE flows. Maintains a + * registry of IOAuthProvider implementations and coordinates the full + * flow: PKCE generation, callback server, browser launch, code exchange. + * + * State machine: 'idle' -> 'authorizing' -> 'exchanging' -> 'idle' + * Only one flow can be active at a time. + * + * Used by: main.ts (registers providers at startup), settings UI + * (starts OAuth flows via startFlow()), LLM adapters (refreshes tokens + * via refreshToken()). 
+ */ + +import { IOAuthProvider, OAuthProviderConfig, OAuthResult } from './IOAuthProvider'; +import { generateCodeVerifier, generateCodeChallenge, generateState } from './PKCEUtils'; +import { startCallbackServer, CallbackServerHandle } from './OAuthCallbackServer'; + +/** Current state of the OAuth service */ +export type OAuthFlowState = 'idle' | 'authorizing' | 'exchanging'; + +export class OAuthService { + private static instance: OAuthService; + private providers: Map = new Map(); + private state: OAuthFlowState = 'idle'; + private activeServerHandle: CallbackServerHandle | null = null; + + private constructor() { + // Private constructor for singleton pattern + } + + /** + * Get the singleton OAuthService instance. + */ + static getInstance(): OAuthService { + if (!OAuthService.instance) { + OAuthService.instance = new OAuthService(); + } + return OAuthService.instance; + } + + /** + * Register an OAuth provider implementation. + * @param provider - Provider implementing IOAuthProvider + */ + registerProvider(provider: IOAuthProvider): void { + this.providers.set(provider.config.providerId, provider); + } + + /** + * Check if a provider with the given ID is registered. + */ + hasProvider(providerId: string): boolean { + return this.providers.has(providerId); + } + + /** + * Get the static configuration for a registered provider. + * @returns Config or null if provider is not registered + */ + getProviderConfig(providerId: string): OAuthProviderConfig | null { + return this.providers.get(providerId)?.config ?? null; + } + + /** + * Get the current flow state. + */ + getState(): OAuthFlowState { + return this.state; + } + + /** + * Start an OAuth PKCE flow for the specified provider. + * + * Opens the user's browser to the provider's authorization page, + * starts a localhost callback server, waits for the callback, and + * exchanges the authorization code for tokens. 
+ * + * @param providerId - ID of the registered provider + * @param preAuthParams - Optional provider-specific params (e.g., key_label for OpenRouter) + * @returns OAuthResult with the API key and optional token data + * @throws Error if provider not found, flow already active, or flow fails + */ + async startFlow( + providerId: string, + preAuthParams?: Record + ): Promise { + // Validate provider exists + const provider = this.providers.get(providerId); + if (!provider) { + throw new Error(`OAuth provider '${providerId}' is not registered`); + } + + // Guard against concurrent flows + if (this.state !== 'idle') { + throw new Error( + `Cannot start OAuth flow: another flow is already ${this.state}` + ); + } + + this.state = 'authorizing'; + + try { + // Generate PKCE parameters + const codeVerifier = generateCodeVerifier(); + const codeChallenge = await generateCodeChallenge(codeVerifier); + const state = generateState(); + + // Start the ephemeral callback server + const serverHandle = await startCallbackServer({ + port: provider.config.preferredPort, + callbackPath: provider.config.callbackPath, + expectedState: state, + callbackUrlHostname: provider.config.callbackHostname, + }); + this.activeServerHandle = serverHandle; + + // Build authorization URL and open browser + const authUrl = provider.buildAuthUrl( + serverHandle.callbackUrl, + codeChallenge, + state, + preAuthParams + ); + + // Open in system browser via Electron shell (preferred) or window.open (fallback) + try { + const { shell } = require('electron'); + shell.openExternal(authUrl); + } catch { + window.open(authUrl, '_blank'); + } + + // Wait for the callback + const callbackResult = await serverHandle.waitForCallback(); + + // Exchange the authorization code for tokens + this.state = 'exchanging'; + const oauthResult = await provider.exchangeCode( + callbackResult.code, + codeVerifier, + serverHandle.callbackUrl + ); + + return oauthResult; + } finally { + // Always clean up: shut down callback 
server and reset state + if (this.activeServerHandle) { + this.activeServerHandle.shutdown(); + this.activeServerHandle = null; + } + this.state = 'idle'; + } + } + + /** + * Cancel an in-progress OAuth flow. + * Shuts down the callback server if one is active. + */ + cancelFlow(): void { + if (this.activeServerHandle) { + this.activeServerHandle.shutdown(); + this.activeServerHandle = null; + } + this.state = 'idle'; + } + + /** + * Refresh an expired token for a provider. + * + * @param providerId - ID of the registered provider + * @param refreshToken - The current refresh token + * @returns New OAuthResult with fresh tokens, or null if refresh fails + * @throws Error if provider not found or doesn't support token refresh + */ + async refreshToken( + providerId: string, + refreshToken: string + ): Promise { + const provider = this.providers.get(providerId); + if (!provider) { + throw new Error(`OAuth provider '${providerId}' is not registered`); + } + + if (!provider.refreshToken) { + throw new Error( + `OAuth provider '${providerId}' does not support token refresh` + ); + } + + return provider.refreshToken(refreshToken); + } + + /** + * Reset the singleton instance. Useful for testing or plugin unload. + */ + static resetInstance(): void { + if (OAuthService.instance) { + OAuthService.instance.cancelFlow(); + OAuthService.instance.providers.clear(); + } + OAuthService.instance = undefined as unknown as OAuthService; + } +} diff --git a/src/services/oauth/PKCEUtils.ts b/src/services/oauth/PKCEUtils.ts new file mode 100644 index 00000000..ec76dddc --- /dev/null +++ b/src/services/oauth/PKCEUtils.ts @@ -0,0 +1,73 @@ +/** + * PKCEUtils.ts + * Location: src/services/oauth/PKCEUtils.ts + * + * Standalone pure functions for OAuth 2.0 PKCE (RFC 7636) cryptographic + * operations. All randomness uses crypto.getRandomValues() -- never + * Math.random(). Challenge method is always S256 (SHA-256). 
+ * + * Exported as standalone functions (not class methods) for testability. + * + * Used by: OAuthService.ts (generates PKCE pairs before each OAuth flow) + * Tested by: tests/services/oauth/PKCEUtils.test.ts + */ + +/** + * Base64url-encode a buffer. Produces URL-safe output with no padding, + * per RFC 7636 Appendix A. + */ +export function base64url(buffer: ArrayBuffer | Uint8Array): string { + const bytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer); + let binary = ''; + for (let i = 0; i < bytes.length; i++) { + binary += String.fromCharCode(bytes[i]); + } + return btoa(binary) + .replace(/\+/g, '-') + .replace(/\//g, '_') + .replace(/=+$/, ''); +} + +/** + * Generate a cryptographically random PKCE code verifier. + * Produces a 43-character string from the unreserved URI character set + * (A-Z, a-z, 0-9, -, ., _, ~) as specified in RFC 7636 Section 4.1. + * + * Uses crypto.getRandomValues() for secure randomness. + */ +export function generateCodeVerifier(): string { + const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~'; + const bytes = new Uint8Array(43); + crypto.getRandomValues(bytes); + let verifier = ''; + for (let i = 0; i < bytes.length; i++) { + verifier += chars[bytes[i] % chars.length]; + } + return verifier; +} + +/** + * Derive a PKCE code challenge from a code verifier using the S256 method. + * Computes SHA-256 hash of the verifier and base64url-encodes the result. + * + * @param verifier - The code verifier string to hash + * @returns Base64url-encoded SHA-256 hash of the verifier + */ +export async function generateCodeChallenge(verifier: string): Promise { + const encoder = new TextEncoder(); + const data = encoder.encode(verifier); + const hash = await crypto.subtle.digest('SHA-256', data); + return base64url(hash); +} + +/** + * Generate a cryptographically random state parameter for CSRF protection. + * Produces a 32-byte random value, base64url-encoded (~43 characters). 
+ * + * Uses crypto.getRandomValues() for secure randomness. + */ +export function generateState(): string { + const bytes = new Uint8Array(32); + crypto.getRandomValues(bytes); + return base64url(bytes); +} diff --git a/src/services/oauth/index.ts b/src/services/oauth/index.ts new file mode 100644 index 00000000..8d5bc588 --- /dev/null +++ b/src/services/oauth/index.ts @@ -0,0 +1,34 @@ +/** + * OAuth Service Barrel Export + * Location: src/services/oauth/index.ts + * + * Re-exports all OAuth service types, utilities, and providers for + * convenient importing from a single path. + */ + +export type { + OAuthProviderConfig, + OAuthResult, + IOAuthProvider, + OAuthState, +} from './IOAuthProvider'; + +export { + base64url, + generateCodeVerifier, + generateCodeChallenge, + generateState, +} from './PKCEUtils'; + +export type { + CallbackResult, + CallbackServerHandle, + CallbackServerOptions, +} from './OAuthCallbackServer'; +export { startCallbackServer } from './OAuthCallbackServer'; + +export type { OAuthFlowState } from './OAuthService'; +export { OAuthService } from './OAuthService'; + +export { OpenRouterOAuthProvider } from './providers/OpenRouterOAuthProvider'; +export { OpenAICodexOAuthProvider } from './providers/OpenAICodexOAuthProvider'; diff --git a/src/services/oauth/providers/OpenAICodexOAuthProvider.ts b/src/services/oauth/providers/OpenAICodexOAuthProvider.ts new file mode 100644 index 00000000..5de6f484 --- /dev/null +++ b/src/services/oauth/providers/OpenAICodexOAuthProvider.ts @@ -0,0 +1,234 @@ +/** + * OpenAICodexOAuthProvider.ts + * Location: src/services/oauth/providers/OpenAICodexOAuthProvider.ts + * + * OAuth 2.0 PKCE provider for OpenAI Codex (ChatGPT Plus/Pro). + * Uses the same client ID and endpoints as Cline, OpenCode, and Roo Code. + * Tokens are expiring (access_token + refresh_token); the adapter must + * proactively refresh before each API call. 
+ * + * Used by: OAuthService (registered at startup via main.ts) + * Reference: docs/preparation/opencode-oauth-source-analysis.md + * Validated: /tmp/codex-oauth-test/test-codex-oauth.mjs + */ + +import { IOAuthProvider, OAuthProviderConfig, OAuthResult } from '../IOAuthProvider'; + +/** Public OAuth client ID shared across Codex CLI tools (Cline, OpenCode, Roo Code) */ +const CLIENT_ID = 'app_EMoamEEZ73f0CkXaXp7hrann'; + +/** OpenAI OAuth issuer */ +const ISSUER = 'https://auth.openai.com'; + +/** Authorization endpoint */ +const AUTH_ENDPOINT = `${ISSUER}/oauth/authorize`; + +/** Token exchange and refresh endpoint */ +const TOKEN_ENDPOINT = `${ISSUER}/oauth/token`; + +/** + * JWT claims structure from OpenAI id_token / access_token. + * Used to extract the chatgpt_account_id needed for API calls. + */ +interface IdTokenClaims { + chatgpt_account_id?: string; + organizations?: Array<{ id: string }>; + email?: string; + 'https://api.openai.com/auth'?: { + chatgpt_account_id?: string; + }; +} + +/** + * Token response from OpenAI OAuth token endpoint. + */ +interface TokenResponse { + access_token: string; + refresh_token: string; + id_token: string; + expires_in?: number; +} + +/** + * Parse JWT claims from a token without signature verification. + * Only used to extract metadata (account ID) -- not for auth decisions. + */ +function parseJwtClaims(token: string): IdTokenClaims | undefined { + const parts = token.split('.'); + if (parts.length !== 3) return undefined; + try { + const payload = parts[1]; + // Convert base64url to base64, then decode + const base64 = payload.replace(/-/g, '+').replace(/_/g, '/'); + const padded = base64 + '='.repeat((4 - (base64.length % 4)) % 4); + const decoded = atob(padded); + return JSON.parse(decoded); + } catch { + return undefined; + } +} + +/** + * Extract the ChatGPT account ID from JWT claims. + * Checks multiple claim locations (direct, nested, organization fallback). 
+ */ +function extractAccountIdFromClaims(claims: IdTokenClaims): string | undefined { + return ( + claims.chatgpt_account_id || + claims['https://api.openai.com/auth']?.chatgpt_account_id || + claims.organizations?.[0]?.id + ); +} + +/** + * Extract account ID from token response, trying id_token first, + * then falling back to access_token. + */ +function extractAccountId(tokens: TokenResponse): string | undefined { + if (tokens.id_token) { + const claims = parseJwtClaims(tokens.id_token); + if (claims) { + const accountId = extractAccountIdFromClaims(claims); + if (accountId) return accountId; + } + } + if (tokens.access_token) { + const claims = parseJwtClaims(tokens.access_token); + if (claims) { + return extractAccountIdFromClaims(claims); + } + } + return undefined; +} + +/** + * Convert a TokenResponse into an OAuthResult. + */ +function tokenResponseToResult(tokens: TokenResponse): OAuthResult { + const accountId = extractAccountId(tokens); + const expiresIn = tokens.expires_in ?? 3600; + + const result: OAuthResult = { + apiKey: tokens.access_token, + refreshToken: tokens.refresh_token, + expiresAt: Date.now() + expiresIn * 1000, + }; + + // Store account ID in metadata; do NOT persist id_token (contains email PII) + if (accountId) { + result.metadata = { accountId }; + } + + return result; +} + +export class OpenAICodexOAuthProvider implements IOAuthProvider { + readonly config: OAuthProviderConfig = { + providerId: 'openai-codex', + displayName: 'ChatGPT', + authUrl: AUTH_ENDPOINT, + tokenUrl: TOKEN_ENDPOINT, + preferredPort: 1455, + callbackPath: '/auth/callback', + scopes: ['openid', 'profile', 'email', 'offline_access'], + tokenType: 'expiring-token', + clientId: CLIENT_ID, + callbackHostname: 'localhost', + }; + + /** + * Build the OpenAI Codex authorization URL. 
+   *
+   * Includes all parameters from the validated spike: client_id,
+   * response_type=code, redirect_uri, scope, state, code_challenge,
+   * code_challenge_method=S256, prompt=login, and Codex-specific flags.
+   */
+  buildAuthUrl(
+    callbackUrl: string,
+    codeChallenge: string,
+    state: string
+  ): string {
+    const params = new URLSearchParams({
+      response_type: 'code',
+      client_id: CLIENT_ID,
+      redirect_uri: callbackUrl,
+      scope: 'openid profile email offline_access',
+      code_challenge: codeChallenge,
+      code_challenge_method: 'S256',
+      state,
+      prompt: 'login',
+      id_token_add_organizations: 'true',
+      codex_cli_simplified_flow: 'true',
+      originator: 'opencode',
+    });
+
+    return `${AUTH_ENDPOINT}?${params.toString()}`;
+  }
+
+  /**
+   * Exchange the authorization code for tokens.
+   *
+   * POST form-urlencoded to the token endpoint with grant_type,
+   * client_id, code, redirect_uri, and code_verifier. Returns
+   * access_token, refresh_token, id_token, and expires_in.
+   */
+  async exchangeCode(
+    code: string,
+    codeVerifier: string,
+    callbackUrl: string
+  ): Promise<OAuthResult> {
+    const response = await fetch(TOKEN_ENDPOINT, {
+      method: 'POST',
+      headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
+      body: new URLSearchParams({
+        grant_type: 'authorization_code',
+        client_id: CLIENT_ID,
+        code,
+        redirect_uri: callbackUrl,
+        code_verifier: codeVerifier,
+      }).toString(),
+    });
+
+    if (!response.ok) {
+      const body = await response.text();
+      throw new Error(
+        `Codex token exchange failed: HTTP ${response.status} - ${body}`
+      );
+    }
+
+    const tokens: TokenResponse = await response.json();
+    return tokenResponseToResult(tokens);
+  }
+
+  /**
+   * Refresh an expired access token.
+   *
+   * POST form-urlencoded with grant_type=refresh_token. Returns a new
+   * set of tokens (including a new refresh_token -- token rotation).
+   * Returns null if refresh fails (user must re-authenticate).
+   */
+  async refreshToken(refreshToken: string): Promise<OAuthResult | null> {
+    try {
+      const response = await fetch(TOKEN_ENDPOINT, {
+        method: 'POST',
+        headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
+        body: new URLSearchParams({
+          grant_type: 'refresh_token',
+          client_id: CLIENT_ID,
+          refresh_token: refreshToken,
+        }).toString(),
+      });
+
+      if (!response.ok) {
+        // Refresh failed -- user must re-authenticate
+        return null;
+      }
+
+      const tokens: TokenResponse = await response.json();
+      return tokenResponseToResult(tokens);
+    } catch {
+      // Network error or other failure -- user must re-authenticate
+      return null;
+    }
+  }
+}
diff --git a/src/services/oauth/providers/OpenRouterOAuthProvider.ts b/src/services/oauth/providers/OpenRouterOAuthProvider.ts
new file mode 100644
index 00000000..9e7f523c
--- /dev/null
+++ b/src/services/oauth/providers/OpenRouterOAuthProvider.ts
@@ -0,0 +1,107 @@
+/**
+ * OpenRouterOAuthProvider.ts
+ * Location: src/services/oauth/providers/OpenRouterOAuthProvider.ts
+ *
+ * OAuth 2.0 PKCE provider for OpenRouter. The flow produces a permanent
+ * API key (sk-or-...) identical to manually created keys -- no token
+ * refresh needed. Supports optional pre-auth params for key_label and
+ * credit_limit.
+ *
+ * Used by: OAuthService (registered at startup via main.ts)
+ * Reference: docs/preparation/openrouter-oauth-research.md
+ */
+
+import { IOAuthProvider, OAuthProviderConfig, OAuthResult } from '../IOAuthProvider';
+
+/** OpenRouter authorization page */
+const AUTH_URL = 'https://openrouter.ai/auth';
+
+/** Token exchange endpoint -- returns a permanent API key */
+const TOKEN_URL = 'https://openrouter.ai/api/v1/auth/keys';
+
+export class OpenRouterOAuthProvider implements IOAuthProvider {
+  readonly config: OAuthProviderConfig = {
+    providerId: 'openrouter',
+    displayName: 'OpenRouter',
+    authUrl: AUTH_URL,
+    tokenUrl: TOKEN_URL,
+    preferredPort: 3456,
+    callbackPath: '/callback',
+    scopes: [],
+    tokenType: 'permanent-key',
+    clientId: '',
+  };
+
+  /**
+   * Build the OpenRouter authorization URL.
+   *
+   * OpenRouter uses a simplified OAuth flow: the auth page accepts
+   * callback_url, code_challenge, code_challenge_method, and state.
+   * Optional preAuthParams support key_label (name for the key in
+   * OpenRouter dashboard) and credit_limit (spending cap in USD).
+   */
+  buildAuthUrl(
+    callbackUrl: string,
+    codeChallenge: string,
+    state: string,
+    preAuthParams?: Record<string, string>
+  ): string {
+    const params = new URLSearchParams({
+      callback_url: callbackUrl,
+      code_challenge: codeChallenge,
+      code_challenge_method: 'S256',
+      state,
+    });
+
+    // Add optional pre-auth parameters
+    if (preAuthParams?.key_label) {
+      params.set('key_label', preAuthParams.key_label);
+    }
+    if (preAuthParams?.credit_limit) {
+      params.set('limit', preAuthParams.credit_limit);
+    }
+
+    return `${AUTH_URL}?${params.toString()}`;
+  }
+
+  /**
+   * Exchange the authorization code for a permanent OpenRouter API key.
+   *
+   * POST to /api/v1/auth/keys with the code, code_verifier, and
+   * code_challenge_method. Returns { key: "sk-or-..." }.
+   */
+  async exchangeCode(
+    code: string,
+    codeVerifier: string,
+    _callbackUrl: string
+  ): Promise<OAuthResult> {
+    const response = await fetch(TOKEN_URL, {
+      method: 'POST',
+      headers: { 'Content-Type': 'application/json' },
+      body: JSON.stringify({
+        code,
+        code_verifier: codeVerifier,
+        code_challenge_method: 'S256',
+      }),
+    });
+
+    if (!response.ok) {
+      const body = await response.text();
+      throw new Error(
+        `OpenRouter token exchange failed: HTTP ${response.status} - ${body}`
+      );
+    }
+
+    const data: { key: string } = await response.json();
+
+    if (!data.key) {
+      throw new Error('OpenRouter token exchange returned no key');
+    }
+
+    return {
+      apiKey: data.key,
+    };
+  }
+
+  // No refreshToken needed -- OpenRouter keys are permanent
+}
diff --git a/src/settings/tabs/DefaultsTab.ts b/src/settings/tabs/DefaultsTab.ts
index 392edd9a..0252bd05 100644
--- a/src/settings/tabs/DefaultsTab.ts
+++ b/src/settings/tabs/DefaultsTab.ts
@@ -78,7 +78,7 @@
     const llmSettings = this.services.llmProviderSettings;
     const pluginSettings = this.services.settings.settings;
 
-    return {
+    const result = {
       provider: llmSettings?.defaultModel?.provider || '',
       model: llmSettings?.defaultModel?.model || '',
       agentProvider: llmSettings?.agentModel?.provider || undefined,
@@ -98,6 +98,7 @@
       promptId: pluginSettings.defaultPromptId || null,
       contextNotes: pluginSettings.defaultContextNotes || []
     };
+    return result;
   }
 
   /**
@@ -228,6 +229,7 @@
    * Cleanup
    */
   destroy(): void {
+    this.renderer?.destroy();
     this.renderer = null;
   }
 }
diff --git a/src/settings/tabs/ProvidersTab.ts b/src/settings/tabs/ProvidersTab.ts
index ae9cf0a7..460b9eb1 100644
--- a/src/settings/tabs/ProvidersTab.ts
+++ b/src/settings/tabs/ProvidersTab.ts
@@ -19,6 +19,8 @@
 import { Settings } from '../../settings';
 import { Card, CardConfig } from '../../components/Card';
 import { LLMSettingsNotifier } from '../../services/llm/LLMSettingsNotifier';
 import {
isDesktop, supportsLocalLLM, MOBILE_COMPATIBLE_PROVIDERS, isProviderComingSoon } from '../../utils/platform'; +import type { OAuthModalConfig, SecondaryOAuthProviderConfig } from '../../components/llm-provider/types'; +import { OAuthService } from '../../services/oauth/OAuthService'; /** * Provider display configuration @@ -28,6 +30,7 @@ interface ProviderDisplayConfig { keyFormat: string; signupUrl: string; category: 'local' | 'cloud'; + oauthConfig?: OAuthModalConfig; } export interface ProvidersTabServices { @@ -115,6 +118,12 @@ export class ProvidersTab { keyFormat: 'pplx-...', signupUrl: 'https://www.perplexity.ai/settings/api', category: 'cloud' + }, + 'openai-codex': { + name: 'ChatGPT (Codex)', + keyFormat: 'OAuth sign-in required', + signupUrl: 'https://chatgpt.com', + category: 'cloud' } }; @@ -137,9 +146,86 @@ export class ProvidersTab { }, this.services.app.vault); } + // Attach OAuth configs to providers that support it (desktop only) + if (isDesktop()) { + this.attachOAuthConfigs(); + } + this.render(); } + /** + * Attach OAuth configurations to providers that support OAuth connect. + * Only called on desktop where the OAuth callback server can run. + */ + private attachOAuthConfigs(): void { + const oauthService = OAuthService.getInstance(); + + // OpenRouter OAuth + if (oauthService.hasProvider('openrouter')) { + this.providerConfigs.openrouter.oauthConfig = { + providerLabel: 'OpenRouter', + preAuthFields: [ + { + key: 'key_name', + label: 'Key label', + defaultValue: 'Claudesidian MCP', + required: false, + }, + { + key: 'limit', + label: 'Credit limit (optional)', + placeholder: 'Leave blank for unlimited', + required: false, + }, + ], + startFlow: (params) => this.startOAuthFlow('openrouter', params), + }; + } + + // OpenAI Codex OAuth (experimental) — attaches to 'openai-codex' provider card, + // NOT 'openai', so tokens are stored under providers['openai-codex'] where + // AdapterRegistry.initializeCodexAdapter() reads them. 
+    if (oauthService.hasProvider('openai-codex')) {
+      this.providerConfigs['openai-codex'] = {
+        ...this.providerConfigs['openai-codex'],
+        oauthConfig: {
+          providerLabel: 'ChatGPT',
+          startFlow: (params) => this.startOAuthFlow('openai-codex', params),
+        },
+      };
+    }
+  }
+
+  /**
+   * Start an OAuth flow for a given provider via OAuthService
+   */
+  private async startOAuthFlow(
+    providerId: string,
+    params: Record<string, string>,
+  ): Promise<{ success: boolean; apiKey?: string; refreshToken?: string; expiresAt?: number; metadata?: Record<string, unknown>; error?: string }> {
+    try {
+      const oauthService = OAuthService.getInstance();
+      // Cancel any stuck flow before starting a new one (e.g., user dismissed modal while connecting)
+      if (oauthService.getState() !== 'idle') {
+        oauthService.cancelFlow();
+      }
+      const result = await oauthService.startFlow(providerId, params);
+      return {
+        success: true,
+        apiKey: result.apiKey,
+        refreshToken: result.refreshToken,
+        expiresAt: result.expiresAt,
+        metadata: result.metadata,
+      };
+    } catch (error) {
+      return {
+        success: false,
+        error: error instanceof Error ?
error.message : 'OAuth flow failed', + }; + } + } + /** * Get current LLM settings */ @@ -281,12 +367,37 @@ export class ProvidersTab { ): void { const settings = this.getSettings(); + // Build secondary OAuth provider config for OpenAI (Codex sub-section) + let secondaryOAuthProvider: SecondaryOAuthProviderConfig | undefined; + if (providerId === 'openai') { + const codexDisplay = this.providerConfigs['openai-codex']; + if (codexDisplay?.oauthConfig) { + const codexConfig = settings.providers['openai-codex'] || { + apiKey: '', + enabled: false, + }; + secondaryOAuthProvider = { + providerId: 'openai-codex', + providerLabel: 'ChatGPT (Codex)', + description: 'Connect your ChatGPT Plus/Pro account to use GPT-5 models via OAuth.', + config: { ...codexConfig }, + oauthConfig: codexDisplay.oauthConfig, + onConfigChange: async (updatedCodexConfig: LLMProviderConfig) => { + settings.providers['openai-codex'] = updatedCodexConfig; + await this.saveSettings(); + }, + }; + } + } + const modalConfig: LLMProviderModalConfig = { providerId, providerName: displayConfig.name, keyFormat: displayConfig.keyFormat, signupUrl: displayConfig.signupUrl, config: { ...providerConfig }, + oauthConfig: displayConfig.oauthConfig, + secondaryOAuthProvider, onSave: async (updatedConfig: LLMProviderConfig) => { settings.providers[providerId] = updatedConfig; diff --git a/src/types/llm/ProviderTypes.ts b/src/types/llm/ProviderTypes.ts index 454035f6..997a793e 100644 --- a/src/types/llm/ProviderTypes.ts +++ b/src/types/llm/ProviderTypes.ts @@ -3,6 +3,8 @@ * Extracted from types.ts for better organization and maintainability */ +import type { OAuthState } from '../../services/oauth/IOAuthProvider'; + /** * Thinking effort levels - unified across all providers */ @@ -46,6 +48,8 @@ export interface LLMProviderConfig { // WebLLM-specific settings webllmModel?: string; // Selected WebLLM model (e.g., 'nexus-tools-q4f16') webllmQuantization?: 'q4f16' | 'q5f16' | 'q8f16'; // Quantization level + // 
OAuth connection state (set when provider connected via OAuth flow) + oauth?: OAuthState; } /** @@ -119,6 +123,10 @@ export const DEFAULT_LLM_PROVIDER_SETTINGS: LLMProviderSettings = { apiKey: '', enabled: false }, + 'openai-codex': { + apiKey: '', + enabled: false + }, ollama: { apiKey: 'http://127.0.0.1:11434', enabled: false, diff --git a/src/types/llm/index.ts b/src/types/llm/index.ts index 92eab56a..a0749185 100644 --- a/src/types/llm/index.ts +++ b/src/types/llm/index.ts @@ -10,6 +10,9 @@ export type { LLMProviderSettings } from './ProviderTypes'; +// Re-export OAuthState so consumers can import from the types barrel +export type { OAuthState } from '../../services/oauth/IOAuthProvider'; + export { DEFAULT_LLM_PROVIDER_SETTINGS } from './ProviderTypes'; diff --git a/src/ui/chat/components/ChatSettingsModal.ts b/src/ui/chat/components/ChatSettingsModal.ts index 84f07ced..2ddc3c58 100644 --- a/src/ui/chat/components/ChatSettingsModal.ts +++ b/src/ui/chat/components/ChatSettingsModal.ts @@ -221,6 +221,7 @@ export class ChatSettingsModal extends Modal { } onClose() { + this.renderer?.destroy(); this.renderer = null; this.pendingSettings = null; this.contentEl.empty(); diff --git a/src/utils/connectorContent.ts b/src/utils/connectorContent.ts index 9065ac43..66b79b58 100644 --- a/src/utils/connectorContent.ts +++ b/src/utils/connectorContent.ts @@ -5,7 +5,7 @@ * DO NOT EDIT MANUALLY - This file is regenerated during the build process. * To update, modify connector.ts and rebuild. * - * Generated: 2026-02-11T19:48:52.933Z + * Generated: 2026-02-22T17:12:41.338Z */ export const CONNECTOR_JS_CONTENT = `"use strict"; diff --git a/src/utils/platform.ts b/src/utils/platform.ts index 75c436c6..271dbbe4 100644 --- a/src/utils/platform.ts +++ b/src/utils/platform.ts @@ -173,6 +173,7 @@ export const MOBILE_COMPATIBLE_PROVIDERS = [ * These use official SDK packages that have Node.js dependencies. 
*/ export const DESKTOP_ONLY_PROVIDERS = [ + 'openai-codex', // Uses OAuth/JWT - desktop only 'openai', // Uses openai SDK 'anthropic', // Uses @anthropic-ai/sdk 'google', // Uses @google/genai diff --git a/styles.css b/styles.css index ea7b3e6d..2f876d99 100644 --- a/styles.css +++ b/styles.css @@ -6800,3 +6800,80 @@ body.is-mobile .chat-loading-overlay { overflow-y: hidden; } +/* ═══════════════════════════════════════════════════════════════════════ + OAuth Connect UI + ═══════════════════════════════════════════════════════════════════════ */ + +/* Connected banner above API key input */ +.oauth-banner-container { + margin-bottom: 4px; +} + +.oauth-connected-banner { + display: flex; + align-items: center; + justify-content: space-between; + padding: 8px 12px; + border-radius: var(--radius-s); + background-color: rgba(34, 197, 94, 0.1); + border: 1px solid rgba(34, 197, 94, 0.3); +} + +.oauth-connected-status { + color: var(--text-normal); + font-size: var(--font-ui-small); + font-weight: 500; +} + +.oauth-connected-status::before { + content: "\2713 "; + color: #22c55e; + font-weight: bold; +} + +.oauth-disconnect-btn { + font-size: var(--font-ui-smaller); + color: var(--text-muted); + background: none; + border: 1px solid var(--background-modifier-border); + border-radius: var(--radius-s); + padding: 2px 8px; + cursor: pointer; +} + +.oauth-disconnect-btn:hover { + color: var(--text-normal); + border-color: var(--text-muted); +} + +/* Connecting state */ +.oauth-connecting { + opacity: 0.6; + cursor: wait; +} + +/* Consent modal */ +.oauth-consent-modal, +.oauth-preauth-modal { + max-width: 480px; +} + +.oauth-consent-warning { + color: var(--text-muted); + line-height: 1.5; + margin-bottom: 16px; +} + +.oauth-consent-fields, +.oauth-preauth-fields { + margin-bottom: 16px; +} + +.oauth-consent-buttons, +.oauth-preauth-buttons { + display: flex; + justify-content: flex-end; + gap: 8px; + margin-top: 16px; +} + diff --git 
a/tests/unit/OAuthCallbackServer.test.ts b/tests/unit/OAuthCallbackServer.test.ts new file mode 100644 index 00000000..9ffc9ffe --- /dev/null +++ b/tests/unit/OAuthCallbackServer.test.ts @@ -0,0 +1,268 @@ +/** + * OAuthCallbackServer Unit Tests + * + * Tests the ephemeral localhost HTTP server that receives OAuth callbacks. + * Uses randomized ephemeral ports (49152-65535) to avoid conflicts. + * + * NOTE: Ideally these tests would use port 0 (OS-assigned) but the source + * OAuthCallbackServer returns the input port in the handle, not the actual + * bound port. Until the source is updated to use server.address().port, + * we use randomized high ports from the IANA ephemeral range. + */ + +import http from 'node:http'; +import { startCallbackServer } from '../../src/services/oauth/OAuthCallbackServer'; + +// Use randomized ports from the IANA ephemeral range to avoid cross-run conflicts +function nextPort(): number { + return 49152 + Math.floor(Math.random() * 16383); +} + +/** Helper: make a GET request to a URL */ +function makeRequest(url: string): Promise<{ statusCode: number; body: string }> { + return new Promise((resolve, reject) => { + http.get(url, (res) => { + let body = ''; + res.on('data', (chunk) => (body += chunk)); + res.on('end', () => resolve({ statusCode: res.statusCode!, body })); + }).on('error', reject); + }); +} + +describe('OAuthCallbackServer', () => { + describe('start and listen', () => { + it('should start successfully and return a handle with correct callbackUrl', async () => { + const port = nextPort(); + const handle = await startCallbackServer({ + port, + callbackPath: '/callback', + expectedState: 'test-state', + }); + + expect(handle).toBeDefined(); + expect(handle.port).toBe(port); + expect(handle.callbackUrl).toBe(`http://127.0.0.1:${port}/callback`); + expect(typeof handle.waitForCallback).toBe('function'); + expect(typeof handle.shutdown).toBe('function'); + + // Cleanup + handle.shutdown(); + await 
handle.waitForCallback().catch(() => {}); + }); + }); + + describe('happy path: valid callback', () => { + it('should resolve with code and state on valid callback', async () => { + const port = nextPort(); + const expectedState = 'valid-state-123'; + const handle = await startCallbackServer({ + port, + callbackPath: '/callback', + expectedState, + }); + + const callbackPromise = handle.waitForCallback(); + + const url = `http://127.0.0.1:${port}/callback?code=auth-code-xyz&state=${expectedState}`; + const response = await makeRequest(url); + + expect(response.statusCode).toBe(200); + expect(response.body).toContain('Connected!'); + + const result = await callbackPromise; + expect(result.code).toBe('auth-code-xyz'); + expect(result.state).toBe(expectedState); + }); + }); + + describe('error: state mismatch', () => { + it('should reject with CSRF error on state mismatch', async () => { + const port = nextPort(); + const handle = await startCallbackServer({ + port, + callbackPath: '/callback', + expectedState: 'expected-state', + }); + + let caughtError: Error | null = null; + const callbackPromise = handle.waitForCallback().catch((e: Error) => { caughtError = e; }); + + const url = `http://127.0.0.1:${port}/callback?code=some-code&state=wrong-state`; + const response = await makeRequest(url); + + expect(response.statusCode).toBe(400); + + await callbackPromise; + expect(caughtError).toBeDefined(); + expect(caughtError!.message).toContain('State mismatch'); + }); + }); + + describe('error: OAuth provider error', () => { + it('should reject with error description from provider', async () => { + const port = nextPort(); + const expectedState = 'state-abc'; + const handle = await startCallbackServer({ + port, + callbackPath: '/callback', + expectedState, + }); + + // Eagerly create a settled-safe promise + let caughtError: Error | null = null; + const callbackPromise = handle.waitForCallback().catch((e: Error) => { caughtError = e; }); + + const url = 
`http://127.0.0.1:${port}/callback?error=access_denied&error_description=User+denied+access&state=${expectedState}`; + const response = await makeRequest(url); + + expect(response.statusCode).toBe(400); + + await callbackPromise; + expect(caughtError).toBeDefined(); + expect(caughtError!.message).toContain('OAuth error: User denied access'); + }); + + it('should use error code when no description is provided', async () => { + const port = nextPort(); + const expectedState = 'state-def'; + const handle = await startCallbackServer({ + port, + callbackPath: '/callback', + expectedState, + }); + + let caughtError: Error | null = null; + const callbackPromise = handle.waitForCallback().catch((e: Error) => { caughtError = e; }); + + const url = `http://127.0.0.1:${port}/callback?error=server_error&state=${expectedState}`; + const response = await makeRequest(url); + + expect(response.statusCode).toBe(400); + + await callbackPromise; + expect(caughtError).toBeDefined(); + expect(caughtError!.message).toContain('OAuth error: server_error'); + }); + }); + + describe('error: missing code', () => { + it('should reject when authorization code is missing', async () => { + const port = nextPort(); + const expectedState = 'state-ghi'; + const handle = await startCallbackServer({ + port, + callbackPath: '/callback', + expectedState, + }); + + let caughtError: Error | null = null; + const callbackPromise = handle.waitForCallback().catch((e: Error) => { caughtError = e; }); + + const url = `http://127.0.0.1:${port}/callback?state=${expectedState}`; + const response = await makeRequest(url); + + expect(response.statusCode).toBe(400); + + await callbackPromise; + expect(caughtError).toBeDefined(); + expect(caughtError!.message).toContain('Missing authorization code'); + }); + }); + + describe('non-callback path', () => { + it('should return 404 for non-callback paths', async () => { + const port = nextPort(); + const handle = await startCallbackServer({ + port, + callbackPath: 
'/callback', + expectedState: 'state-jkl', + }); + + const url = `http://127.0.0.1:${port}/other-path`; + const response = await makeRequest(url); + + expect(response.statusCode).toBe(404); + expect(response.body).toBe('Not found'); + + // Cleanup + handle.shutdown(); + await handle.waitForCallback().catch(() => {}); + }); + }); + + describe('timeout', () => { + it('should reject with timeout error after configured timeout', async () => { + const port = nextPort(); + const handle = await startCallbackServer({ + port, + callbackPath: '/callback', + expectedState: 'state-timeout', + timeoutMs: 100, + }); + + let caughtError: Error | null = null; + await handle.waitForCallback().catch((e: Error) => { caughtError = e; }); + + expect(caughtError).toBeDefined(); + expect(caughtError!.message).toContain('OAuth callback timeout'); + }); + }); + + describe('shutdown', () => { + it('should reject callback promise when shut down before callback', async () => { + const port = nextPort(); + const handle = await startCallbackServer({ + port, + callbackPath: '/callback', + expectedState: 'state-shutdown', + }); + + let caughtError: Error | null = null; + const callbackPromise = handle.waitForCallback().catch((e: Error) => { caughtError = e; }); + + handle.shutdown(); + await callbackPromise; + + expect(caughtError).toBeDefined(); + expect(caughtError!.message).toContain('shut down'); + }); + + it('should be idempotent (calling shutdown twice is safe)', async () => { + const port = nextPort(); + const handle = await startCallbackServer({ + port, + callbackPath: '/callback', + expectedState: 'state-idempotent', + }); + + const callbackPromise = handle.waitForCallback().catch(() => {}); + + handle.shutdown(); + expect(() => handle.shutdown()).not.toThrow(); + + await callbackPromise; + }); + }); + + describe('EADDRINUSE', () => { + it('should reject with descriptive error when port is in use', async () => { + const port = nextPort(); + + // Occupy the port + const blockingServer = 
http.createServer(); + await new Promise((resolve) => blockingServer.listen(port, '127.0.0.1', resolve)); + + try { + await expect( + startCallbackServer({ + port, + callbackPath: '/callback', + expectedState: 'state-busy', + }) + ).rejects.toThrow(`Port ${port} is already in use`); + } finally { + blockingServer.close(); + } + }); + }); +}); diff --git a/tests/unit/OAuthModals.test.ts b/tests/unit/OAuthModals.test.ts new file mode 100644 index 00000000..09e56ac8 --- /dev/null +++ b/tests/unit/OAuthModals.test.ts @@ -0,0 +1,223 @@ +/** + * OAuthModals Unit Tests + * + * Tests the OAuth consent and pre-auth modals: + * - OAuthConsentModal: experimental warning display, confirm/cancel callbacks + * - OAuthPreAuthModal: field rendering, confirm/cancel callbacks + */ + +import { App } from 'obsidian'; +import { OAuthConsentModal, OAuthPreAuthModal } from '../../src/components/llm-provider/providers/OAuthModals'; +import type { OAuthModalConfig } from '../../src/components/llm-provider/types'; + +// Add addText method to Setting mock (OAuthModals use Setting.addText, not addTextArea) +jest.mock('obsidian', () => { + const actual = jest.requireActual('obsidian'); + + // Extend the Setting mock with addText + class SettingWithText extends actual.Setting { + addText(callback: (text: any) => void): SettingWithText { + const mockText = { + _value: '', + _onChange: null as ((value: string) => void) | null, + setPlaceholder() { return this; }, + setValue(v: string) { this._value = v; return this; }, + getValue() { return this._value; }, + onChange(cb: (value: string) => void) { this._onChange = cb; return this; }, + }; + callback(mockText); + return this; + } + } + + return { + ...actual, + Setting: SettingWithText, + }; +}); + +function createMockApp(): App { + return new App(); +} + +function createConsentConfig(overrides?: Partial): OAuthModalConfig { + return { + providerLabel: 'ChatGPT (Experimental)', + experimental: true, + experimentalWarning: 'This is an experimental 
feature.', + preAuthFields: [], + startFlow: jest.fn(async () => ({ success: true, apiKey: 'key-123' })), + ...overrides, + }; +} + +function createPreAuthConfig(overrides?: Partial): OAuthModalConfig { + return { + providerLabel: 'OpenRouter', + preAuthFields: [ + { + key: 'key_label', + label: 'Key Name', + placeholder: 'My Obsidian Key', + required: true, + defaultValue: '', + }, + { + key: 'credit_limit', + label: 'Credit Limit (USD)', + placeholder: '10', + required: false, + defaultValue: '', + }, + ], + startFlow: jest.fn(async () => ({ success: true, apiKey: 'or-key' })), + ...overrides, + }; +} + +describe('OAuthConsentModal', () => { + let app: App; + + beforeEach(() => { + app = createMockApp(); + }); + + it('should construct without errors', () => { + const config = createConsentConfig(); + const modal = new OAuthConsentModal(app, config, jest.fn(), jest.fn()); + expect(modal).toBeDefined(); + }); + + it('should call onOpen without errors', () => { + const config = createConsentConfig(); + const modal = new OAuthConsentModal(app, config, jest.fn(), jest.fn()); + expect(() => modal.onOpen()).not.toThrow(); + }); + + it('should call onClose without errors', () => { + const config = createConsentConfig(); + const modal = new OAuthConsentModal(app, config, jest.fn(), jest.fn()); + modal.onOpen(); + expect(() => modal.onClose()).not.toThrow(); + }); + + it('should create heading element with "Experimental feature" text', () => { + const config = createConsentConfig(); + const modal = new OAuthConsentModal(app, config, jest.fn(), jest.fn()); + modal.onOpen(); + + // The mock contentEl should have createEl called with 'h2' + expect(modal.contentEl.createEl).toHaveBeenCalledWith('h2', { + text: 'Experimental feature', + }); + }); + + it('should display experimental warning when provided', () => { + const config = createConsentConfig({ + experimentalWarning: 'This is risky!', + }); + const modal = new OAuthConsentModal(app, config, jest.fn(), jest.fn()); + 
modal.onOpen(); + + expect(modal.contentEl.createEl).toHaveBeenCalledWith('p', { + text: 'This is risky!', + cls: 'oauth-consent-warning', + }); + }); + + it('should create button container and buttons', () => { + const config = createConsentConfig({ preAuthFields: [] }); + const modal = new OAuthConsentModal(app, config, jest.fn(), jest.fn()); + modal.onOpen(); + + // Verify the button container div was created + expect(modal.contentEl.createDiv).toHaveBeenCalledWith('oauth-consent-buttons'); + + // Get the mock button container returned by createDiv + const createDivCalls = (modal.contentEl.createDiv as jest.Mock).mock.calls; + const buttonContainerCallIndex = createDivCalls.findIndex( + (call: any) => call[0] === 'oauth-consent-buttons' + ); + const buttonContainer = (modal.contentEl.createDiv as jest.Mock).mock.results[buttonContainerCallIndex].value; + + // Verify buttons were created on the container + const buttonCreateElCalls = (buttonContainer.createEl as jest.Mock).mock.calls; + const buttonCalls = buttonCreateElCalls.filter( + (call: any) => call[0] === 'button' + ); + expect(buttonCalls.length).toBe(2); // Cancel + Confirm + }); + + it('should add oauth-consent-modal class to contentEl', () => { + const config = createConsentConfig(); + const modal = new OAuthConsentModal(app, config, jest.fn(), jest.fn()); + modal.onOpen(); + + expect(modal.contentEl.addClass).toHaveBeenCalledWith('oauth-consent-modal'); + }); +}); + +describe('OAuthPreAuthModal', () => { + let app: App; + + beforeEach(() => { + app = createMockApp(); + }); + + it('should construct without errors', () => { + const config = createPreAuthConfig(); + const modal = new OAuthPreAuthModal(app, config, jest.fn(), jest.fn()); + expect(modal).toBeDefined(); + }); + + it('should call onOpen without errors', () => { + const config = createPreAuthConfig(); + const modal = new OAuthPreAuthModal(app, config, jest.fn(), jest.fn()); + expect(() => modal.onOpen()).not.toThrow(); + }); + + it('should 
call onClose without errors', () => { + const config = createPreAuthConfig(); + const modal = new OAuthPreAuthModal(app, config, jest.fn(), jest.fn()); + modal.onOpen(); + expect(() => modal.onClose()).not.toThrow(); + }); + + it('should create heading with provider name', () => { + const config = createPreAuthConfig({ providerLabel: 'OpenRouter' }); + const modal = new OAuthPreAuthModal(app, config, jest.fn(), jest.fn()); + modal.onOpen(); + + expect(modal.contentEl.createEl).toHaveBeenCalledWith('h2', { + text: 'Connect with OpenRouter', + }); + }); + + it('should add oauth-preauth-modal class to contentEl', () => { + const config = createPreAuthConfig(); + const modal = new OAuthPreAuthModal(app, config, jest.fn(), jest.fn()); + modal.onOpen(); + + expect(modal.contentEl.addClass).toHaveBeenCalledWith('oauth-preauth-modal'); + }); + + it('should create buttons container', () => { + const config = createPreAuthConfig(); + const modal = new OAuthPreAuthModal(app, config, jest.fn(), jest.fn()); + modal.onOpen(); + + expect(modal.contentEl.createDiv).toHaveBeenCalledWith('oauth-preauth-buttons'); + }); + + it('should handle empty preAuthFields gracefully', () => { + const config = createPreAuthConfig({ preAuthFields: [] }); + const modal = new OAuthPreAuthModal(app, config, jest.fn(), jest.fn()); + expect(() => modal.onOpen()).not.toThrow(); + }); + + it('should handle undefined preAuthFields gracefully', () => { + const config = createPreAuthConfig({ preAuthFields: undefined }); + const modal = new OAuthPreAuthModal(app, config, jest.fn(), jest.fn()); + expect(() => modal.onOpen()).not.toThrow(); + }); +}); diff --git a/tests/unit/OAuthService.test.ts b/tests/unit/OAuthService.test.ts new file mode 100644 index 00000000..1ed4101e --- /dev/null +++ b/tests/unit/OAuthService.test.ts @@ -0,0 +1,436 @@ +/** + * OAuthService Unit Tests + * + * Tests the singleton OAuth service orchestrating PKCE flows: + * - Provider registration and lookup + * - State machine 
transitions (idle -> authorizing -> exchanging -> idle) + * - Concurrent flow prevention + * - Token refresh delegation + * - Flow cancellation + */ + +import { OAuthService } from '../../src/services/oauth/OAuthService'; +import type { IOAuthProvider, OAuthProviderConfig } from '../../src/services/oauth/IOAuthProvider'; + +// Mock the callback server +jest.mock('../../src/services/oauth/OAuthCallbackServer', () => ({ + startCallbackServer: jest.fn(), +})); + +// Mock PKCEUtils +jest.mock('../../src/services/oauth/PKCEUtils', () => ({ + generateCodeVerifier: jest.fn(() => 'mock-verifier-12345678901234567890123'), + generateCodeChallenge: jest.fn(async () => 'mock-challenge-abc'), + generateState: jest.fn(() => 'mock-state-xyz'), +})); + +import { startCallbackServer } from '../../src/services/oauth/OAuthCallbackServer'; + +const mockStartCallbackServer = startCallbackServer as jest.MockedFunction; + +// Mock window.open for browser launch fallback (electron require fails in test env) +const mockWindowOpen = jest.fn(); +const originalWindow = (global as any).window; +(global as any).window = { open: mockWindowOpen }; + +/** Helper to wait for all microtasks and pending callbacks to drain */ +function tick(): Promise { + return new Promise(resolve => setImmediate(resolve)); +} + +function createMockProvider(overrides?: Partial): IOAuthProvider { + const config: OAuthProviderConfig = { + providerId: 'test-provider', + displayName: 'Test Provider', + authUrl: 'https://example.com/auth', + tokenUrl: 'https://example.com/token', + preferredPort: 3000, + callbackPath: '/callback', + scopes: ['read'], + tokenType: 'permanent-key', + clientId: 'test-client-id', + ...overrides, + }; + + return { + config, + buildAuthUrl: jest.fn(() => 'https://example.com/auth?params=test'), + exchangeCode: jest.fn(async () => ({ apiKey: 'test-api-key-123' })), + }; +} + +function createMockProviderWithRefresh(overrides?: Partial): IOAuthProvider { + const provider = createMockProvider({ + 
tokenType: 'expiring-token', + ...overrides, + }); + (provider as any).refreshToken = jest.fn(async () => ({ + apiKey: 'refreshed-token', + refreshToken: 'new-refresh-token', + expiresAt: Date.now() + 3600000, + })); + return provider; +} + +/** Simple mock server handle for success-path tests */ +function createSimpleServerHandle(code: string = 'auth-code') { + const mockShutdown = jest.fn(); + return { + handle: { + port: 3000, + callbackUrl: 'http://127.0.0.1:3000/callback', + waitForCallback: jest.fn(async () => ({ code, state: 'mock-state-xyz' })), + shutdown: mockShutdown, + }, + mockShutdown, + }; +} + +/** Controllable mock server handle for cancel/concurrent tests */ +function createControllableServerHandle() { + let resolveCallback!: (result: { code: string; state: string }) => void; + let rejectCallback!: (error: Error) => void; + let shutdownCalled = false; + const callbackPromise = new Promise<{ code: string; state: string }>((resolve, reject) => { + resolveCallback = resolve; + rejectCallback = reject; + }); + const mockShutdown = jest.fn(() => { + if (!shutdownCalled) { + shutdownCalled = true; + rejectCallback(new Error('OAuth callback server was shut down')); + } + }); + + return { + handle: { + port: 3000, + callbackUrl: 'http://127.0.0.1:3000/callback', + waitForCallback: () => callbackPromise, + shutdown: mockShutdown, + }, + resolveCallback, + rejectCallback, + mockShutdown, + }; +} + +describe('OAuthService', () => { + let service: OAuthService; + + beforeEach(() => { + OAuthService.resetInstance(); + service = OAuthService.getInstance(); + jest.clearAllMocks(); + }); + + afterEach(() => { + OAuthService.resetInstance(); + }); + + afterAll(() => { + // Restore original window to prevent mock leaking to other test files + if (originalWindow === undefined) { + delete (global as any).window; + } else { + (global as any).window = originalWindow; + } + }); + + describe('singleton', () => { + it('should return the same instance on subsequent calls', 
() => { + const instance1 = OAuthService.getInstance(); + const instance2 = OAuthService.getInstance(); + expect(instance1).toBe(instance2); + }); + + it('should create a new instance after resetInstance()', () => { + OAuthService.resetInstance(); + const instance2 = OAuthService.getInstance(); + expect(instance2.getState()).toBe('idle'); + }); + }); + + describe('provider registration', () => { + it('should register a provider', () => { + const provider = createMockProvider(); + service.registerProvider(provider); + expect(service.hasProvider('test-provider')).toBe(true); + }); + + it('should return false for unregistered provider', () => { + expect(service.hasProvider('nonexistent')).toBe(false); + }); + + it('should return provider config for registered provider', () => { + const provider = createMockProvider(); + service.registerProvider(provider); + const config = service.getProviderConfig('test-provider'); + expect(config).toBeDefined(); + expect(config!.displayName).toBe('Test Provider'); + }); + + it('should return null config for unregistered provider', () => { + const config = service.getProviderConfig('nonexistent'); + expect(config).toBeNull(); + }); + + it('should allow registering multiple providers', () => { + service.registerProvider(createMockProvider({ providerId: 'provider-a' })); + service.registerProvider(createMockProvider({ providerId: 'provider-b' })); + expect(service.hasProvider('provider-a')).toBe(true); + expect(service.hasProvider('provider-b')).toBe(true); + }); + }); + + describe('state machine', () => { + it('should start in idle state', () => { + expect(service.getState()).toBe('idle'); + }); + + it('should transition through idle -> authorizing -> exchanging -> idle on success', async () => { + const provider = createMockProvider(); + service.registerProvider(provider); + + const states: string[] = []; + (provider.exchangeCode as jest.Mock).mockImplementation(async () => { + states.push(service.getState()); + return { apiKey: 
'key-123' }; + }); + + const { handle } = createSimpleServerHandle(); + mockStartCallbackServer.mockResolvedValue(handle); + + await service.startFlow('test-provider'); + + expect(states).toContain('exchanging'); + expect(service.getState()).toBe('idle'); + }); + + it('should return to idle after flow cancellation', () => { + service.cancelFlow(); + expect(service.getState()).toBe('idle'); + }); + }); + + describe('startFlow', () => { + it('should throw if provider is not registered', async () => { + await expect(service.startFlow('nonexistent')).rejects.toThrow( + "OAuth provider 'nonexistent' is not registered" + ); + }); + + it('should prevent concurrent flows', async () => { + const provider = createMockProvider(); + service.registerProvider(provider); + + const { handle } = createControllableServerHandle(); + mockStartCallbackServer.mockResolvedValue(handle); + + // Start first flow (won't complete). Eagerly attach catch to prevent unhandled rejection + const firstFlow = service.startFlow('test-provider').catch(() => {}); + await tick(); + + // Second flow should be rejected + await expect(service.startFlow('test-provider')).rejects.toThrow( + 'Cannot start OAuth flow: another flow is already authorizing' + ); + + // Clean up + service.cancelFlow(); + await firstFlow; + }); + + it('should return to idle state even if flow fails', async () => { + const provider = createMockProvider(); + service.registerProvider(provider); + + mockStartCallbackServer.mockRejectedValue(new Error('Port in use')); + + await expect(service.startFlow('test-provider')).rejects.toThrow('Port in use'); + expect(service.getState()).toBe('idle'); + }); + + it('should call provider.buildAuthUrl with correct parameters', async () => { + const provider = createMockProvider(); + service.registerProvider(provider); + + const { handle } = createSimpleServerHandle(); + mockStartCallbackServer.mockResolvedValue(handle); + + await service.startFlow('test-provider', { key_label: 'my-key' }); + + 
expect(provider.buildAuthUrl).toHaveBeenCalledWith( + 'http://127.0.0.1:3000/callback', + 'mock-challenge-abc', + 'mock-state-xyz', + { key_label: 'my-key' } + ); + }); + + it('should call provider.exchangeCode with code, verifier, and callbackUrl', async () => { + const provider = createMockProvider(); + service.registerProvider(provider); + + const { handle } = createSimpleServerHandle('auth-code-999'); + mockStartCallbackServer.mockResolvedValue(handle); + + await service.startFlow('test-provider'); + + expect(provider.exchangeCode).toHaveBeenCalledWith( + 'auth-code-999', + 'mock-verifier-12345678901234567890123', + 'http://127.0.0.1:3000/callback' + ); + }); + + it('should return the OAuthResult from provider.exchangeCode', async () => { + const provider = createMockProvider(); + (provider.exchangeCode as jest.Mock).mockResolvedValue({ + apiKey: 'sk-or-final-key', + refreshToken: 'rt-123', + expiresAt: 9999999, + }); + service.registerProvider(provider); + + const { handle } = createSimpleServerHandle(); + mockStartCallbackServer.mockResolvedValue(handle); + + const result = await service.startFlow('test-provider'); + expect(result.apiKey).toBe('sk-or-final-key'); + expect(result.refreshToken).toBe('rt-123'); + }); + + it('should shut down callback server after successful flow', async () => { + const provider = createMockProvider(); + service.registerProvider(provider); + + const { handle, mockShutdown } = createSimpleServerHandle(); + mockStartCallbackServer.mockResolvedValue(handle); + + await service.startFlow('test-provider'); + + expect(mockShutdown).toHaveBeenCalled(); + }); + + it('should shut down callback server on flow failure', async () => { + const provider = createMockProvider(); + (provider.exchangeCode as jest.Mock).mockRejectedValue(new Error('Exchange failed')); + service.registerProvider(provider); + + const { handle, mockShutdown } = createSimpleServerHandle(); + mockStartCallbackServer.mockResolvedValue(handle); + + await 
expect(service.startFlow('test-provider')).rejects.toThrow('Exchange failed'); + expect(mockShutdown).toHaveBeenCalled(); + }); + + it('should open browser with auth URL', async () => { + const provider = createMockProvider(); + service.registerProvider(provider); + + const { handle } = createSimpleServerHandle(); + mockStartCallbackServer.mockResolvedValue(handle); + + await service.startFlow('test-provider'); + + expect(mockWindowOpen).toHaveBeenCalledWith( + 'https://example.com/auth?params=test', + '_blank' + ); + }); + }); + + describe('cancelFlow', () => { + it('should reset state to idle', () => { + service.cancelFlow(); + expect(service.getState()).toBe('idle'); + }); + + it('should shut down active callback server', async () => { + const provider = createMockProvider(); + service.registerProvider(provider); + + const { handle, mockShutdown } = createControllableServerHandle(); + mockStartCallbackServer.mockResolvedValue(handle); + + // Start flow, eagerly handle rejection + const flowPromise = service.startFlow('test-provider').catch(() => {}); + await tick(); + + service.cancelFlow(); + + expect(mockShutdown).toHaveBeenCalled(); + expect(service.getState()).toBe('idle'); + + await flowPromise; + }); + }); + + describe('refreshToken', () => { + it('should throw if provider is not registered', async () => { + await expect(service.refreshToken('nonexistent', 'rt-123')).rejects.toThrow( + "OAuth provider 'nonexistent' is not registered" + ); + }); + + it('should throw if provider does not support refresh', async () => { + const provider = createMockProvider(); + service.registerProvider(provider); + + await expect(service.refreshToken('test-provider', 'rt-123')).rejects.toThrow( + 'does not support token refresh' + ); + }); + + it('should delegate to provider.refreshToken', async () => { + const provider = createMockProviderWithRefresh(); + service.registerProvider(provider); + + const result = await service.refreshToken('test-provider', 'old-rt'); + + 
expect(provider.refreshToken).toHaveBeenCalledWith('old-rt'); + expect(result!.apiKey).toBe('refreshed-token'); + expect(result!.refreshToken).toBe('new-refresh-token'); + }); + + it('should return null when provider refresh returns null', async () => { + const provider = createMockProviderWithRefresh(); + (provider.refreshToken as jest.Mock).mockResolvedValue(null); + service.registerProvider(provider); + + const result = await service.refreshToken('test-provider', 'expired-rt'); + expect(result).toBeNull(); + }); + }); + + describe('resetInstance', () => { + it('should cancel any active flow', async () => { + const provider = createMockProvider(); + service.registerProvider(provider); + + const { handle, mockShutdown } = createControllableServerHandle(); + mockStartCallbackServer.mockResolvedValue(handle); + + const flowPromise = service.startFlow('test-provider').catch(() => {}); + await tick(); + + OAuthService.resetInstance(); + + expect(mockShutdown).toHaveBeenCalled(); + + await flowPromise; + }); + + it('should clear all registered providers', () => { + service.registerProvider(createMockProvider({ providerId: 'p1' })); + service.registerProvider(createMockProvider({ providerId: 'p2' })); + + OAuthService.resetInstance(); + + const newService = OAuthService.getInstance(); + expect(newService.hasProvider('p1')).toBe(false); + expect(newService.hasProvider('p2')).toBe(false); + }); + }); +}); diff --git a/tests/unit/OpenAICodexAdapter.test.ts b/tests/unit/OpenAICodexAdapter.test.ts new file mode 100644 index 00000000..7fa1aabb --- /dev/null +++ b/tests/unit/OpenAICodexAdapter.test.ts @@ -0,0 +1,982 @@ +/** + * OpenAICodexAdapter Unit Tests + * + * Tests the LLM adapter for the Codex inference endpoint: + * - Token management (fresh check, proactive refresh) + * - Request construction (headers, body format) + * - SSE stream parsing + * - Error handling + * - Token status diagnostics + */ + +import { OpenAICodexAdapter, CodexOAuthTokens, TokenPersistCallback } 
from '../../src/services/llm/adapters/openai-codex/OpenAICodexAdapter'; +import { EventEmitter } from 'events'; + +// --- https mock infrastructure --- + +/** + * Create a mock IncomingMessage (EventEmitter + async iterable of Buffers). + * Simulates a Node.js http.IncomingMessage for SSE streaming. + */ +function createMockIncomingMessage( + statusCode: number, + chunks: string[] +): { message: any; emit: () => void } { + const emitter = new EventEmitter(); + (emitter as any).statusCode = statusCode; + + // Make it async-iterable (Node.js IncomingMessage supports this) + (emitter as any)[Symbol.asyncIterator] = async function* () { + for (const chunk of chunks) { + yield Buffer.from(chunk); + } + }; + + const emitFn = () => { + // Defer emission so the consumer has time to register event listeners. + // The adapter sets up on('data')/on('end') after the Promise resolves, + // which happens in the same microtask as callback(message). Using + // setTimeout(0) pushes emission to the next macrotask. 
+ setTimeout(() => { + for (const chunk of chunks) { + emitter.emit('data', Buffer.from(chunk)); + } + emitter.emit('end'); + }, 0); + }; + + return { message: emitter, emit: emitFn }; +} + +// Track all https.request calls for assertions +interface CapturedRequest { + options: any; + body: string; +} + +let capturedRequests: CapturedRequest[] = []; +let requestMockImpl: ((options: any, body: string) => { statusCode: number; chunks: string[] }) | null = null; + +// Mock the https module +jest.mock('https', () => ({ + request: jest.fn((options: any, callback: (res: any) => void) => { + const reqEmitter = new EventEmitter(); + let writtenBody = ''; + + (reqEmitter as any).write = (data: string) => { writtenBody += data; }; + (reqEmitter as any).setTimeout = jest.fn(); // Mock ClientRequest.setTimeout + (reqEmitter as any).destroy = jest.fn(); // Mock ClientRequest.destroy + (reqEmitter as any).end = () => { + capturedRequests.push({ options, body: writtenBody }); + + if (requestMockImpl) { + const { statusCode, chunks } = requestMockImpl(options, writtenBody); + const { message, emit } = createMockIncomingMessage(statusCode, chunks); + callback(message); + // Emit data/end events for all responses — the adapter uses event + // listeners (not async iteration) to read both SSE streams and error bodies + emit(); + } + }; + + return reqEmitter; + }), +})); + +// Mock ModelRegistry +jest.mock('../../src/services/llm/adapters/ModelRegistry', () => ({ + ModelRegistry: { + getProviderModels: jest.fn(() => [ + { + provider: 'openai-codex', + name: 'GPT-5.3 Codex', + apiName: 'gpt-5.3-codex', + contextWindow: 400000, + maxTokens: 128000, + inputCostPerMillion: 0, + outputCostPerMillion: 0, + }, + ]), + toModelInfo: jest.fn((model) => ({ + id: model.apiName, + name: model.name, + provider: model.provider, + contextWindow: model.contextWindow, + maxOutputTokens: model.maxTokens, + })), + }, +})); + +// Mock branding +jest.mock('../../../../src/constants/branding', () => ({ + 
BRAND_NAME: 'TestBrand', +}), { virtual: true }); + +// Helper: create valid tokens +function createTokens(overrides?: Partial): CodexOAuthTokens { + return { + accessToken: 'test-access-token', + refreshToken: 'test-refresh-token', + expiresAt: Date.now() + 3600_000, // 1 hour from now + accountId: 'acct-test-123', + ...overrides, + }; +} + +// Helper: set up a mock for the Codex API endpoint returning SSE events +function mockCodexSSE(events: string[]) { + const sseText = events.join(''); + requestMockImpl = (options) => { + if (options.path === '/oauth/token') { + // Should not reach here for non-refresh tests + return { statusCode: 200, chunks: ['{}'] }; + } + return { statusCode: 200, chunks: [sseText] }; + }; +} + +// Helper: set up a mock for error responses +function mockCodexError(statusCode: number, errorBody: string) { + requestMockImpl = () => { + return { statusCode, chunks: [errorBody] }; + }; +} + +describe('OpenAICodexAdapter', () => { + let adapter: OpenAICodexAdapter; + let tokens: CodexOAuthTokens; + + beforeEach(() => { + tokens = createTokens(); + adapter = new OpenAICodexAdapter(tokens); + capturedRequests = []; + requestMockImpl = null; + jest.clearAllMocks(); + }); + + describe('constructor', () => { + it('should initialize with provided tokens', () => { + const status = adapter.getTokenStatus(); + expect(status.hasAccessToken).toBe(true); + expect(status.hasRefreshToken).toBe(true); + expect(status.hasAccountId).toBe(true); + expect(status.isExpired).toBe(false); + }); + + it('should set adapter name to "openai-codex"', () => { + expect(adapter.name).toBe('openai-codex'); + }); + }); + + describe('isAvailable', () => { + it('should return true when all required tokens are present', async () => { + expect(await adapter.isAvailable()).toBe(true); + }); + + it('should return false when accessToken is empty', async () => { + const emptyAdapter = new OpenAICodexAdapter(createTokens({ accessToken: '' })); + expect(await 
emptyAdapter.isAvailable()).toBe(false); + }); + + it('should return false when refreshToken is empty', async () => { + const emptyAdapter = new OpenAICodexAdapter(createTokens({ refreshToken: '' })); + expect(await emptyAdapter.isAvailable()).toBe(false); + }); + + it('should return false when accountId is empty', async () => { + const emptyAdapter = new OpenAICodexAdapter(createTokens({ accountId: '' })); + expect(await emptyAdapter.isAvailable()).toBe(false); + }); + }); + + describe('getTokenStatus', () => { + it('should report correct token status', () => { + const status = adapter.getTokenStatus(); + expect(status.hasAccessToken).toBe(true); + expect(status.hasRefreshToken).toBe(true); + expect(status.hasAccountId).toBe(true); + expect(status.isExpired).toBe(false); + expect(status.needsRefresh).toBe(false); + }); + + it('should detect expired tokens', () => { + const expiredAdapter = new OpenAICodexAdapter( + createTokens({ expiresAt: Date.now() - 1000 }) + ); + const status = expiredAdapter.getTokenStatus(); + expect(status.isExpired).toBe(true); + expect(status.needsRefresh).toBe(true); + }); + + it('should detect tokens needing refresh (within 5-minute threshold)', () => { + const soonAdapter = new OpenAICodexAdapter( + createTokens({ expiresAt: Date.now() + 60_000 }) // 1 minute from now + ); + const status = soonAdapter.getTokenStatus(); + expect(status.isExpired).toBe(false); + expect(status.needsRefresh).toBe(true); + }); + }); + + describe('updateTokens', () => { + it('should update token state', () => { + const newTokens = createTokens({ + accessToken: 'new-at', + refreshToken: 'new-rt', + accountId: 'acct-new', + }); + adapter.updateTokens(newTokens); + + const status = adapter.getTokenStatus(); + expect(status.hasAccessToken).toBe(true); + expect(status.hasRefreshToken).toBe(true); + expect(status.hasAccountId).toBe(true); + }); + }); + + describe('proactive token refresh', () => { + it('should refresh token before API call when close to expiry', 
async () => { + const nearExpiryTokens = createTokens({ + expiresAt: Date.now() + 60_000, // 1 minute (< 5 min threshold) + }); + const onRefresh = jest.fn(); + const nearExpiryAdapter = new OpenAICodexAdapter(nearExpiryTokens, onRefresh); + + let refreshCalled = false; + requestMockImpl = (options) => { + if (options.path === '/oauth/token') { + refreshCalled = true; + return { + statusCode: 200, + chunks: [JSON.stringify({ + access_token: 'refreshed-at', + refresh_token: 'rotated-rt', + expires_in: 3600, + })], + }; + } + // API call + return { + statusCode: 200, + chunks: [ + 'data: {"type":"response.output_text.delta","delta":{"text":"Hello"}}\n\n', + 'data: [DONE]\n\n', + ], + }; + }; + + const chunks: string[] = []; + for await (const chunk of nearExpiryAdapter.generateStreamAsync('test prompt')) { + if (chunk.content) chunks.push(chunk.content); + } + + // Refresh endpoint should have been called + expect(refreshCalled).toBe(true); + + // Callback should have been invoked with refreshed tokens + expect(onRefresh).toHaveBeenCalled(); + + // Should have captured both refresh + API calls + expect(capturedRequests.length).toBe(2); + expect(capturedRequests[0].options.path).toBe('/oauth/token'); + }); + + it('should not refresh token when far from expiry', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_text.delta","delta":{"text":"Hello"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const chunks: string[] = []; + for await (const chunk of adapter.generateStreamAsync('test prompt')) { + if (chunk.content) chunks.push(chunk.content); + } + + // Only the API call, no refresh call + expect(capturedRequests.length).toBe(1); + expect(capturedRequests[0].options.hostname).toBe('chatgpt.com'); + }); + }); + + describe('generateStreamAsync request construction', () => { + it('should send correct headers including Authorization and ChatGPT-Account-Id', async () => { + mockCodexSSE(['data: [DONE]\n\n']); + + for await (const _ of 
adapter.generateStreamAsync('hello')) { /* no-op */ } + + const req = capturedRequests[0]; + expect(req.options.headers['Authorization']).toBe('Bearer test-access-token'); + expect(req.options.headers['ChatGPT-Account-Id']).toBe('acct-test-123'); + expect(req.options.headers['Content-Type']).toBe('application/json'); + }); + + it('should construct correct request body', async () => { + mockCodexSSE(['data: [DONE]\n\n']); + + for await (const _ of adapter.generateStreamAsync('hello', { + model: 'gpt-5.2-codex', + temperature: 0.7, + maxTokens: 1000, + systemPrompt: 'You are a helper.', + })) { /* no-op */ } + + const body = JSON.parse(capturedRequests[0].body); + expect(body.model).toBe('gpt-5.2-codex'); + expect(body.stream).toBe(true); + expect(body.store).toBe(false); + expect(body.temperature).toBe(0.7); + expect(body.max_output_tokens).toBe(1000); + expect(body.instructions).toBe('You are a helper.'); + expect(body.input).toEqual( + expect.arrayContaining([ + expect.objectContaining({ role: 'user', content: 'hello' }), + ]) + ); + }); + + it('should include system prompt in input array', async () => { + mockCodexSSE(['data: [DONE]\n\n']); + + for await (const _ of adapter.generateStreamAsync('hello', { + systemPrompt: 'System message', + })) { /* no-op */ } + + const body = JSON.parse(capturedRequests[0].body); + expect(body.input[0]).toEqual({ role: 'system', content: 'System message' }); + expect(body.input[1]).toEqual({ role: 'user', content: 'hello' }); + }); + + it('should use conversation history when provided', async () => { + mockCodexSSE(['data: [DONE]\n\n']); + + const history = [ + { role: 'system', content: 'You are helpful.' }, + { role: 'user', content: 'Hi' }, + { role: 'assistant', content: 'Hello!' 
}, + { role: 'user', content: 'Follow up' }, + ]; + + for await (const _ of adapter.generateStreamAsync('follow up', { + conversationHistory: history, + })) { /* no-op */ } + + const body = JSON.parse(capturedRequests[0].body); + expect(body.input).toEqual(history); + }); + }); + + describe('SSE stream parsing', () => { + it('should extract text from delta.text events', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_text.delta","delta":{"text":"Hello "}}\n\n', + 'data: {"type":"response.output_text.delta","delta":{"text":"world!"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const texts: string[] = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + if (chunk.content) texts.push(chunk.content); + } + + expect(texts).toEqual(['Hello ', 'world!']); + }); + + it('should extract text from delta as plain string (Shape 1a)', async () => { + // The Codex Responses API can send delta as a plain string instead of + // a nested object. This was the fix for the production text rendering bug. 
+ mockCodexSSE([ + 'data: {"type":"response.output_text.delta","delta":"Plain string delta"}\n\n', + 'data: {"type":"response.output_text.delta","delta":"Second chunk"}\n\n', + 'data: [DONE]\n\n', + ]); + + const texts: string[] = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + if (chunk.content) texts.push(chunk.content); + } + + expect(texts).toEqual(['Plain string delta', 'Second chunk']); + }); + + it('should extract text from top-level content field (Shape 3)', async () => { + // Some Codex event variants place content at the top level + mockCodexSSE([ + 'data: {"type":"response.some_event","content":"top level content"}\n\n', + 'data: [DONE]\n\n', + ]); + + const texts: string[] = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + if (chunk.content) texts.push(chunk.content); + } + + expect(texts).toEqual(['top level content']); + }); + + it('should handle delta.content variant', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_text.delta","delta":{"content":"content text"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const texts: string[] = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + if (chunk.content) texts.push(chunk.content); + } + + expect(texts).toEqual(['content text']); + }); + + it('should skip output_text.done recap events', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_text.delta","delta":{"text":"Hello"}}\n\n', + 'data: {"type":"response.output_text.done","text":"Hello full text recap"}\n\n', + 'data: [DONE]\n\n', + ]); + + const texts: string[] = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + if (chunk.content) texts.push(chunk.content); + } + + // Should only get delta, not the done recap + expect(texts).toEqual(['Hello']); + }); + + it('should emit complete=true on [DONE]', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_text.delta","delta":{"text":"hi"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const 
completeFlags: boolean[] = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + completeFlags.push(chunk.complete); + } + + expect(completeFlags[completeFlags.length - 1]).toBe(true); + }); + + it('should emit complete=true on response.completed event', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_text.delta","delta":{"text":"hi"}}\n\n', + 'data: {"type":"response.completed","id":"resp-123"}\n\n', + ]); + + const completeFlags: boolean[] = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + completeFlags.push(chunk.complete); + } + + expect(completeFlags[completeFlags.length - 1]).toBe(true); + }); + + it('should handle malformed JSON lines gracefully (skip them)', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_text.delta","delta":{"text":"before"}}\n\n', + 'data: {malformed json\n\n', + 'data: {"type":"response.output_text.delta","delta":{"text":"after"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const texts: string[] = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + if (chunk.content) texts.push(chunk.content); + } + + expect(texts).toEqual(['before', 'after']); + }); + + it('should handle SSE comments (lines starting with :)', async () => { + mockCodexSSE([ + ': this is a comment\n\n', + 'data: {"type":"response.output_text.delta","delta":{"text":"text"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const texts: string[] = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + if (chunk.content) texts.push(chunk.content); + } + + expect(texts).toEqual(['text']); + }); + + it('should handle empty lines gracefully', async () => { + mockCodexSSE([ + '\n', + 'data: {"type":"response.output_text.delta","delta":{"text":"ok"}}\n\n', + '\n', + 'data: [DONE]\n\n', + ]); + + const texts: string[] = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + if (chunk.content) texts.push(chunk.content); + } + + expect(texts).toEqual(['ok']); + }); + + 
it('should correctly buffer and parse SSE data split across chunk boundaries', async () => { + // Simulate a network scenario where a single SSE event arrives in two TCP chunks, + // splitting mid-JSON. The async iterator yields two separate Buffer chunks. + requestMockImpl = () => ({ + statusCode: 200, + chunks: [ + 'data: {"type":"response.output_text.delta","delta":{"tex', // chunk 1: incomplete line + 't":"split across chunks"}}\n\ndata: [DONE]\n\n', // chunk 2: rest of line + DONE + ], + }); + + const texts: string[] = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + if (chunk.content) texts.push(chunk.content); + } + + expect(texts).toEqual(['split across chunks']); + }); + }); + + describe('HTTP error handling', () => { + it('should throw AUTHENTICATION_ERROR on 401', async () => { + mockCodexError(401, 'Unauthorized'); + + await expect(async () => { + for await (const _ of adapter.generateStreamAsync('test')) { /* no-op */ } + }).rejects.toThrow('authentication failed'); + }); + + it('should throw AUTHENTICATION_ERROR on 403', async () => { + mockCodexError(403, 'Forbidden'); + + await expect(async () => { + for await (const _ of adapter.generateStreamAsync('test')) { /* no-op */ } + }).rejects.toThrow('authentication failed'); + }); + + it('should throw RATE_LIMIT_ERROR on 429', async () => { + mockCodexError(429, 'Rate limit exceeded'); + + try { + for await (const _ of adapter.generateStreamAsync('test')) { /* no-op */ } + fail('Expected error to be thrown'); + } catch (error: any) { + expect(error.name).toBe('LLMProviderError'); + expect(error.code).toBe('RATE_LIMIT_ERROR'); + expect(error.provider).toBe('openai-codex'); + expect(error.message).toContain('rate limited'); + expect(error.message).toContain('429'); + } + }); + + it('should throw HTTP_ERROR on other status codes', async () => { + mockCodexError(500, 'Internal Server Error'); + + await expect(async () => { + for await (const _ of adapter.generateStreamAsync('test')) { 
/* no-op */ } + }).rejects.toThrow('Codex API error'); + }); + }); + + describe('concurrent token refresh deduplication', () => { + it('should make only one refresh call when two requests need refresh simultaneously', async () => { + const nearExpiryTokens = createTokens({ + expiresAt: Date.now() + 60_000, // Within 5-min threshold + }); + const onRefresh = jest.fn(); + const sharedAdapter = new OpenAICodexAdapter(nearExpiryTokens, onRefresh); + + let refreshCallCount = 0; + requestMockImpl = (options) => { + if (options.path === '/oauth/token') { + refreshCallCount++; + return { + statusCode: 200, + chunks: [JSON.stringify({ + access_token: 'refreshed-at', + refresh_token: 'rotated-rt', + expires_in: 3600, + })], + }; + } + return { + statusCode: 200, + chunks: [ + 'data: {"type":"response.output_text.delta","delta":{"text":"ok"}}\n\n', + 'data: [DONE]\n\n', + ], + }; + }; + + // Fire two concurrent streaming requests + const stream1 = (async () => { + const texts: string[] = []; + for await (const chunk of sharedAdapter.generateStreamAsync('prompt1')) { + if (chunk.content) texts.push(chunk.content); + } + return texts; + })(); + + const stream2 = (async () => { + const texts: string[] = []; + for await (const chunk of sharedAdapter.generateStreamAsync('prompt2')) { + if (chunk.content) texts.push(chunk.content); + } + return texts; + })(); + + const [result1, result2] = await Promise.all([stream1, stream2]); + + expect(result1).toEqual(['ok']); + expect(result2).toEqual(['ok']); + + // Only 1 refresh call should be made (not 2) + expect(refreshCallCount).toBe(1); + }); + }); + + describe('token refresh error handling', () => { + it('should throw AUTHENTICATION_ERROR when token refresh fails', async () => { + const nearExpiryAdapter = new OpenAICodexAdapter( + createTokens({ expiresAt: Date.now() + 60_000 }) + ); + + requestMockImpl = (options) => { + if (options.path === '/oauth/token') { + return { statusCode: 400, chunks: ['Invalid grant'] }; + } + return { 
statusCode: 200, chunks: ['data: [DONE]\n\n'] }; + }; + + await expect(async () => { + for await (const _ of nearExpiryAdapter.generateStreamAsync('test')) { /* no-op */ } + }).rejects.toThrow('Token refresh failed'); + }); + }); + + describe('getCapabilities', () => { + it('should report streaming support', () => { + const caps = adapter.getCapabilities(); + expect(caps.supportsStreaming).toBe(true); + }); + + it('should report function calling support', () => { + const caps = adapter.getCapabilities(); + expect(caps.supportsFunctions).toBe(true); + }); + + it('should include tool_calling in supported features', () => { + const caps = adapter.getCapabilities(); + expect(caps.supportedFeatures).toContain('tool_calling'); + }); + + it('should include oauth_required in supported features', () => { + const caps = adapter.getCapabilities(); + expect(caps.supportedFeatures).toContain('oauth_required'); + }); + }); + + describe('getModelPricing', () => { + it('should return $0 pricing for known models', async () => { + const pricing = await adapter.getModelPricing('gpt-5.3-codex'); + expect(pricing).not.toBeNull(); + expect(pricing!.rateInputPerMillion).toBe(0); + expect(pricing!.rateOutputPerMillion).toBe(0); + }); + + it('should return null for unknown models', async () => { + const pricing = await adapter.getModelPricing('unknown-model'); + expect(pricing).toBeNull(); + }); + }); + + describe('listModels', () => { + it('should return models from ModelRegistry', async () => { + const models = await adapter.listModels(); + expect(models.length).toBeGreaterThan(0); + expect(models[0].id).toBe('gpt-5.3-codex'); + }); + }); + + describe('tool call support', () => { + it('should convert tools from Chat Completions format to Responses API format in request body', async () => { + mockCodexSSE(['data: [DONE]\n\n']); + + const tools = [ + { + type: 'function', + function: { + name: 'get_weather', + description: 'Get current weather', + parameters: { type: 'object', properties: 
{ city: { type: 'string' } } }, + }, + }, + ]; + + for await (const _ of adapter.generateStreamAsync('What is the weather?', { tools })) { /* no-op */ } + + const body = JSON.parse(capturedRequests[0].body); + expect(body.tools).toEqual([ + { + type: 'function', + name: 'get_weather', + description: 'Get current weather', + parameters: { type: 'object', properties: { city: { type: 'string' } } }, + }, + ]); + + // tool_choice must be 'auto' so the model actually selects tools + expect(body.tool_choice).toBe('auto'); + + // instructions should be prepended with tool preamble + expect(body.instructions).toContain('tool access'); + expect(body.instructions).toContain('Call getTools first'); + }); + + it('should pass through tools already in Responses API format', async () => { + mockCodexSSE(['data: [DONE]\n\n']); + + const tools = [ + { + type: 'function', + name: 'search', + description: 'Search the web', + parameters: { type: 'object', properties: { q: { type: 'string' } } }, + }, + ]; + + for await (const _ of adapter.generateStreamAsync('Search for cats', { tools })) { /* no-op */ } + + const body = JSON.parse(capturedRequests[0].body); + expect(body.tools[0].name).toBe('search'); + }); + + it('should accumulate tool calls from response.output_item.done events', async () => { + mockCodexSSE([ + 'data: {"type":"response.function_call_arguments.delta","delta":"{\\"city\\":"}\n\n', + 'data: {"type":"response.function_call_arguments.delta","delta":"\\"NYC\\"}"}\n\n', + 'data: {"type":"response.output_item.done","output_index":0,"item":{"type":"function_call","call_id":"call_abc","name":"get_weather","arguments":"{\\"city\\":\\"NYC\\"}"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const chunks: Array<{ toolCalls?: any[]; toolCallsReady?: boolean }> = []; + for await (const chunk of adapter.generateStreamAsync('weather in NYC')) { + chunks.push({ toolCalls: chunk.toolCalls, toolCallsReady: chunk.toolCallsReady }); + } + + const finalChunk = chunks[chunks.length - 1]; + 
expect(finalChunk.toolCallsReady).toBe(true); + expect(finalChunk.toolCalls).toHaveLength(1); + expect(finalChunk.toolCalls![0]).toEqual({ + id: 'call_abc', + type: 'function', + function: { + name: 'get_weather', + arguments: '{"city":"NYC"}', + }, + }); + }); + + it('should accumulate multiple tool calls', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_item.done","output_index":0,"item":{"type":"function_call","call_id":"call_1","name":"get_weather","arguments":"{\\"city\\":\\"NYC\\"}"}}\n\n', + 'data: {"type":"response.output_item.done","output_index":1,"item":{"type":"function_call","call_id":"call_2","name":"get_time","arguments":"{\\"tz\\":\\"EST\\"}"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const chunks: Array<{ toolCalls?: any[] }> = []; + for await (const chunk of adapter.generateStreamAsync('weather and time')) { + chunks.push({ toolCalls: chunk.toolCalls }); + } + + const finalChunk = chunks[chunks.length - 1]; + expect(finalChunk.toolCalls).toHaveLength(2); + expect(finalChunk.toolCalls![0].function.name).toBe('get_weather'); + expect(finalChunk.toolCalls![1].function.name).toBe('get_time'); + }); + + it('should include tool calls in completion event from response.completed', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_item.done","output_index":0,"item":{"type":"function_call","call_id":"call_x","name":"search","arguments":"{\\"q\\":\\"cats\\"}"}}\n\n', + 'data: {"type":"response.completed","id":"resp-456"}\n\n', + ]); + + const chunks: Array<{ toolCalls?: any[]; complete: boolean }> = []; + for await (const chunk of adapter.generateStreamAsync('search cats')) { + chunks.push({ toolCalls: chunk.toolCalls, complete: chunk.complete }); + } + + const finalChunk = chunks[chunks.length - 1]; + expect(finalChunk.complete).toBe(true); + expect(finalChunk.toolCalls).toHaveLength(1); + expect(finalChunk.toolCalls![0].id).toBe('call_x'); + }); + + it('should not include toolCalls in final chunk when no function calls were 
made', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_text.delta","delta":{"text":"Hello"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const chunks: Array<{ toolCalls?: any[]; toolCallsReady?: boolean }> = []; + for await (const chunk of adapter.generateStreamAsync('hello')) { + chunks.push({ toolCalls: chunk.toolCalls, toolCallsReady: chunk.toolCallsReady }); + } + + const finalChunk = chunks[chunks.length - 1]; + expect(finalChunk.toolCalls).toBeUndefined(); + expect(finalChunk.toolCallsReady).toBeUndefined(); + }); + + it('should use item.id as fallback when call_id is not present', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_item.done","output_index":0,"item":{"type":"function_call","id":"item_123","name":"test_fn","arguments":"{}"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const chunks: Array<{ toolCalls?: any[] }> = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + chunks.push({ toolCalls: chunk.toolCalls }); + } + + const finalChunk = chunks[chunks.length - 1]; + expect(finalChunk.toolCalls![0].id).toBe('item_123'); + }); + }); + + describe('network and stream errors', () => { + it('should propagate error when https.request emits error event', async () => { + // Override the mock to emit an error on the request object instead of + // calling the response callback + const https = require('https'); + (https.request as jest.Mock).mockImplementationOnce( + (_options: any, _callback: (res: any) => void) => { + const reqEmitter = new EventEmitter(); + (reqEmitter as any).write = () => {}; + (reqEmitter as any).setTimeout = jest.fn(); + (reqEmitter as any).destroy = jest.fn(); + (reqEmitter as any).end = () => { + // Simulate a network-level failure (DNS, connection refused, etc.) 
+ setTimeout(() => { + reqEmitter.emit('error', new Error('connect ECONNREFUSED 127.0.0.1:443')); + }, 0); + }; + return reqEmitter; + } + ); + + await expect(async () => { + for await (const _ of adapter.generateStreamAsync('test')) { /* no-op */ } + }).rejects.toThrow('ECONNREFUSED'); + }); + + it('should propagate error when SSE stream emits error mid-stream', async () => { + // Override the mock to emit data then an error on the response stream + const https = require('https'); + (https.request as jest.Mock).mockImplementationOnce( + (_options: any, callback: (res: any) => void) => { + const reqEmitter = new EventEmitter(); + let writtenBody = ''; + (reqEmitter as any).write = (data: string) => { writtenBody += data; }; + (reqEmitter as any).setTimeout = jest.fn(); + (reqEmitter as any).destroy = jest.fn(); + (reqEmitter as any).end = () => { + capturedRequests.push({ options: _options, body: writtenBody }); + + const resEmitter = new EventEmitter(); + (resEmitter as any).statusCode = 200; + callback(resEmitter); + + // Emit one good chunk, then an error + setTimeout(() => { + resEmitter.emit('data', Buffer.from( + 'data: {"type":"response.output_text.delta","delta":{"text":"partial"}}\n\n' + )); + resEmitter.emit('error', new Error('socket hang up')); + }, 0); + }; + return reqEmitter; + } + ); + + await expect(async () => { + for await (const _ of adapter.generateStreamAsync('test')) { /* no-op */ } + }).rejects.toThrow('socket hang up'); + }); + + it('should emit fallback completion when stream ends without [DONE]', async () => { + // Stream sends a delta then ends abruptly — no [DONE] or response.completed + const https = require('https'); + (https.request as jest.Mock).mockImplementationOnce( + (_options: any, callback: (res: any) => void) => { + const reqEmitter = new EventEmitter(); + let writtenBody = ''; + (reqEmitter as any).write = (data: string) => { writtenBody += data; }; + (reqEmitter as any).setTimeout = jest.fn(); + (reqEmitter as any).destroy 
= jest.fn(); + (reqEmitter as any).end = () => { + capturedRequests.push({ options: _options, body: writtenBody }); + + const resEmitter = new EventEmitter(); + (resEmitter as any).statusCode = 200; + callback(resEmitter); + + setTimeout(() => { + resEmitter.emit('data', Buffer.from( + 'data: {"type":"response.output_text.delta","delta":{"text":"truncated"}}\n\n' + )); + resEmitter.emit('end'); + }, 0); + }; + return reqEmitter; + } + ); + + const chunks: Array<{ content: string; complete: boolean }> = []; + for await (const chunk of adapter.generateStreamAsync('test')) { + chunks.push({ content: chunk.content, complete: chunk.complete }); + } + + // Should get the text delta plus a fallback completion + expect(chunks.some(c => c.content === 'truncated')).toBe(true); + expect(chunks[chunks.length - 1].complete).toBe(true); + }); + }); + + describe('generateUncached', () => { + it('should collect all stream chunks and return assembled response', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_text.delta","delta":{"text":"Hello "}}\n\n', + 'data: {"type":"response.output_text.delta","delta":{"text":"World!"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const response = await adapter.generateUncached('test prompt'); + + expect(response.text).toBe('Hello World!'); + expect(response.usage.totalTokens).toBe(0); + expect(response.finishReason).toBe('stop'); + }); + + it('should return tool_calls finishReason and toolCalls when function calls are present', async () => { + mockCodexSSE([ + 'data: {"type":"response.output_item.done","output_index":0,"item":{"type":"function_call","call_id":"call_uc1","name":"get_weather","arguments":"{\\"city\\":\\"SF\\"}"}}\n\n', + 'data: [DONE]\n\n', + ]); + + const response = await adapter.generateUncached('weather in SF'); + + expect(response.finishReason).toBe('tool_calls'); + expect(response.toolCalls).toHaveLength(1); + expect(response.toolCalls![0]).toEqual({ + id: 'call_uc1', + type: 'function', + function: { + name: 
'get_weather', + arguments: '{\"city\":\"SF\"}', + }, + }); + }); + }); +}); diff --git a/tests/unit/OpenAICodexOAuthProvider.test.ts b/tests/unit/OpenAICodexOAuthProvider.test.ts new file mode 100644 index 00000000..0f5955f2 --- /dev/null +++ b/tests/unit/OpenAICodexOAuthProvider.test.ts @@ -0,0 +1,430 @@ +/** + * OpenAICodexOAuthProvider Unit Tests + * + * Tests the OpenAI Codex OAuth provider: + * - Static configuration + * - Authorization URL construction + * - Token exchange (form-urlencoded) + * - JWT parsing for account ID extraction + * - Token refresh + */ + +import { OpenAICodexOAuthProvider } from '../../src/services/oauth/providers/OpenAICodexOAuthProvider'; + +// Mock global fetch +const mockFetch = jest.fn(); +global.fetch = mockFetch; + +/** Helper: create a mock JWT with given claims payload */ +function createMockJwt(claims: Record<string, unknown>): string { + const header = btoa(JSON.stringify({ alg: 'RS256', typ: 'JWT' })) + .replace(/\+/g, '-') + .replace(/\//g, '_') + .replace(/=+$/g, ''); + const payload = btoa(JSON.stringify(claims)) + .replace(/\+/g, '-') + .replace(/\//g, '_') + .replace(/=+$/g, ''); + const signature = 'mock-signature'; + return `${header}.${payload}.${signature}`; +} + +describe('OpenAICodexOAuthProvider', () => { + let provider: OpenAICodexOAuthProvider; + + beforeEach(() => { + provider = new OpenAICodexOAuthProvider(); + jest.clearAllMocks(); + }); + + describe('config', () => { + it('should have providerId "openai-codex"', () => { + expect(provider.config.providerId).toBe('openai-codex'); + }); + + it('should have correct client_id', () => { + expect(provider.config.clientId).toBe('app_EMoamEEZ73f0CkXaXp7hrann'); + }); + + it('should prefer port 1455', () => { + expect(provider.config.preferredPort).toBe(1455); + }); + + it('should use /auth/callback path', () => { + expect(provider.config.callbackPath).toBe('/auth/callback'); + }); + + it('should request openid, profile, email, offline_access scopes', () => {
expect(provider.config.scopes).toEqual(['openid', 'profile', 'email', 'offline_access']); + }); + + it('should use expiring-token type', () => { + expect(provider.config.tokenType).toBe('expiring-token'); + }); + + it('should display name "ChatGPT"', () => { + expect(provider.config.displayName).toBe('ChatGPT'); + }); + + it('should point to correct auth endpoint', () => { + expect(provider.config.authUrl).toBe('https://auth.openai.com/oauth/authorize'); + }); + + it('should point to correct token endpoint', () => { + expect(provider.config.tokenUrl).toBe('https://auth.openai.com/oauth/token'); + }); + }); + + describe('buildAuthUrl', () => { + const callbackUrl = 'http://127.0.0.1:1455/auth/callback'; + const codeChallenge = 'test-challenge'; + const state = 'test-state'; + + it('should produce a URL starting with the auth endpoint', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + expect(url).toMatch(/^https:\/\/auth\.openai\.com\/oauth\/authorize\?/); + }); + + it('should include response_type=code', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('response_type')).toBe('code'); + }); + + it('should include the correct client_id', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('client_id')).toBe('app_EMoamEEZ73f0CkXaXp7hrann'); + }); + + it('should include redirect_uri', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('redirect_uri')).toBe(callbackUrl); + }); + + it('should include correct scope', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('scope')).toBe('openid profile email offline_access'); + }); + + it('should include code_challenge and 
code_challenge_method=S256', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('code_challenge')).toBe(codeChallenge); + expect(params.get('code_challenge_method')).toBe('S256'); + }); + + it('should include state parameter', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('state')).toBe(state); + }); + + it('should include prompt=login', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('prompt')).toBe('login'); + }); + + it('should include codex_cli_simplified_flow=true', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('codex_cli_simplified_flow')).toBe('true'); + }); + + it('should include id_token_add_organizations=true', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('id_token_add_organizations')).toBe('true'); + }); + + it('should include originator=opencode', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('originator')).toBe('opencode'); + }); + }); + + describe('exchangeCode', () => { + const callbackUrl = 'http://127.0.0.1:1455/auth/callback'; + + it('should POST form-urlencoded to token endpoint', async () => { + const idToken = createMockJwt({ chatgpt_account_id: 'acct-123' }); + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'at-xyz', + refresh_token: 'rt-abc', + id_token: idToken, + expires_in: 3600, + }), + }); + + await provider.exchangeCode('auth-code', 'verifier-123', callbackUrl); + + expect(mockFetch).toHaveBeenCalledWith( + 
'https://auth.openai.com/oauth/token', + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + }) + ); + + const body = new URLSearchParams(mockFetch.mock.calls[0][1].body); + expect(body.get('grant_type')).toBe('authorization_code'); + expect(body.get('client_id')).toBe('app_EMoamEEZ73f0CkXaXp7hrann'); + expect(body.get('code')).toBe('auth-code'); + expect(body.get('redirect_uri')).toBe(callbackUrl); + expect(body.get('code_verifier')).toBe('verifier-123'); + }); + + it('should return OAuthResult with access_token as apiKey', async () => { + const idToken = createMockJwt({ chatgpt_account_id: 'acct-456' }); + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'at-my-token', + refresh_token: 'rt-my-refresh', + id_token: idToken, + expires_in: 7200, + }), + }); + + const result = await provider.exchangeCode('code', 'verifier', callbackUrl); + + expect(result.apiKey).toBe('at-my-token'); + expect(result.refreshToken).toBe('rt-my-refresh'); + expect(result.expiresAt).toBeGreaterThan(Date.now()); + }); + + it('should extract accountId from id_token chatgpt_account_id claim', async () => { + const idToken = createMockJwt({ chatgpt_account_id: 'acct-from-id-token' }); + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'at-1', + refresh_token: 'rt-1', + id_token: idToken, + expires_in: 3600, + }), + }); + + const result = await provider.exchangeCode('code', 'verifier', callbackUrl); + + expect(result.metadata?.accountId).toBe('acct-from-id-token'); + }); + + it('should extract accountId from nested auth claim', async () => { + const idToken = createMockJwt({ + 'https://api.openai.com/auth': { chatgpt_account_id: 'nested-acct-id' }, + }); + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'at-2', + refresh_token: 'rt-2', + id_token: idToken, + expires_in: 3600, + }), + }); + + const result = await 
provider.exchangeCode('code', 'verifier', callbackUrl); + + expect(result.metadata?.accountId).toBe('nested-acct-id'); + }); + + it('should fall back to organizations[0].id for accountId', async () => { + const idToken = createMockJwt({ + organizations: [{ id: 'org-abc-123' }], + }); + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'at-3', + refresh_token: 'rt-3', + id_token: idToken, + expires_in: 3600, + }), + }); + + const result = await provider.exchangeCode('code', 'verifier', callbackUrl); + + expect(result.metadata?.accountId).toBe('org-abc-123'); + }); + + it('should fall back to access_token for accountId when id_token has no account info', async () => { + const idTokenWithoutAccount = createMockJwt({ email: 'user@example.com' }); + const accessTokenWithAccount = createMockJwt({ chatgpt_account_id: 'acct-from-at' }); + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: accessTokenWithAccount, + refresh_token: 'rt-4', + id_token: idTokenWithoutAccount, + expires_in: 3600, + }), + }); + + const result = await provider.exchangeCode('code', 'verifier', callbackUrl); + + expect(result.metadata?.accountId).toBe('acct-from-at'); + }); + + it('should NOT include idToken in metadata (PII prevention)', async () => { + const idToken = createMockJwt({ chatgpt_account_id: 'acct-x' }); + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'at-5', + refresh_token: 'rt-5', + id_token: idToken, + expires_in: 3600, + }), + }); + + const result = await provider.exchangeCode('code', 'verifier', callbackUrl); + + // id_token contains email PII and must NOT be persisted in metadata + expect(result.metadata?.idToken).toBeUndefined(); + // accountId should still be extracted + expect(result.metadata?.accountId).toBe('acct-x'); + }); + + it('should default expires_in to 3600 when not provided', async () => { + const idToken = createMockJwt({}); + mockFetch.mockResolvedValue({ + ok: true, + 
json: async () => ({ + access_token: 'at-6', + refresh_token: 'rt-6', + id_token: idToken, + // No expires_in + }), + }); + + const beforeTime = Date.now(); + const result = await provider.exchangeCode('code', 'verifier', callbackUrl); + const afterTime = Date.now(); + + // Should default to 3600 seconds (1 hour) + expect(result.expiresAt).toBeGreaterThanOrEqual(beforeTime + 3600 * 1000); + expect(result.expiresAt).toBeLessThanOrEqual(afterTime + 3600 * 1000); + }); + + it('should throw on HTTP error response', async () => { + mockFetch.mockResolvedValue({ + ok: false, + status: 401, + text: async () => 'Unauthorized', + }); + + await expect( + provider.exchangeCode('bad-code', 'verifier', callbackUrl) + ).rejects.toThrow('Codex token exchange failed: HTTP 401 - Unauthorized'); + }); + + it('should handle invalid JWT in id_token gracefully', async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'at-7', + refresh_token: 'rt-7', + id_token: 'not-a-jwt', + expires_in: 3600, + }), + }); + + // Should not throw -- just won't extract accountId + const result = await provider.exchangeCode('code', 'verifier', callbackUrl); + expect(result.apiKey).toBe('at-7'); + // accountId won't be extracted from invalid JWT + expect(result.metadata?.accountId).toBeUndefined(); + }); + + it('should handle missing id_token by trying access_token', async () => { + const accessTokenWithAccount = createMockJwt({ chatgpt_account_id: 'acct-at-only' }); + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: accessTokenWithAccount, + refresh_token: 'rt-8', + id_token: '', + expires_in: 3600, + }), + }); + + const result = await provider.exchangeCode('code', 'verifier', callbackUrl); + expect(result.metadata?.accountId).toBe('acct-at-only'); + }); + }); + + describe('refreshToken', () => { + it('should POST form-urlencoded with refresh_token grant type', async () => { + const idToken = createMockJwt({ chatgpt_account_id: 
'acct-r1' }); + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'new-at', + refresh_token: 'new-rt', + id_token: idToken, + expires_in: 3600, + }), + }); + + await provider.refreshToken!('old-refresh-token'); + + expect(mockFetch).toHaveBeenCalledWith( + 'https://auth.openai.com/oauth/token', + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + }) + ); + + const body = new URLSearchParams(mockFetch.mock.calls[0][1].body); + expect(body.get('grant_type')).toBe('refresh_token'); + expect(body.get('client_id')).toBe('app_EMoamEEZ73f0CkXaXp7hrann'); + expect(body.get('refresh_token')).toBe('old-refresh-token'); + }); + + it('should return new OAuthResult with refreshed tokens', async () => { + const idToken = createMockJwt({ chatgpt_account_id: 'acct-r2' }); + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'refreshed-at', + refresh_token: 'rotated-rt', + id_token: idToken, + expires_in: 7200, + }), + }); + + const result = await provider.refreshToken!('old-rt'); + + expect(result).not.toBeNull(); + expect(result!.apiKey).toBe('refreshed-at'); + expect(result!.refreshToken).toBe('rotated-rt'); + expect(result!.expiresAt).toBeGreaterThan(Date.now()); + }); + + it('should return null on HTTP error', async () => { + mockFetch.mockResolvedValue({ + ok: false, + status: 400, + text: async () => 'Invalid grant', + }); + + const result = await provider.refreshToken!('expired-rt'); + expect(result).toBeNull(); + }); + + it('should return null on network error', async () => { + mockFetch.mockRejectedValue(new Error('Network error')); + + const result = await provider.refreshToken!('some-rt'); + expect(result).toBeNull(); + }); + }); +}); diff --git a/tests/unit/OpenRouterOAuthProvider.test.ts b/tests/unit/OpenRouterOAuthProvider.test.ts new file mode 100644 index 00000000..c3e81602 --- /dev/null +++ b/tests/unit/OpenRouterOAuthProvider.test.ts @@ -0,0 
+1,225 @@ +/** + * OpenRouterOAuthProvider Unit Tests + * + * Tests the OpenRouter OAuth provider: + * - Static configuration (provider ID, port, etc.) + * - Authorization URL construction + * - Token exchange via mocked fetch + * - Pre-auth parameter handling + */ + +import { OpenRouterOAuthProvider } from '../../src/services/oauth/providers/OpenRouterOAuthProvider'; + +// Mock global fetch +const mockFetch = jest.fn(); +global.fetch = mockFetch; + +describe('OpenRouterOAuthProvider', () => { + let provider: OpenRouterOAuthProvider; + + beforeEach(() => { + provider = new OpenRouterOAuthProvider(); + jest.clearAllMocks(); + }); + + describe('config', () => { + it('should have providerId "openrouter"', () => { + expect(provider.config.providerId).toBe('openrouter'); + }); + + it('should have displayName "OpenRouter"', () => { + expect(provider.config.displayName).toBe('OpenRouter'); + }); + + it('should prefer port 3456', () => { + expect(provider.config.preferredPort).toBe(3456); + }); + + it('should use /callback path', () => { + expect(provider.config.callbackPath).toBe('/callback'); + }); + + it('should have empty scopes', () => { + expect(provider.config.scopes).toEqual([]); + }); + + it('should use permanent-key token type', () => { + expect(provider.config.tokenType).toBe('permanent-key'); + }); + + it('should have empty clientId', () => { + expect(provider.config.clientId).toBe(''); + }); + + it('should not be marked experimental', () => { + expect(provider.config.experimental).toBeUndefined(); + }); + + it('should point to correct auth URL', () => { + expect(provider.config.authUrl).toBe('https://openrouter.ai/auth'); + }); + + it('should point to correct token URL', () => { + expect(provider.config.tokenUrl).toBe('https://openrouter.ai/api/v1/auth/keys'); + }); + }); + + describe('buildAuthUrl', () => { + const callbackUrl = 'http://127.0.0.1:3000/callback'; + const codeChallenge = 'test-challenge-abc123'; + const state = 'test-state-xyz'; + + it('should 
produce a URL starting with the auth endpoint', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + expect(url).toMatch(/^https:\/\/openrouter\.ai\/auth\?/); + }); + + it('should include callback_url parameter', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('callback_url')).toBe(callbackUrl); + }); + + it('should include code_challenge parameter', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('code_challenge')).toBe(codeChallenge); + }); + + it('should include code_challenge_method=S256', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('code_challenge_method')).toBe('S256'); + }); + + it('should include state parameter', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.get('state')).toBe(state); + }); + + it('should include key_label when provided in preAuthParams', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state, { + key_label: 'My Obsidian Key', + }); + const params = new URL(url).searchParams; + expect(params.get('key_label')).toBe('My Obsidian Key'); + }); + + it('should include credit_limit as "limit" parameter', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state, { + credit_limit: '10', + }); + const params = new URL(url).searchParams; + expect(params.get('limit')).toBe('10'); + }); + + it('should not include key_label when not provided', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.has('key_label')).toBe(false); + }); + + it('should not include limit when credit_limit not provided', () => { + const url = 
provider.buildAuthUrl(callbackUrl, codeChallenge, state); + const params = new URL(url).searchParams; + expect(params.has('limit')).toBe(false); + }); + + it('should not include key_label when it is empty string', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state, { + key_label: '', + }); + const params = new URL(url).searchParams; + expect(params.has('key_label')).toBe(false); + }); + + it('should include both key_label and limit when both provided', () => { + const url = provider.buildAuthUrl(callbackUrl, codeChallenge, state, { + key_label: 'MyKey', + credit_limit: '5', + }); + const params = new URL(url).searchParams; + expect(params.get('key_label')).toBe('MyKey'); + expect(params.get('limit')).toBe('5'); + }); + }); + + describe('exchangeCode', () => { + it('should POST to token URL with correct JSON body', async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ key: 'sk-or-v1-test-key' }), + }); + + await provider.exchangeCode('auth-code-123', 'verifier-abc', 'http://127.0.0.1:3000/callback'); + + expect(mockFetch).toHaveBeenCalledWith( + 'https://openrouter.ai/api/v1/auth/keys', + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + }) + ); + + const body = JSON.parse(mockFetch.mock.calls[0][1].body); + expect(body.code).toBe('auth-code-123'); + expect(body.code_verifier).toBe('verifier-abc'); + expect(body.code_challenge_method).toBe('S256'); + }); + + it('should return OAuthResult with the API key', async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ key: 'sk-or-v1-my-key' }), + }); + + const result = await provider.exchangeCode('code', 'verifier', 'http://localhost:3000/callback'); + + expect(result.apiKey).toBe('sk-or-v1-my-key'); + // Permanent key -- no refresh token or expiry + expect(result.refreshToken).toBeUndefined(); + expect(result.expiresAt).toBeUndefined(); + }); + + it('should throw on HTTP error response', async () 
=> { + mockFetch.mockResolvedValue({ + ok: false, + status: 400, + text: async () => 'Invalid code', + }); + + await expect( + provider.exchangeCode('bad-code', 'verifier', 'http://localhost:3000/callback') + ).rejects.toThrow('OpenRouter token exchange failed: HTTP 400 - Invalid code'); + }); + + it('should throw when response has no key', async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({}), + }); + + await expect( + provider.exchangeCode('code', 'verifier', 'http://localhost:3000/callback') + ).rejects.toThrow('OpenRouter token exchange returned no key'); + }); + + it('should throw when response key is empty string', async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ key: '' }), + }); + + await expect( + provider.exchangeCode('code', 'verifier', 'http://localhost:3000/callback') + ).rejects.toThrow('OpenRouter token exchange returned no key'); + }); + }); + + describe('refreshToken', () => { + it('should not have a refreshToken method (permanent keys)', () => { + expect((provider as any).refreshToken).toBeUndefined(); + }); + }); +}); diff --git a/tests/unit/PKCEUtils.test.ts b/tests/unit/PKCEUtils.test.ts new file mode 100644 index 00000000..cb59cbd1 --- /dev/null +++ b/tests/unit/PKCEUtils.test.ts @@ -0,0 +1,170 @@ +/** + * PKCEUtils Unit Tests + * + * Tests PKCE (RFC 7636) cryptographic operations: + * - base64url encoding + * - Code verifier generation + * - Code challenge (S256) derivation + * - State parameter generation + */ + +import { + base64url, + generateCodeVerifier, + generateCodeChallenge, + generateState, +} from '../../src/services/oauth/PKCEUtils'; + +describe('PKCEUtils', () => { + describe('base64url', () => { + it('should encode an empty buffer', () => { + const result = base64url(new Uint8Array(0)); + expect(result).toBe(''); + }); + + it('should produce URL-safe output (no +, /, or = padding)', () => { + // Use bytes that would produce +, /, and = in standard base64 + const buffer 
= new Uint8Array([251, 255, 254, 63, 62]); + const encoded = base64url(buffer); + expect(encoded).not.toContain('+'); + expect(encoded).not.toContain('/'); + expect(encoded).not.toContain('='); + }); + + it('should replace + with -', () => { + // 0xFB 0xEF => standard base64 "u+8" which contains + + const buffer = new Uint8Array([251, 239]); + const encoded = base64url(buffer); + expect(encoded).toContain('-'); + expect(encoded).not.toContain('+'); + }); + + it('should replace / with _', () => { + // 0xFF 0xFF => standard base64 "//8" which contains / + const buffer = new Uint8Array([255, 255]); + const encoded = base64url(buffer); + expect(encoded).toContain('_'); + expect(encoded).not.toContain('/'); + }); + + it('should strip trailing = padding', () => { + // Single byte produces 2 base64 chars + 2 padding chars + const buffer = new Uint8Array([65]); // 'A' in ASCII + const encoded = base64url(buffer); + expect(encoded).not.toMatch(/=$/); + }); + + it('should accept ArrayBuffer as input', () => { + const arrayBuffer = new Uint8Array([72, 101, 108, 108, 111]).buffer; + const encoded = base64url(arrayBuffer); + expect(encoded).toBe('SGVsbG8'); // base64url of "Hello" + }); + + it('should accept Uint8Array as input', () => { + const uint8 = new Uint8Array([72, 101, 108, 108, 111]); + const encoded = base64url(uint8); + expect(encoded).toBe('SGVsbG8'); // base64url of "Hello" + }); + + it('should produce consistent output for same input', () => { + const buffer = new Uint8Array([1, 2, 3, 4, 5]); + const first = base64url(buffer); + const second = base64url(buffer); + expect(first).toBe(second); + }); + }); + + describe('generateCodeVerifier', () => { + it('should produce a 43-character string', () => { + const verifier = generateCodeVerifier(); + expect(verifier).toHaveLength(43); + }); + + it('should only contain unreserved URI characters (A-Z, a-z, 0-9, -, ., _, ~)', () => { + const verifier = generateCodeVerifier(); + 
expect(verifier).toMatch(/^[A-Za-z0-9\-._~]+$/); + }); + + it('should produce unique values on successive calls', () => { + const verifiers = new Set(); + for (let i = 0; i < 20; i++) { + verifiers.add(generateCodeVerifier()); + } + // All 20 should be unique (collision probability is astronomically low) + expect(verifiers.size).toBe(20); + }); + + it('should use the full character set over many generations', () => { + // Generate many verifiers and check we see a good spread + const allChars = new Set(); + for (let i = 0; i < 100; i++) { + const v = generateCodeVerifier(); + for (const ch of v) { + allChars.add(ch); + } + } + // We should see letters, digits, and at least some special chars + expect(allChars.size).toBeGreaterThan(30); + }); + }); + + describe('generateCodeChallenge', () => { + it('should produce a non-empty string', async () => { + const verifier = generateCodeVerifier(); + const challenge = await generateCodeChallenge(verifier); + expect(challenge.length).toBeGreaterThan(0); + }); + + it('should produce base64url-encoded output', async () => { + const verifier = generateCodeVerifier(); + const challenge = await generateCodeChallenge(verifier); + expect(challenge).toMatch(/^[A-Za-z0-9\-_]+$/); + }); + + it('should be deterministic for the same verifier', async () => { + const verifier = 'fixed-test-verifier-value-that-is-long-en'; + const challenge1 = await generateCodeChallenge(verifier); + const challenge2 = await generateCodeChallenge(verifier); + expect(challenge1).toBe(challenge2); + }); + + it('should produce different challenges for different verifiers', async () => { + const v1 = generateCodeVerifier(); + const v2 = generateCodeVerifier(); + const c1 = await generateCodeChallenge(v1); + const c2 = await generateCodeChallenge(v2); + expect(c1).not.toBe(c2); + }); + + it('should produce a 43-character challenge (SHA-256 = 32 bytes => 43 base64url chars)', async () => { + const verifier = generateCodeVerifier(); + const challenge = await 
generateCodeChallenge(verifier); + expect(challenge).toHaveLength(43); + }); + }); + + describe('generateState', () => { + it('should produce a non-empty string', () => { + const state = generateState(); + expect(state.length).toBeGreaterThan(0); + }); + + it('should produce base64url-encoded output', () => { + const state = generateState(); + expect(state).toMatch(/^[A-Za-z0-9\-_]+$/); + }); + + it('should produce approximately 43 characters (32 bytes base64url)', () => { + const state = generateState(); + expect(state).toHaveLength(43); + }); + + it('should produce unique values on successive calls', () => { + const states = new Set(); + for (let i = 0; i < 20; i++) { + states.add(generateState()); + } + expect(states.size).toBe(20); + }); + }); +}); diff --git a/tests/unit/ProviderMessageBuilder.codex.test.ts b/tests/unit/ProviderMessageBuilder.codex.test.ts new file mode 100644 index 00000000..7344a1dc --- /dev/null +++ b/tests/unit/ProviderMessageBuilder.codex.test.ts @@ -0,0 +1,253 @@ +/** + * ProviderMessageBuilder — openai-codex continuation branch tests + * + * Tests the stateless Responses API input array construction used by the + * Codex provider in buildContinuationOptions. 
+ */ + +import { + ProviderMessageBuilder, + ConversationMessage, + GenerateOptionsInternal, +} from '../../src/services/llm/core/ProviderMessageBuilder'; + +// Mock ConversationContextBuilder — not needed for the openai-codex branch +// (it builds its own input array without calling ConversationContextBuilder) +jest.mock('../../src/services/chat/ConversationContextBuilder', () => ({ + ConversationContextBuilder: { + buildToolContinuation: jest.fn(), + buildResponsesAPIToolInput: jest.fn(), + }, +})); + +describe('ProviderMessageBuilder — openai-codex continuation', () => { + let builder: ProviderMessageBuilder; + const baseGenerateOptions: GenerateOptionsInternal = { + model: 'gpt-5.3-codex', + systemPrompt: 'You are helpful.', + tools: [{ type: 'function', name: 'getTools', parameters: {} } as any], + }; + + beforeEach(() => { + builder = new ProviderMessageBuilder(new Map()); + }); + + it('should build a full input array with prior messages, user prompt, function_call, and function_call_output', () => { + const previousMessages: ConversationMessage[] = [ + { role: 'user', content: 'Hi there' }, + { role: 'assistant', content: 'Hello!' }, + ]; + + const toolCalls = [ + { + id: 'call_abc', + type: 'function' as const, + function: { name: 'get_weather', arguments: '{"city":"NYC"}' }, + }, + ]; + + const toolResults = [ + { id: 'call_abc', name: 'get_weather', success: true, result: { temp: 72 } }, + ]; + + const result = builder.buildContinuationOptions( + 'openai-codex', + 'What is the weather?', + toolCalls, + toolResults, + previousMessages, + baseGenerateOptions, + ); + + const input = result.conversationHistory as Array<Record<string, unknown>>; + + // Prior messages included (non-system) + expect(input[0]).toEqual({ role: 'user', content: 'Hi there' }); + expect(input[1]).toEqual({ role: 'assistant', content: 'Hello!' }); + + // Current user prompt + expect(input[2]).toEqual({ role: 'user', content: 'What is the weather?'
}); + + // function_call item + expect(input[3]).toEqual({ + type: 'function_call', + call_id: 'call_abc', + name: 'get_weather', + arguments: '{"city":"NYC"}', + }); + + // function_call_output item + expect(input[4]).toEqual({ + type: 'function_call_output', + call_id: 'call_abc', + output: JSON.stringify({ temp: 72 }), + }); + }); + + it('should skip system messages from previous messages', () => { + const previousMessages: ConversationMessage[] = [ + { role: 'system', content: 'System prompt in messages' }, + { role: 'user', content: 'Hello' }, + ]; + + const result = builder.buildContinuationOptions( + 'openai-codex', + 'Follow up', + [], + [], + previousMessages, + baseGenerateOptions, + ); + + const input = result.conversationHistory as Array<Record<string, unknown>>; + + // System message should be skipped + expect(input[0]).toEqual({ role: 'user', content: 'Hello' }); + expect(input[1]).toEqual({ role: 'user', content: 'Follow up' }); + expect(input).toHaveLength(2); + }); + + it('should not include previousResponseId in output', () => { + const result = builder.buildContinuationOptions( + 'openai-codex', + 'test', + [], + [], + [], + baseGenerateOptions, + { conversationId: 'conv-1', responsesApiId: 'resp-old' } as any, + ); + + expect(result.previousResponseId).toBeUndefined(); + }); + + it('should preserve systemPrompt and tools from generateOptions', () => { + const result = builder.buildContinuationOptions( + 'openai-codex', + 'test', + [], + [], + [], + baseGenerateOptions, + ); + + expect(result.systemPrompt).toBe('You are helpful.'); + expect(result.tools).toEqual(baseGenerateOptions.tools); + }); + + it('should handle multiple tool calls and results', () => { + const toolCalls = [ + { + id: 'call_1', + type: 'function' as const, + function: { name: 'get_weather', arguments: '{"city":"NYC"}' }, + }, + { + id: 'call_2', + type: 'function' as const, + function: { name: 'get_time', arguments: '{"tz":"EST"}' }, + }, + ]; + + const toolResults = [ + { id: 'call_1', success:
true, result: { temp: 72 } }, + { id: 'call_2', success: true, result: { time: '3:00 PM' } }, + ]; + + const result = builder.buildContinuationOptions( + 'openai-codex', + 'weather and time', + toolCalls, + toolResults, + [], + baseGenerateOptions, + ); + + const input = result.conversationHistory as Array<Record<string, unknown>>; + + // user prompt + 2 function_call + 2 function_call_output = 5 items + expect(input).toHaveLength(5); + + expect(input[1]).toMatchObject({ type: 'function_call', call_id: 'call_1', name: 'get_weather' }); + expect(input[2]).toMatchObject({ type: 'function_call', call_id: 'call_2', name: 'get_time' }); + expect(input[3]).toMatchObject({ type: 'function_call_output', call_id: 'call_1' }); + expect(input[4]).toMatchObject({ type: 'function_call_output', call_id: 'call_2' }); + }); + + it('should format failed tool results with error JSON', () => { + const toolCalls = [ + { + id: 'call_fail', + type: 'function' as const, + function: { name: 'broken_tool', arguments: '{}' }, + }, + ]; + + const toolResults = [ + { id: 'call_fail', success: false, error: 'Tool timed out' }, + ]; + + const result = builder.buildContinuationOptions( + 'openai-codex', + 'try this', + toolCalls, + toolResults, + [], + baseGenerateOptions, + ); + + const input = result.conversationHistory as Array<Record<string, unknown>>; + const output = input.find(i => i.type === 'function_call_output'); + + expect(output).toBeDefined(); + expect(JSON.parse(output!.output as string)).toEqual({ error: 'Tool timed out' }); + }); + + it('should extract name from ChatToolCall union type (name property)', () => { + // ChatToolCall has a top-level `name` property in addition to function.name + const toolCalls = [ + { + id: 'call_chat', + type: 'function' as const, + name: 'chat_tool_name', + function: { name: 'function_name', arguments: '{}' }, + }, + ]; + + const toolResults = [ + { id: 'call_chat', success: true, result: {} }, + ]; + + const result = builder.buildContinuationOptions( + 'openai-codex', + 'test', + toolCalls, +
toolResults, + [], + baseGenerateOptions, + ); + + const input = result.conversationHistory as Array<Record<string, unknown>>; + const fnCall = input.find(i => i.type === 'function_call'); + + // Should prefer the top-level name (ChatToolCall path) + expect(fnCall!.name).toBe('chat_tool_name'); + }); + + it('should omit user prompt from input when empty', () => { + const result = builder.buildContinuationOptions( + 'openai-codex', + '', // empty prompt (continuation without new user message) + [], + [], + [{ role: 'user', content: 'earlier' }], + baseGenerateOptions, + ); + + const input = result.conversationHistory as Array<Record<string, unknown>>; + + // Only the prior message, no empty user prompt added + expect(input).toHaveLength(1); + expect(input[0]).toEqual({ role: 'user', content: 'earlier' }); + }); +});