diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e765b596..fdf55037e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- Moved string literal escaping into plugin drivers via `escapeStringLiteral` on `PluginDatabaseDriver` and `DatabaseDriver` protocols; `SQLEscaping.escapeStringLiteral` now uses ANSI SQL escaping only (doubles single quotes, strips null bytes) +- SQL autocomplete data types and CREATE TABLE options now use plugin-provided dialect data instead of hardcoded per-database switches +- `FilterSQLGenerator` now uses `SQLDialectDescriptor` data (regex syntax, boolean literals, LIKE escape style, pagination style) instead of `DatabaseType` switch statements + ### Added - `SQLDialectDescriptor` in TableProPluginKit: plugins can now self-describe their SQL dialect (keywords, functions, data types, identifier quoting), with `SQLDialectFactory` preferring plugin-provided dialect info over built-in structs @@ -27,10 +33,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Pre-connect script: run a shell command before each connection (e.g., to refresh credentials or update ~/.pgpass) - `ParameterStyle` enum in TableProPluginKit: plugins declare `?` or `$1` placeholder style via `parameterStyle` property on `PluginDatabaseDriver` - DML statement generation in ClickHouse, MSSQL, and Oracle plugins via `generateStatements()` for database-specific UPDATE/DELETE syntax +- `quoteIdentifier` method on `PluginDatabaseDriver` and `DatabaseDriver` protocols: plugins provide database-specific identifier quoting (backticks for MySQL/SQLite/ClickHouse, brackets for MSSQL, double-quotes for PostgreSQL/Oracle/DuckDB, passthrough for MongoDB/Redis) ### Changed - Moved MSSQL and Oracle pagination query building (`OFFSET...FETCH NEXT`) from `TableQueryBuilder` into their respective plugin drivers via `buildBrowseQuery`/`buildFilteredQuery`/`buildQuickSearchQuery`/`buildCombinedQuery` hooks +- Moved identifier quoting from `DatabaseType` into plugin drivers via `quoteIdentifier` method on `PluginDatabaseDriver` protocol, with each plugin providing its own quoting style (backtick, bracket, double-quote, or passthrough) ### Fixed diff --git a/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift b/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift index 00757c4a8..70e7f2b34 100644 --- a/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift +++ b/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift @@ -84,7 +84,14 @@ final class ClickHousePlugin: NSObject, TableProPlugin, DriverPlugin { "ENUM8", "ENUM16", "IPV4", "IPV6", "JSON", "BOOL" - ] + ], + tableOptions: [ + "ENGINE=MergeTree()", "ORDER BY", "PARTITION BY", "SETTINGS" + ], + regexSyntax: .match, + booleanLiteralStyle: .numeric, + likeEscapeStyle: .implicit, + paginationStyle: .limit ) func createDriver(config: DriverConnectionConfig) -> any PluginDatabaseDriver { @@ -134,6 +141,25 @@ final class ClickHousePluginDriver: PluginDatabaseDriver, @unchecked Sendable { var serverVersion: String? { _serverVersion } var supportsSchemas: Bool { false } var supportsTransactions: Bool { false } + + func quoteIdentifier(_ name: String) -> String { + let escaped = name.replacingOccurrences(of: "`", with: "``") + return "`\(escaped)`" + } + + func escapeStringLiteral(_ value: String) -> String { + var result = value + result = result.replacingOccurrences(of: "\\", with: "\\\\") + result = result.replacingOccurrences(of: "'", with: "''") + result = result.replacingOccurrences(of: "\n", with: "\\n") + result = result.replacingOccurrences(of: "\r", with: "\\r") + result = result.replacingOccurrences(of: "\t", with: "\\t") + result = result.replacingOccurrences(of: "\0", with: "\\0") + result = result.replacingOccurrences(of: "\u{08}", with: "\\b") + result = result.replacingOccurrences(of: "\u{0C}", with: "\\f") + result = result.replacingOccurrences(of: "\u{1A}", with: "\\Z") + return result + } func beginTransaction() async throws {} func commitTransaction() async throws {} func rollbackTransaction() async throws {} diff --git a/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift b/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift index 94de4d07d..55ebb8472 100644 --- a/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift +++ b/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift @@ -88,7 +88,11 @@ final class DuckDBPlugin: NSObject, TableProPlugin, DriverPlugin { "DATE", "TIME", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "INTERVAL", "UUID", "JSON", "LIST", "MAP", "STRUCT", "UNION", "ENUM", "BIT" - ] + ], + regexSyntax: .regexpMatches, + booleanLiteralStyle: .truefalse, + likeEscapeStyle: .explicit, + paginationStyle: .limit ) func createDriver(config: DriverConnectionConfig) -> any PluginDatabaseDriver { diff --git a/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift b/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift index d51530f17..475a1ce22 100644 --- a/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift +++ b/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift @@ -80,7 +80,14 @@ final class MSSQLPlugin: NSObject, TableProPlugin, DriverPlugin { "DATE", "TIME", "DATETIME", "DATETIME2", "SMALLDATETIME", "DATETIMEOFFSET", "BIT", "UNIQUEIDENTIFIER", "XML", "SQL_VARIANT", "ROWVERSION", "TIMESTAMP", "HIERARCHYID" - ] + ], + tableOptions: [ + "ON", "CLUSTERED", "NONCLUSTERED", "WITH", "TEXTIMAGE_ON" + ], + regexSyntax: .unsupported, + booleanLiteralStyle: .numeric, + likeEscapeStyle: .explicit, + paginationStyle: .offsetFetch ) func createDriver(config: DriverConnectionConfig) -> any PluginDatabaseDriver { @@ -416,6 +423,11 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { var supportsSchemas: Bool { true } var supportsTransactions: Bool { true } + func quoteIdentifier(_ name: String) -> String { + let escaped = name.replacingOccurrences(of: "]", with: "]]") + return "[\(escaped)]" + } + init(config: DriverConnectionConfig) { self.config = config self._currentSchema = config.additionalFields["mssqlSchema"]?.isEmpty == false @@ -1250,7 +1262,7 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { // MARK: - Query Building Helpers private func mssqlQuoteIdentifier(_ identifier: String) -> String { - "[\(identifier.replacingOccurrences(of: "]", with: "]]"))]" + quoteIdentifier(identifier) } private func mssqlBuildOrderByClause( diff --git a/Plugins/MongoDBDriverPlugin/MongoDBPluginDriver.swift b/Plugins/MongoDBDriverPlugin/MongoDBPluginDriver.swift index 93ddd970b..89a2ed650 100644 --- a/Plugins/MongoDBDriverPlugin/MongoDBPluginDriver.swift +++ b/Plugins/MongoDBDriverPlugin/MongoDBPluginDriver.swift @@ -20,6 +20,7 @@ final class MongoDBPluginDriver: PluginDatabaseDriver { func beginTransaction() async throws {} func commitTransaction() async throws {} func rollbackTransaction() async throws {} + func quoteIdentifier(_ name: String) -> String { name } init(config: DriverConnectionConfig) { self.config = config diff --git a/Plugins/MySQLDriverPlugin/MySQLPlugin.swift b/Plugins/MySQLDriverPlugin/MySQLPlugin.swift index 4818855da..2de3921ef 100644 --- a/Plugins/MySQLDriverPlugin/MySQLPlugin.swift +++ b/Plugins/MySQLDriverPlugin/MySQLPlugin.swift @@ -73,7 +73,15 @@ final class MySQLPlugin: NSObject, TableProPlugin, DriverPlugin { "BLOB", "TINYBLOB", "MEDIUMBLOB", "LONGBLOB", "DATE", "TIME", "DATETIME", "TIMESTAMP", "YEAR", "ENUM", "SET", "JSON", "BOOL", "BOOLEAN" - ] + ], + tableOptions: [ + "ENGINE=InnoDB", "DEFAULT CHARSET=utf8mb4", "COLLATE=utf8mb4_unicode_ci", + "AUTO_INCREMENT=", "COMMENT=", "ROW_FORMAT=" + ], + regexSyntax: .regexp, + booleanLiteralStyle: .numeric, + likeEscapeStyle: .implicit, + paginationStyle: .limit ) func createDriver(config: DriverConnectionConfig) -> any PluginDatabaseDriver { diff --git a/Plugins/MySQLDriverPlugin/MySQLPluginDriver.swift b/Plugins/MySQLDriverPlugin/MySQLPluginDriver.swift index fb79e92e2..fd4d0db4d 100644 --- a/Plugins/MySQLDriverPlugin/MySQLPluginDriver.swift +++ b/Plugins/MySQLDriverPlugin/MySQLPluginDriver.swift @@ -24,6 +24,25 @@ final class MySQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { var supportsSchemas: Bool { false } var supportsTransactions: Bool { true } + func quoteIdentifier(_ name: String) -> String { + let escaped = name.replacingOccurrences(of: "`", with: "``") + return "`\(escaped)`" + } + + func escapeStringLiteral(_ value: String) -> String { + var result = value + result = result.replacingOccurrences(of: "\\", with: "\\\\") + result = result.replacingOccurrences(of: "'", with: "''") + result = result.replacingOccurrences(of: "\n", with: "\\n") + result = result.replacingOccurrences(of: "\r", with: "\\r") + result = result.replacingOccurrences(of: "\t", with: "\\t") + result = result.replacingOccurrences(of: "\0", with: "\\0") + result = result.replacingOccurrences(of: "\u{08}", with: "\\b") + result = result.replacingOccurrences(of: "\u{0C}", with: "\\f") + result = result.replacingOccurrences(of: "\u{1A}", with: "\\Z") + return result + } + private static let tableNameRegex = try? NSRegularExpression(pattern: "(?i)\\bFROM\\s+[`\"']?([\\w]+)[`\"']?") private static let limitRegex = try? NSRegularExpression(pattern: "(?i)\\s+LIMIT\\s+\\d+(\\s*,\\s*\\d+)?") private static let offsetRegex = try? NSRegularExpression(pattern: "(?i)\\s+OFFSET\\s+\\d+") diff --git a/Plugins/OracleDriverPlugin/OraclePlugin.swift b/Plugins/OracleDriverPlugin/OraclePlugin.swift index 912c30fae..52519b0ac 100644 --- a/Plugins/OracleDriverPlugin/OraclePlugin.swift +++ b/Plugins/OracleDriverPlugin/OraclePlugin.swift @@ -79,7 +79,14 @@ final class OraclePlugin: NSObject, TableProPlugin, DriverPlugin { "DATE", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH LOCAL TIME ZONE", "INTERVAL YEAR TO MONTH", "INTERVAL DAY TO SECOND", "BOOLEAN", "ROWID", "UROWID", "XMLTYPE", "SDO_GEOMETRY" - ] + ], + tableOptions: [ + "TABLESPACE", "PCTFREE", "INITRANS" + ], + regexSyntax: .regexpLike, + booleanLiteralStyle: .numeric, + likeEscapeStyle: .explicit, + paginationStyle: .offsetFetch ) func createDriver(config: DriverConnectionConfig) -> any PluginDatabaseDriver { diff --git a/Plugins/PostgreSQLDriverPlugin/PostgreSQLPlugin.swift b/Plugins/PostgreSQLDriverPlugin/PostgreSQLPlugin.swift index bc6f36649..f37856dbe 100644 --- a/Plugins/PostgreSQLDriverPlugin/PostgreSQLPlugin.swift +++ b/Plugins/PostgreSQLDriverPlugin/PostgreSQLPlugin.swift @@ -81,7 +81,14 @@ final class PostgreSQLPlugin: NSObject, TableProPlugin, DriverPlugin { "CHAR", "CHARACTER", "VARCHAR", "TEXT", "DATE", "TIME", "TIMESTAMP", "TIMESTAMPTZ", "INTERVAL", "BOOLEAN", "BOOL", "JSON", "JSONB", "UUID", "BYTEA", "ARRAY" - ] + ], + tableOptions: [ + "INHERITS", "PARTITION BY", "TABLESPACE", "WITH", "WITHOUT OIDS" + ], + regexSyntax: .tilde, + booleanLiteralStyle: .truefalse, + likeEscapeStyle: .explicit, + paginationStyle: .limit ) static func driverVariant(for databaseTypeId: String) -> String? { diff --git a/Plugins/RedisDriverPlugin/RedisPluginDriver.swift b/Plugins/RedisDriverPlugin/RedisPluginDriver.swift index 98fde2210..3578c082e 100644 --- a/Plugins/RedisDriverPlugin/RedisPluginDriver.swift +++ b/Plugins/RedisDriverPlugin/RedisPluginDriver.swift @@ -23,6 +23,8 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { redisConnection?.serverVersion() } + func quoteIdentifier(_ name: String) -> String { name } + init(config: DriverConnectionConfig) { self.config = config } diff --git a/Plugins/SQLiteDriverPlugin/SQLitePlugin.swift b/Plugins/SQLiteDriverPlugin/SQLitePlugin.swift index 4c6195711..a7f592468 100644 --- a/Plugins/SQLiteDriverPlugin/SQLitePlugin.swift +++ b/Plugins/SQLiteDriverPlugin/SQLitePlugin.swift @@ -71,7 +71,14 @@ final class SQLitePlugin: NSObject, TableProPlugin, DriverPlugin { "NVARCHAR", "CLOB", "DOUBLE", "PRECISION", "FLOAT", "DECIMAL", "BOOLEAN", "DATE", "DATETIME" - ] + ], + tableOptions: [ + "WITHOUT ROWID", "STRICT" + ], + regexSyntax: .unsupported, + booleanLiteralStyle: .numeric, + likeEscapeStyle: .explicit, + paginationStyle: .limit ) func createDriver(config: DriverConnectionConfig) -> any PluginDatabaseDriver { @@ -317,6 +324,11 @@ final class SQLitePluginDriver: PluginDatabaseDriver, @unchecked Sendable { var supportsSchemas: Bool { false } var supportsTransactions: Bool { true } + func quoteIdentifier(_ name: String) -> String { + let escaped = name.replacingOccurrences(of: "`", with: "``") + return "`\(escaped)`" + } + init(config: DriverConnectionConfig) { self.config = config } @@ -651,10 +663,6 @@ final class SQLitePluginDriver: PluginDatabaseDriver, @unchecked Sendable { return path } - private func escapeStringLiteral(_ value: String) -> String { - value.replacingOccurrences(of: "'", with: "''") - } - private func stripLimitOffset(from query: String) -> String { var result = query diff --git a/Plugins/TableProPluginKit/PluginDatabaseDriver.swift b/Plugins/TableProPluginKit/PluginDatabaseDriver.swift index d7848038b..82f318631 100644 --- a/Plugins/TableProPluginKit/PluginDatabaseDriver.swift +++ b/Plugins/TableProPluginKit/PluginDatabaseDriver.swift @@ -110,6 +110,12 @@ public protocol PluginDatabaseDriver: AnyObject, Sendable { // EXPLAIN query building (optional) func buildExplainQuery(_ sql: String) -> String? + + // Identifier quoting + func quoteIdentifier(_ name: String) -> String + + // String escaping + func escapeStringLiteral(_ value: String) -> String } public extension PluginDatabaseDriver { @@ -218,6 +224,18 @@ public extension PluginDatabaseDriver { func buildExplainQuery(_ sql: String) -> String? { nil } + func quoteIdentifier(_ name: String) -> String { + let escaped = name.replacingOccurrences(of: "\"", with: "\"\"") + return "\"\(escaped)\"" + } + + func escapeStringLiteral(_ value: String) -> String { + var result = value + result = result.replacingOccurrences(of: "'", with: "''") + result = result.replacingOccurrences(of: "\0", with: "") + return result + } + func executeParameterized(query: String, parameters: [String?]) async throws -> PluginQueryResult { guard !parameters.isEmpty else { return try await execute(query: query) diff --git a/Plugins/TableProPluginKit/SQLDialectDescriptor.swift b/Plugins/TableProPluginKit/SQLDialectDescriptor.swift index f5b817460..5deecdb5c 100644 --- a/Plugins/TableProPluginKit/SQLDialectDescriptor.swift +++ b/Plugins/TableProPluginKit/SQLDialectDescriptor.swift @@ -5,16 +5,57 @@ public struct SQLDialectDescriptor: Sendable { public let keywords: Set public let functions: Set public let dataTypes: Set + public let tableOptions: [String] + + // Filter dialect + public let regexSyntax: RegexSyntax + public let booleanLiteralStyle: BooleanLiteralStyle + public let likeEscapeStyle: LikeEscapeStyle + public let paginationStyle: PaginationStyle + + public enum RegexSyntax: String, Sendable { + case regexp // MySQL: column REGEXP 'pattern' + case tilde // PostgreSQL: column ~ 'pattern' + case regexpMatches // DuckDB: regexp_matches(column, 'pattern') + case match // ClickHouse: match(column, 'pattern') + case regexpLike // Oracle: REGEXP_LIKE(column, 'pattern') + case unsupported // SQLite, MSSQL, MongoDB, Redis + } + + public enum BooleanLiteralStyle: String, Sendable { + case truefalse // PostgreSQL, DuckDB: TRUE/FALSE + case numeric // MySQL, SQLite, etc: 1/0 + } + + public enum LikeEscapeStyle: String, Sendable { + case implicit // MySQL: backslash is default escape, no ESCAPE clause needed + case explicit // PostgreSQL, SQLite, etc: need ESCAPE '\' clause + } + + public enum PaginationStyle: String, Sendable { + case limit // MySQL, PostgreSQL, SQLite, etc: LIMIT n + case offsetFetch // Oracle, MSSQL: OFFSET n ROWS FETCH NEXT m ROWS ONLY + } public init( identifierQuote: String, keywords: Set, functions: Set, - dataTypes: Set + dataTypes: Set, + tableOptions: [String] = [], + regexSyntax: RegexSyntax = .unsupported, + booleanLiteralStyle: BooleanLiteralStyle = .numeric, + likeEscapeStyle: LikeEscapeStyle = .explicit, + paginationStyle: PaginationStyle = .limit ) { self.identifierQuote = identifierQuote self.keywords = keywords self.functions = functions self.dataTypes = dataTypes + self.tableOptions = tableOptions + self.regexSyntax = regexSyntax + self.booleanLiteralStyle = booleanLiteralStyle + self.likeEscapeStyle = likeEscapeStyle + self.paginationStyle = paginationStyle } } diff --git a/TablePro/Core/Autocomplete/CompletionEngine.swift b/TablePro/Core/Autocomplete/CompletionEngine.swift index 84be48241..faabc78a8 100644 --- a/TablePro/Core/Autocomplete/CompletionEngine.swift +++ b/TablePro/Core/Autocomplete/CompletionEngine.swift @@ -6,6 +6,7 @@ // import Foundation +import TableProPluginKit /// Completion context returned by the engine struct CompletionContext { @@ -29,8 +30,8 @@ final class CompletionEngine { // MARK: - Initialization - init(schemaProvider: SQLSchemaProvider, databaseType: DatabaseType? = nil) { - self.provider = SQLCompletionProvider(schemaProvider: schemaProvider, databaseType: databaseType) + init(schemaProvider: SQLSchemaProvider, databaseType: DatabaseType? = nil, dialect: SQLDialectDescriptor? = nil) { + self.provider = SQLCompletionProvider(schemaProvider: schemaProvider, databaseType: databaseType, dialect: dialect) } // MARK: - Public API diff --git a/TablePro/Core/Autocomplete/SQLCompletionProvider.swift b/TablePro/Core/Autocomplete/SQLCompletionProvider.swift index 1d23c4777..0ed797c01 100644 --- a/TablePro/Core/Autocomplete/SQLCompletionProvider.swift +++ b/TablePro/Core/Autocomplete/SQLCompletionProvider.swift @@ -6,6 +6,7 @@ // import Foundation +import TableProPluginKit /// Main provider for SQL autocomplete suggestions final class SQLCompletionProvider { @@ -14,6 +15,7 @@ final class SQLCompletionProvider { private let contextAnalyzer = SQLContextAnalyzer() private let schemaProvider: SQLSchemaProvider private var databaseType: DatabaseType? + private var cachedDialect: SQLDialectDescriptor? /// Minimum prefix length to trigger suggestions private let minPrefixLength = 1 @@ -23,14 +25,16 @@ final class SQLCompletionProvider { // MARK: - Init - init(schemaProvider: SQLSchemaProvider, databaseType: DatabaseType? = nil) { + init(schemaProvider: SQLSchemaProvider, databaseType: DatabaseType? = nil, dialect: SQLDialectDescriptor? = nil) { self.schemaProvider = schemaProvider self.databaseType = databaseType + self.cachedDialect = dialect } /// Update the database type for context-aware completions - func setDatabaseType(_ type: DatabaseType) { + func setDatabaseType(_ type: DatabaseType, dialect: SQLDialectDescriptor? = nil) { self.databaseType = type + self.cachedDialect = dialect } // MARK: - Public API @@ -316,31 +320,12 @@ final class SQLCompletionProvider { ]) items += dataTypeKeywords() } else { - // Pre-paren (CREATE TABLE ...) or post-paren (CREATE TABLE (...) ...) - items = filterKeywords([ - "IF NOT EXISTS", - ]) - // Database-specific table options (for post-paren context) - switch databaseType { - case .mysql, .mariadb: - items += filterKeywords([ - "ENGINE", "CHARSET", "COLLATE", "COMMENT", - "AUTO_INCREMENT", "ROW_FORMAT", "DEFAULT CHARSET", - ]) - case .postgresql, .redshift: - items += filterKeywords([ - "TABLESPACE", "INHERITS", "PARTITION BY", - "WITH", "WITHOUT OIDS", - ]) - case .mssql: - items += filterKeywords([ - "ON", "CLUSTERED", "NONCLUSTERED", - "WITH", "TEXTIMAGE_ON", - ]) - default: + items = filterKeywords(["IF NOT EXISTS"]) + if let options = cachedDialect?.tableOptions { + items += filterKeywords(options) + } else { items += filterKeywords([ - "ENGINE", "CHARSET", "COLLATE", "COMMENT", - "TABLESPACE", + "ENGINE", "CHARSET", "COLLATE", "COMMENT", "TABLESPACE" ]) } } @@ -466,141 +451,48 @@ final class SQLCompletionProvider { } /// SQL data type keywords (database-aware), with a slight priority boost - /// so they sort before generic constraint keywords in CREATE TABLE context + /// so they sort before generic constraint keywords in CREATE TABLE context. + /// Uses plugin-provided dialect data when available; falls back to common SQL types. private func dataTypeKeywords() -> [SQLCompletionItem] { - var types: [String] = [ - // Common numeric types (all databases) - "INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", - "DECIMAL", "NUMERIC", "FLOAT", "DOUBLE", "REAL", - // Common string types - "VARCHAR", "CHAR", "TEXT", - // Common date/time types - "DATE", "TIME", "DATETIME", "TIMESTAMP", - // Boolean - "BOOLEAN", "BOOL", - ] - - // Add database-specific types - switch databaseType { - case .mysql, .mariadb: - types += [ - "MEDIUMINT", "DOUBLE PRECISION", - "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", - "BLOB", "TINYBLOB", "MEDIUMBLOB", "LONGBLOB", - "YEAR", "ENUM", "SET", "JSON", - "BINARY", "VARBINARY", - ] - - case .postgresql, .redshift: - types += [ - "BIGSERIAL", "SERIAL", "SMALLSERIAL", - "DOUBLE PRECISION", "MONEY", - "CHARACTER", "CHARACTER VARYING", "CLOB", - "BYTEA", "UUID", "JSON", "JSONB", "XML", "ARRAY", - "TIMESTAMPTZ", "TIMETZ", "INTERVAL", - "POINT", "LINE", "LSEG", "BOX", "PATH", "POLYGON", "CIRCLE", - "INET", "CIDR", "MACADDR", "MACADDR8", - ] - - case .mssql: - types += [ - "NVARCHAR", "NCHAR", "NTEXT", - "MONEY", "SMALLMONEY", - "DATETIMEOFFSET", "DATETIME2", "SMALLDATETIME", - "BINARY", "VARBINARY", "IMAGE", - "UNIQUEIDENTIFIER", "XML", "SQL_VARIANT", - "ROWVERSION", "HIERARCHYID", - ] - - case .oracle: - types += [ - "NUMBER", "BINARY_FLOAT", "BINARY_DOUBLE", - "VARCHAR2", "NVARCHAR2", "NCHAR", "NCLOB", - "CLOB", "LONG", "RAW", "LONG RAW", "BFILE", - "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH LOCAL TIME ZONE", - "INTERVAL YEAR TO MONTH", "INTERVAL DAY TO SECOND", - "ROWID", "UROWID", "XMLTYPE", "SDO_GEOMETRY", - ] - - case .clickhouse: - types += [ - "UInt8", "UInt16", "UInt32", "UInt64", "UInt128", "UInt256", - "Int8", "Int16", "Int32", "Int64", "Int128", "Int256", - "Float32", "Float64", - "Decimal32", "Decimal64", "Decimal128", "Decimal256", - "String", "FixedString", "UUID", - "Date32", "DateTime64", - "Array", "Tuple", "Map", "Nested", - "Nullable", "LowCardinality", - "Enum8", "Enum16", - "IPv4", "IPv6", - "JSON", "Bool", - ] - - - case .sqlite: - types += [ - "BLOB", - ] - - case .duckdb: - types += [ - "HUGEINT", "TINYINT", "SMALLINT", "REAL", "NUMERIC", - "CHAR", "BPCHAR", - "BLOB", "BYTEA", - "TIMESTAMP WITH TIME ZONE", - "LIST", "MAP", "STRUCT", "UNION", "ENUM", "UUID", "JSON", "BIT", "INTERVAL", - ] - - case .mongodb: - // MongoDB types are case-sensitive — return directly without uppercasing - let mongoTypes = [ + // MongoDB and Redis use case-sensitive, non-SQL types + if databaseType == .mongodb { + return [ "ObjectId", "String", "Int32", "Int64", "Double", "Decimal128", "Boolean", "Date", "Timestamp", "BinData", "Array", "Object", - "Null", "Regex", "UUID", - ] - return mongoTypes.map { typeName in - var item = SQLCompletionItem( - label: typeName, - kind: .keyword, - insertText: typeName - ) + "Null", "Regex", "UUID" + ].map { typeName in + var item = SQLCompletionItem(label: typeName, kind: .keyword, insertText: typeName) item.sortPriority = 380 return item } - - case .redis: - let redisTypes = [ - "String", "List", "Set", "Sorted Set", "Hash", "Stream", - ] - return redisTypes.map { typeName in - var item = SQLCompletionItem( - label: typeName, - kind: .keyword, - insertText: typeName - ) + } + if databaseType == .redis { + return [ + "String", "List", "Set", "Sorted Set", "Hash", "Stream" + ].map { typeName in + var item = SQLCompletionItem(label: typeName, kind: .keyword, insertText: typeName) item.sortPriority = 380 return item } + } - case .none: - // Include all types if database type is unknown - types += [ - "MEDIUMINT", "DOUBLE PRECISION", - "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", - "BLOB", "TINYBLOB", "MEDIUMBLOB", "LONGBLOB", - "CLOB", "NCHAR", "NVARCHAR", - "YEAR", "INTERVAL", "TIMESTAMPTZ", "TIMETZ", - "BIT", "JSON", "JSONB", "XML", "ARRAY", - "UUID", "BINARY", "VARBINARY", "BYTEA", - "ENUM", "SET", - "SERIAL", "BIGSERIAL", "SMALLSERIAL", "MONEY", - "POINT", "LINE", "LSEG", "BOX", "PATH", "POLYGON", "CIRCLE", - "INET", "CIDR", "MACADDR", "MACADDR8", - ] + if let descriptor = cachedDialect, !descriptor.dataTypes.isEmpty { + return descriptor.dataTypes.sorted().map { typeName in + var item = SQLCompletionItem(label: typeName, kind: .keyword, insertText: typeName) + item.sortPriority = 380 + return item + } } - return types.map { typeName in + let commonTypes: [String] = [ + "INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", + "DECIMAL", "NUMERIC", "FLOAT", "DOUBLE", "REAL", + "VARCHAR", "CHAR", "TEXT", + "DATE", "TIME", "DATETIME", "TIMESTAMP", + "BOOLEAN", "BOOL", + "BLOB", "JSON", "UUID" + ] + return commonTypes.map { typeName in var item = SQLCompletionItem.keyword(typeName) item.sortPriority = 380 return item diff --git a/TablePro/Core/ChangeTracking/DataChangeManager.swift b/TablePro/Core/ChangeTracking/DataChangeManager.swift index ad5766347..09f6f0a56 100644 --- a/TablePro/Core/ChangeTracking/DataChangeManager.swift +++ b/TablePro/Core/ChangeTracking/DataChangeManager.swift @@ -674,7 +674,8 @@ final class DataChangeManager { tableName: tableName, columns: columns, primaryKeyColumn: primaryKeyColumn, - databaseType: databaseType + databaseType: databaseType, + quoteIdentifier: pluginDriver?.quoteIdentifier ) let statements = generator.generateStatements( from: changes, diff --git a/TablePro/Core/ChangeTracking/SQLStatementGenerator.swift b/TablePro/Core/ChangeTracking/SQLStatementGenerator.swift index 57bec156b..6c89eccf4 100644 --- a/TablePro/Core/ChangeTracking/SQLStatementGenerator.swift +++ b/TablePro/Core/ChangeTracking/SQLStatementGenerator.swift @@ -25,19 +25,22 @@ struct SQLStatementGenerator { let primaryKeyColumn: String? let databaseType: DatabaseType let parameterStyle: ParameterStyle + private let quoteIdentifierFn: (String) -> String init( tableName: String, columns: [String], primaryKeyColumn: String?, databaseType: DatabaseType, - parameterStyle: ParameterStyle? = nil + parameterStyle: ParameterStyle? = nil, + quoteIdentifier: ((String) -> String)? = nil ) { self.tableName = tableName self.columns = columns self.primaryKeyColumn = primaryKeyColumn self.databaseType = databaseType self.parameterStyle = parameterStyle ?? Self.defaultParameterStyle(for: databaseType) + self.quoteIdentifierFn = quoteIdentifier ?? databaseType.quoteIdentifier } private static func defaultParameterStyle(for databaseType: DatabaseType) -> ParameterStyle { @@ -156,7 +159,7 @@ struct SQLStatementGenerator { guard index < columns.count else { continue } let columnName = columns[index] - nonDefaultColumns.append(databaseType.quoteIdentifier(columnName)) + nonDefaultColumns.append(quoteIdentifierFn(columnName)) if let val = value { if isSQLFunctionExpression(val) { @@ -177,7 +180,7 @@ struct SQLStatementGenerator { let placeholders = placeholderParts.joined(separator: ", ") let sql = - "INSERT INTO \(databaseType.quoteIdentifier(tableName)) (\(columnList)) VALUES (\(placeholders))" + "INSERT INTO \(quoteIdentifierFn(tableName)) (\(columnList)) VALUES (\(placeholders))" return ParameterizedStatement(sql: sql, parameters: bindParameters) } @@ -196,7 +199,7 @@ struct SQLStatementGenerator { guard !nonDefaultChanges.isEmpty else { return nil } let columnNames = nonDefaultChanges.map { - databaseType.quoteIdentifier($0.columnName) + quoteIdentifierFn($0.columnName) }.joined(separator: ", ") var parameters: [Any?] = [] @@ -214,7 +217,7 @@ struct SQLStatementGenerator { }.joined(separator: ", ") let sql = - "INSERT INTO \(databaseType.quoteIdentifier(tableName)) (\(columnNames)) VALUES (\(placeholders))" + "INSERT INTO \(quoteIdentifierFn(tableName)) (\(columnNames)) VALUES (\(placeholders))" return ParameterizedStatement(sql: sql, parameters: parameters) } @@ -234,20 +237,20 @@ struct SQLStatementGenerator { var parameters: [Any?] = [] let setClauses = change.cellChanges.map { cellChange -> String in if cellChange.newValue == "__DEFAULT__" { - return "\(databaseType.quoteIdentifier(cellChange.columnName)) = DEFAULT" + return "\(quoteIdentifierFn(cellChange.columnName)) = DEFAULT" } else if let newValue = cellChange.newValue { if isSQLFunctionExpression(newValue) { return - "\(databaseType.quoteIdentifier(cellChange.columnName)) = \(newValue.trimmingCharacters(in: .whitespaces).uppercased())" + "\(quoteIdentifierFn(cellChange.columnName)) = \(newValue.trimmingCharacters(in: .whitespaces).uppercased())" } else { parameters.append(newValue) return - "\(databaseType.quoteIdentifier(cellChange.columnName)) = \(placeholder(at: parameters.count - 1))" + "\(quoteIdentifierFn(cellChange.columnName)) = \(placeholder(at: parameters.count - 1))" } } else { parameters.append(nil) return - "\(databaseType.quoteIdentifier(cellChange.columnName)) = \(placeholder(at: parameters.count - 1))" + "\(quoteIdentifierFn(cellChange.columnName)) = \(placeholder(at: parameters.count - 1))" } }.joined(separator: ", ") @@ -271,9 +274,9 @@ struct SQLStatementGenerator { parameters.append(pkValue) let whereClause = - "\(databaseType.quoteIdentifier(pkColumn)) = \(placeholder(at: parameters.count - 1))" + "\(quoteIdentifierFn(pkColumn)) = \(placeholder(at: parameters.count - 1))" let sql = - "UPDATE \(databaseType.quoteIdentifier(tableName)) SET \(setClauses) WHERE \(whereClause)" + "UPDATE \(quoteIdentifierFn(tableName)) SET \(setClauses) WHERE \(whereClause)" return ParameterizedStatement(sql: sql, parameters: parameters) } else { guard let originalRow = change.originalRow else { @@ -287,7 +290,7 @@ struct SQLStatementGenerator { for (index, columnName) in columns.enumerated() { guard index < originalRow.count else { continue } let value = originalRow[index] - let quotedColumn = databaseType.quoteIdentifier(columnName) + let quotedColumn = quoteIdentifierFn(columnName) if let value = value { parameters.append(value) conditions.append("\(quotedColumn) = \(placeholder(at: parameters.count - 1))") @@ -300,7 +303,7 @@ struct SQLStatementGenerator { let whereClause = conditions.joined(separator: " AND ") let sql = - "UPDATE \(databaseType.quoteIdentifier(tableName)) SET \(setClauses) WHERE \(whereClause)" + "UPDATE \(quoteIdentifierFn(tableName)) SET \(setClauses) WHERE \(whereClause)" return ParameterizedStatement(sql: sql, parameters: parameters) } @@ -327,13 +330,13 @@ struct SQLStatementGenerator { parameters.append(originalRow[pkIndex]) return - "\(databaseType.quoteIdentifier(pkColumn)) = \(placeholder(at: parameters.count - 1))" + "\(quoteIdentifierFn(pkColumn)) = \(placeholder(at: parameters.count - 1))" } guard !conditions.isEmpty else { return nil } let whereClause = conditions.joined(separator: " OR ") - let sql = "DELETE FROM \(databaseType.quoteIdentifier(tableName)) WHERE \(whereClause)" + let sql = "DELETE FROM \(quoteIdentifierFn(tableName)) WHERE \(whereClause)" return ParameterizedStatement(sql: sql, parameters: parameters) } @@ -354,7 +357,7 @@ struct SQLStatementGenerator { guard index < originalRow.count else { continue } let value = originalRow[index] - let quotedColumn = databaseType.quoteIdentifier(columnName) + let quotedColumn = quoteIdentifierFn(columnName) if let value = value { parameters.append(value) @@ -367,7 +370,7 @@ struct SQLStatementGenerator { guard !conditions.isEmpty else { return nil } let whereClause = conditions.joined(separator: " AND ") - let sql = "DELETE FROM \(databaseType.quoteIdentifier(tableName)) WHERE \(whereClause)" + let sql = "DELETE FROM \(quoteIdentifierFn(tableName)) WHERE \(whereClause)" return ParameterizedStatement(sql: sql, parameters: parameters) } diff --git a/TablePro/Core/Database/DatabaseDriver.swift b/TablePro/Core/Database/DatabaseDriver.swift index 3cd7ed59b..e5be8aa6a 100644 --- a/TablePro/Core/Database/DatabaseDriver.swift +++ b/TablePro/Core/Database/DatabaseDriver.swift @@ -140,6 +140,12 @@ protocol DatabaseDriver: AnyObject { /// Access to the underlying plugin driver for query building dispatch var queryBuildingPluginDriver: (any PluginDatabaseDriver)? { get } + + /// Quote an identifier (table or column name) using the driver's quoting style + func quoteIdentifier(_ name: String) -> String + + /// Escape a string value for safe use in SQL string literals + func escapeStringLiteral(_ value: String) -> String } // MARK: - Schema Switching @@ -160,6 +166,19 @@ extension DatabaseDriver { var queryBuildingPluginDriver: (any PluginDatabaseDriver)? { nil } + func quoteIdentifier(_ name: String) -> String { + let q = "\"" + let escaped = name.replacingOccurrences(of: q, with: q + q) + return "\(q)\(escaped)\(q)" + } + + func escapeStringLiteral(_ value: String) -> String { + var result = value + result = result.replacingOccurrences(of: "'", with: "''") + result = result.replacingOccurrences(of: "\0", with: "") + return result + } + func testConnection() async throws -> Bool { try await connect() disconnect() diff --git a/TablePro/Core/Database/FilterSQLGenerator.swift b/TablePro/Core/Database/FilterSQLGenerator.swift index 4fdb4eeaf..c740d8443 100644 --- a/TablePro/Core/Database/FilterSQLGenerator.swift +++ b/TablePro/Core/Database/FilterSQLGenerator.swift @@ -11,6 +11,73 @@ import TableProPluginKit /// Generates SQL WHERE clauses from filter definitions struct FilterSQLGenerator { let databaseType: DatabaseType + private let dialect: SQLDialectDescriptor + private let quoteIdentifierFn: (String) -> String + + init( + databaseType: DatabaseType, + dialect: SQLDialectDescriptor? = nil, + quoteIdentifier: ((String) -> String)? = nil + ) { + self.databaseType = databaseType + self.dialect = dialect ?? Self.fallbackDialect(for: databaseType) + self.quoteIdentifierFn = quoteIdentifier ?? databaseType.quoteIdentifier + } + + /// Fallback dialect properties when no plugin-provided descriptor is available. + /// Preserves pre-existing behavior for each database type. + private static func fallbackDialect(for databaseType: DatabaseType) -> SQLDialectDescriptor { + switch databaseType { + case .mysql, .mariadb: + return SQLDialectDescriptor( + identifierQuote: "`", keywords: [], functions: [], dataTypes: [], + regexSyntax: .regexp, booleanLiteralStyle: .numeric, + likeEscapeStyle: .implicit, paginationStyle: .limit + ) + case .postgresql, .redshift: + return SQLDialectDescriptor( + identifierQuote: "\"", keywords: [], functions: [], dataTypes: [], + regexSyntax: .tilde, booleanLiteralStyle: .truefalse, + likeEscapeStyle: .explicit, paginationStyle: .limit + ) + case .sqlite: + return SQLDialectDescriptor( + identifierQuote: "`", keywords: [], functions: [], dataTypes: [], + regexSyntax: .unsupported, booleanLiteralStyle: .numeric, + likeEscapeStyle: .explicit, paginationStyle: .limit + ) + case .clickhouse: + return SQLDialectDescriptor( + identifierQuote: "`", keywords: [], functions: [], dataTypes: [], + regexSyntax: .match, booleanLiteralStyle: .numeric, + likeEscapeStyle: .implicit, paginationStyle: .limit + ) + case .mssql: + return SQLDialectDescriptor( + identifierQuote: "[", keywords: [], functions: [], dataTypes: [], + regexSyntax: .unsupported, booleanLiteralStyle: .numeric, + likeEscapeStyle: .explicit, paginationStyle: .offsetFetch + ) + case .oracle: + return SQLDialectDescriptor( + identifierQuote: "\"", keywords: [], functions: [], dataTypes: [], + regexSyntax: .regexpLike, booleanLiteralStyle: .numeric, + likeEscapeStyle: .explicit, paginationStyle: .offsetFetch + ) + case .duckdb: + return SQLDialectDescriptor( + identifierQuote: "\"", keywords: [], functions: [], dataTypes: [], + regexSyntax: .regexpMatches, booleanLiteralStyle: .truefalse, + likeEscapeStyle: .explicit, paginationStyle: .limit + ) + case .mongodb, .redis: + return SQLDialectDescriptor( + identifierQuote: "`", keywords: [], functions: [], dataTypes: [], + regexSyntax: .unsupported, booleanLiteralStyle: .numeric, + likeEscapeStyle: .explicit, paginationStyle: .limit + ) + } + } // MARK: - Public API @@ -38,7 +105,7 @@ struct FilterSQLGenerator { return "(\(rawSQL))" } - let quotedColumn = databaseType.quoteIdentifier(filter.columnName) + let quotedColumn = quoteIdentifierFn(filter.columnName) switch filter.filterOperator { case .equal: @@ -102,14 +169,14 @@ struct FilterSQLGenerator { return "\(quotedColumn) BETWEEN \(escapeValue(filter.value)) AND \(escapeValue(secondValue))" case .regex: - // MongoDB filters are handled natively by MongoDBQueryBuilder + // MongoDB/Redis filters are handled natively by their query builders if databaseType == .mongodb || databaseType == .redis { return nil } - // SQLite doesn't support REGEXP without a custom function; fall back to LIKE - if databaseType == .sqlite { + let syntax = dialect.regexSyntax + if syntax == .unsupported { let escaped = escapeSQLQuote(filter.value) return "\(quotedColumn) LIKE '%\(escaped)%'" } - if databaseType == .clickhouse { + if syntax == .match { let escapedPattern = escapeStringValue(filter.value) return "match(\(quotedColumn), '\(escapedPattern)')" } @@ -120,15 +187,11 @@ struct FilterSQLGenerator { // MARK: - LIKE Conditions /// Database-specific ESCAPE clause for LIKE patterns. - /// MySQL/MariaDB default to `\` as the LIKE escape character, so no clause needed. - /// PostgreSQL and SQLite require an explicit ESCAPE declaration. + /// Implicit style (MySQL/MariaDB): backslash is the default LIKE escape, no clause needed. + /// Explicit style: requires an ESCAPE declaration. private var likeEscapeClause: String { - switch databaseType { - case .mysql, .mariadb: - return "" - case .postgresql, .redshift, .sqlite, .mongodb, .redis, .mssql, .oracle, .clickhouse, .duckdb: - return " ESCAPE '\\'" - } + if dialect.likeEscapeStyle == .implicit { return "" } + return " ESCAPE '\\'" } private func generateLikeCondition(column: String, pattern: String) -> String { @@ -141,19 +204,23 @@ struct FilterSQLGenerator { return "\(column) NOT LIKE '\(quotedPattern)'\(likeEscapeClause)" } - // MARK: - REGEX Conditions (Database-Specific) + // MARK: - REGEX Conditions private func generateRegexCondition(column: String, pattern: String) -> String { let escapedPattern = escapeStringValue(pattern) - switch databaseType { - case .mysql, .mariadb: + switch dialect.regexSyntax { + case .regexp: return "\(column) REGEXP '\(escapedPattern)'" - case .postgresql, .redshift: + case .tilde: return "\(column) ~ '\(escapedPattern)'" - case .duckdb: + case .regexpMatches: return "regexp_matches(\(column), '\(escapedPattern)')" - case .sqlite, .mongodb, .redis, .mssql, .oracle, .clickhouse: + case .regexpLike: + return "REGEXP_LIKE(\(column), '\(escapedPattern)')" + case .match: + return "match(\(column), '\(escapedPattern)')" + case .unsupported: return "\(column) LIKE '%\(escapedPattern)%'" } } @@ -171,10 +238,10 @@ struct FilterSQLGenerator { // Check for boolean literals if trimmed.caseInsensitiveCompare("TRUE") == .orderedSame { - return databaseType == .postgresql || databaseType == .redshift || databaseType == .duckdb ? "TRUE" : "1" + return dialect.booleanLiteralStyle == .truefalse ? "TRUE" : "1" } if trimmed.caseInsensitiveCompare("FALSE") == .orderedSame { - return databaseType == .postgresql || databaseType == .redshift || databaseType == .duckdb ? "FALSE" : "0" + return dialect.booleanLiteralStyle == .truefalse ? "FALSE" : "0" } // Try to detect numeric values @@ -197,17 +264,23 @@ struct FilterSQLGenerator { /// Escape special characters in string values private func escapeStringValue(_ value: String) -> String { // Fast path: most values have no special chars - guard value.contains("\\") || value.contains("'") else { return value } - return value - .replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "'", with: "''") + if dialect.likeEscapeStyle == .implicit { + // MySQL/MariaDB/ClickHouse: backslash is significant in string literals + guard value.contains("\\") || value.contains("'") else { return value } + return value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "'", with: "''") + } else { + // ANSI SQL: only single-quote needs escaping + guard value.contains("'") else { return value } + return value.replacingOccurrences(of: "'", with: "''") + } } private func escapeLikeWildcards(_ value: String) -> String { guard value.contains("\\") || value.contains("%") || value.contains("_") else { return value } - switch databaseType { - case .mysql, .mariadb: + if dialect.likeEscapeStyle == .implicit { // MySQL uses \ as both string escape and default LIKE escape. // Need double backslash in SQL string so string layer yields single \ // which LIKE then uses as escape char. @@ -215,12 +288,11 @@ struct FilterSQLGenerator { .replacingOccurrences(of: "\\", with: "\\\\\\\\") .replacingOccurrences(of: "%", with: "\\\\%") .replacingOccurrences(of: "_", with: "\\\\_") - default: - return value - .replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "%", with: "\\%") - .replacingOccurrences(of: "_", with: "\\_") } + return value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "%", with: "\\%") + .replacingOccurrences(of: "_", with: "\\_") } // MARK: - List Parsing @@ -259,7 +331,7 @@ extension FilterSQLGenerator { } } - let quotedTable = databaseType.quoteIdentifier(tableName) + let quotedTable = quoteIdentifierFn(tableName) var sql = "SELECT * FROM \(quotedTable)" let whereClause = generateWhereClause(from: filters) @@ -267,12 +339,10 @@ extension FilterSQLGenerator { sql += "\n\(whereClause)" } - switch databaseType { - case .oracle: - sql += "\nORDER BY 1 OFFSET 0 ROWS FETCH NEXT \(limit) ROWS ONLY" - case .mssql: - sql += "\nORDER BY (SELECT NULL) OFFSET 0 ROWS FETCH NEXT \(limit) ROWS ONLY" - default: + if dialect.paginationStyle == .offsetFetch { + let orderBy = databaseType == .oracle ? "ORDER BY 1" : "ORDER BY (SELECT NULL)" + sql += "\n\(orderBy) OFFSET 0 ROWS FETCH NEXT \(limit) ROWS ONLY" + } else { sql += "\nLIMIT \(limit)" } return sql diff --git a/TablePro/Core/Database/SQLEscaping.swift b/TablePro/Core/Database/SQLEscaping.swift index 46ffaba20..48c9a6b9d 100644 --- a/TablePro/Core/Database/SQLEscaping.swift +++ b/TablePro/Core/Database/SQLEscaping.swift @@ -10,53 +10,19 @@ import Foundation /// Centralized SQL escaping utilities to prevent SQL injection vulnerabilities enum SQLEscaping { - /// Escape a string value for use in SQL string literals (VALUES, WHERE clauses, etc.) + /// Escape a string value for use in SQL string literals using ANSI SQL rules. + /// Only doubles single quotes and strips null bytes. /// - /// MySQL/MariaDB: Uses backslash escape sequences for control characters (`\n`, `\t`, etc.) - /// PostgreSQL/SQLite: Uses standard SQL escaping (only single quotes doubled). - /// Newlines, tabs, and backslashes are valid as-is in standard SQL string literals. + /// For database-specific escaping (e.g., MySQL backslash sequences), use the + /// driver's `escapeStringLiteral` method instead. /// - /// Example: - /// ```swift - /// let safe = SQLEscaping.escapeStringLiteral("O'Brien\\test", databaseType: .mysql) - /// // Result: "O''Brien\\\\test" - /// let safe2 = SQLEscaping.escapeStringLiteral("O'Brien\\test", databaseType: .postgresql) - /// // Result: "O''Brien\\test" - /// ``` - /// - /// - Parameters: - /// - str: The raw string to escape - /// - databaseType: The target database type (defaults to `.mysql` for backward compatibility) - /// - Returns: The escaped string safe for use in SQL string literals - static func escapeStringLiteral(_ str: String, databaseType: DatabaseType = .mysql) -> String { - switch databaseType { - case .mysql, .mariadb, .clickhouse: - // MySQL/MariaDB/ClickHouse: backslash escaping is active by default - var result = str - // IMPORTANT: Escape backslashes FIRST to avoid double-escaping - result = result.replacingOccurrences(of: "\\", with: "\\\\") - // Single quote: SQL standard escaping (double the quote) - result = result.replacingOccurrences(of: "'", with: "''") - // Common control characters - result = result.replacingOccurrences(of: "\n", with: "\\n") - result = result.replacingOccurrences(of: "\r", with: "\\r") - result = result.replacingOccurrences(of: "\t", with: "\\t") - result = result.replacingOccurrences(of: "\0", with: "\\0") - // Additional control characters that can cause issues - result = result.replacingOccurrences(of: "\u{08}", with: "\\b") // Backspace - result = result.replacingOccurrences(of: "\u{0C}", with: "\\f") // Form feed - result = result.replacingOccurrences(of: "\u{1A}", with: "\\Z") // MySQL EOF marker (Ctrl+Z) - return result - - case .postgresql, .redshift, .sqlite, .mongodb, .redis, .mssql, .oracle, .duckdb: - // Standard SQL: only single quotes need doubling - // Newlines, tabs, backslashes are valid as-is in string literals - var result = str - result = result.replacingOccurrences(of: "'", with: "''") - // Strip null bytes (PostgreSQL rejects them in text) - result = result.replacingOccurrences(of: "\0", with: "") - return result - } + /// - Parameter str: The raw string to escape + /// - Returns: The escaped string safe for use in ANSI SQL string literals + static func escapeStringLiteral(_ str: String) -> String { + var result = str + result = result.replacingOccurrences(of: "'", with: "''") + result = result.replacingOccurrences(of: "\0", with: "") + return result } /// Known SQL temporal function expressions that should not be quoted/parameterized. diff --git a/TablePro/Core/Plugins/ExportDataSourceAdapter.swift b/TablePro/Core/Plugins/ExportDataSourceAdapter.swift index bf574494e..82222d0a3 100644 --- a/TablePro/Core/Plugins/ExportDataSourceAdapter.swift +++ b/TablePro/Core/Plugins/ExportDataSourceAdapter.swift @@ -50,11 +50,11 @@ final class ExportDataSourceAdapter: PluginExportDataSource, @unchecked Sendable } func quoteIdentifier(_ identifier: String) -> String { - dbType.quoteIdentifier(identifier) + driver.quoteIdentifier(identifier) } func escapeStringLiteral(_ value: String) -> String { - SQLEscaping.escapeStringLiteral(value, databaseType: dbType) + driver.escapeStringLiteral(value) } func fetchApproximateRowCount(table: String, databaseName: String) async throws -> Int? { @@ -75,10 +75,10 @@ final class ExportDataSourceAdapter: PluginExportDataSource, @unchecked Sendable private func qualifiedTableRef(table: String, databaseName: String) -> String { if databaseName.isEmpty { - return dbType.quoteIdentifier(table) + return driver.quoteIdentifier(table) } else { - let quotedDb = dbType.quoteIdentifier(databaseName) - let quotedTable = dbType.quoteIdentifier(table) + let quotedDb = driver.quoteIdentifier(databaseName) + let quotedTable = driver.quoteIdentifier(table) return "\(quotedDb).\(quotedTable)" } } diff --git a/TablePro/Core/Plugins/PluginDriverAdapter.swift b/TablePro/Core/Plugins/PluginDriverAdapter.swift index e110db798..be5102d34 100644 --- a/TablePro/Core/Plugins/PluginDriverAdapter.swift +++ b/TablePro/Core/Plugins/PluginDriverAdapter.swift @@ -46,7 +46,7 @@ final class PluginDriverAdapter: DatabaseDriver, SchemaSwitchable { return pluginDriver } var currentSchema: String { pluginDriver.currentSchema ?? connection.username } - var escapedSchema: String { SQLEscaping.escapeStringLiteral(currentSchema, databaseType: connection.type) } + var escapedSchema: String { pluginDriver.escapeStringLiteral(currentSchema) } private static let logger = Logger(subsystem: "com.TablePro", category: "PluginDriverAdapter") @@ -356,6 +356,16 @@ final class PluginDriverAdapter: DatabaseDriver, SchemaSwitchable { pluginDriver.buildExplainQuery(sql) } + // MARK: - Identifier Quoting + + func quoteIdentifier(_ name: String) -> String { + pluginDriver.quoteIdentifier(name) + } + + func escapeStringLiteral(_ value: String) -> String { + pluginDriver.escapeStringLiteral(value) + } + // MARK: - Result Mapping private func mapQueryResult(_ pluginResult: PluginQueryResult) -> QueryResult { diff --git a/TablePro/Core/Services/Export/ExportService.swift b/TablePro/Core/Services/Export/ExportService.swift index 1543e10b6..f6c46eae8 100644 --- a/TablePro/Core/Services/Export/ExportService.swift +++ b/TablePro/Core/Services/Export/ExportService.swift @@ -218,10 +218,10 @@ final class ExportService { let unionParts = batch.map { table -> String in let tableRef: String if table.databaseName.isEmpty { - tableRef = databaseType.quoteIdentifier(table.name) + tableRef = driver.quoteIdentifier(table.name) } else { - let quotedDb = databaseType.quoteIdentifier(table.databaseName) - let quotedTable = databaseType.quoteIdentifier(table.name) + let quotedDb = driver.quoteIdentifier(table.databaseName) + let quotedTable = driver.quoteIdentifier(table.name) tableRef = "\(quotedDb).\(quotedTable)" } return "SELECT COUNT(*) AS c FROM \(tableRef)" @@ -240,10 +240,10 @@ final class ExportService { do { let tableRef: String if table.databaseName.isEmpty { - tableRef = databaseType.quoteIdentifier(table.name) + tableRef = driver.quoteIdentifier(table.name) } else { - let quotedDb = databaseType.quoteIdentifier(table.databaseName) - let quotedTable = databaseType.quoteIdentifier(table.name) + let quotedDb = driver.quoteIdentifier(table.databaseName) + let quotedTable = driver.quoteIdentifier(table.name) tableRef = "\(quotedDb).\(quotedTable)" } let result = try await driver.execute(query: "SELECT COUNT(*) FROM \(tableRef)") diff --git a/TablePro/Core/Services/Query/TableQueryBuilder.swift b/TablePro/Core/Services/Query/TableQueryBuilder.swift index e46e3c5fd..8c7bc5ce0 100644 --- a/TablePro/Core/Services/Query/TableQueryBuilder.swift +++ b/TablePro/Core/Services/Query/TableQueryBuilder.swift @@ -27,6 +27,14 @@ struct TableQueryBuilder { pluginDriver = driver } + // MARK: - Identifier Quoting + + private func quote(_ name: String) -> String { + if let pluginDriver { return pluginDriver.quoteIdentifier(name) } + let escaped = name.replacingOccurrences(of: "\"", with: "\"\"") + return "\"\(escaped)\"" + } + // MARK: - Query Building func buildBaseQuery( @@ -46,7 +54,7 @@ struct TableQueryBuilder { } } - let quotedTable = databaseType.quoteIdentifier(tableName) + let quotedTable = quote(tableName) var query = "SELECT * FROM \(quotedTable)" if let orderBy = buildOrderByClause(sortState: sortState, columns: columns) { @@ -80,7 +88,7 @@ struct TableQueryBuilder { } } - let quotedTable = databaseType.quoteIdentifier(tableName) + let quotedTable = quote(tableName) return "SELECT * FROM \(quotedTable) LIMIT \(limit) OFFSET \(offset)" } @@ -102,7 +110,7 @@ struct TableQueryBuilder { } } - let quotedTable = databaseType.quoteIdentifier(tableName) + let quotedTable = quote(tableName) return "SELECT * FROM \(quotedTable) LIMIT \(limit) OFFSET \(offset)" } @@ -132,7 +140,7 @@ struct TableQueryBuilder { } } - let quotedTable = databaseType.quoteIdentifier(tableName) + let quotedTable = quote(tableName) return "SELECT * FROM \(quotedTable) LIMIT \(limit) OFFSET \(offset)" } @@ -143,7 +151,7 @@ struct TableQueryBuilder { ) -> String { var query = removeOrderBy(from: baseQuery) let direction = ascending ? "ASC" : "DESC" - let quotedColumn = databaseType.quoteIdentifier(columnName) + let quotedColumn = quote(columnName) let orderByClause = "ORDER BY \(quotedColumn) \(direction)" if let limitRange = query.range(of: "LIMIT", options: .caseInsensitive) { @@ -211,7 +219,7 @@ struct TableQueryBuilder { guard sortCol.columnIndex >= 0, sortCol.columnIndex < columns.count else { return nil } let columnName = columns[sortCol.columnIndex] let direction = sortCol.direction == .ascending ? "ASC" : "DESC" - let quotedColumn = databaseType.quoteIdentifier(columnName) + let quotedColumn = quote(columnName) return "\(quotedColumn) \(direction)" } diff --git a/TablePro/Core/Utilities/SQL/SQLRowToStatementConverter.swift b/TablePro/Core/Utilities/SQL/SQLRowToStatementConverter.swift index 6b42abc2f..8bfb009e7 100644 --- a/TablePro/Core/Utilities/SQL/SQLRowToStatementConverter.swift +++ b/TablePro/Core/Utilities/SQL/SQLRowToStatementConverter.swift @@ -9,9 +9,44 @@ internal struct SQLRowToStatementConverter { internal let columns: [String] internal let primaryKeyColumn: String? internal let databaseType: DatabaseType + private let quoteIdentifierFn: (String) -> String + private let escapeStringFn: (String) -> String + + init( + tableName: String, + columns: [String], + primaryKeyColumn: String?, + databaseType: DatabaseType, + quoteIdentifier: ((String) -> String)? = nil, + escapeStringLiteral: ((String) -> String)? = nil + ) { + self.tableName = tableName + self.columns = columns + self.primaryKeyColumn = primaryKeyColumn + self.databaseType = databaseType + self.quoteIdentifierFn = quoteIdentifier ?? databaseType.quoteIdentifier + self.escapeStringFn = escapeStringLiteral ?? Self.defaultEscapeFunction(for: databaseType) + } private static let maxRows = 50_000 + /// Fallback escape function when no plugin driver is available. + /// MySQL/MariaDB/ClickHouse need backslash escaping; others use ANSI SQL. + private static func defaultEscapeFunction(for databaseType: DatabaseType) -> (String) -> String { + switch databaseType { + case .mysql, .mariadb, .clickhouse: + return { value in + var result = value + result = result.replacingOccurrences(of: "\\", with: "\\\\") + result = result.replacingOccurrences(of: "'", with: "''") + result = result.replacingOccurrences(of: "\0", with: "\\0") + return result + } + default: + return SQLEscaping.escapeStringLiteral + } + } + internal func generateInserts(rows: [[String?]]) -> String { let capped = rows.prefix(Self.maxRows) let quotedTable = quoteColumn(tableName) @@ -84,14 +119,11 @@ internal struct SQLRowToStatementConverter { guard let value else { return "NULL" } - var escaped = value.replacingOccurrences(of: "'", with: "''") - if databaseType == .mysql || databaseType == .mariadb { - escaped = escaped.replacingOccurrences(of: "\\", with: "\\\\") - } + let escaped = escapeStringFn(value) return "'\(escaped)'" } private func quoteColumn(_ name: String) -> String { - databaseType.quoteIdentifier(name) + quoteIdentifierFn(name) } } diff --git a/TablePro/Models/Query/QueryTab.swift b/TablePro/Models/Query/QueryTab.swift index 68f52458a..220f61403 100644 --- a/TablePro/Models/Query/QueryTab.swift +++ b/TablePro/Models/Query/QueryTab.swift @@ -431,7 +431,12 @@ struct QueryTab: Identifiable, Equatable { /// Build a clean base query for a table tab (no filters/sort). /// Used when restoring table tabs from persistence to avoid stale WHERE clauses. - @MainActor static func buildBaseTableQuery(tableName: String, databaseType: DatabaseType) -> String { + @MainActor static func buildBaseTableQuery( + tableName: String, + databaseType: DatabaseType, + quoteIdentifier: ((String) -> String)? = nil + ) -> String { + let quote = quoteIdentifier ?? databaseType.quoteIdentifier let pageSize = AppSettingsManager.shared.dataGrid.defaultPageSize if databaseType == .mongodb { let escaped = tableName.replacingOccurrences(of: "\\", with: "\\\\").replacingOccurrences(of: "\"", with: "\\\"") @@ -439,13 +444,13 @@ struct QueryTab: Identifiable, Equatable { } else if databaseType == .redis { return "SCAN 0 MATCH * COUNT \(pageSize)" } else if databaseType == .mssql { - let quotedName = databaseType.quoteIdentifier(tableName) + let quotedName = quote(tableName) return "SELECT * FROM \(quotedName) ORDER BY (SELECT NULL) OFFSET 0 ROWS FETCH NEXT \(pageSize) ROWS ONLY;" } else if databaseType == .oracle { - let quotedName = databaseType.quoteIdentifier(tableName) + let quotedName = quote(tableName) return "SELECT * FROM \(quotedName) ORDER BY 1 OFFSET 0 ROWS FETCH NEXT \(pageSize) ROWS ONLY;" } else { - let quotedName = databaseType.quoteIdentifier(tableName) + let quotedName = quote(tableName) return "SELECT * FROM \(quotedName) LIMIT \(pageSize);" } } @@ -534,7 +539,12 @@ final class QueryTabManager { selectedTabId = newTab.id } - func addTableTab(tableName: String, databaseType: DatabaseType = .mysql, databaseName: String = "") { + func addTableTab( + tableName: String, + databaseType: DatabaseType = .mysql, + databaseName: String = "", + quoteIdentifier: ((String) -> String)? = nil + ) { // Check if table tab already exists (match on databaseName) if let existingTab = tabs.first(where: { $0.tabType == .table && $0.tableName == tableName && $0.databaseName == databaseName @@ -544,7 +554,9 @@ final class QueryTabManager { } let pageSize = AppSettingsManager.shared.dataGrid.defaultPageSize - let query = QueryTab.buildBaseTableQuery(tableName: tableName, databaseType: databaseType) + let query = QueryTab.buildBaseTableQuery( + tableName: tableName, databaseType: databaseType, quoteIdentifier: quoteIdentifier + ) var newTab = QueryTab( title: tableName, query: query, @@ -557,9 +569,16 @@ final class QueryTabManager { selectedTabId = newTab.id } - func addPreviewTableTab(tableName: String, databaseType: DatabaseType = .mysql, databaseName: String = "") { + func addPreviewTableTab( + tableName: String, + databaseType: DatabaseType = .mysql, + databaseName: String = "", + quoteIdentifier: ((String) -> String)? = nil + ) { let pageSize = AppSettingsManager.shared.dataGrid.defaultPageSize - let query = QueryTab.buildBaseTableQuery(tableName: tableName, databaseType: databaseType) + let query = QueryTab.buildBaseTableQuery( + tableName: tableName, databaseType: databaseType, quoteIdentifier: quoteIdentifier + ) var newTab = QueryTab( title: tableName, query: query, @@ -580,7 +599,8 @@ final class QueryTabManager { func replaceTabContent( tableName: String, databaseType: DatabaseType = .mysql, isView: Bool = false, databaseName: String = "", - isPreview: Bool = false + isPreview: Bool = false, + quoteIdentifier: ((String) -> String)? = nil ) -> Bool { guard let selectedId = selectedTabId, let selectedIndex = tabs.firstIndex(where: { $0.id == selectedId }) @@ -588,23 +608,12 @@ final class QueryTabManager { return false } + let query = QueryTab.buildBaseTableQuery( + tableName: tableName, + databaseType: databaseType, + quoteIdentifier: quoteIdentifier + ) let pageSize = AppSettingsManager.shared.dataGrid.defaultPageSize - let query: String - if databaseType == .mongodb { - let escaped = tableName.replacingOccurrences(of: "\\", with: "\\\\").replacingOccurrences(of: "\"", with: "\\\"") - query = "db[\"\(escaped)\"].find({}).limit(\(pageSize))" - } else if databaseType == .redis { - query = "SCAN 0 MATCH * COUNT \(pageSize)" - } else if databaseType == .mssql { - let quotedName = databaseType.quoteIdentifier(tableName) - query = "SELECT * FROM \(quotedName) ORDER BY (SELECT NULL) OFFSET 0 ROWS FETCH NEXT \(pageSize) ROWS ONLY;" - } else if databaseType == .oracle { - let quotedName = databaseType.quoteIdentifier(tableName) - query = "SELECT * FROM \(quotedName) ORDER BY 1 OFFSET 0 ROWS FETCH NEXT \(pageSize) ROWS ONLY;" - } else { - let quotedName = databaseType.quoteIdentifier(tableName) - query = "SELECT * FROM \(quotedName) LIMIT \(pageSize);" - } // Build locally and write back once to avoid 14 CoW copies (UI-11). var tab = tabs[selectedIndex] diff --git a/TablePro/Models/UI/FilterState.swift b/TablePro/Models/UI/FilterState.swift index 15d0013d0..155bf2f0a 100644 --- a/TablePro/Models/UI/FilterState.swift +++ b/TablePro/Models/UI/FilterState.swift @@ -356,7 +356,8 @@ final class FilterStateManager { /// Generate preview SQL for the "SQL" button /// Uses selected filters if any are selected, otherwise uses all valid filters func generatePreviewSQL(databaseType: DatabaseType) -> String { - let generator = FilterSQLGenerator(databaseType: databaseType) + let dialect = PluginManager.shared.sqlDialect(for: databaseType) + let generator = FilterSQLGenerator(databaseType: databaseType, dialect: dialect) let filtersToPreview = getFiltersForPreview() // If no valid filters but filters exist, show helpful message diff --git a/TablePro/Resources/Localizable.xcstrings b/TablePro/Resources/Localizable.xcstrings index aae717bef..e7a7656a5 100644 --- a/TablePro/Resources/Localizable.xcstrings +++ b/TablePro/Resources/Localizable.xcstrings @@ -489,6 +489,16 @@ } } }, + "%@: %lld" : { + "localizations" : { + "en" : { + "stringUnit" : { + "state" : "new", + "value" : "%1$@: %2$lld" + } + } + } + }, "%@." : { "extractionState" : "stale", "localizations" : { @@ -4886,6 +4896,7 @@ } }, "Database Index: %lld" : { + "extractionState" : "stale", "localizations" : { "vi" : { "stringUnit" : { @@ -12951,6 +12962,7 @@ } }, "Redis" : { + "extractionState" : "stale", "localizations" : { "vi" : { "stringUnit" : { @@ -16781,6 +16793,9 @@ } } } + }, + "Unsupported schema operation: %@" : { + }, "Untitled" : { "localizations" : { diff --git a/TablePro/Views/Editor/SQLCompletionAdapter.swift b/TablePro/Views/Editor/SQLCompletionAdapter.swift index 9d56c2741..2519e699e 100644 --- a/TablePro/Views/Editor/SQLCompletionAdapter.swift +++ b/TablePro/Views/Editor/SQLCompletionAdapter.swift @@ -25,13 +25,15 @@ final class SQLCompletionAdapter: CodeSuggestionDelegate { init(schemaProvider: SQLSchemaProvider?, databaseType: DatabaseType? = nil) { if let provider = schemaProvider { - self.completionEngine = CompletionEngine(schemaProvider: provider, databaseType: databaseType) + let dialect = databaseType.flatMap { PluginManager.shared.sqlDialect(for: $0) } + self.completionEngine = CompletionEngine(schemaProvider: provider, databaseType: databaseType, dialect: dialect) } } /// Update the schema provider (e.g. when connection changes) func updateSchemaProvider(_ provider: SQLSchemaProvider, databaseType: DatabaseType? = nil) { - self.completionEngine = CompletionEngine(schemaProvider: provider, databaseType: databaseType) + let dialect = databaseType.flatMap { PluginManager.shared.sqlDialect(for: $0) } + self.completionEngine = CompletionEngine(schemaProvider: provider, databaseType: databaseType, dialect: dialect) } // MARK: - CodeSuggestionDelegate diff --git a/TablePro/Views/Main/Extensions/MainContentCoordinator+SidebarSave.swift b/TablePro/Views/Main/Extensions/MainContentCoordinator+SidebarSave.swift index 6426ea378..978ddab87 100644 --- a/TablePro/Views/Main/Extensions/MainContentCoordinator+SidebarSave.swift +++ b/TablePro/Views/Main/Extensions/MainContentCoordinator+SidebarSave.swift @@ -6,6 +6,7 @@ // import Foundation +import TableProPluginKit extension MainContentCoordinator { // MARK: - Sidebar Save @@ -43,7 +44,8 @@ extension MainContentCoordinator { tableName: tableName, columns: tab.resultColumns, primaryKeyColumn: changeManager.primaryKeyColumn, - databaseType: connection.type + databaseType: connection.type, + quoteIdentifier: changeManager.pluginDriver?.quoteIdentifier ) var statements: [ParameterizedStatement] = [] diff --git a/TablePro/Views/Main/Extensions/MainContentCoordinator+TableOperations.swift b/TablePro/Views/Main/Extensions/MainContentCoordinator+TableOperations.swift index aae7310fc..293cc9842 100644 --- a/TablePro/Views/Main/Extensions/MainContentCoordinator+TableOperations.swift +++ b/TablePro/Views/Main/Extensions/MainContentCoordinator+TableOperations.swift @@ -32,6 +32,8 @@ extension MainContentCoordinator { ) -> [String] { var statements: [String] = [] let dbType = connection.type + let driver = DatabaseManager.shared.driver(for: connectionId) + let quote: (String) -> String = driver?.quoteIdentifier ?? dbType.quoteIdentifier // Sort tables for consistent execution order let sortedTruncates = truncates.sorted() @@ -47,7 +49,7 @@ extension MainContentCoordinator { } for tableName in sortedTruncates { - let quotedName = dbType.quoteIdentifier(tableName) + let quotedName = quote(tableName) let tableOptions = options[tableName] ?? TableOperationOptions() statements.append(contentsOf: truncateStatements( tableName: tableName, quotedName: quotedName, options: tableOptions, dbType: dbType @@ -60,7 +62,7 @@ extension MainContentCoordinator { }() for tableName in sortedDeletes { - let quotedName = dbType.quoteIdentifier(tableName) + let quotedName = quote(tableName) let tableOptions = options[tableName] ?? TableOperationOptions() let stmt = dropTableStatement( tableName: tableName, quotedName: quotedName, diff --git a/TablePro/Views/Main/MainContentCoordinator.swift b/TablePro/Views/Main/MainContentCoordinator.swift index d18431a89..6b61d1f7d 100644 --- a/TablePro/Views/Main/MainContentCoordinator.swift +++ b/TablePro/Views/Main/MainContentCoordinator.swift @@ -1363,12 +1363,12 @@ private extension MainContentCoordinator { connectionType: DatabaseType, schemaResult: SchemaResult? ) { - let quotedTable = connectionType.quoteIdentifier(tableName) Task { [weak self] in guard let self else { return } try? await Task.sleep(nanoseconds: 200_000_000) guard !self.isTearingDown else { return } guard let mainDriver = DatabaseManager.shared.driver(for: connectionId) else { return } + let quotedTable = mainDriver.quoteIdentifier(tableName) let countResult = try? await mainDriver.execute( query: "SELECT COUNT(*) FROM \(quotedTable)" ) @@ -1439,10 +1439,10 @@ private extension MainContentCoordinator { capturedGeneration: Int, connectionType: DatabaseType ) { - let quotedTable = connectionType.quoteIdentifier(tableName) Task { [weak self] in guard let self else { return } guard let mainDriver = DatabaseManager.shared.driver(for: connectionId) else { return } + let quotedTable = mainDriver.quoteIdentifier(tableName) let countResult = try? await mainDriver.execute( query: "SELECT COUNT(*) FROM \(quotedTable)" ) diff --git a/TablePro/Views/Results/DataGridView+RowActions.swift b/TablePro/Views/Results/DataGridView+RowActions.swift index 83c1b5b57..c2918c8a6 100644 --- a/TablePro/Views/Results/DataGridView+RowActions.swift +++ b/TablePro/Views/Results/DataGridView+RowActions.swift @@ -105,11 +105,14 @@ extension TableViewCoordinator { func copyRowsAsInsert(at indices: Set) { guard let tableName, let databaseType else { return } + let driver = resolveDriver() let converter = SQLRowToStatementConverter( tableName: tableName, columns: rowProvider.columns, primaryKeyColumn: primaryKeyColumn, - databaseType: databaseType + databaseType: databaseType, + quoteIdentifier: driver?.quoteIdentifier, + escapeStringLiteral: driver?.escapeStringLiteral ) let rows = indices.sorted().compactMap { rowProvider.rowValues(at: $0) } guard !rows.isEmpty else { return } @@ -118,14 +121,22 @@ extension TableViewCoordinator { func copyRowsAsUpdate(at indices: Set) { guard let tableName, let databaseType else { return } + let driver = resolveDriver() let converter = SQLRowToStatementConverter( tableName: tableName, columns: rowProvider.columns, primaryKeyColumn: primaryKeyColumn, - databaseType: databaseType + databaseType: databaseType, + quoteIdentifier: driver?.quoteIdentifier, + escapeStringLiteral: driver?.escapeStringLiteral ) let rows = indices.sorted().compactMap { rowProvider.rowValues(at: $0) } guard !rows.isEmpty else { return } ClipboardService.shared.writeText(converter.generateUpdates(rows: rows)) } + + private func resolveDriver() -> (any DatabaseDriver)? { + guard let connectionId else { return nil } + return DatabaseManager.shared.driver(for: connectionId) + } } diff --git a/TablePro/Views/Results/ForeignKeyPopoverContentView.swift b/TablePro/Views/Results/ForeignKeyPopoverContentView.swift index ada4e0171..763ed882d 100644 --- a/TablePro/Views/Results/ForeignKeyPopoverContentView.swift +++ b/TablePro/Views/Results/ForeignKeyPopoverContentView.swift @@ -115,8 +115,8 @@ struct ForeignKeyPopoverContentView: View { return } - let quotedTable = databaseType.quoteIdentifier(fkInfo.referencedTable) - let quotedColumn = databaseType.quoteIdentifier(fkInfo.referencedColumn) + let quotedTable = driver.quoteIdentifier(fkInfo.referencedTable) + let quotedColumn = driver.quoteIdentifier(fkInfo.referencedColumn) // Try to find a display column (first text-like column that isn't the FK column) var displayColumn: String? @@ -140,7 +140,7 @@ struct ForeignKeyPopoverContentView: View { limitSuffix = "LIMIT \(Self.maxFetchRows)" } if let displayCol = displayColumn { - let quotedDisplay = databaseType.quoteIdentifier(displayCol) + let quotedDisplay = driver.quoteIdentifier(displayCol) query = "SELECT \(quotedColumn), \(quotedDisplay) FROM \(quotedTable) ORDER BY \(quotedColumn) \(limitSuffix)" } else { query = "SELECT DISTINCT \(quotedColumn) FROM \(quotedTable) ORDER BY \(quotedColumn) \(limitSuffix)" diff --git a/TablePro/Views/Structure/TableStructureView.swift b/TablePro/Views/Structure/TableStructureView.swift index 191d7edd9..96d842d84 100644 --- a/TablePro/Views/Structure/TableStructureView.swift +++ b/TablePro/Views/Structure/TableStructureView.swift @@ -746,7 +746,7 @@ struct TableStructureView: View { } for enumType in enumTypes { let quotedName = "\"\(enumType.name.replacingOccurrences(of: "\"", with: "\"\""))\"" - let quotedLabels = enumType.labels.map { "'\(SQLEscaping.escapeStringLiteral($0, databaseType: .postgresql))'" } + let quotedLabels = enumType.labels.map { "'\(SQLEscaping.escapeStringLiteral($0))'" } preamble += "CREATE TYPE \(quotedName) AS ENUM (\(quotedLabels.joined(separator: ", ")));\n" } ddlStatement = preamble + "\n" + baseDDL diff --git a/TableProTests/Core/Database/PostgreSQLDriverTests.swift b/TableProTests/Core/Database/PostgreSQLDriverTests.swift index 9f7977852..3db181920 100644 --- a/TableProTests/Core/Database/PostgreSQLDriverTests.swift +++ b/TableProTests/Core/Database/PostgreSQLDriverTests.swift @@ -16,54 +16,35 @@ import Testing @Suite("PostgreSQL SQL Escaping Correctness") struct PostgreSQLSQLEscapingCorrectness { - @Test("Backslash in table name — MySQL doubles backslashes, PostgreSQL preserves them") - func backslashInTableName() { + @Test("ANSI escaping preserves backslashes") + func backslashPreserved() { let input = "test\\table" - let mysql = SQLEscaping.escapeStringLiteral(input, databaseType: .mysql) - let postgresql = SQLEscaping.escapeStringLiteral(input, databaseType: .postgresql) - - #expect(mysql == "test\\\\table") - #expect(postgresql == "test\\table") - #expect(mysql != postgresql, "MySQL and PostgreSQL escaping must differ for backslashes") + let result = SQLEscaping.escapeStringLiteral(input) + #expect(result == "test\\table") } - @Test("Newline in value — MySQL escapes to \\n, PostgreSQL preserves literal newline") - func newlineInValue() { + @Test("ANSI escaping preserves literal newlines") + func newlinePreserved() { let input = "line1\nline2" - let mysql = SQLEscaping.escapeStringLiteral(input, databaseType: .mysql) - let postgresql = SQLEscaping.escapeStringLiteral(input, databaseType: .postgresql) - - #expect(mysql == "line1\\nline2") - #expect(postgresql == "line1\nline2") - #expect(mysql != postgresql, "MySQL and PostgreSQL escaping must differ for newlines") + let result = SQLEscaping.escapeStringLiteral(input) + #expect(result == "line1\nline2") } - @Test("Tab in value — MySQL escapes to \\t, PostgreSQL preserves literal tab") - func tabInValue() { + @Test("ANSI escaping preserves literal tabs") + func tabPreserved() { let input = "col1\tcol2" - let mysql = SQLEscaping.escapeStringLiteral(input, databaseType: .mysql) - let postgresql = SQLEscaping.escapeStringLiteral(input, databaseType: .postgresql) - - #expect(mysql == "col1\\tcol2") - #expect(postgresql == "col1\tcol2") - #expect(mysql != postgresql, "MySQL and PostgreSQL escaping must differ for tabs") + let result = SQLEscaping.escapeStringLiteral(input) + #expect(result == "col1\tcol2") } - @Test("Combined special chars — backslash and quote produce different results per DB type") + @Test("ANSI escaping doubles single quotes and preserves control chars") func combinedSpecialChars() { let input = "it's a \\path\n" - let mysql = SQLEscaping.escapeStringLiteral(input, databaseType: .mysql) - let postgresql = SQLEscaping.escapeStringLiteral(input, databaseType: .postgresql) - - #expect(mysql.contains("\\\\"), "MySQL should double backslashes") - #expect(mysql.contains("\\n"), "MySQL should escape newlines") - #expect(!postgresql.contains("\\\\"), "PostgreSQL should not double backslashes") - #expect(postgresql.contains("\n"), "PostgreSQL should preserve literal newlines") - - #expect(mysql.contains("''"), "MySQL should double single quotes") - #expect(postgresql.contains("''"), "PostgreSQL should double single quotes") + let result = SQLEscaping.escapeStringLiteral(input) - #expect(mysql != postgresql, "MySQL and PostgreSQL escaping must differ for combined special chars") + #expect(!result.contains("\\\\"), "ANSI escaping should not double backslashes") + #expect(result.contains("\n"), "ANSI escaping should preserve literal newlines") + #expect(result.contains("''"), "ANSI escaping should double single quotes") } } @@ -274,7 +255,7 @@ struct DDLLoadingFlowTests { } for enumType in enumTypes { let quotedName = "\"\(enumType.name.replacingOccurrences(of: "\"", with: "\"\""))\"" - let quotedLabels = enumType.labels.map { "'\(SQLEscaping.escapeStringLiteral($0, databaseType: .postgresql))'" } + let quotedLabels = enumType.labels.map { "'\(SQLEscaping.escapeStringLiteral($0))'" } preamble += "CREATE TYPE \(quotedName) AS ENUM (\(quotedLabels.joined(separator: ", ")));\n" } diff --git a/TableProTests/Core/Database/SQLEscapingTests.swift b/TableProTests/Core/Database/SQLEscapingTests.swift index 81481884b..b3ea8d3aa 100644 --- a/TableProTests/Core/Database/SQLEscapingTests.swift +++ b/TableProTests/Core/Database/SQLEscapingTests.swift @@ -12,7 +12,7 @@ import Testing @Suite("SQL Escaping") struct SQLEscapingTests { - // MARK: - escapeStringLiteral Tests + // MARK: - escapeStringLiteral Tests (ANSI SQL) @Test("Plain string unchanged") func testPlainStringUnchanged() { @@ -28,67 +28,67 @@ struct SQLEscapingTests { #expect(result == "O''Brien") } - @Test("Backslashes doubled") - func testBackslashesDoubled() { + @Test("Backslashes preserved") + func testBackslashesPreserved() { let input = "C:\\Users\\Test" let result = SQLEscaping.escapeStringLiteral(input) - #expect(result == "C:\\\\Users\\\\Test") + #expect(result == "C:\\Users\\Test") } - @Test("Newline escaped") - func testNewlineEscaped() { + @Test("Newlines preserved") + func testNewlinesPreserved() { let input = "Line1\nLine2" let result = SQLEscaping.escapeStringLiteral(input) - #expect(result == "Line1\\nLine2") + #expect(result == "Line1\nLine2") } - @Test("Carriage return escaped") - func testCarriageReturnEscaped() { + @Test("Carriage returns preserved") + func testCarriageReturnsPreserved() { let input = "Text\rMore" let result = SQLEscaping.escapeStringLiteral(input) - #expect(result == "Text\\rMore") + #expect(result == "Text\rMore") } - @Test("Tab escaped") - func testTabEscaped() { + @Test("Tabs preserved") + func testTabsPreserved() { let input = "Col1\tCol2" let result = SQLEscaping.escapeStringLiteral(input) - #expect(result == "Col1\\tCol2") + #expect(result == "Col1\tCol2") } - @Test("Null character escaped") - func testNullCharacterEscaped() { + @Test("Null bytes stripped") + func testNullBytesStripped() { let input = "Text\0End" let result = SQLEscaping.escapeStringLiteral(input) - #expect(result == "Text\\0End") + #expect(result == "TextEnd") } - @Test("Backspace escaped") - func testBackspaceEscaped() { + @Test("Backspace preserved") + func testBackspacePreserved() { let input = "Text\u{08}End" let result = SQLEscaping.escapeStringLiteral(input) - #expect(result == "Text\\bEnd") + #expect(result == "Text\u{08}End") } - @Test("Form feed escaped") - func testFormFeedEscaped() { + @Test("Form feed preserved") + func testFormFeedPreserved() { let input = "Text\u{0C}End" let result = SQLEscaping.escapeStringLiteral(input) - #expect(result == "Text\\fEnd") + #expect(result == "Text\u{0C}End") } - @Test("EOF marker escaped") - func testEOFMarkerEscaped() { + @Test("EOF marker preserved") + func testEOFMarkerPreserved() { let input = "Text\u{1A}End" let result = SQLEscaping.escapeStringLiteral(input) - #expect(result == "Text\\ZEnd") + #expect(result == "Text\u{1A}End") } - @Test("Combined special characters") + @Test("Combined special characters — ANSI escaping") func testCombinedSpecialCharacters() { let input = "O'Brien\\test\nline2\t\0end" let result = SQLEscaping.escapeStringLiteral(input) - #expect(result == "O''Brien\\\\test\\nline2\\t\\0end") + #expect(result == "O''Brien\\test\nline2\tend") } @Test("Empty string unchanged") @@ -98,82 +98,11 @@ struct SQLEscapingTests { #expect(result == "") } - @Test("Backslash and quote order prevents double-escaping") + @Test("Backslash and quote — ANSI escaping") func testBackslashQuoteEscapingOrder() { - // Verify that backslash+quote produces \\'' and not \\\\' let input = "\\'" let result = SQLEscaping.escapeStringLiteral(input) - #expect(result == "\\\\''") - } - - // MARK: - escapeLikeWildcards Tests - - // MARK: - PostgreSQL/SQLite escapeStringLiteral Tests - - @Test("PostgreSQL: plain string unchanged") - func testPostgreSQLPlainStringUnchanged() { - let result = SQLEscaping.escapeStringLiteral("Hello World", databaseType: .postgresql) - #expect(result == "Hello World") - } - - @Test("PostgreSQL: single quotes doubled") - func testPostgreSQLSingleQuotesDoubled() { - let result = SQLEscaping.escapeStringLiteral("O'Brien", databaseType: .postgresql) - #expect(result == "O''Brien") - } - - @Test("PostgreSQL: newlines preserved") - func testPostgreSQLNewlinesPreserved() { - let result = SQLEscaping.escapeStringLiteral("Line1\nLine2", databaseType: .postgresql) - #expect(result == "Line1\nLine2") - } - - @Test("PostgreSQL: carriage returns preserved") - func testPostgreSQLCarriageReturnsPreserved() { - let result = SQLEscaping.escapeStringLiteral("Text\rMore", databaseType: .postgresql) - #expect(result == "Text\rMore") - } - - @Test("PostgreSQL: tabs preserved") - func testPostgreSQLTabsPreserved() { - let result = SQLEscaping.escapeStringLiteral("Col1\tCol2", databaseType: .postgresql) - #expect(result == "Col1\tCol2") - } - - @Test("PostgreSQL: backslashes preserved") - func testPostgreSQLBackslashesPreserved() { - let result = SQLEscaping.escapeStringLiteral("C:\\Users\\Test", databaseType: .postgresql) - #expect(result == "C:\\Users\\Test") - } - - @Test("PostgreSQL: null bytes stripped") - func testPostgreSQLNullBytesStripped() { - let result = SQLEscaping.escapeStringLiteral("Text\0End", databaseType: .postgresql) - #expect(result == "TextEnd") - } - - @Test("PostgreSQL: combined special characters") - func testPostgreSQLCombinedSpecialCharacters() { - let result = SQLEscaping.escapeStringLiteral("O'Brien\\test\nline2\t\0end", databaseType: .postgresql) - #expect(result == "O''Brien\\test\nline2\tend") - } - - @Test("SQLite: newlines preserved") - func testSQLiteNewlinesPreserved() { - let result = SQLEscaping.escapeStringLiteral("Line1\nLine2", databaseType: .sqlite) - #expect(result == "Line1\nLine2") - } - - @Test("SQLite: backslashes preserved") - func testSQLiteBackslashesPreserved() { - let result = SQLEscaping.escapeStringLiteral("path\\to\\file", databaseType: .sqlite) - #expect(result == "path\\to\\file") - } - - @Test("SQLite: single quotes doubled") - func testSQLiteSingleQuotesDoubled() { - let result = SQLEscaping.escapeStringLiteral("it's", databaseType: .sqlite) - #expect(result == "it''s") + #expect(result == "\\''") } // MARK: - escapeLikeWildcards Tests @@ -222,7 +151,6 @@ struct SQLEscapingTests { @Test("LIKE backslash and percent order prevents double-escaping") func testLikeBackslashPercentEscapingOrder() { - // Verify that backslash+percent produces \\% and not \\\\% let input = "\\%" let result = SQLEscaping.escapeLikeWildcards(input) #expect(result == "\\\\\\%")