diff --git a/src/components/notebook/NotebookView.tsx b/src/components/notebook/NotebookView.tsx index 0336a8cb..f1f53c8b 100644 --- a/src/components/notebook/NotebookView.tsx +++ b/src/components/notebook/NotebookView.tsx @@ -44,6 +44,7 @@ import { createNotebookFromState, } from "../../utils/notebookStore"; import { useDatabase } from "../../hooks/useDatabase"; +import { useSqlAutocompleteRegistration } from "../../hooks/useSqlAutocompleteRegistration"; import { isMultiDatabaseCapable } from "../../utils/database"; import { useSettings } from "../../hooks/useSettings"; import { useAlert } from "../../hooks/useAlert"; @@ -59,12 +60,14 @@ interface NotebookViewProps { tab: Tab; updateTab: (id: string, partial: Partial) => void; connectionId: string; + isActive: boolean; } export function NotebookView({ tab, updateTab, connectionId, + isActive, }: NotebookViewProps) { const { t } = useTranslation(); const { activeSchema, activeCapabilities, selectedDatabases } = useDatabase(); @@ -72,6 +75,10 @@ export function NotebookView({ isMultiDatabaseCapable(activeCapabilities) && selectedDatabases.length > 1; const effectiveSchema = tab.schema || activeSchema || (isMultiDb ? selectedDatabases[0] : null); + useSqlAutocompleteRegistration(connectionId, { + schema: effectiveSchema, + enabled: isActive, + }); const { settings } = useSettings(); const { showAlert } = useAlert(); const { matchesShortcut } = useKeybindings(); diff --git a/src/contexts/DatabaseProvider.tsx b/src/contexts/DatabaseProvider.tsx index a6620839..eb3fcf20 100644 --- a/src/contexts/DatabaseProvider.tsx +++ b/src/contexts/DatabaseProvider.tsx @@ -14,7 +14,7 @@ import { } from './DatabaseContext'; import type { ReactNode } from 'react'; import type { PluginManifest } from '../types/plugins'; -import { clearAutocompleteCache } from '../utils/autocomplete'; +import { clearAutocompleteCache, disposeSqlAutocomplete } from '../utils/autocomplete'; import { toErrorMessage } from '../utils/errors'; import { useSettings } from '../hooks/useSettings'; import { findConnectionsForDrivers } from '../utils/connectionManager'; @@ -691,6 +691,7 @@ export const DatabaseProvider = ({ children }: { children: ReactNode }) => { if (!targetId) return; clearAutocompleteCache(targetId); + disposeSqlAutocomplete(); try { await invoke('disconnect_connection', { connectionId: targetId }); @@ -782,6 +783,7 @@ export const DatabaseProvider = ({ children }: { children: ReactNode }) => { console.warn(`[DatabaseProvider] Connection health check failed for ${connectionId}: ${event.payload.error}`); clearAutocompleteCache(connectionId); + disposeSqlAutocomplete(); setOpenConnectionIds(prev => prev.filter(id => id !== connectionId)); setConnectionDataMap(prev => { diff --git a/src/hooks/useSqlAutocompleteRegistration.ts b/src/hooks/useSqlAutocompleteRegistration.ts new file mode 100644 index 00000000..1f3e476f --- /dev/null +++ b/src/hooks/useSqlAutocompleteRegistration.ts @@ -0,0 +1,89 @@ +import { useEffect } from "react"; +import type { Monaco } from "@monaco-editor/react"; +import { loader } from "@monaco-editor/react"; +import { useDatabase } from "./useDatabase"; +import { isMultiDatabaseCapable } from "../utils/database"; +import { registerSqlAutocomplete } from "../utils/autocomplete"; + +type Options = { + monaco?: Monaco | null; + schema?: string | null; + /** When false, skips registration (e.g. inactive notebook tabs). Defaults to true. */ + enabled?: boolean; +}; + +/** + * Keeps the global SQL completion provider in sync with the active connection. + * Pass `monaco` from the main editor when available; otherwise Monaco is loaded via loader.init (notebook). + */ +export function useSqlAutocompleteRegistration( + connectionId: string | null, + options?: Options, +) { + const { + tables, + activeDriver, + activeSchema, + activeCapabilities, + schemaDataMap, + databaseDataMap, + selectedDatabases, + } = useDatabase(); + + const schema = options?.schema ?? activeSchema; + const isMultiDb = + isMultiDatabaseCapable(activeCapabilities) && selectedDatabases.length > 1; + + const enabled = options?.enabled ?? true; + + useEffect(() => { + if (!connectionId || !enabled) return; + + let cancelled = false; + + const register = (monaco: Monaco) => { + if (cancelled) return; + + let effectiveTables = tables; + if (activeCapabilities?.schemas && schema) { + effectiveTables = schemaDataMap[schema]?.tables ?? tables; + } else if (isMultiDb) { + effectiveTables = selectedDatabases.flatMap( + (db) => databaseDataMap[db]?.tables ?? [], + ); + } + + registerSqlAutocomplete( + monaco, + connectionId, + effectiveTables, + schema, + activeDriver ?? null, + ); + }; + + if (options?.monaco) { + register(options.monaco); + return () => { + cancelled = true; + }; + } + + loader.init().then((monaco) => register(monaco)); + return () => { + cancelled = true; + }; + }, [ + connectionId, + enabled, + options?.monaco, + schema, + tables, + activeDriver, + activeCapabilities, + schemaDataMap, + databaseDataMap, + isMultiDb, + selectedDatabases, + ]); +} diff --git a/src/pages/Editor.tsx b/src/pages/Editor.tsx index 6f226067..8dfda8bc 100644 --- a/src/pages/Editor.tsx +++ b/src/pages/Editor.tsx @@ -82,7 +82,7 @@ import { SqlEditorWrapper } from "../components/ui/SqlEditorWrapper"; import { NotebookView } from "../components/notebook/NotebookView"; import { extractSqlFromCells } from "../utils/notebook"; import { createNotebook } from "../utils/notebookStore"; -import { registerSqlAutocomplete } from "../utils/autocomplete"; +import { useSqlAutocompleteRegistration } from "../hooks/useSqlAutocompleteRegistration"; import { type OnMount, type Monaco } from "@monaco-editor/react"; import { save } from "@tauri-apps/plugin-dialog"; import { useAlert } from "../hooks/useAlert"; @@ -137,7 +137,6 @@ export const Editor = () => { const { t } = useTranslation(); const { activeConnectionId, - tables, views, activeDriver, activeSchema, @@ -145,8 +144,6 @@ export const Editor = () => { selectedDatabases, activeConnectionName, activeDatabaseName, - schemaDataMap, - databaseDataMap, } = useDatabase(); const { explorerConnectionId } = useConnectionLayoutContext(); const { settings } = useSettings(); @@ -2136,23 +2133,11 @@ export const Editor = () => { }); }; - useEffect(() => { - if (monacoInstance && activeConnectionId) { - let effectiveTables = tables; - if (activeCapabilities?.schemas && activeSchema) { - effectiveTables = schemaDataMap[activeSchema]?.tables ?? tables; - } else if (isMultiDb) { - effectiveTables = selectedDatabases.flatMap(db => databaseDataMap[db]?.tables ?? []); - } - const disposable = registerSqlAutocomplete( - monacoInstance, - activeConnectionId, - effectiveTables, - activeSchema, - ); - return () => disposable.dispose(); - } - }, [monacoInstance, activeConnectionId, tables, activeSchema, activeCapabilities, schemaDataMap, databaseDataMap, isMultiDb, selectedDatabases]); + useSqlAutocompleteRegistration(activeConnectionId, { + monaco: monacoInstance, + schema: activeSchema, + enabled: !isNotebookTab, + }); useEffect(() => { const state = location.state as EditorState; @@ -2742,6 +2727,7 @@ export const Editor = () => { tab={tab} updateTab={updateTab} connectionId={activeConnectionId || ""} + isActive={isActive} /> ); diff --git a/src/utils/autocomplete.ts b/src/utils/autocomplete.ts index 40807522..34c5f31b 100644 --- a/src/utils/autocomplete.ts +++ b/src/utils/autocomplete.ts @@ -1,6 +1,7 @@ import type { Monaco } from "@monaco-editor/react"; import { invoke } from "@tauri-apps/api/core"; import type { TableInfo } from "../contexts/DatabaseContext"; +import { formatSqlIdentifier, quoteTableRef } from "./identifiers"; import { getCurrentStatement, parseTablesFromQuery } from "./sqlAnalysis"; // Lightweight column cache with TTL and size limits @@ -98,11 +99,32 @@ export const clearAutocompleteCache = (connectionId?: string) => { } }; +// Find a table by name in the list of tables +const findTableByName = (name: string, tables: TableInfo[]) => + tables.find((t) => t.name.toLowerCase() === name.toLowerCase())?.name; + +const tableInsertText = ( + tableName: string, + driver?: string | null, + schema?: string | null, +) => + schema + ? quoteTableRef(tableName, driver, schema) + : formatSqlIdentifier(tableName, driver); + +let sqlCompletionProvider: { dispose: () => void } | null = null; + +export const disposeSqlAutocomplete = (): void => { + sqlCompletionProvider?.dispose(); + sqlCompletionProvider = null; +}; + export const registerSqlAutocomplete = ( monaco: Monaco, connectionId: string | null, tables: TableInfo[], schema?: string | null, + driver?: string | null, ) => { const provider = monaco.languages.registerCompletionItemProvider("sql", { triggerCharacters: [".", " "], @@ -141,13 +163,13 @@ export const registerSqlAutocomplete = ( // Check if it's an alias or table name let actualTableName = tableAliases?.get(typedName); - + if (!actualTableName) { - // Try direct table name match - const foundTable = tables.find(t => t.name.toLowerCase() === typedName); - actualTableName = foundTable?.name; + actualTableName = findTableByName(typedName, tables); + } else { + actualTableName = findTableByName(actualTableName, tables) ?? actualTableName; } - + if (actualTableName) { const columns = await getTableColumns(connectionId, actualTableName, schema); @@ -163,13 +185,15 @@ export const registerSqlAutocomplete = ( label: c.label, kind: monaco.languages.CompletionItemKind.Field, detail: c.detail, - insertText: c.label, + insertText: formatSqlIdentifier(c.label, driver), range: columnRange, sortText: `0_${c.label}`, })); return { suggestions }; } + + return { suggestions: [] }; } // ============================================ @@ -188,7 +212,7 @@ export const registerSqlAutocomplete = ( // User is inside a query with FROM/JOIN - suggest columns from those tables const tableNames = Array.from(new Set(tableAliases.values())); const matchingTables = tableNames - .map(name => tables.find(t => t.name.toLowerCase() === name.toLowerCase())) + .map((name) => tables.find((t) => t.name.toLowerCase() === name.toLowerCase())) .filter(Boolean) as TableInfo[]; // Limit parallel fetches to prevent memory spikes @@ -223,7 +247,7 @@ export const registerSqlAutocomplete = ( label: col.label, kind: monaco.languages.CompletionItemKind.Field, detail: `${col.detail} — ${table.name}${aliasHint}`, - insertText: col.label, + insertText: formatSqlIdentifier(col.label, driver), range, sortText: `0_${col.label}`, }); @@ -259,7 +283,7 @@ export const registerSqlAutocomplete = ( label: t.name, kind: monaco.languages.CompletionItemKind.Class, detail: "Table", - insertText: t.name, + insertText: tableInsertText(t.name, driver, schema), range, sortText: `1_${t.name}` })); @@ -274,5 +298,7 @@ export const registerSqlAutocomplete = ( }, }); + sqlCompletionProvider?.dispose(); + sqlCompletionProvider = provider; return provider; }; diff --git a/src/utils/sqlAnalysis.ts b/src/utils/sqlAnalysis.ts index d38a5ccb..4da6e684 100644 --- a/src/utils/sqlAnalysis.ts +++ b/src/utils/sqlAnalysis.ts @@ -1,49 +1,69 @@ // SQL Analysis Utilities - Pure logic functions for parsing and analyzing SQL +// Removes wrapping SQL identifier quotes/backticks. +// Unquoted identifiers are normalized to lowercase. +function stripIdentifierQuotes(token: string): string { + const q = token[0]; + if (q === '"' || q === '`') return token.slice(1, -1); + return token.toLowerCase(); +} + // Optimized table parser - early exit and minimal allocations export const parseTablesFromQuery = (sql: string): Map | null => { if (!sql || sql.length === 0) return null; - + const lowerSql = sql.toLowerCase(); - + // Quick check if query contains FROM/JOIN keywords if (!lowerSql.includes('from') && !lowerSql.includes('join')) { return null; } - + + // Only scan FROM clause onward (avoids SELECT-list commas; keeps quoted case) + const fromAt = lowerSql.search(/\bfrom\b/); + const scan = fromAt >= 0 ? sql.slice(fromAt) : sql; + const tableMap = new Map(); - const fromPattern = /(?:from|join)\s+(?:`)?([a-z_][a-z0-9_]*)(?:`)?(?:\s+(?:as\s+)?(?:`)?([a-z_][a-z0-9_]*)(?:`)?)?/gi; - + const fromPattern = + /(?:from|join|,)\s+("(?:[^"]|"")*"|`[^`]+`|[a-zA-Z_][a-zA-Z0-9_]*)(?:\.("(?:[^"]|"")*"|`[^`]+`|[a-zA-Z_][a-zA-Z0-9_]*))?(?:\s+(?:as\s+)?("(?:[^"]|"")*"|`[^`]+`|[a-zA-Z_][a-zA-Z0-9_]*))?/gi; + let match; let matchCount = 0; const MAX_MATCHES = 10; // Prevent regex catastrophic backtracking - - while ((match = fromPattern.exec(lowerSql)) !== null && matchCount++ < MAX_MATCHES) { - const tableName = match[1]; - const alias = match[2] || tableName; - tableMap.set(alias, tableName); + + while ((match = fromPattern.exec(scan)) !== null && matchCount++ < MAX_MATCHES) { + const tableToken = match[2] ?? match[1]; + + if (!tableToken) continue; + + const tableName = stripIdentifierQuotes(tableToken); + const aliasToken = match[3]; + const alias = aliasToken ? stripIdentifierQuotes(aliasToken) : tableName; + + tableMap.set(alias.toLowerCase(), tableName); } - + return tableMap.size > 0 ? tableMap : null; }; // Optimized statement extractor - avoid full text scan when possible export const getCurrentStatement = (model: { getValue: () => string; getOffsetAt: (position: { lineNumber: number; column: number }) => number }, position: { lineNumber: number; column: number }): string => { const fullText = model.getValue(); - + // For small files, just return full text if (fullText.length < 500) { return fullText; } - + const offset = model.getOffsetAt(position); let start = 0; let end = fullText.length; - + + // Search within reasonable bounds (±2000 chars from cursor) const searchStart = Math.max(0, offset - 2000); const searchEnd = Math.min(fullText.length, offset + 2000); - + // Find previous semicolon for (let i = offset - 1; i >= searchStart; i--) { if (fullText[i] === ';') { @@ -51,7 +71,7 @@ export const getCurrentStatement = (model: { getValue: () => string; getOffsetAt break; } } - + // Find next semicolon for (let i = offset; i < searchEnd; i++) { if (fullText[i] === ';') { @@ -59,6 +79,6 @@ export const getCurrentStatement = (model: { getValue: () => string; getOffsetAt break; } } - + return fullText.substring(start, end).trim(); }; diff --git a/tests/contexts/DatabaseProvider.test.tsx b/tests/contexts/DatabaseProvider.test.tsx index adc92e4b..b3e6dc33 100644 --- a/tests/contexts/DatabaseProvider.test.tsx +++ b/tests/contexts/DatabaseProvider.test.tsx @@ -15,6 +15,7 @@ vi.mock('@tauri-apps/api/event', () => ({ vi.mock('../../src/utils/autocomplete', () => ({ clearAutocompleteCache: vi.fn(), + disposeSqlAutocomplete: vi.fn(), })); vi.mock('../../src/hooks/useSettings', () => ({ diff --git a/tests/utils/autocomplete.test.ts b/tests/utils/autocomplete.test.ts index 2d85e7d6..647f917c 100644 --- a/tests/utils/autocomplete.test.ts +++ b/tests/utils/autocomplete.test.ts @@ -151,6 +151,50 @@ describe('autocomplete', () => { expect(tableSuggestions[1].label).toBe('orders'); }); + it('inserts double-quoted table names for postgres', async () => { + const monaco = createMockMonaco(); + registerSqlAutocomplete( + monaco as unknown as Parameters[0], + 'conn1', + [{ name: 'AccountEventLog' }], + null, + 'postgres', + ); + + const provider = monaco.languages.registerCompletionItemProvider.mock.calls[0][1]; + const result = await provider.provideCompletionItems( + createMockModel('SELECT * FROM '), + { lineNumber: 1, column: 15 }, + ); + + const tableSuggestions = result.suggestions.filter((s: { sortText?: string }) => + s.sortText?.startsWith('1_'), + ); + expect(tableSuggestions[0]?.insertText).toBe('"AccountEventLog"'); + }); + + it('inserts schema-qualified table names for postgres when schema is set', async () => { + const monaco = createMockMonaco(); + registerSqlAutocomplete( + monaco as unknown as Parameters[0], + 'conn1', + [{ name: 'AccountEventLog' }], + 'public', + 'postgres', + ); + + const provider = monaco.languages.registerCompletionItemProvider.mock.calls[0][1]; + const result = await provider.provideCompletionItems( + createMockModel('SELECT * FROM '), + { lineNumber: 1, column: 15 }, + ); + + const tableSuggestions = result.suggestions.filter((s: { sortText?: string }) => + s.sortText?.startsWith('1_'), + ); + expect(tableSuggestions[0]?.insertText).toBe('"public"."AccountEventLog"'); + }); + it('should include all table suggestions regardless of count', async () => { const monaco = createMockMonaco(); const tables: TableInfo[] = Array.from({ length: 60 }, (_, i) => ({ @@ -321,6 +365,33 @@ describe('autocomplete', () => { // Should include column suggestions expect(result.suggestions.length).toBeGreaterThan(0); }); + + it('inserts double-quoted column names for postgres', async () => { + const mockInvoke = invoke as unknown as ReturnType; + mockInvoke.mockResolvedValue([{ name: 'CreatedAt', data_type: 'timestamp' }]); + + const { parseTablesFromQuery } = await import('../../src/utils/sqlAnalysis'); + (parseTablesFromQuery as ReturnType).mockReturnValue( + new Map([['ael', 'AccountEventLog']]), + ); + + const monaco = createMockMonaco(); + registerSqlAutocomplete( + monaco as unknown as Parameters[0], + 'conn1', + [{ name: 'AccountEventLog' }], + 'public', + 'postgres', + ); + + const provider = monaco.languages.registerCompletionItemProvider.mock.calls[0][1]; + const model = createMockModel('SELECT ael.'); + model.getValueInRange = vi.fn(() => 'SELECT ael.'); + + const result = await provider.provideCompletionItems(model, { lineNumber: 1, column: 12 }); + + expect(result.suggestions[0]?.insertText).toBe('"CreatedAt"'); + }); }); describe('suggestion limits', () => { diff --git a/tests/utils/sqlAnalysis.test.ts b/tests/utils/sqlAnalysis.test.ts index 150689d9..889c81a9 100644 --- a/tests/utils/sqlAnalysis.test.ts +++ b/tests/utils/sqlAnalysis.test.ts @@ -48,6 +48,22 @@ describe('sqlAnalysis utils', () => { expect(result?.get('u')).toBe('users'); expect(result?.get('p')).toBe('products'); }); + + it('should extract PostgreSQL double-quoted table with alias', () => { + const result = parseTablesFromQuery('SELECT ael. FROM "AccountEventLog" ael'); + expect(result?.get('ael')).toBe('AccountEventLog'); + }); + + it('should extract schema-qualified table with alias', () => { + const result = parseTablesFromQuery('SELECT u. FROM public.users u'); + expect(result?.get('u')).toBe('users'); + }); + + it('should extract comma-separated FROM tables', () => { + const result = parseTablesFromQuery('SELECT * FROM users u, orders o'); + expect(result?.get('u')).toBe('users'); + expect(result?.get('o')).toBe('orders'); + }); }); describe('getCurrentStatement', () => {